首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何加载逻辑回归模型?

如何加载逻辑回归模型?
EN

Stack Overflow用户
提问于 2017-12-07 16:53:32
回答 2查看 3.3K关注 0票数 4

我想用Java中的Apache Spark训练逻辑回归模型。作为第一步,我想只训练模型一次,并保存模型参数(截距和系数)。随后,使用保存的模型参数在稍后的时间点进行评分。我可以使用以下代码将模型保存在parquet文件中

代码语言:javascript
复制
LogisticRegressionModel trainedLRModel = logReg.fit(data);
trainedLRModel.write().overwrite().save("mypath");

当我加载模型进行评分时,我得到以下错误。

代码语言:javascript
复制
LogisticRegression lr = new LogisticRegression();
lr.load("//saved_model_path");

Exception in thread "main" java.lang.NoSuchMethodException: org.apache.spark.ml.classification.LogisticRegressionModel.<init>(java.lang.String)
    at java.lang.Class.getConstructor0(Class.java:3082)
    at java.lang.Class.getConstructor(Class.java:1825)
    at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:325)
    at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:215)
    at org.apache.spark.ml.classification.LogisticRegression$.load(LogisticRegression.scala:672)
    at org.apache.spark.ml.classification.LogisticRegression.load(LogisticRegression.scala)

有没有一种方法可以训练和保存模型,然后在以后进行评估(评分)?我在Java中使用Spark ML 2.1.0。

EN

回答 2

Stack Overflow用户

发布于 2018-05-05 17:11:23

我在使用PySpark2.1.1时遇到了同样的问题,当我从LogisticRegression改为LogisticRegressionModel时,一切都运行得很好。

代码语言:javascript
复制
LogisticRegression.load("/model/path") # not works 

LogisticRegressionModel.load("/model/path") # works well
票数 5
EN

Stack Overflow用户

发布于 2017-12-07 21:12:01

TL;DR使用LogisticRegressionModel.load

load( path : String):LogisticRegressionModel从输入路径读取ML实例,输入路径是read.load(路径)的快捷方式。

事实上,从Spark2.0.0开始,使用Spark MLlib的推荐方法包括。LogisticRegression estimator,使用的是全新闪亮的Pipeline API

代码语言:javascript
复制
import org.apache.spark.ml.classification._
val lr = new LogisticRegression()

import org.apache.spark.ml.feature._
val tok = new Tokenizer().setInputCol("body")
val hashTF = new HashingTF().setInputCol(tok.getOutputCol).setOutputCol("features")

import org.apache.spark.ml._
val pipeline = new Pipeline().setStages(Array(tok, hashTF, lr))

// training dataset
val emails = Seq(("hello world", 1)).toDF("body", "label")

val model = pipeline.fit(emails)

model.write.overwrite.save("mypath")
val loadedModel = PipelineModel.load("mypath")
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47691063

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档