首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在tensorflow.js上加载/再培训/保存tensorflow.js?

如何在tensorflow.js上加载/再培训/保存tensorflow.js?
EN

Stack Overflow用户
提问于 2019-02-27 03:01:33
回答 1查看 2.4K关注 0票数 3

ML /流量初学者。

这些已经训练过的模型中的任何一个都可以加载到tfjs上并在那里进行再培训,然后导出到下载中,还是Tensorflow python是唯一的出路?

我看到这个过程在Tensorflow的教程中有很好的描述和文档记录,但不幸的是,我找不到任何文档/教程来用tfjs重新训练浏览器上的对象检测模型(图像分类是的,对象检测不是)。

我知道如何使用npm加载coco模型,然后可能触发将其保存到下载,但如何:

  • 配置文件(需要修改它,因为我只想拥有一个类,而不是90个)
  • 带注释的图像( .jpg、.xml和.csv)
  • labels.pbtxt
  • .record文件

有没有办法通过重新培训的过程,如ssd_inception_v2_coco和我没有击中正确的谷歌关键字,还是它只是不可能在当前的框架状态?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-02-27 18:36:36

通过使用coco模型作为特征提取器,您可以使用转移学习。转移学习的一个例子可以看到这里.

下面是一个模型,它使用特征提取器作为一个新的序列模型的输入来提取特征。

代码语言:javascript
复制
const loadModel = async () => {
  const loadedModel = await tf.loadModel(MODEL_URL)
  console.log(loadedModel)
  // take whatever layer except last output
  loadedModel.layers.forEach(layer => console.log(layer.name))
  const layer = loadedModel.getLayer(LAYER_NAME)
  return tf.model({ inputs: loadedModel.inputs, outputs: layer.output });
}
loadModel().then(featureExtractor => {
  model = tf.sequential({
    layers: [
      // Flattens the input to a vector so we can use it in a dense layer. While
      // technically a layer, this only performs a reshape (and has no training
      // parameters).
      // slice so as not to take the batch size
      tf.layers.flatten(
        { inputShape: featureExtractor.outputs[0].shape.slice(1) }),
      // add all the layers of the model to train
      tf.layers.dense({
        units: UNITS,
        activation: 'relu',
        kernelInitializer: 'varianceScaling',
        useBias: true
      }),
      // Last Layer. The number of units of the last layer should correspond
      // to the number of classes to predict.
      tf.layers.dense({
        units: NUM_CLASSES,
        kernelInitializer: 'varianceScaling',
        useBias: false,
        activation: 'softmax'
      })
    ]
  });
})

要检测出90类coco中的单个对象,只需对coco的预测使用条件测试即可。

代码语言:javascript
复制
const image = document.getElementById(id)

cocoSsd.load()
  .then(model => model.detect(image))
  .then(prediction => {
if (prediction.class === OBJECT_DETECTED) {
  // display it the bbox to the user}
})

如果类不存在于coco中,那么需要构建一个检测器。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54897356

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档