Spark linear regression model training fails after migrating to 1.6.1

Stack Overflow user
Asked on 2016-06-03 14:58:01
1 answer · 274 views · 0 followers · 0 votes

I use spark-ml to train a linear regression model. It works perfectly with Spark 1.5.2, but with 1.6.1 I get the following error:

```text
java.lang.AssertionError: assertion failed: lapack.dppsv returned 228.
```

It seems to be related to some low-level linear algebra library, but it worked fine before the Spark version update.
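For context, `dppsv` is the LAPACK routine that solves a symmetric positive-definite system A x = b (with A in packed storage) through a Cholesky factorization. A positive return value i means the leading minor of order i of A is not positive definite, so the factorization could not be completed; a singular matrix, for example one built from linearly dependent feature columns, causes exactly that. The following is a minimal plain-Java sketch of that return convention, assuming dense (not packed) storage; `CholeskyInfo` and `choleskyInfo` are hypothetical names, not LAPACK or Spark code.

```java
/** Hypothetical helper mimicking LAPACK's INFO convention for a Cholesky factorization. */
class CholeskyInfo {

    /**
     * Computes the Cholesky factor L (A = L * L^T) of a symmetric matrix without modifying A.
     * Returns 0 on success, or i (1-based) if the leading minor of order i is not
     * positive definite, which is what LAPACK reports through INFO.
     */
    static int choleskyInfo(double[][] a) {
        int n = a.length;
        double[][] l = new double[n][n];
        for (int j = 0; j < n; j++) {
            // Diagonal pivot: a[j][j] minus the squares already computed in row j of L.
            double d = a[j][j];
            for (int k = 0; k < j; k++) {
                d -= l[j][k] * l[j][k];
            }
            if (d <= 0.0) {
                return j + 1; // leading minor of order j + 1 is not positive definite
            }
            l[j][j] = Math.sqrt(d);
            // Entries of column j below the diagonal.
            for (int i = j + 1; i < n; i++) {
                double s = a[i][j];
                for (int k = 0; k < j; k++) {
                    s -= l[i][k] * l[j][k];
                }
                l[i][j] = s / l[j][j];
            }
        }
        return 0;
    }

    public static void main(String[] args) {
        double[][] spd = {{4, 2}, {2, 3}};      // positive definite
        double[][] singular = {{1, 1}, {1, 1}}; // rank 1, like two identical feature columns
        System.out.println(choleskyInfo(spd));      // 0
        System.out.println(choleskyInfo(singular)); // 2
    }
}
```

In floating point, a nearly singular matrix can produce a tiny positive pivot instead of zero, so real data does not always fail this cleanly.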

With both versions, I get the same warnings before training starts, saying that it cannot load BLAS and LAPACK:

```text
[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
[Executor task launch worker-6] com.github.fommil.netlib.BLAS - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
[main]    com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeSystemLAPACK
[main]    com.github.fommil.netlib.LAPACK - Failed to load implementation from: com.github.fommil.netlib.NativeRefLAPACK
```
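These warnings come from netlib-java: when neither native implementation loads, it falls back to a pure-JVM (F2J) implementation, which is slower but functionally equivalent. Since the same warnings appear with 1.5.2, where training works, they alone are unlikely to explain the failure. A minimal way to see which implementation was actually picked is to call the static `getInstance()` factory of `com.github.fommil.netlib.BLAS` / `LAPACK` and print the class of the returned object. The sketch below uses reflection only so that it compiles without netlib-java on the classpath; `NetlibCheck` and `loadedImpl` are hypothetical names.

```java
import java.lang.reflect.Method;

/** Hypothetical helper: reports which netlib-java implementation a facade class resolved to. */
class NetlibCheck {

    /**
     * Calls the static getInstance() method of the given facade class
     * (e.g. "com.github.fommil.netlib.BLAS") and returns the runtime class name of the
     * object it returns, or a marker string if the class or method is not available.
     */
    static String loadedImpl(String facadeClassName) {
        try {
            Class<?> facade = Class.forName(facadeClassName);
            Method getInstance = facade.getMethod("getInstance");
            Object impl = getInstance.invoke(null); // static method: no receiver
            return impl.getClass().getName();
        } catch (ReflectiveOperationException e) {
            return "unavailable (" + e.getClass().getSimpleName() + ")";
        }
    }

    public static void main(String[] args) {
        // With the warnings above, these are expected to name the F2J fallbacks
        // (e.g. com.github.fommil.netlib.F2jBLAS) rather than a Native* class.
        System.out.println("BLAS:   " + loadedImpl("com.github.fommil.netlib.BLAS"));
        System.out.println("LAPACK: " + loadedImpl("com.github.fommil.netlib.LAPACK"));
    }
}
```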

Here is a minimal code example:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;

public class Application {

    public static void main(String[] args) {

        // create context
        JavaSparkContext javaSparkContext = new JavaSparkContext("local[*]", "CalculCote");
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        // describe fields
        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("brand", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("commercial_name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("mileage", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("price", DataTypes.DoubleType, true));

        // load dataframe from file
        DataFrame df = sqlContext.read().format("com.databricks.spark.csv") //
                .option("header", "true") //
                .option("inferSchema", "false") //
                .option("delimiter", ";") //
                .schema(DataTypes.createStructType(fields)) //
                .load("input.csv").persist();

        // show first rows
        df.show();

        // indexers and encoders for non numerical values
        StringIndexer brandIndexer = new StringIndexer() //
                .setInputCol("brand") //
                .setOutputCol("brandIndex");

        OneHotEncoder brandEncoder = new OneHotEncoder() //
                .setInputCol("brandIndex") //
                .setOutputCol("brandVec");

        StringIndexer commNameIndexer = new StringIndexer() //
                .setInputCol("commercial_name") //
                .setOutputCol("commNameIndex");

        OneHotEncoder commNameEncoder = new OneHotEncoder() //
                .setInputCol("commNameIndex") //
                .setOutputCol("commNameVec");

        // model predictors
        VectorAssembler predictors = new VectorAssembler() //
                .setInputCols(new String[] { "brandVec", "commNameVec", "mileage" }) //
                .setOutputCol("features");

        // train model
        LinearRegression lr = new LinearRegression().setLabelCol("price");

        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { //
                brandIndexer, brandEncoder, commNameIndexer, commNameEncoder, predictors, lr });

        PipelineModel pm = pipeline.fit(df);

        DataFrame result = pm.transform(df);

        result.show();

        // release Spark resources
        javaSparkContext.stop();
    }
}
```

And the input.csv data:

```csv
brand;commercial_name;mileage;price
APRILIA;ATLANTIC 125;18237;1400
BMW;R1200 GS;10900;12400
HONDA;CB 1000;58225;4250
HONDA;CB 1000;1780;7610
HONDA;CROSSRUNNER 800;2067;11490
KAWASAKI;ER-6F 600;51600;2010
KAWASAKI;VERSYS 1000;5900;13900
KAWASAKI;VERSYS 650;3350;6200
KTM;SUPER DUKE 990;36420;4760
```

And the pom.xml:

```xml
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>test</groupId>
    <artifactId>sparkmigration</artifactId>
    <packaging>jar</packaging>
    <name>sparkmigration</name>
    <version>0.0.1</version>
    <url>http://maven.apache.org</url>

    <properties>
        <java.version>1.8</java.version>
        <spark.version>1.6.1</spark.version>
        <!-- <spark.version>1.5.2</spark.version> -->
        <spark.csv.version>1.3.0</spark.csv.version>
        <slf4j.version>1.7.2</slf4j.version>
        <logback.version>1.0.9</logback.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>org.slf4j</groupId>
                    <artifactId>slf4j-log4j12</artifactId>
                </exclusion>
            </exclusions>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>com.databricks</groupId>
            <artifactId>spark-csv_2.11</artifactId>
            <version>${spark.csv.version}</version>
        </dependency>

        <!-- Logs -->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>log4j-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>jcl-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.2</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>
```
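A quick count on the minimal example's data is telling (this is my own arithmetic, not part of the original question). `StringIndexer` followed by `OneHotEncoder`, which drops the last category by default (`dropLast = true`), turns a column with k distinct values into a vector of size k - 1. In input.csv, `brand` has 5 distinct values and `commercial_name` has 8, so the assembled `features` vector has (5 - 1) + (8 - 1) + 1 = 12 entries for only 9 rows. With more features than rows, the XᵀX matrix behind a normal-equation solve cannot be full rank, which plausibly explains a Cholesky failure on this example. The sketch recomputes the counts from the CSV text; `FeatureCount` is a hypothetical name.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical helper: counts one-hot feature dimensions for the input.csv data. */
class FeatureCount {

    static final String CSV =
            "brand;commercial_name;mileage;price\n"
            + "APRILIA;ATLANTIC 125;18237;1400\n"
            + "BMW;R1200 GS;10900;12400\n"
            + "HONDA;CB 1000;58225;4250\n"
            + "HONDA;CB 1000;1780;7610\n"
            + "HONDA;CROSSRUNNER 800;2067;11490\n"
            + "KAWASAKI;ER-6F 600;51600;2010\n"
            + "KAWASAKI;VERSYS 1000;5900;13900\n"
            + "KAWASAKI;VERSYS 650;3350;6200\n"
            + "KTM;SUPER DUKE 990;36420;4760\n";

    /** Returns {rows, featureDimension}, assuming OneHotEncoder's default dropLast = true. */
    static int[] count(String csv) {
        String[] lines = csv.trim().split("\n");
        Set<String> brands = new HashSet<>();
        Set<String> names = new HashSet<>();
        for (int i = 1; i < lines.length; i++) { // skip the header line
            String[] cols = lines[i].split(";");
            brands.add(cols[0]);
            names.add(cols[1]);
        }
        int rows = lines.length - 1;
        // k distinct categories -> k - 1 one-hot slots; plus 1 slot for numeric mileage.
        int features = (brands.size() - 1) + (names.size() - 1) + 1;
        return new int[] {rows, features};
    }

    public static void main(String[] args) {
        int[] rc = count(CSV);
        System.out.println("rows = " + rc[0] + ", features = " + rc[1]); // rows = 9, features = 12
    }
}
```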

1 answer

Stack Overflow user

Accepted answer

Posted on 2016-06-15 10:01:13

Problem fixed (thanks to the Apache Spark mailing list).

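Since Spark 1.6, the linear regression solver is set to "auto" by default, and in some cases (number of features <= 4096, no elastic net parameter set, ...) the WLS (weighted least squares) algorithm is used instead of L-BFGS.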
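The selection rule described above can be restated roughly as follows. This is a hypothetical plain-Java sketch, not Spark's source: it assumes the three `solver` values accepted by Spark 1.6 (`"auto"`, `"normal"`, `"l-bfgs"`) and the 4096-feature limit quoted in the answer, and it covers only the conditions the answer lists, not the ones elided by "...". WLS solves the normal equations through a Cholesky factorization, which is where `lapack.dppsv` is called.

```java
/** Hypothetical restatement of the solver choice described in the answer (not Spark's code). */
class SolverChoice {

    /** Feature-count limit quoted in the answer for the WLS path. */
    static final int MAX_WLS_FEATURES = 4096;

    /**
     * Returns "WLS" (normal equations via Cholesky) or "L-BFGS" (iterative optimization),
     * using only the conditions listed in the answer.
     */
    static String choose(String solver, int numFeatures, double elasticNetParam) {
        switch (solver) {
            case "normal":
                return "WLS";
            case "l-bfgs":
                return "L-BFGS";
            case "auto":
                boolean smallAndNoElasticNet =
                        numFeatures <= MAX_WLS_FEATURES && elasticNetParam == 0.0;
                return smallAndNoElasticNet ? "WLS" : "L-BFGS";
            default:
                throw new IllegalArgumentException("unknown solver: " + solver);
        }
    }

    public static void main(String[] args) {
        System.out.println(choose("auto", 12, 0.0));   // WLS: the 1.6 default for a small model
        System.out.println(choose("l-bfgs", 12, 0.0)); // L-BFGS: the workaround in this answer
    }
}
```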

I forced the solver to use l-bfgs, and it worked:

```java
LinearRegression lr = new LinearRegression().setLabelCol("price").setSolver("l-bfgs");
```
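Why forcing L-BFGS helps: an iterative optimizer only needs the gradient Xᵀ(Xw - y) of the squared loss, so it never has to factorize XᵀX and does not break when that matrix is singular. The sketch below uses plain fixed-step gradient descent as a simpler stand-in for L-BFGS (it is not Spark's implementation) on a toy system with more features than rows, where XᵀX is singular; the data, step size, and iteration count are arbitrary illustrative choices, and `GradientDescentLsq` is a hypothetical name.

```java
/** Fixed-step gradient descent on least squares; a simplified stand-in for L-BFGS. */
class GradientDescentLsq {

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int k = 0; k < a.length; k++) {
            s += a[k] * b[k];
        }
        return s;
    }

    /** Squared loss 0.5 * ||X w - y||^2. */
    static double loss(double[][] x, double[] y, double[] w) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double r = dot(x[i], w) - y[i];
            sum += r * r;
        }
        return 0.5 * sum;
    }

    /** Runs gradient descent from w = 0; it only multiplies by X and X^T, never inverts X^T X. */
    static double[] fit(double[][] x, double[] y, double step, int iterations) {
        int d = x[0].length;
        double[] w = new double[d];
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[d];
            for (int i = 0; i < x.length; i++) {
                double r = dot(x[i], w) - y[i]; // residual of row i
                for (int k = 0; k < d; k++) {
                    grad[k] += r * x[i][k];
                }
            }
            for (int k = 0; k < d; k++) {
                w[k] -= step * grad[k];
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // 2 rows, 3 features: X^T X is 3x3 with rank 2, so a Cholesky solve would fail.
        double[][] x = {{1, 1, 0}, {0, 1, 1}};
        double[] y = {1, 2};
        double[] w = fit(x, y, 0.1, 500);
        System.out.println("loss = " + loss(x, y, w)); // close to 0
    }
}
```

Unlike a normal-equation solve, this does not depend on a unique solution existing: starting from zero, it settles on the minimum-norm one among the infinitely many exact fits.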
Original page content provided by Stack Overflow; translation support by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/37617628
