首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在MATLAB上训练CNN时如何设置验证集(trainNetwork())

在MATLAB上训练CNN时如何设置验证集(trainNetwork())
EN

Stack Overflow用户
提问于 2017-07-24 01:05:14
回答 1查看 1.1K关注 0票数 0

我在试着用MATLAB训练CNN。matlab文档说,加载数据,设置层和选项。最后使用trainNetwork()进行训练。

代码语言:javascript
复制
    layers = [imageInputLayer([28 28 1])
          convolution2dLayer(5,10,...
                                'Stride',1,...
                                'Padding',[0,0])
          reluLayer
          maxPooling2dLayer(2,'Stride',2)
          fullyConnectedLayer(10)
          softmaxLayer
          classificationLayer];


 options = trainingOptions('sgdm',...Environment
                            'CheckpointPath','',...
                            'ExecutionEnvironment','gpu',...                'auto'  | 'cpu' | 'gpu' | 'multi-gpu' | 'parallel'
                            'InitialLearnRate',0.0001,...   Learning Rate
                            'LearnRateSchedule','none',...                  none    |piecewise
                            'LearnRateDropPeriod',10,...
                            'LearnRateDropFactor',0.1,...
                            'L2Regularization',0.0001,...   Regularization
                            'MaxEpochs',15,...              Epochs
                            'MiniBatchSize',128,...         Batch           128     |
                            'Momentum',0.9,...                              0.9     |
                            'Shuffle','once',...                            once    |never
                            'Verbose',1,...                                 1       | 0             — Indicator to display the information on the training progress
                            'VerboseFrequency',100,...                      50      | 0 
                            'OutputFcn',@plotTrainingAccuracy);

convnet = trainNetwork(trainDigitData,layers,options);

下面是我训练CNN的程序,但问题是我找不到设置验证集的选项。我设置的‘纪元’数字越大,它训练的时间就越长。甚至在过度拟合之前它就会停止吗?

不喜欢nnstart工具箱,当训练一个NN时,它会显示交叉熵和验证,训练错误率。

那么,在matlab上训练CNN时,你通常使用什么?使用第三方lib接口,比如caffe?还是自己写程序?

EN

回答 1

Stack Overflow用户

发布于 2018-11-15 06:05:34

您可以将数据拆分为训练数据和测试数据

代码语言:javascript
复制
idx = floor(0.8 * height(data));
trainingData = data(1:idx,:);
testData = data(idx:end,:);

然后,在trainNetwork之后,您可以运行测试部分

代码语言:javascript
复制
resultsStruct = struct([]);

for i = 1:height(testData)

    % Read the image.
    I = imread(testData.imageFilename{i});
    % Run the detector.
    [bboxes, scores, labels] = detect(detector, I);

    % Collect the results.
    resultsStruct(i).Boxes = bboxes;
    resultsStruct(i).Scores = scores;
    resultsStruct(i).Labels = labels;
end

% Convert the results into a table.
results = struct2table(resultsStruct);

如果你想检查实现更快的R-CNN实现,https://www.mathworks.com/help/vision/examples/object-detection-using-faster-r-cnn-deep-learning.html

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45267723

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档