首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >本地加载模型以进行推理

本地加载模型以进行推理
EN

Stack Overflow用户
提问于 2020-09-29 19:25:42
回答 1查看 168关注 0票数 1

我已经训练了一个NTM模型,这是一个神经主题模型,直接在AWS救世主平台上。一旦培训完成,您就可以下载mxnet模型文件。解压缩后,文件包含:

  • 参数
  • symbol.json
  • meta.json

我遵循mxnet上的docs加载模型,并有以下代码:

代码语言:javascript
复制
sym, arg_params, aux_params = mx.model.load_checkpoint('model_algo-1', 0)
module_model = mx.mod.Module(symbol=sym, label_names=None, context=mx.cpu())

module_model.bind(
    for_training=False,
    data_shapes=[('data', (1, VOCAB_SIZE))]
)

module_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True) # must set allow missing true here or receive an error for a missing n_epoch var

现在,我尝试使用该模型进行推理:

代码语言:javascript
复制
module_model.predict(x) # where x is a numpy array of size (1, VOCAB_SIZE)

代码运行,但结果只是一个值,我希望在这个值中发布一个主题:

代码语言:javascript
复制
[11.060672]
<NDArray 1 @cpu(0)>

编辑:

我尝试使用符号API加载它,但仍然没有成功:

代码语言:javascript
复制
import warnings
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    deserialized_net = gluon.nn.SymbolBlock.imports('model_algo-1-symbol.json', ['data'], 'model_algo-1-0000.params', ctx=mx.cpu())

错误:

代码语言:javascript
复制
AssertionError: Parameter 'n_epoch' is missing in file: model_algo-1-0000.params, which contains parameters: 'logsigma_bias', 'enc_0_bias', 'projection_bias', ..., 'enc_1_weight', 'enc_0_weight', 'mean_bias', 'logsigma_weight'. Please make sure source and target networks have the same prefix.

任何帮助都会很好!

EN

回答 1

Stack Overflow用户

发布于 2020-10-29 19:38:16

SageMaker不支持此用例。该模型可以托管在SageMaker上进行在线推理,也可以用于批量预测和转换作业。

请参阅更多细节:

  1. https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html
  2. https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64126327

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档