
Error reading metric properties of BinaryClassificationMetrics in Spark 2.4.4

Stack Overflow user
Asked on 2020-01-10 09:58:19
Answers: 1 · Views: 152 · Followers: 0 · Votes: 0

I am trying to reproduce this Spark / Scala example, but when I try to extract some metrics from the processed .csv file, I get an error.

My code snippet:

val splitSeed = 5043
val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

trainingData.show(20);

// Fit the model
val model = lr.fit(trainingData)

// Print the coefficients and intercept for logistic regression
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")

// Show the held-out test set
testData.show()

// Run the model on the test features to get predictions
val predictions = model.transform(testData)
// The transform adds new columns: rawPrediction, probability and prediction
predictions.show()

// Use MLlib to evaluate: convert the DataFrame to an RDD
val myRdd = predictions.select("rawPrediction", "label").rdd

val predictionAndLabels = myRdd.map(x => (x(0).asInstanceOf[DenseVector](1), x(1).asInstanceOf[Double]))
// Instantiate metrics object
val metrics = new BinaryClassificationMetrics(predictionAndLabels)
println("area under the precision-recall curve: " + metrics.areaUnderPR)
println("area under the receiver operating characteristic (ROC) curve : " + metrics.areaUnderROC)
// A precision-recall curve plots precision against recall for different threshold values, while a
// receiver operating characteristic (ROC) curve plots recall against the false positive rate.
// The closer the area under the ROC curve is to 1, the better the model's predictions.
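
As a side note on the curves described in the comments above, here is a minimal plain-Scala sketch of how ROC points and the area under the ROC curve can be derived from (score, label) pairs. It is not Spark's implementation. `RocSketch`, `rocPoints` and `areaUnderRoc` are hypothetical names introduced only for illustration, and labels are assumed to be 0.0 or 1.0.

```scala
// Illustrative sketch only (hypothetical names, not part of Spark).
// Assumes labels are 1.0 (positive) or 0.0 (negative).
object RocSketch {

  /** ROC points as (falsePositiveRate, recall), starting at (0, 0) and ending at (1, 1). */
  def rocPoints(scoreAndLabels: Seq[(Double, Double)]): Seq[(Double, Double)] = {
    val positives = scoreAndLabels.count(_._2 == 1.0).toDouble
    val negatives = scoreAndLabels.size - positives
    require(positives > 0 && negatives > 0, "need at least one example of each class")

    // Group tied scores so they share one threshold, then sweep from high to low score.
    val byScoreDesc = scoreAndLabels.groupBy(_._1).toSeq.sortBy(-_._1)

    // Running counts of true positives and false positives as the threshold is lowered.
    val cumulative = byScoreDesc.scanLeft((0.0, 0.0)) { case ((tp, fp), (_, group)) =>
      val pos = group.count(_._2 == 1.0)
      (tp + pos, fp + group.size - pos)
    }
    cumulative.map { case (tp, fp) => (fp / negatives, tp / positives) }
  }

  /** Area under the ROC curve, using the trapezoidal rule over consecutive points. */
  def areaUnderRoc(scoreAndLabels: Seq[(Double, Double)]): Double =
    rocPoints(scoreAndLabels).sliding(2).collect {
      case Seq((x1, y1), (x2, y2)) => (x2 - x1) * (y1 + y2) / 2.0
    }.sum
}
```

For example, `RocSketch.areaUnderRoc(Seq((0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0), (0.2, 0.0)))` comes out at 5/6 (about 0.833), because five of the six positive/negative pairs are ranked in the right order.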

When I try to read the areaUnderPR property, I get the following error:

20/01/10 10:41:02 WARN TaskSetManager: Lost task 0.0 in stage 56.0 (TID 246, 10.10.252.172, executor 1): [...] prediction.TestCancerOriginal$$anonfun$1
at java.net.URLClassLoader.findClass(URLClassLoader.java:382)
at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
at java.lang.Class.forName0(Native Method)
at java.lang.Class.forName(Class.java:348)
at org.apache.spark.serializer.JavaDeserializationStream$$anon$1.resolveClass(JavaSerializer.scala:67)
at java.io.ObjectInputStream.readNonProxyDesc(ObjectInputStream.java:1868)
at java.io.ObjectInputStream.readClassDesc(ObjectInputStream.java:1751)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2042)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2287)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2211)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream...
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2287)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2211)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2069)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2287)
at java.io.ObjectInputStream...
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2069)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1573)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:431)
at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(...
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:55)
at org.apache.spark.scheduler.Task.run(Task.scala:123)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)

My predictions.show() output:

+------+---------+----+-----+----+------+----+------+----+---+----+------------+--------------------+-----+--------------------+--------------------+----------+
|    id|thickness|size|shape|madh|epsize|bnuc|bchrom|nNuc|mit|clas|clasLogistic|            features|label|       rawPrediction|         probability|prediction|
+------+---------+----+-----+----+------+----+------+----+---+----+------------+--------------------+-----+--------------------+--------------------+----------+
| 63375|      9.0| 1.0|  2.0| 6.0|   4.0|10.0|   7.0| 7.0|2.0|   4|           1|[9.0,1.0,2.0,6.0,...|  1.0|[0.36391634252951...|[0.58998813846052...|       0.0|
|128059|      1.0| 1.0|  1.0| 1.0|   2.0| 5.0|   5.0| 1.0|1.0|   2|           0|[1.0,1.0,1.0,1.0,...|  0.0|[0.81179252636135...|[0.69249134920886...|       0.0|
|145447|      8.0| 4.0|  4.0| 1.0|   2.0| 9.0|   3.0| 3.0|1.0|   4|           1|[8.0,4.0,4.0,1.0,...|  1.0|[0.06964047482828...|[0.51740308582457...|       0.0|
|183913|      1.0| 2.0|  2.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[1.0,2.0,2.0,1.0,...|  0.0|[0.96139876234944...|[0.72340177322811...|       0.0|
|342245|      1.0| 1.0|  3.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[1.0,1.0,3.0,1.0,...|  0.0|[0.95750903648839...|[0.72262279564412...|       0.0|
|434518|      3.0| 1.0|  1.0| 1.0|   2.0| 1.0|   2.0| 1.0|1.0|   2|           0|[3.0,1.0,1.0,1.0,...|  0.0|[1.10995557408198...|[0.75212082898242...|       0.0|
|493452|      1.0| 1.0|  3.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[1.0,1.0,3.0,1.0,...|  0.0|[0.95750903648839...|[0.72262279564412...|       0.0|
|508234|      7.0| 4.0|  5.0|10.0|   2.0|10.0|   3.0| 8.0|2.0|   4|           1|[7.0,4.0,5.0,10.0...|  1.0|[-0.0809133769755...|[0.47978268474014...|       1.0|
|521441|      5.0| 1.0|  1.0| 2.0|   2.0| 1.0|   2.0| 1.0|1.0|   2|           0|[5.0,1.0,1.0,2.0,...|  0.0|[1.10995557408198...|[0.75212082898242...|       0.0|
|527337|      4.0| 1.0|  1.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[4.0,1.0,1.0,1.0,...|  0.0|[1.11079628977456...|[0.75227753466134...|       0.0|
|534555|      1.0| 1.0|  1.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[1.0,1.0,1.0,1.0,...|  0.0|[1.11079628977456...|[0.75227753466134...|       0.0|
|535331|      3.0| 1.0|  1.0| 1.0|   3.0| 1.0|   2.0| 1.0|1.0|   2|           0|[3.0,1.0,1.0,1.0,...|  0.0|[1.10995557408198...|[0.75212082898242...|       0.0|
|558538|      4.0| 1.0|  3.0| 3.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[4.0,1.0,3.0,3.0,...|  0.0|[0.95750903648839...|[0.72262279564412...|       0.0|
|560680|      1.0| 1.0|  1.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[1.0,1.0,1.0,1.0,...|  0.0|[1.11079628977456...|[0.75227753466134...|       0.0|
|601265|     10.0| 4.0|  4.0| 6.0|   2.0|10.0|   2.0| 3.0|1.0|   4|           1|[10.0,4.0,4.0,6.0...|  1.0|[-0.0034290346398...|[0.49914274218002...|       1.0|
|603148|      4.0| 1.0|  1.0| 1.0|   2.0| 1.0|   1.0| 1.0|1.0|   2|           0|[4.0,1.0,1.0,1.0,...|  0.0|[1.11079628977456...|[0.75227753466134...|       0.0|
|606722|      5.0| 5.0|  7.0| 8.0|   6.0|10.0|   7.0| 4.0|1.0|   4|           1|[5.0,5.0,7.0,8.0,...|  1.0|[-0.3103173938140...|[0.42303726852941...|       1.0|
|616240|      5.0| 3.0|  4.0| 3.0|   4.0| 5.0|   4.0| 7.0|1.0|   2|           0|[5.0,3.0,4.0,3.0,...|  0.0|[0.43719456056061...|[0.60759034803682...|       0.0|
|640712|      1.0| 1.0|  1.0| 1.0|   2.0| 1.0|   2.0| 1.0|1.0|   2|           0|[1.0,1.0,1.0,1.0,...|  0.0|[1.10995557408198...|[0.75212082898242...|       0.0|
|654546|      1.0| 1.0|  1.0| 1.0|   2.0| 1.0|   1.0| 1.0|8.0|   2|           0|[1.0,1.0,1.0,1.0,...|  0.0|[1.11079628977456...|[0.75227753466134...|       0.0|
+------+---------+----+-----+----+------+----+------+----+---+----+------------+--------------------+-----+--------------------+--------------------+----------+
only showing top 20 rows
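
1 Answer

Stack Overflow user
Accepted answer
Posted on 2020-01-12 10:58:02

One mistake I see here is that you are passing the rawPrediction column to the BinaryClassificationMetrics object instead of the prediction column. rawPrediction contains an array with a kind of "probability" for each class, whereas BinaryClassificationMetrics expects a single Double value, as its signature specifies:
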
new BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)])

You can see the details here.

I ran a quick test with this modification and it appears to work. Here is the code snippet:

import org.apache.spark.sql.{Encoders, SparkSession}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.functions._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics



case class Obs(id: Int, thickness: Double, size: Double, shape: Double, madh: Double,
               epsize: Double, bnuc: Double, bchrom: Double, nNuc: Double, mit: Double, clas: Double)
val obsSchema = Encoders.product[Obs].schema

val spark = SparkSession.builder
  .appName("StackoverflowQuestions")
  .master("local[*]")
  .getOrCreate()
// Implicits necessary to transform DataFrame to Dataset using .as[] method
import spark.implicits._


val df = spark.read
              .schema(obsSchema)
              .csv("breast-cancer-wisconsin.data")
              .drop("id")
              .withColumn("clas", when(col("clas").equalTo(4.0), 1.0).otherwise(0.0))
              .na.drop() // Make sure to drop nulls, or the VectorAssembler will fail

// Define the feature columns to put in the feature vector
val featureCols = Array("thickness", "size", "shape", "madh", "epsize", "bnuc", "bchrom", "nNuc", "mit")
// Set the input and output column names
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
// Return a DataFrame with all of the feature columns in a vector column
val df2 = assembler.transform(df)
// Create a label column with the StringIndexer
val labelIndexer = new StringIndexer().setInputCol("clas").setOutputCol("label")
val df3 = labelIndexer.fit(df2).transform(df2)

val splitSeed = 5043
val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

trainingData.show(20);

// Fit the model
val model = lr.fit(trainingData)

// Print the coefficients and intercept for logistic regression
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")

// Run the model on the test features to get predictions
val predictions = model.transform(testData)
// The transform adds new columns: rawPrediction, probability and prediction
predictions.show(truncate=false)

// Use MLlib to evaluate: convert the DataFrame to an RDD of (prediction, label) pairs
val predictionAndLabels = predictions.select("prediction", "label").as[(Double, Double)].rdd

// Instantiate metrics object
val metrics = new BinaryClassificationMetrics(predictionAndLabels)
println("area under the precision-recall curve: " + metrics.areaUnderPR)
println("area under the receiver operating characteristic (ROC) curve : " + metrics.areaUnderROC)
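
One caveat about the snippet above: the `prediction` column only holds hard 0.0/1.0 values, so the metrics object has just two thresholds (0.0 and 1.0) to sweep, and the resulting curves reduce to a single operating point. A possible variation, sketched below and not part of the original answer, is to use the probability of class 1.0 as the score. `toScoreAndLabels` is a hypothetical helper name, and the sketch assumes the `probability` and `label` columns produced by the pipeline above.

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

// Hypothetical helper (an assumption, not from the original answer):
// builds (score, label) pairs where the score is the predicted probability of class 1.0.
// Assumes a Vector column "probability" and a Double column "label".
def toScoreAndLabels(predictions: DataFrame): RDD[(Double, Double)] =
  predictions.select("probability", "label").rdd.map { row =>
    val probabilityOfPositive = row.getAs[Vector]("probability")(1)
    (probabilityOfPositive, row.getAs[Double]("label"))
  }
```

With the `predictions` DataFrame from the snippet above, `new BinaryClassificationMetrics(toScoreAndLabels(predictions))` would then evaluate the model across every distinct probability value rather than only the final 0/1 decision.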

Votes: 1

Page content originally provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/59679368
