
Using a single string in RandomForestClassifier.predict()?

Stack Overflow user
Asked 2018-07-22 05:39:48

I'm a newbie. I'm trying to predict the label of a given string with a RandomForestClassifier() trained on text.

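Clearly, I don't understand how to use predict() with a single string. I'm calling reshape() because some time ago I got this error: "Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample."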
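As a sketch of why the reshape is not the missing piece: when `CountVectorizer.transform` receives a list holding one string, it already returns a 2-D sparse matrix with one row, so `reshape(1, -1)` leaves the shape unchanged. What has to match the model is the number of columns. The training texts below are made up purely for illustration:

```python
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy training texts (not the asker's data).
train_texts = [
    "phishing email asking for bank password",
    "ssh brute force login attempts from scanner",
    "marketing spam with unsubscribe link",
]

cv = CountVectorizer(stop_words="english", lowercase=True)
train_matrix = cv.fit_transform(train_texts)

# Wrapping the single string in a list already yields one row (one sample).
single = cv.transform(["possible phishing email"])
print(single.shape)        # one row, same column count as the training matrix
print(train_matrix.shape)  # three rows, one column per vocabulary term
```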

How do I predict the label of a single text string?

Script:

Code language: Python
#!/usr/bin/env python
''' Read a txt file consisting of '<label>: <long string of text>'
    to use as a model for predicting the label for a string
'''

from argparse import ArgumentParser
import json
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder


def main(args):
    '''
    args: Arguments obtained by _Get_Args()
    '''

    print('Loading data...')
    # Load data from args.txtfile and split the lines into
    # two lists (labels, texts).
    with open(args.txtfile) as f:
        data = f.readlines()
    labels, texts = ([], [])
    for line in data:
        label, text = line.split(': ', 1)
        labels.append(label)
        texts.append(text)

    # Print a list of unique labels
    print(json.dumps(list(set(labels)), indent=4))

    # Instantiate a CountVectorizer, fit it to the texts,
    # and encode the labels as integers.
    cv = CountVectorizer(
            stop_words='english',
            strip_accents='unicode',
            lowercase=True,
            )
    matrix = cv.fit_transform(texts)
    encoder = LabelEncoder()
    labels = encoder.fit_transform(labels)
    rf = RandomForestClassifier()
    rf.fit(matrix, labels)

    # Try to predict the label for args.string.
    prediction = Predict_Label(args.string, cv, rf)
    print(prediction)


def Predict_Label(string, cv, rf):
    '''
    string: str() - A string of text
    cv: The CountVectorizer class
    rf: The RandomForestClassifier class
    '''

    matrix = cv.fit_transform([string])
    matrix = matrix.reshape(1, -1)
    try:
        prediction = rf.predict(matrix)
    except Exception as E:
        print(str(E))
    else:
        return prediction


def _Get_Args():
    parser = ArgumentParser(description='Learn labels from text')
    parser.add_argument('-t', '--txtfile', required=True)
    parser.add_argument('-s', '--string', required=True)
    return parser.parse_args()


if __name__ == '__main__':
    args = _Get_Args()
    main(args)

The actual training data file is 43,663 lines long, but a sample is in small_list.txt, where every line has the format: <label>: <long text string>
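The script parses this format with `line.split(': ', 1)`; the `maxsplit=1` argument keeps any later `': '` inside the text intact. A small sketch with two made-up lines (hypothetical sample data, not the real small_list.txt). Stripping the trailing newline is an optional tidy-up that the original script does not do:

```python
# Hypothetical sample lines in the '<label>: <long text string>' format.
sample = [
    "Fraud__Phishing: Fake login page: please verify your account\n",
    "Spam__Marketing Spam: Buy now and save 50%\n",
]

labels, texts = [], []
for line in sample:
    # maxsplit=1: only the first ': ' separates the label from the text.
    label, text = line.split(": ", 1)
    labels.append(label)
    texts.append(text.rstrip("\n"))  # optional; the question's script keeps the newline

print(labels)    # ['Fraud__Phishing', 'Spam__Marketing Spam']
print(texts[0])  # Fake login page: please verify your account
```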

The error is printed by the exception handler:

Code language: shell
$ ./learn.py -t small_list.txt -s 'This is a string that might have something to do with phishing or fraud'
Loading data...
[
    "Vulnerabilities__Unknown",
    "Vulnerabilities__MSSQL Browsing Service",
    "Fraud__Phishing",
    "Fraud__Copyright/Trademark Infringement",
    "Attacks and Reconnaissance__Web Attacks",
    "Vulnerabilities__Vulnerable SMB",
    "Internal Report__SBL Notify",
    "Objectionable Content__Russian Federation Objectionable Material",
    "Malicious Code/Traffic__Malicious URL",
    "Spam__Marketing Spam",
    "Attacks and Reconnaissance__Scanning",
    "Malicious Code/Traffic__Unknown",
    "Attacks and Reconnaissance__SSH Brute Force",
    "Spam__URL in Spam",
    "Vulnerabilities__Vulnerable Open Memcached",
    "Malicious Code/Traffic__Sinkhole",
    "Attacks and Reconnaissance__SMTP Brute Force",
    "Illegal content__Child Pornography"
]
Number of features of the model must match the input. Model n_features is 2070 and input n_features is 3 
None
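The numbers in this message line up with what `Predict_Label` does: it calls `cv.fit_transform([string])`, which throws away the vocabulary learned from the training file and learns a new one from the single query. With `stop_words='english'`, only three tokens of the test sentence survive, which matches `input n_features is 3`. A sketch reproducing that count with the question's vectorizer settings:

```python
from sklearn.feature_extraction.text import CountVectorizer

query = "This is a string that might have something to do with phishing or fraud"

# Same settings as the question's vectorizer.
cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)

# fit_transform refits on the query alone, so the vocabulary is only its own tokens.
matrix = cv.fit_transform([query])
print(sorted(cv.vocabulary_))  # ['fraud', 'phishing', 'string']
print(matrix.shape)            # (1, 3)
```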

1 Answer

Stack Overflow user

Answered 2019-05-01 00:34:16

You need to take the vocabulary of the first CountVectorizer (cv) and use it to transform the new single text before predicting.
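One way to see why this works: a second `CountVectorizer` built with `vocabulary=cv.vocabulary_` maps each token to the same column as the fitted `cv`, so it produces the same matrix as calling `cv.transform()` on the fitted vectorizer directly. A minimal check, assuming hypothetical toy training texts:

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy training texts.
texts = [
    "phishing page imitating a bank login",
    "copyright infringement report for pirated video",
    "ssh brute force attempts against server",
]
settings = dict(stop_words="english", strip_accents="unicode", lowercase=True)

cv = CountVectorizer(**settings)
cv.fit(texts)

query = ["This is a string that might have something to do with phishing or fraud"]

# Option 1 (the answer's approach): a new vectorizer pinned to the learned vocabulary.
cv_new = CountVectorizer(vocabulary=cv.vocabulary_, **settings)
a = cv_new.fit_transform(query)

# Option 2: reuse the already-fitted vectorizer and only call transform().
b = cv.transform(query)

print(a.shape == b.shape)                       # True
print(np.array_equal(a.toarray(), b.toarray()))  # True: identical counts
```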

Code language: Python
...

cv = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        )

matrix = cv.fit_transform(texts)
encoder = LabelEncoder()
labels = encoder.fit_transform(labels)
rf = RandomForestClassifier()
rf.fit(matrix, labels)

# Try to predict the label for args.string.
cv_new = CountVectorizer(
        stop_words='english',
        strip_accents='unicode',
        lowercase=True,
        vocabulary=cv.vocabulary_
        )
prediction = Predict_Label(args.string, cv_new, rf)
print(prediction)

...
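Since `cv` is already fitted, a shorter alternative (not the answerer's code) is to call `cv.transform()` inside `Predict_Label` and drop the `reshape`. Note also that `rf.predict` returns the integer codes produced by `LabelEncoder`, so `encoder.inverse_transform` is needed to recover a label string. A self-contained sketch; the toy texts and labels, the `predict_label` name, and its extra `encoder` parameter are illustrative assumptions:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder


def predict_label(string, cv, rf, encoder):
    """Predict a label for one string with an already-fitted vectorizer."""
    # transform(), not fit_transform(): keep the training vocabulary and columns.
    matrix = cv.transform([string])  # shape (1, n_features); no reshape needed
    codes = rf.predict(matrix)       # integer codes from LabelEncoder
    return encoder.inverse_transform(codes)[0]


# Hypothetical toy training data in the question's label style.
labels = ["Fraud__Phishing", "Fraud__Phishing", "Spam__Marketing Spam", "Spam__Marketing Spam"]
texts = [
    "phishing page stealing bank credentials",
    "fraud email with fake login phishing link",
    "marketing newsletter discount offer",
    "buy now limited discount marketing offer",
]

cv = CountVectorizer(stop_words="english", strip_accents="unicode", lowercase=True)
matrix = cv.fit_transform(texts)
encoder = LabelEncoder()
y = encoder.fit_transform(labels)
rf = RandomForestClassifier(n_estimators=50, random_state=0)
rf.fit(matrix, y)

result = predict_label("something to do with phishing or fraud", cv, rf, encoder)
print(result)  # one of the training labels
```

Passing the fitted `encoder` alongside `cv` and `rf` keeps the helper free of globals; wrapping the vectorizer and classifier in a scikit-learn `Pipeline` is another common way to guarantee the same fitted vocabulary is used at prediction time.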
Page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/51460337
