我正在研究一个不平衡类的分类问题(5% 1)。我想预测的是班级,而不是概率。
在二进制分类问题中,scikit的classifier.predict()默认使用0.5吗?如果没有,默认的方法是什么?如果是的话,我该如何改变呢?
在scikit中,有些分类器有class_weight='auto'选项,但并不是所有的分类器都有。对于class_weight='auto',.predict()会使用实际的人口比例作为阈值吗?
在像MultinomialNB这样不支持class_weight的分类器中,有什么方法可以做到这一点呢?而不是使用predict_proba(),然后自己计算类。
发布于 2013-11-15 09:23:53
scikit的
classifier.predict()默认使用0.5吗?
在概率分类器中是的。正如其他人所解释的,从数学的角度来看,这是唯一合理的阈值。
在像MultinomialNB这样不支持
class_weight的分类器中,有什么方法可以做到这一点呢?
您可以设置class_prior,即每类y的先验概率P(y),这有效地移动了决策边界。例如。
# minimal dataset
>>> X = [[1, 0], [1, 0], [0, 1]]
>>> y = [0, 0, 1]
# use empirical prior, learned from y
>>> MultinomialNB().fit(X,y).predict([1,1])
array([0])
# use custom prior to make 1 more likely
>>> MultinomialNB(class_prior=[.1, .9]).fit(X,y).predict([1,1])
array([1])发布于 2018-08-03 19:32:20
可以使用clf.predict_proba()设置阈值
例如:
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state = 2)
clf.fit(X_train,y_train)
# y_pred = clf.predict(X_test) # default threshold is 0.5
y_pred = (clf.predict_proba(X_test)[:,1] >= 0.3).astype(bool) # set threshold as 0.3发布于 2016-02-09 19:32:18
在scikit学习中,二进制分类的阈值为0.5,多类分类的概率最大的类别为哪一类。在许多问题中,通过调整阈值可以得到更好的结果。然而,这必须小心,而不是对坚持测试数据,而是通过交叉验证的培训数据。如果您对测试数据的阈值做了任何调整,您只是对测试数据进行了过度拟合。
大多数调整阈值的方法都是基于接收机工作特性(ROC)和尤登J统计量的,但也可以通过其他方法来实现,比如用遗传算法进行搜索。
这是一篇同行评议期刊上的文章,描述了在医学上这样做的情况:
http://www.ncbi.nlm.nih.gov/pmc/articles/PMC2515362/
据我所知,在Python中没有实现它的包,但是在Python中使用蛮力搜索查找它相对简单(但效率低)。
这是一些R代码来完成它。
## load data
DD73OP <- read.table("/my_probabilites.txt", header=T, quote="\"")
library("pROC")
# No smoothing
roc_OP <- roc(DD73OP$tc, DD73OP$prob)
auc_OP <- auc(roc_OP)
auc_OP
Area under the curve: 0.8909
plot(roc_OP)
# Best threshold
# Method: Youden
#Youden's J statistic (Youden, 1950) is employed. The optimal cut-off is the threshold that maximizes the distance to the identity (diagonal) line. Can be shortened to "y".
#The optimality criterion is:
#max(sensitivities + specificities)
coords(roc_OP, "best", ret=c("threshold", "specificity", "sensitivity"), best.method="youden")
#threshold specificity sensitivity
#0.7276835 0.9092466 0.7559022https://stackoverflow.com/questions/19984957
复制相似问题