
How can I fix BinaryRelevance running out of memory, even when using csr_matrix?

Stack Overflow user
Asked 2021-07-28 17:48:01
Answers: 1 · Views: 182 · Followers: 0 · Votes: 1

I'm trying to predict toxic comments using the Toxic Comment data from Kaggle:

import skmultilearn, sys
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from scipy.sparse import csr_matrix, issparse
from sklearn.naive_bayes import MultinomialNB
from skmultilearn.problem_transform import BinaryRelevance

data_frame = pd.read_csv('data/train.csv')
corpus = data_frame['comment_text']
tfidf = TfidfVectorizer()
Xfeatures = csr_matrix(tfidf.fit_transform(corpus))
y = csr_matrix(data_frame[['toxic','severe_toxic','obscene','threat','insult','identity_hate']])
binary_rel_clf = BinaryRelevance(MultinomialNB())
binary_rel_clf.fit(Xfeatures,y)
predict_text = ['fuck die shit moron suck']
X_predict = tfidf.transform(predict_text)
br_prediction = binary_rel_clf.predict(X_predict)
br_prediction = br_prediction.toarray().astype(bool)
predictions = [y.columns.values[prediction].tolist() for prediction in br_prediction]
print(predictions)

However, I get the following error:

Traceback (most recent call last):
  File "...\multi_label_toxic.py", line 15, in <module>
    binary_rel_clf.fit(Xfeatures,y)
  File "...\problem_transform\br.py", line 161, in fit
    classifier.fit(self._ensure_input_format(
  File "...\base\base.py", line 86, in _ensure_input_format
    return X.toarray()
  File "...\scipy\sparse\compressed.py", line 1031, in toarray
    out = self._process_toarray_args(order, out)
  File "...\scipy\sparse\base.py", line 1202, in _process_toarray_args
    return np.zeros(self.shape, dtype=self.dtype, order=order)
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 226. GiB for an array with shape (159571, 189775) and data type float64
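The size in this error follows directly from the matrix shape. Converting the TF-IDF matrix to a dense array allocates one 8-byte `float64` for every cell, including the zero cells. The quick check below reproduces that arithmetic. The helper name `dense_gib` is made up for illustration.

```python
def dense_gib(n_rows: int, n_cols: int, itemsize: int = 8) -> float:
    """Memory in GiB for a dense n_rows x n_cols array of itemsize-byte values."""
    return n_rows * n_cols * itemsize / 2**30

# Shape from the traceback: 159571 comments x 189775 TF-IDF terms, float64 (8 bytes)
size = dense_gib(159571, 189775)
print(f"{size:.1f} GiB")  # 225.6 GiB, which NumPy reports as "226. GiB"
```

The sparse CSR matrix stores only the nonzero entries, so it fits in memory. The dense copy made inside `BinaryRelevance` is what fails.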

Even when I pass the parameter `require_dense=False`, I get a different error:

Traceback (most recent call last):
  File "...\multi_label_toxic.py", line 15, in <module>
    binary_rel_clf.fit(Xfeatures,y)
  File "...\skmultilearn\problem_transform\br.py", line 161, in fit
    classifier.fit(self._ensure_input_format(
  File "...\sklearn\naive_bayes.py", line 612, in fit
    X, y = self._check_X_y(X, y)
  File "...\sklearn\naive_bayes.py", line 477, in _check_X_y
    return self._validate_data(X, y, accept_sparse='csr')
  File "...\sklearn\base.py", line 433, in _validate_data
    X, y = check_X_y(X, y, **check_params)
  File "...\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "...\sklearn\utils\validation.py", line 826, in check_X_y
    y = column_or_1d(y, warn=True)
  File "...\sklearn\utils\validation.py", line 63, in inner_f
    return f(*args, **kwargs)
  File "...\sklearn\utils\validation.py", line 864, in column_or_1d
    raise ValueError(
ValueError: y should be a 1d array, got an array of shape () instead.

How can I fix this and train the model on the whole dataset?


1 Answer

Stack Overflow user

Accepted answer

Answered 2021-07-31 18:29:33

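It looks like you are specifying the `require_dense` parameter incorrectly. You need `require_dense=[False, True]`:

- The first flag keeps the feature matrix X in sparse format.
- The second flag gives the base classifier the labels y in dense format.

With a plain scikit-learn base classifier, skmultilearn densifies both X and y by default, and densifying X is what triggered the 226 GiB allocation. Passing `require_dense=False` keeps both sparse, so `MultinomialNB` then receives a sparse label column instead of the 1-D array it expects.

Also, in the second-to-last line (`predictions = ...`), you need the labels as a DataFrame, before converting them to a matrix, so that you can access the column names. The following code should work.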
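What `require_dense=[False, True]` asks for can be reproduced by hand with scikit-learn alone. This sketch illustrates the idea and is not skmultilearn's actual internals. It uses a made-up 4×3 count matrix with 2 labels and fits `MultinomialNB` on a single label column in two ways, keeping X sparse both times:

- First, the label column is passed as a sparse matrix. This is rejected; the exception type and message vary between scikit-learn versions.
- Then, the same column is passed as a dense 1-D array, which works.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.naive_bayes import MultinomialNB

# Made-up toy data: 4 documents x 3 terms, and 4 documents x 2 labels
X = csr_matrix(np.array([[2, 0, 1], [0, 3, 0], [1, 0, 2], [0, 1, 0]], dtype=float))
Y = csr_matrix(np.array([[1, 0], [0, 1], [1, 0], [0, 1]]))

clf = MultinomialNB()

# Sparse label column of shape (4, 1): rejected, as with require_dense=False
try:
    clf.fit(X, Y[:, 0])
    sparse_y_failed = False
except (ValueError, TypeError) as exc:
    sparse_y_failed = True
    print("sparse y rejected:", type(exc).__name__)

# Dense 1-D label column with X still sparse, as with require_dense=[False, True]
y0 = Y[:, 0].toarray().ravel()
clf.fit(X, y0)
print(clf.predict(X))  # [1 0 1 0]
```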

from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from scipy.sparse import csr_matrix, issparse
from sklearn.naive_bayes import MultinomialNB
from skmultilearn.problem_transform import BinaryRelevance
import numpy as np

data_frame = pd.read_csv('data/train.csv')
corpus = data_frame['comment_text']
tfidf = TfidfVectorizer()
Xfeatures = csr_matrix(tfidf.fit_transform(corpus))
cats = data_frame[['toxic','severe_toxic','obscene','threat','insult','identity_hate']]
y = csr_matrix(cats)
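# require_dense=[False, True]: keep X sparse, but give MultinomialNB dense labels;
# each label column arrives as a 1-D array, roughly y[:, 0].toarray().reshape(-1)
binary_rel_clf = BinaryRelevance(MultinomialNB(), require_dense=[False, True])
binary_rel_clf.fit(Xfeatures, y)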
predict_text = ['fuck die shit moron suck']
X_predict = tfidf.transform(predict_text)
br_prediction = binary_rel_clf.predict(X_predict)
br_prediction = br_prediction.toarray().astype(bool)
# Use the DataFrame `cats` for the label names; the csr_matrix `y` has no .columns
predictions = [cats.columns[prediction].tolist() for prediction in br_prediction]
print(predictions)

Output:

[['toxic', 'obscene', 'insult']]
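The last step works because each row of the dense prediction is a boolean mask over the six label columns, and a pandas `Index` can be indexed with such a mask. A `csr_matrix` has no `.columns` attribute, which is why the labels are kept as a DataFrame. In this sketch the prediction row is hard-coded and made up to match the output above. It is not produced by a model.

```python
import numpy as np
import pandas as pd

# Label columns of the Kaggle data, as listed in the question
cats_columns = pd.Index(['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'])

# Hypothetical dense prediction: one row per comment, one column per label
br_prediction = np.array([[1, 0, 1, 0, 1, 0]]).astype(bool)

# Each boolean row selects the names of the predicted labels
predictions = [cats_columns[row].tolist() for row in br_prediction]
print(predictions)  # [['toxic', 'obscene', 'insult']]
```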
Votes: 1
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/68565172
