首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将tensorflow模型部署在sagemaker异步端点上,并包含一个inference.py脚本

将tensorflow模型部署在sagemaker异步端点上,并包含一个inference.py脚本
EN

Stack Overflow用户
提问于 2022-09-03 20:27:27
回答 1查看 88关注 0票数 0

我正在尝试将tensorflow模型部署到sagemaker上的异步端点。

我以前使用以下代码将相同的模型部署到实时端点:

代码语言:javascript
复制
from sagemaker.tensorflow.serving import Model

tensorflow_serving_model = Model(model_data=model_artifact,
                                 entry_point = 'inference.py',
                                 source_dir = 'code',
                                 role=role,
                                 framework_version='2.3',
                                 sagemaker_session=sagemaker_session)
代码语言:javascript
复制
predictor = tensorflow_serving_model.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')

使用source_dir参数;我能够在我的模型中包含inference.py和requirements.txt文件.

现在要做的事情:尝试将相同的模型部署到异步端点,遵循文档博客示例.我使用了以下片段:

代码语言:javascript
复制
from sagemaker.image_uris import retrieve

deploy_instance_type = 'ml.m5.xlarge'
tensorflow_inference_image_uri = retrieve('tensorflow',
                                       region,
                                       version='2.8',
                                       py_version='py3',
                                       instance_type = deploy_instance_type,
                                       accelerator_type=None,
                                       image_scope='inference')

container = tensorflow_inference_image_uri
model_name = 'sagemaker-{0}'.format(str(int(time.time())))

# Create model
create_model_response = sm_client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        'Image': container,
        'ModelDataUrl': model_artifact,
        'Environment': {
            'TS_MAX_REQUEST_SIZE': '100000000', #default max request size is 6 Mb for torchserve, need to update it to support the 70 mb input payload
            'TS_MAX_RESPONSE_SIZE': '100000000',
            'TS_DEFAULT_RESPONSE_TIMEOUT': '1000'
        }
    },    
)
代码语言:javascript
复制
# Create endpoint config
endpoint_config_name = f"AsyncEndpointConfig-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": "ml.m5.xlarge",
            "InitialInstanceCount": 1
        }
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            "S3OutputPath": f"s3://{bucket}/{bucket_prefix}/output",
            #  Optionally specify Amazon SNS topics
            "NotificationConfig": {
              "SuccessTopic": success_topic,
              "ErrorTopic": error_topic,
            }
        },
        "ClientConfig": {
            "MaxConcurrentInvocationsPerInstance": 2
        }
    }
)
代码语言:javascript
复制
# Create endpoint
endpoint_name = f"sm-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
create_endpoint_response = sm_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name)

问题:在尝试将模型部署到异步端点时,无法指定包含inference.py和requirements.txt的源目录。

我确信我不能根据docs在.tar模型文件中包含代码/目录,唯一的方法是通过SDK类初始化中的.tar参数。

我的问题:如何在异步端点上使用包含的inference.py的代码/目录?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-09-06 23:10:50

您没有source_dir选项的原因是,您现在试图使用boto3来部署模型,而不是使用最初使用的SageMaker Python。

您可以像以前一样使用SDK将模型部署到异步端点。唯一的区别是你需要一个AsyncInferenceConfig

您可以使用以下内容:

代码语言:javascript
复制
from sagemaker.tensorflow.serving import Model

tensorflow_serving_model = Model(model_data=model_artifact,
                                 entry_point = 'inference.py',
                                 source_dir = 'code',
                                 role=role,
                                 framework_version='2.3',
                                 sagemaker_session=sagemaker_session)
代码语言:javascript
复制
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig

async_config = AsyncInferenceConfig(
    output_path=f"s3://{s3_bucket}/{bucket_prefix}/output",
    max_concurrent_invocations_per_instance=4,
    # Optionally specify Amazon SNS topics
    # notification_config = {
    # "SuccessTopic": "arn:aws:sns:<aws-region>:<account-id>:<topic-name>",
    # "ErrorTopic": "arn:aws:sns:<aws-region>:<account-id>:<topic-name>",
    # }
)
代码语言:javascript
复制
endpoint_name = resource_name.format("Endpoint", datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))

async_predictor = tensorflow_serving_model.deploy(
    async_inference_config=async_config,
    instance_type="ml.m5.xlarge",
    initial_instance_count=1,
    endpoint_name=endpoint_name,
)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73595234

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档