首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >是否可以在PySpark中使用带有OneVsRest的LinearSVC模型?

是否可以在PySpark中使用带有OneVsRest的LinearSVC模型?
EN

Stack Overflow用户
提问于 2019-10-10 16:13:32
回答 2查看 607关注 0票数 0

我正在尝试在PySpark的OneVsRest中使用LinearSVC模型,但似乎还不支持。

我的错误消息

代码语言:javascript
复制
LinearSVC only supports binary classification. 1 classes detected in LinearSVC_43a50b0b70d60a8cbdb1__labelCol

为了在PySpark中实现它,我需要做哪些更改?

有人知道Pyspark中的OneVsRest什么时候会支持LinearSVC吗?

EN

回答 2

Stack Overflow用户

发布于 2019-10-10 19:29:05

错误消息告诉您,您的数据集当前只包含一个类,但LinearSVM是一个二进制分类算法,它只需要两个类。我不确定您的其余代码是否会导致任何问题,因为您还没有发布任何内容。以防万一你或其他人需要它,看看下面。

就像alrady说的那样,LinearSVM是一个二进制分类算法,它永远不会根据定义支持多类分类,但你总是可以将一个多类分类问题简化为一个二进制分类问题。One-vs-Rest就是这样一种减少的方法。它为每个类训练一个分类器,从工程角度来看,将其分离到spark did这样的专用类是有意义的。OneVsRest为您的每个类训练一个分类器,并根据此分类器列表对给定样本进行评分。分数最高的分类器确定样本的预测标签。

请看下面的代码,了解如何在LinearSVC中使用OneVsRest:

代码语言:javascript
复制
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import OneVsRest, LinearSVC
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

df = spark.read.csv('/tmp/iris.data', schema='sepalLength DOUBLE, sepalWidth DOUBLE, petalLength DOUBLE, petalWidth DOUBLE, class STRING')


vecAssembler = VectorAssembler(inputCols=["sepalLength", "sepalWidth", "petalLength", 'petalWidth'], outputCol="features")
df = vecAssembler.transform(df)

stringIndexer = StringIndexer(inputCol="class", outputCol="label")
si_model = stringIndexer.fit(df)
df = si_model.transform(df)

svm = LinearSVC()
ovr = OneVsRest(classifier=svm)
ovrModel = ovr.fit(df)

evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

predictions = ovrModel.transform(df)

print("Accuracy: {}".format(evaluator.evaluate(predictions)))

输出:

代码语言:javascript
复制
Accuracy: 0.9533333333333334
票数 0
EN

Stack Overflow用户

发布于 2020-03-09 20:16:05

这是PySpark中的一个有趣的bug。如果您有多个类,则必须从零开始标识它们。

我刚刚经历了这个bug。我用与他们在LinearSVC guide中建议的相同的方式构建了一个数据帧。

代码语言:javascript
复制
df = sc.parallelize([
    Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
    Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()

最初,它是一个普通的RDD,然后我在Row中转换了每个RDD记录。我遇到了一个三类问题,其中类被命名为123。我实例化了一个OneVsRest对象(就像@cronoik建议的那样),并且我经历了与您相同的错误。

因此,我完全按照他们的用户指南(见上)初始化了df数据帧,并决定通过添加和删除类来开始使用它。所以我简单地用label=2.0替换了第二个模式中的label=0.0,错误就出现了。即使他们的数据帧,即使只有两个类。

因此,我将类的名称从123更改为012,错误就消失了。

希望这能有所帮助!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58318296

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档