首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Deeplearning4j预测二手车价格

Deeplearning4j预测二手车价格
EN

Stack Overflow用户
提问于 2019-05-21 10:15:12
回答 1查看 2.3K关注 0票数 1

我想预测二手车的价格,我有销售汽车的历史数据。我将数值缩放为0-1,并使其他特性成为热点。

代码语言:javascript
复制
public RestResponse<JSONObject> buildModelDl4j( HttpServletRequest request, HttpServletResponse response, @RequestBody Map<String, String> json ) throws IOException
{
    RestResponse<JSONObject> restResponse = ControllerBase.getRestResponse( request, response, null ) ;

    String path = "\\HOME_EXCHANGE\\uploads\\" + json.get( "filePath" ) ;

    int numLinesToSkip = 1 ;
    char delimiter = ',' ;

    RecordReader recordReader = new CSVRecordReader( numLinesToSkip, delimiter ) ;

    try
    {
        recordReader.initialize( new FileSplit( new File( path ) ) ) ;
    }
    catch( InterruptedException e )
    {
        e.printStackTrace( ) ;
    }

    DataSetIterator iter = new RecordReaderDataSetIterator( recordReader, batchSize, indexToCalc, indexToCalc, true ) ;
    json.put( "numAttr", String.valueOf( numAttr ) ) ;

    //        ds.shuffle( ) ;   //TODO should I shuffle the data ?

    MultiLayerNetwork net = buildNetwork( json ) ;

    net.init( ) ;

    net.setListeners( new ScoreIterationListener( 30 ) ) ;

    DataSet testData = null ;

    for( int i = 0; i < nEpochs; i++ )
    {
        iter.reset( ) ;

        while( iter.hasNext( ) )
        {
            DataSet ds = iter.next( ) ;
            SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
            DataSet trainingData = testAndTrain.getTrain( ) ;
            testData = testAndTrain.getTest( ) ;
            net.fit( trainingData ) ;
        }

        iter.reset( ) ;

        int cnt = 0 ;
        while( iter.hasNext( ) && cnt++ < 3 )
        {
            DataSet ds = iter.next( ) ;
            SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
            testData = testAndTrain.getTest( ) ;
            String testResults = testResults( net, testData, indexToCalc ) ;
            System.err.println( "Test results:  [" + i + "]  \n" + testResults ) ;
        }

    }

    RegressionEvaluation eval = new RegressionEvaluation( ) ;
    INDArray output = net.output( testData.getFeatures( ) ) ;
    eval.eval( testData.getLabels( ), output ) ;
    System.out.println( eval.stats( ) ) ;

    String testResults = testResults( net, testData, indexToCalc ) ;

    result.put( "testResults", testResults ) ;

    System.err.println( "Test results last: \n" + testResults ) ;

    restResponse.setData( result ) ;

    return restResponse ;
}

我用前端传递的参数建立模型,从csv文件中读取数据,然后对模型进行训练。我做的对吗?我应该如何使用测试和训练数据?在示例中有两种方法,它们使用

代码语言:javascript
复制
DataSet ds = iter.next( ) ;
SplitTestAndTrain testAndTrain = ds.splitTestAndTrain( splitRate / 100.0 ) ;
DataSet trainingData = testAndTrain.getTrain( ) ;
testData = testAndTrain.getTest( ) ;
net.fit( trainingData ) ;

代码语言:javascript
复制
for( int i = 0; i < nEpochs; i++ )
{
  net.fit( iter ) ;
  iter.reset( ) ;
}

哪一种是正确的方法?

EN

回答 1

Stack Overflow用户

发布于 2019-06-02 10:12:19

我用前端传递的参数建立模型,从csv文件中读取数据,然后对模型进行训练。我做的对吗?我应该如何使用测试和训练数据?

一个更好的方法是使用DataSetIteratorSplitter,如下所示:

代码语言:javascript
复制
DataSetIteratorSplitter dataSetIteratorSplitter = new DataSetIteratorSplitter(dataSetIterator,totalNumBatches,ratio);
multiLayerNetwork.fit(dataSetIteratorSplitter.getTrainIterator(),epochCount);

totalNumBatches将是总数数据集除以小批量大小。例如,如果您有10000个数据集,并且假设我们在一个批处理中分配了8个样本,那么总共有1250个批。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56236060

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档