首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在MXNET中从预训练模型(RSNET-152)中提取特征

在MXNET中从预训练模型(RSNET-152)中提取特征
EN

Stack Overflow用户
提问于 2018-12-09 21:12:34
回答 1查看 178关注 0票数 1

嗨,我正在尝试从MxNet中的预训练模型(RsNet-152)中提取半月板层的输出。由于我需要在java应用程序中使用该脚本,因此我选择了scala作为语言。

我遵循了这里提到的步骤https://mxnet.incubator.apache.org/tutorials/python/predict_image.html

并由脚本相应地修改。下面是loadModel函数。

代码语言:javascript
复制
  def loadResnetModel(modelPath: String): Module = {
val (net, argParams, auxParams) = Model.loadCheckpoint(modelPath, modelFileNumber)
val allLayer = net.getInternals()
val secondLastLayer = allLayer.get("flatten0_output")
val mod = new Module(symbolVar = secondLastLayer, contexts = Context.cpu(), labelNames =null)
val dataShape = ListMap("data" -> Shape(1, 3, 224, 224))
mod.bind(dataShapes=dataShape, forTraining = false)
mod.setParams(argParams, auxParams, allowMissing=true)
mod

当我尝试运行脚本时,我得到了以下错误。

代码语言:javascript
复制
 Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: Find name fc1_bias that is not in the arguments
 [java]     at scala.Predef$.require(Predef.scala:224)
 [java]     at org.apache.mxnet.Executor$$anonfun$copyParamsFrom$1.apply(Executor.scala:274)
 [java]     at org.apache.mxnet.Executor$$anonfun$copyParamsFrom$1.apply(Executor.scala:270)
 [java]     at scala.collection.immutable.HashMap$HashMap1.foreach(HashMap.scala:221)
 [java]     at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
 [java]     at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
 [java]     at scala.collection.immutable.HashMap$HashTrieMap.foreach(HashMap.scala:428)
 [java]     at org.apache.mxnet.Executor.copyParamsFrom(Executor.scala:270)
 [java]     at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$setParams$1.apply(DataParallelExecutorGroup.scala:452)
 [java]     at org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$setParams$1.apply(DataParallelExecutorGroup.scala:452)
 [java]     at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
 [java]     at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
 [java]     at org.apache.mxnet.module.DataParallelExecutorGroup.setParams(DataParallelExecutorGroup.scala:452)
 [java]     at org.apache.mxnet.module.Module.setParams(Module.scala:201)

附言:我是mxnet和scala的新手。有什么明显的错误是我看不到的吗?

EN

回答 1

Stack Overflow用户

发布于 2018-12-18 09:18:52

您需要更改函数的最后一行:您需要调用mod.setParams(argParams, auxParams)而不是mod.setParams(argParams, auxParams, allowMissing=true)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53692644

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档