可以为sagemaker.sklearn.estimator的SKLearn指定S3存储桶吗?

我在学习SageMaker的处理作业API时,参考了这个示例笔记本:https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker_processing/scikit_learn_data_processing_and_model_evaluation/scikit_learn_data_processing_and_model_evaluation.ipynb

我尝试修改他们的代码,以避免使用默认的S3存储桶,即:s3://sagemaker-<region>-<account_id>/

对于使用.run方法的数据处理步骤:

from sagemaker.processing import ProcessingInput, ProcessingOutputsklearn_processor.run(    code="preprocessing.py",    inputs=[ProcessingInput(source=input_data, destination="/opt/ml/processing/input")],    outputs=[        ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),        ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),    ],    arguments=["--train-test-split-ratio", "0.2"],)

我通过使用destination参数成功修改了代码,以使用我自己的S3存储桶,像这样:

sklearn_processor.run(     code=output_bucket_uri + "preprocessing.py",     inputs=[ProcessingInput(         source=input_bucket_uri + "census-income.csv",         destination=path+"input/",     )],     outputs=[         ProcessingOutput(             output_name="train_data",             source=path+"train/",             destination=output_bucket_uri + "train/",         ),         ProcessingOutput(             output_name="test_data",             source=path+"test/",             destination=output_bucket_uri + "test/",         ),     ],     arguments=["--train-test-split-ratio", "0.2"], )

但是对于.fit方法:

sklearn.fit({"train": preprocessed_training_data})

我没有找到一个参数可以传递,使输出工件保存到我指定的S3存储桶,而不是默认的s3://sagemaker-<region>-<account_id>/存储桶。


回答:

对于SKLearnProcessor,指定默认存储桶的理想方法是通过创建一个带有该存储桶的sagemaker会话,并将其作为sagemaker_session参数发送。示例:

from sagemaker.session import Session    sklearn_processor = SKLearnProcessor(framework_version='0.20.0',                                     role='<arn-role>',                                     instance_type='ml.m5.xlarge',                                     instance_count=1,                                     sagemaker_session=Session(default_bucket='<s3-bucket-name>'))

我知道这不是你的具体问题,但你在问题详情中添加了一个替代方案。所以我在这里添加它作为一个更清晰的方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注