
SageMaker: accessing custom_attributes in an inference job

Stack Overflow user
Asked on 2022-04-28 00:02:52
1 answer · 184 views · 0 following · score 0

I am using SageMaker for my inference job, following this guide. I trigger my SageMaker inference job from Airflow with the Python callable below:

Code language: Python
from sagemaker.tensorflow import TensorFlowModel


def transform(sage_role, inference_file_local_path, **kwargs):
    """
    Python callable to execute a SageMaker SDK batch transform job. It takes
    infer_batch_output, infer_batch_input, model_artifact, instance_type and
    infer_file_name as run-time parameters.
    :param sage_role: SageMaker execution role.
    :param inference_file_local_path: local entry-point directory for the inference file.
    """
    # infer_file_name, model_artifact, instance_type, batch_input and
    # batch_output are resolved from the run-time parameters listed above.
    model = TensorFlowModel(entry_point=infer_file_name,
                            source_dir=inference_file_local_path,
                            model_data=model_artifact,
                            role=sage_role,
                            framework_version="2.5.1")

    tensorflow_serving_transformer = model.transformer(
        instance_count=1,
        instance_type=instance_type,
        accept="text/csv",
        strategy="SingleRecord",
        max_payload=10,
        max_concurrent_transforms=10,
        output_path=batch_output)

    return tensorflow_serving_transformer.transform(data=batch_input, content_type='text/csv')

My simple inference.py looks like this:

Code language: Python
import io
import json

import numpy as np
import pandas as pd

# config_exists, get_s3_json_file, model_bucket and config_path are helper
# utilities/globals defined elsewhere in the project.


def input_handler(data, context):
    """Pre-process request input before it is sent to the TensorFlow Serving REST API.
    Args:
        data (obj): the request data, in format of dict or string
        context (Context): an object containing request and configuration details
    Returns:
        (dict): a JSON-serializable dict that contains request body and headers
    """
    if context.request_content_type == 'application/x-npy':
        # very simple numpy handler: np.load expects a file-like object,
        # and the arrays must be converted to lists to be JSON-serializable
        payload = np.load(io.BytesIO(data.read()), allow_pickle=True)
        x_user_feature = np.asarray(payload.item().get('test').get('feature_a_list'))
        x_channel_feature = np.asarray(payload.item().get('test').get('feature_b_list'))
        examples = []
        for index, elem in enumerate(x_user_feature):
            examples.append({'feature_a_list': elem.tolist(),
                             'feature_b_list': x_channel_feature[index].tolist()})
        return json.dumps({'instances': examples})

    if context.request_content_type == 'text/csv':
        payload = pd.read_csv(data)
        print("Model name is ..............")
        model_name = context.model_name
        print(model_name)
        examples = []
        row_ch = []
        if config_exists(model_bucket, "{}{}".format(config_path, model_name)):
            config_keys = get_s3_json_file(model_bucket, "{}{}".format(config_path, model_name))
            feature_b_list = config_keys["feature_b_list"].split(",")
            row_ch = [float(ch_feature_str) for ch_feature_str in feature_b_list]
            if "column_names" in config_keys.keys():
                cols = config_keys["column_names"].split(",")
                payload.columns = cols
        for index, row in payload.iterrows():
            row_user = row['feature_a_list'].replace('[', '').replace(']', '').split()
            row_user = [float(x) for x in row_user]
            if not row_ch:
                row_ch = row['feature_b_list'].replace('[', '').replace(']', '').split()
                row_ch = [float(x) for x in row_ch]
            example = {'feature_a_list': row_user, 'feature_b_list': row_ch}
            examples.append(example)
        # return here, otherwise execution falls through to the error below
        return json.dumps({'instances': examples})

    raise ValueError('{{"error": "unsupported content type {}"}}'.format(
        context.request_content_type or "unknown"))


def output_handler(data, context):
    """Post-process TensorFlow Serving output before it is returned to the client.
    Args:
        data (obj): the TensorFlow serving response
        context (Context): an object containing request and configuration details
    Returns:
        (bytes, string): data to return to client, response content type
    """
    if data.status_code != 200:
        raise ValueError(data.content.decode('utf-8'))

    response_content_type = context.accept_header
    prediction = data.content
    return prediction, response_content_type
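For reference, a client could serialize the 'application/x-npy' payload the first branch expects roughly like this (a sketch; the nested 'test' key and the feature names are taken from the handler above, and the concrete feature values are made up for illustration):

```python
import io

import numpy as np

# Build the nested dict the handler unwraps via payload.item()['test'][...]
payload = {'test': {'feature_a_list': [[0.1, 0.2]], 'feature_b_list': [[0.3, 0.4]]}}

buf = io.BytesIO()
np.save(buf, np.array(payload, dtype=object))  # 0-d object array wrapping the dict
body = buf.getvalue()  # bytes to send with Content-Type: application/x-npy

# Round-trip check: the server side recovers the dict with np.load
restored = np.load(io.BytesIO(body), allow_pickle=True).item()
```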

This works fine, but I would like to pass custom parameters to inference.py so it can adapt the input data to each requirement. I considered keeping one config file per requirement and downloading it from S3 based on the model name, but since I pass model.tar.gz via model_data at run time, context.model_name is always None.

Is there a way to pass run-time parameters to inference.py for this kind of customization? In the documentation I have seen that SageMaker provides custom_attributes, but I have not found any example of how to use and access it inside inference.py:

From the documentation:

custom_attributes (string): content of 'X-Amzn-SageMaker-Custom-Attributes' header from the original request. For example, 'tfs-model-name=half_plus_three,tfs-method=predict'

1 Answer

Stack Overflow user

Answered on 2022-04-29 03:57:47

Currently, CustomAttributes is supported in the InvokeEndpoint API call when using a real-time endpoint.
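As an illustration of that real-time path, the request arguments for a call carrying CustomAttributes could be assembled like this (a minimal sketch; the endpoint name and attribute string are hypothetical, and the helper function is not part of any SDK):

```python
def build_invoke_args(endpoint_name, payload, custom_attrs):
    """Assemble keyword arguments for sagemaker-runtime's invoke_endpoint call.
    CustomAttributes becomes the X-Amzn-SageMaker-Custom-Attributes header."""
    return {
        "EndpointName": endpoint_name,
        "ContentType": "text/csv",
        "Body": payload,
        "CustomAttributes": custom_attrs,
    }

# Usage (requires boto3 and a deployed real-time endpoint):
# import boto3
# runtime = boto3.client("sagemaker-runtime")
# response = runtime.invoke_endpoint(**build_invoke_args(
#     "my-endpoint", "1,2,3,4", "my_custom_arg=foo"))
# On the serving side, the string should then be readable from the
# handler's context (context.custom_attributes in the TFS container).
```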

For a transform job, you can instead use JSON Lines as the input, where each record contains the input payload together with custom parameters that your inference.py file can use.

For example:

Code language: JSON
{
   "input":"1,2,3,4",
   "custom_args":"my_custom_arg"
}
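A sketch of how input_handler could consume such a record (assuming an 'application/json' content type and the field names shown above; the parsing logic is illustrative, not prescribed by the answer):

```python
import json


def input_handler(data, context):
    # Illustrative handler for JSON Lines batch-transform input where each
    # record carries the payload plus a custom argument.
    if context.request_content_type == 'application/json':
        record = json.loads(data.read().decode('utf-8'))
        custom_arg = record.get('custom_args')  # e.g. "my_custom_arg"
        features = [float(x) for x in record['input'].split(',')]
        # custom_arg can now drive any per-record preprocessing decisions
        return json.dumps({'instances': [features]})
    raise ValueError('unsupported content type {}'.format(context.request_content_type))
```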
Score: 1
Original content provided by Stack Overflow. Source link:

https://stackoverflow.com/questions/72036515
