首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow.js中的自定义激活(Swish)

tensorflow.js中的自定义激活(Swish)
EN

Stack Overflow用户
提问于 2019-08-10 04:28:51
回答 1查看 627关注 0票数 3

我正在尝试在tensorflow.js中实现一个自定义激活(Swish)函数。我该怎么做呢?

我没有找到任何自定义激活函数的信息,但添加了一个custom layer。所以我实现了一个自定义的层,我在没有分配任何激活的层之后手动添加。它似乎可以工作,但在代码中(尤其是模型定义)相当繁琐。如何直接扩展激活类以获得可以在任何层定义中调用的激活函数?激活函数是在TF.js代码中的https://github.com/tensorflow/tfjs-layers/blob/master/src/activations.ts中定义的。它应该是这样工作的:const dense_layer = tf.layers.dense({units: 8, activation: my_Swish});

我把我的代码放在一起,在那里我扩展了一个层,使其像一个激活函数。这是类Swish。我还尝试通过Swich2类扩展tf.layers.activations,但调用它会产生错误。

代码语言:javascript
复制
errors.ts:48 Uncaught Error: activation: Improper config format: {"_callHook":null,"_addedWeightNames":[],"_stateful":false,"id":1,"activityRegularizer":null,"inputSpec":null,"supportsMasking":true,"_trainableWeights":[],"_nonTrainableWeights":[],"_losses":[],"_updates":[],"_built":false,"inboundNodes":[],"outboundNodes":[],"name":"activation_Activation1","trainable_":true,"initialWeights":null,"_refCount":null,"fastWeightInitDuringBuild":false,"activation":{}}.
'className' and 'config' must set.
    at new e (errors.ts:48)
    at Lp (generic_utils.ts:227)
    at rm (activations.ts:218)
    at im (activations.ts:239)
    at new e (core.ts:218)
    at Object.dense (exports_layers.ts:524)
    at tf.html:77

我已经在代码中添加了这两个测试用例。只需取消第64行和第66行的注释,即可查看相应的错误。我的tf_test.html (可以在浏览器本地运行):

代码语言:javascript
复制
<html>
  <head>
    <!-- Load TensorFlow.js, Plotly for plots -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
      <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <!-- Place your code in the script tag below. You can also use an external .js file -->    
  </head>

  <body>
  <div id="myDiv"><!-- Plotly chart will be drawn inside this DIV --></div>
  <script>
    // Notice there is no 'import' statement. 'tf' is available on the index-page
    // because of the script tag above.

    // custom layer
    // https://github.com/tensorflow/tfjs-examples/blob/master/custom-layer/custom_layer.js
    // https://gist.github.com/caisq/33ed021e0c7b9d0e728cb1dce399527d
    // https://github.com/tensorflow/tfjs-layers/blob/master/src/activations.ts
    class Swish extends tf.layers.Layer {
      //static className = 'Swish';
      constructor(config) {
        super(config);
        this.alpha = config.alpha;
      }

      call(input) {
        return tf.tidy(() => {
          const x = input[0]; //tf.getExactlyOneTensor(input);
          return tf.sigmoid(x.mul(this.alpha)).mul(x);
        });
      }

      computeOutputShape(inputShape){
        return inputShape;
      }

      static get className() {
      return 'Swish';
      }

      getConfig() {
        const config = {alpha: this.alpha};
        const baseConfig = super.getConfig();
        Object.assign(config, baseConfig);
        return config;
      }
    }

    tf.serialization.registerClass(Swish);

    class Swish2 extends tf.layers.activation {
      static get className() {return 'Swish2'};
      apply(x, alpha = 1) {
        return tf.sigmoid(x.mul(this.alpha)).mul(x);
      }
    }
    tf.serialization.registerClass(Swish2);

    const sw = new Swish({alpha: 1});
    const sw2 = new Swish2(1);
    // create Model
    const m_in = tf.input({shape: [1]});
    // this doesn't work!!! activations have to be defined manually as layers and then added, see below
    //const test = tf.layers.dense({units: 8, activation: sw}).apply(m_in);
    // sw2 can't be called this way either, yields 'Improper config format' error
    //const test2 = tf.layers.dense({units: 8, activation: sw2}).apply(m_in);
    const dense1 = tf.layers.dense({units: 8}).apply(m_in);
    const swa = sw.apply(dense1);
    const m_out = tf.layers.dense({units: 1}).apply(swa);
    const model = tf.model({inputs: m_in, outputs: m_out});

  </script>
  </body>
</html>
EN

回答 1

Stack Overflow用户

发布于 2019-08-12 16:50:58

根据文档,tf.layers.dense的激活属性是一个字符串:'elu'|'hardSigmoid'|'linear'|'relu'|'relu6'| 'selu'|'sigmoid'|'softmax'|'softplus'|'softsign'|'tanh'。通过使用tf.layers.activation.layer的实例,将引发以下错误:

激活:配置格式错误:{"_callHook":null,"_addedWeightNames":[],"_stateful":false,"id":0,"activityRegularizer":null,"inputSpec":null,"supportsMasking":true,"_trainableWeights":[],"_nonTrainableWeights":[],"_losses":[],"_updates":[],"_built":false,"inboundNodes":[],"outboundNodes":[],"name":"activation_Activation1","trainable_":true,"initialWeights":null,"_refCount":null,"fastWeightInitDuringBuild":false,“激活”:{},"alpha":1}。

该错误指示没有设置className,也没有设置配置,因为激活类的getActivation返回一个空对象。查看代码here,我们可以看到,如果激活不是一个字符串,则返回一个没有类名和配置的空层。

设置自定义激活层的正确方法是将其应用为doc中指定的层。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57436740

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档