正如文档/教程中提到的,我们可以调用 Estimator.fit()
来启动训练任务。
该方法的必需参数是 inputs
,它是对训练文件的 S3/文件引用。例如:
estimator.fit({'train':'s3://my-bucket/training_data'})
training-script.py
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
我期望 os.environ['SM_CHANNEL_TRAIN']
返回的是 S3 路径。但实际上,它返回的是 /opt/ml/input/data/train
。
有人知道这是为什么吗?
更新
我也尝试调用了 estimator.fit('s3://my-bucket/training_data')
。不知为何,训练实例没有获取到 SM_CHANNEL_TRAIN 环境变量。实际上,我在环境变量中根本没有看到 S3 URI。
回答:
在 SageMaker 中运行训练任务时,提供的包含训练数据的 S3 URL 会被从指定的 URL 复制到 docker 容器(即训练任务)中。因此,环境变量 SM_CHANNEL_TRAIN 指向的是从提供的 S3 URL 复制过来的训练数据的本地路径。