首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >拟合多标签文本分类模型时出现的错误

拟合多标签文本分类模型时出现的错误
EN

Stack Overflow用户
提问于 2019-08-10 22:41:35
回答 1查看 97关注 0票数 0

我现在正在尝试为一个多标签文本分类问题建立一个分类模型。

我有一个包含已清理文本列表的训练集X_train,例如

代码语言:javascript
复制
["I am constructing Markov chains with  to  states and inferring     
transition probabilities empirically by simply counting how many 
times I saw each transition in my raw data",
"I know the chips only of the  players of my table and mine obviously I 
also know the total number of chips the max and min amount chips the 
players have and the average stackIs it possible to make an 
approximation of my probability of winningI have,
...]

X_train中的每个文本对应的训练多标签集合y,如

代码语言:javascript
复制
[['hypothesis-testing', 'statistical-significance', 'markov-process'],
['probability', 'normal-distribution', 'games'],
...]

现在,我想建立一个模型,该模型可以预测与X_train格式相同的文本集X_test中的标记。

我已经使用MultiLabelBinarizer转换标签,并使用TfidfVectorizer转换火车集中的清理文本。

代码语言:javascript
复制
multilabel_binarizer = MultiLabelBinarizer()
multilabel_binarizer.fit(y)
Y = multilabel_binarizer.transform(y)

vectorizer = TfidfVectorizer(stop_words = stopWordList)
vectorizer.fit(X_train)
x_train = vectorizer.transform(X_train)

但是,当我尝试拟合模型时,总是会遇到错误。我已经尝试过OneVsRestClassifierLogisticRegression

当我拟合一个OneVsRestClassifier模型时,我遇到了这样的错误

代码语言:javascript
复制
Traceback (most recent call last):
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 696, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 268, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 241, in poll
    if func():
  File "/usr/local/spark/python/pyspark/accumulators.py", line 245, in accum_updates
    num_updates = read_int(self.rfile)
  File "/usr/local/spark/python/pyspark/serializers.py", line 714, in read_int
    raise EOFError
EOFError

当我拟合一个LogisticRegression模型时,我遇到了这样的错误

代码语言:javascript
复制
/opt/conda/envs/data3/lib/python3.6/site-packages/sklearn/linear_model/sag.py:326: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge
  "the coef_ did not converge", ConvergenceWarning)

有谁知道问题出在哪里以及如何解决这个问题吗?非常感谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-08-11 22:35:24

OneVsRestClassifier适合每个类一个分类器。您需要告诉它您想要哪种类型的分类器(例如,Losgistic回归)。

以下代码适用于我:

代码语言:javascript
复制
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

classifier = OneVsRestClassifier(LogisticRegression())
classifier.fit(x_train, Y)

X_test= ["I play with Markov chains"]
x_test = vectorizer.transform(X_test)

classifier.predict(x_test)

输出: array([0,1,1,0,0,0,1])

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57443096

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档