首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >火炬CrossEntropyCriterion误差

火炬CrossEntropyCriterion误差
EN

Stack Overflow用户
提问于 2016-02-08 09:31:54
回答 1查看 1.2K关注 0票数 1

我试图在Torch中训练一个简单的XOR函数测试网络。当我使用MSECriterion时,它可以工作,但当我尝试使用CrossEntropyCriterion时,它会失败,出现以下错误消息:

代码语言:javascript
复制
/home/a/torch/install/bin/luajit: /home/a/torch/install/share/lua/5.1/nn/THNN.lua:699: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /tmp/luarocks_nn-scm-1-6937/nn/lib/THNN/generic/ClassNLLCriterion.c:31
stack traceback:
    [C]: in function 'v'
    /home/a/torch/install/share/lua/5.1/nn/THNN.lua:699: in function 'ClassNLLCriterion_updateOutput'
    ...e/a/torch/install/share/lua/5.1/nn/ClassNLLCriterion.lua:41: in function 'updateOutput'
    ...torch/install/share/lua/5.1/nn/CrossEntropyCriterion.lua:13: in function 'forward'
    .../a/torch/install/share/lua/5.1/nn/StochasticGradient.lua:35: in function 'train'
    a.lua:34: in main chunk
    [C]: in function 'dofile'
    /home/a/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
    [C]: at 0x00406670

在将其分解为LogSoftMax和ClassNLLCriterion时,我会得到相同的错误消息。守则是:

代码语言:javascript
复制
dataset={};
function dataset:size() return 100 end -- 100 examples
for i=1,dataset:size() do
  local input = torch.randn(2);     -- normally distributed example in 2d
  local output = torch.Tensor(2);
  if input[1]<0 then
      input[1]=-1
  else
      input[1]=1
  end
  if input[2]<0 then
      input[2]=-1
  else
      input[2]=1
  end
  if input[1]*input[2]>0 then     -- calculate label for XOR function
    output[2] = 1;
  else
    output[1] = 1
  end
  dataset[i] = {input, output}
end

require "nn"
mlp = nn.Sequential();  -- make a multi-layer perceptron
inputs = 2; outputs = 2; HUs = 20; -- parameters
mlp:add(nn.Linear(inputs, HUs))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(HUs, outputs))

criterion = nn.CrossEntropyCriterion()
trainer = nn.StochasticGradient(mlp, criterion)
trainer.learningRate = 0.01
trainer:train(dataset)

x = torch.Tensor(2)
x[1] =  1; x[2] =  1; print(mlp:forward(x))
x[1] =  1; x[2] = -1; print(mlp:forward(x))
x[1] = -1; x[2] =  1; print(mlp:forward(x))
x[1] = -1; x[2] = -1; print(mlp:forward(x))
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-02-08 15:23:48

针对回归问题,设计了MSE准则。当它用于分类任务时,目标应该是一个热点向量。交叉熵/负对数似然准则专门用于分类;因此,不需要将目标类显式表示为向量。在torch中,这些条件的目标仅仅是指定类的索引(1到类的数量)。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/35266217

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档