我正在使用Iris数据集使用Back Propagation训练我的NN。代码附在附件中。
p = [
5.1,3.5,1.4,0.2; %iris data set
4.9,3.0,1.4,0.2;
4.7,3.2,1.3,0.2;
4.6,3.1,1.5,0.2;
5.0,3.6,1.4,0.2;
5.4,3.9,1.7,0.4;
4.6,3.4,1.4,0.3;
5.0,3.4,1.5,0.2;
4.4,2.9,1.4,0.2;
4.9,3.1,1.5,0.1;
5.4,3.7,1.5,0.2;
4.8,3.4,1.6,0.2;
4.8,3.0,1.4,0.1;
4.3,3.0,1.1,0.1;
5.8,4.0,1.2,0.2;
5.7,4.4,1.5,0.4;
5.4,3.9,1.3,0.4;
5.1,3.5,1.4,0.3;
5.7,3.8,1.7,0.3;
5.1,3.8,1.5,0.3;
5.4,3.4,1.7,0.2;
5.1,3.7,1.5,0.4;
4.6,3.6,1.0,0.2;
5.1,3.3,1.7,0.5;
4.8,3.4,1.9,0.2;
5.0,3.0,1.6,0.2;
5.0,3.4,1.6,0.4;
5.2,3.5,1.5,0.2;
5.2,3.4,1.4,0.2;
4.7,3.2,1.6,0.2;
4.8,3.1,1.6,0.2;
5.4,3.4,1.5,0.4;
5.2,4.1,1.5,0.1;
5.5,4.2,1.4,0.2;
4.9,3.1,1.5,0.1;
5.0,3.2,1.2,0.2;
5.5,3.5,1.3,0.2;
4.9,3.1,1.5,0.1;
4.4,3.0,1.3,0.2;
5.1,3.4,1.5,0.2;
5.0,3.5,1.3,0.3;
4.5,2.3,1.3,0.3;
4.4,3.2,1.3,0.2;
5.0,3.5,1.6,0.6;
5.1,3.8,1.9,0.4;
4.8,3.0,1.4,0.3;
5.1,3.8,1.6,0.2;
4.6,3.2,1.4,0.2;
5.3,3.7,1.5,0.2;
5.0,3.3,1.4,0.2;
7.0,3.2,4.7,1.4;
6.4,3.2,4.5,1.5;
6.9,3.1,4.9,1.5;
5.5,2.3,4.0,1.3;
6.5,2.8,4.6,1.5;
5.7,2.8,4.5,1.3;
6.3,3.3,4.7,1.6;
4.9,2.4,3.3,1.0;
6.6,2.9,4.6,1.3;
5.2,2.7,3.9,1.4;
5.0,2.0,3.5,1.0;
5.9,3.0,4.2,1.5;
6.0,2.2,4.0,1.0;
6.1,2.9,4.7,1.4;
5.6,2.9,3.6,1.3;
6.7,3.1,4.4,1.4;
5.6,3.0,4.5,1.5;
5.8,2.7,4.1,1.0;
6.2,2.2,4.5,1.5;
5.6,2.5,3.9,1.1;
5.9,3.2,4.8,1.8;
6.1,2.8,4.0,1.3;
6.3,2.5,4.9,1.5;
6.1,2.8,4.7,1.2;
6.4,2.9,4.3,1.3;
6.6,3.0,4.4,1.4;
6.8,2.8,4.8,1.4;
6.7,3.0,5.0,1.7;
6.0,2.9,4.5,1.5;
5.7,2.6,3.5,1.0;
5.5,2.4,3.8,1.1;
5.5,2.4,3.7,1.0;
5.8,2.7,3.9,1.2;
6.0,2.7,5.1,1.6;
5.4,3.0,4.5,1.5;
6.0,3.4,4.5,1.6;
6.7,3.1,4.7,1.5;
6.3,2.3,4.4,1.3;
5.6,3.0,4.1,1.3;
5.5,2.5,4.0,1.3;
5.5,2.6,4.4,1.2;
6.1,3.0,4.6,1.4;
5.8,2.6,4.0,1.2;
5.0,2.3,3.3,1.0;
5.6,2.7,4.2,1.3;
5.7,3.0,4.2,1.2;
5.7,2.9,4.2,1.3;
6.2,2.9,4.3,1.3;
5.1,2.5,3.0,1.1;
5.7,2.8,4.1,1.3;
6.3,3.3,6.0,2.5;
5.8,2.7,5.1,1.9;
7.1,3.0,5.9,2.1;
6.3,2.9,5.6,1.8;
6.5,3.0,5.8,2.2;
7.6,3.0,6.6,2.1;
4.9,2.5,4.5,1.7;
7.3,2.9,6.3,1.8;
6.7,2.5,5.8,1.8;
7.2,3.6,6.1,2.5;
6.5,3.2,5.1,2.0;
6.4,2.7,5.3,1.9;
6.8,3.0,5.5,2.1;
5.7,2.5,5.0,2.0;
5.8,2.8,5.1,2.4;
6.4,3.2,5.3,2.3;
6.5,3.0,5.5,1.8;
7.7,3.8,6.7,2.2;
7.7,2.6,6.9,2.3;
6.0,2.2,5.0,1.5;
6.9,3.2,5.7,2.3;
5.6,2.8,4.9,2.0;
7.7,2.8,6.7,2.0;
6.3,2.7,4.9,1.8;
6.7,3.3,5.7,2.1;
7.2,3.2,6.0,1.8;
6.2,2.8,4.8,1.8;
6.1,3.0,4.9,1.8;
6.4,2.8,5.6,2.1;
7.2,3.0,5.8,1.6;
7.4,2.8,6.1,1.9;
7.9,3.8,6.4,2.0;
6.4,2.8,5.6,2.2;
6.3,2.8,5.1,1.5;
6.1,2.6,5.6,1.4;
7.7,3.0,6.1,2.3;
6.3,3.4,5.6,2.4;
6.4,3.1,5.5,1.8;
6.0,3.0,4.8,1.8;
6.9,3.1,5.4,2.1;
6.7,3.1,5.6,2.4;
6.9,3.1,5.1,2.3;
5.8,2.7,5.1,1.9;
6.8,3.2,5.9,2.3;
6.7,3.3,5.7,2.5;
6.7,3.0,5.2,2.3;
6.3,2.5,5.0,1.9;
6.5,3.0,5.2,2.0;
6.2,3.4,5.4,2.3;
5.9,3.0,5.1,1.8;
]';
t = [
0; %assign 0 to output neuron for Iris-setosa
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0;
0.5; %assign 0.5 to output neuron for Iris-versicolor
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
0.5;
1; %assign 1 to output neuron for Iris-virginica
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
1;
]';
net = feedforwardnet(3,'traingd'); %3 hidden layers and training algorithm
net = configure(net,p,t);
net.layers{2}.transferFcn = 'logsig'; %sigmoid function in output layer
net.layers{1}.transferFcn = 'logsig'; %sigmiod fucntion in hidden layer
net.performFcn = 'mse';
net = init(net);
net.trainParam.epochs = 10000;
net.trainParam.lr = 0.7; %learning rate
net.trainParam.goal = 0.01; %mse
net = train(net,p,t);
view(net);问题是我没有得到第一个类的期望输出(它的输出应该接近于零)。当我将第一个类的向量输入到训练好的网络中时,输出接近0.5 (但它应该接近于零)。
这是第一个类的第一个向量的输出:
输出=净(5.1,3.5,1.4,0.2‘)
输出= 0.5003
这个输出应该接近于零(因为我已经将0赋给了第一个类),但结果是0.5。第一类的所有输入都是这种情况。对于第二类和第三类,输出很好,即第二类接近0.5,第三类接近1.0。
你能运行这段代码并告诉我我哪里做错了吗?
(我认为这可能是偏置输入的问题,因为类1的所有输出都被偏移了0.5。)
致以问候。
发布于 2015-06-18 22:50:58
你有没有尝试过用较低的均方误差进行训练(在你的代码中net.trainParam.goal = 0.01;)?尝试使用net.trainParam.goal = 0.0001或更低的值。
https://stackoverflow.com/questions/29344010
复制相似问题