首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用SageMaker SDK为Boto3培训作业指定源目录和入口点?用例是通过Lambda调用开始培训。

如何使用SageMaker SDK为Boto3培训作业指定源目录和入口点?用例是通过Lambda调用开始培训。
EN

Stack Overflow用户
提问于 2021-02-23 01:40:59
回答 1查看 1.6K关注 0票数 0

我一直在使用SageMaker Python在SageMaker笔记本实例上运行培训作业,并在本地使用IAM凭据。他们工作得很好,但我想通过AWS + Gateway开始一项训练工作。

Lambda不支持SageMaker SDK (高级SDK),因此我不得不在我的Lambda处理程序中使用来自boto3的SageMaker客户端。

代码语言:javascript
复制
sagemaker = boto3.client('sagemaker')

据推测,这个boto3服务级SDK将给我100%的控制权,但我找不到参数或配置名称来指定源目录和入口点。我正在运行一个自定义培训作业,它需要在飞行过程中生成一些数据(使用Keras生成器)。

下面是我的SageMaker SDK调用的一个示例

代码语言:javascript
复制
tf_estimator = TensorFlow(base_job_name='tensorflow-nn-training',
                          role=sagemaker.get_execution_role(),
                          source_dir=training_src_path,
                          code_location=training_code_path,
                          output_path=training_output_path,
                          dependencies=['requirements.txt'],
                          entry_point='main.py',
                          script_mode=True,
                          instance_count=1,
                          instance_type='ml.g4dn.2xlarge',
                          framework_version='2.3',
                          py_version='py37',
                          hyperparameters={
                              'model-name': 'my-model-name',
                              'epochs': 1000,
                              'batch-size': 64,
                              'learning-rate': 0.01,
                              'training-split': 0.80,
                              'patience': 50,
                          })

通过调用fit()注入输入路径

代码语言:javascript
复制
input_channels = {
    'train': training_input_path,
}
tf_estimator.fit(inputs=input_channels)
  • source_dir是查找我的src.zip.gz的一个S3 URI,它包含用于执行培训的模型和脚本。
  • entry_point是培训开始的地方。TensorFlow容器只运行python main.py
  • code_location是一个S3前缀,如果我使用本地模型和脚本在本地运行此培训,可以将培训源代码上传到这里。
  • output_path是一个S3 URI,培训作业将将模型构件上传到该URI。

但是,我查看了工作的文档,找不到允许我设置源目录和入口点的任何字段。

举个例子,

代码语言:javascript
复制
sagemaker = boto3.client('sagemaker')
sagemaker.create_training_job(
    TrainingJobName='tf-training-job-from-lambda',
    Hyperparameters={} # Same dictionary as above,
    AlgorithmSpecification={
        'TrainingImage': '763104351884.dkr.ecr.us-west-1.amazonaws.com/tensorflow-training:2.3.1-gpu-py37-cu110-ubuntu18.04',
        'TrainingInputMode': 'File',
        'EnableSageMakerMetricsTimeSeries': True
    },
    RoleArn='My execution role goes here',
    InputDataConfig=[
        {
            'ChannelName': 'train',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'S3Prefix',
                    'S3Uri': training_input_path,
                    'S3DataDistributionType': 'FullyReplicated'
                }
            },
            'CompressionType': 'None',
            'RecordWrapperType': 'None',
            'InputMode': 'File',
        }  
    ],
    OutputDataConfig={
        'S3OutputPath': training_output_path,
    }
    ResourceConfig={
        'InstanceType': 'ml.g4dn.2xlarge',
        'InstanceCount': 1,
        'VolumeSizeInGB': 16
    }
    StoppingCondition={
        'MaxRuntimeInSeconds': 600 # 10 minutes for testing
    }
)

从上面的配置中,SDK接受培训输入和输出位置,但是哪个配置字段允许用户指定源代码目录和入口点?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-20 11:05:59

您可以将source_dir传递给超级参数,如下所示:

代码语言:javascript
复制
    response = sm_boto3.create_training_job(
        TrainingJobName=f"{your job name}"),
        HyperParameters={
            'model-name': 'my-model-name',
            'epochs': 1000,
            'batch-size': 64,
            'learning-rate': 0.01,
            'training-split': 0.80,
            'patience': 50,
            "sagemaker_program": "script.py", # this is where you specify your train script
            "sagemaker_submit_directory": "s3://" + bucket + "/" + project + "/" + source, # your s3 URI like s3://sm/tensorflow/source/sourcedir.tar.gz
        },
        AlgorithmSpecification={
            "TrainingImage": training_image,
            ...
        }, 

注意:否则,请确保它是xxx.tar.gz。否则,萨吉克会抛出错误。

请参阅end2end.ipynb

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66325857

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档