首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >dl4j lstm不成功

dl4j lstm不成功
EN

Stack Overflow用户
提问于 2020-06-23 06:12:30
回答 2查看 200关注 0票数 0

我正在尝试将练习复制到这个链接页面的一半位置:https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

本练习使用正弦函数在-1到1之间创建1000个数据点,并使用递归网络近似该函数。

下面是我使用的代码。我将回去研究为什么这个方法不起作用,因为当我能够很容易地使用前馈网络来近似这个函数时,它对我来说没有太大的意义。

代码语言:javascript
复制
      //get data
    ArrayList<DataSet> list = new ArrayList();
   
    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);

    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()){
        list.add(dss);
    }

   
    if (list.isEmpty()){

        return;
    }

    //format dataset
   list = DataSetFormatter.formatReccurnent(list, 0);

    //get network
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1,history, Activation.RELU);
    ldlist.add(l);     
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);


    //train network
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();


    //test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size()+1);     
    INDArray holder;

    if (list.size() > 1){
        //test on training data
        System.err.println("oops");

    }else{
        //test on original or scaled data
        for (int i = 0; i < lister.size(); i++) {

            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i,holder.getFloat(0));

        }
    }


    //add originaldata
    resList.add(dsMain);
    //result       
    result.setFeatures(dsMain.getFeatures());
  
    result.setLabels(arr);
    resList.add(result);

    //display
    DisplayData.plot2DScatterGraph(resList);

你能解释一下我需要使用1 in 10隐藏和1出lstm网络来近似正弦函数的代码吗?

Im不使用任何归一化,因为函数已经是-1:1,并且Im使用Y输入作为特征,使用以下Y输入作为标签来训练网络。

您注意到,我正在构建一个允许更容易地构建网络的类,并且我已经尝试在这个问题上进行了许多更改,但我厌倦了猜测。

以下是我的结果的一些示例。蓝色表示数据,红色表示结果

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-06-24 04:52:48

这是其中的一次,你从想为什么这不起作用,到地狱我的原始结果是如何像他们一样好的。

我的失败在于没有清楚地理解文档,也没有理解BPTT。

对于前馈网络,每次迭代存储为一行,每次输入存储为一列。网络inputs.size dataset.size就是一个例子

然而,对于循环输入,情况正好相反,每一行都是一个输入,每一列都是激活lstm事件链状态所必需的迭代。我的输入至少需要是0、networkinputs.size、dataset.size,但也可以是dataset.size、networkinputs.size、statelength.size

在我之前的例子中,我是用dataset.size,networkinputs.size,1这种格式的数据来训练网络的,所以从我的低分辨率来看,lstm网络应该根本不会工作,但至少会产生一些东西。

在将数据集转换为列表时,可能还会出现一些问题,因为我还更改了向网络提供数据的方式,但我认为大部分问题都是数据结构问题。

下面是我的新结果

票数 1
EN

Stack Overflow用户

发布于 2020-06-23 07:47:30

如果看不到完整的代码,就很难知道发生了什么。首先,我没有看到指定的RnnOutputLayer。你可以看看this,它向你展示了如何在DL4J中构建一个RNN。如果你的RNN设置是正确的,这可能是一个调优问题。您可以找到有关调优here的更多信息。Adam可能是比RMSProp更好的更新器选择。而且tanh可能是激活输出层的一个很好的选择,因为它的范围是(-1,1)。其他要检查/调整的东西-学习率,纪元数,数据的设置(比如你是否想要预测得太远?)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62524314

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档