首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >R中的AODE机器学习

R中的AODE机器学习
EN

Stack Overflow用户
提问于 2020-03-12 03:36:57
回答 1查看 82关注 0票数 1

我想知道真正的AODE是否比朴素的Bayes更好,就像描述中说的那样:

https://cran.r-project.org/web/packages/AnDE/AnDE.pdf

-> "AODE通过对所有小空间进行平均来实现高度精确的分类.“

https://www.quora.com/What-is-the-difference-between-a-Naive-Bayes-classifier-and-AODE

-> "AODE是放松朴素bayes独立性假设的一种奇怪方法。它不再是一种生成模型,但它以一种与logistic回归略有不同(且不那么原则性)的方式放松独立性假设。它用二次(关于特征数)依赖于训练和测试时间来代替用于训练logistic回归分类器的凸优化问题。“

但是当我进行实验时,我发现预测结果似乎是错误的,我用以下代码实现了它:

代码语言:javascript
复制
library(gmodels)
library(AnDE)
AODE_Model = aode(iris)
predict_aode = predict(AODE_Model, iris)
CrossTable(as.numeric(iris$Species), predict_aode) 

有人能跟我解释一下吗?或者有什么好的实用解决方案来实现AODE?事先谢谢你

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-12 13:00:01

如果您查看该函数的vignette:

列车: data.frame :培训数据。它应该是一个数据框架。AODE只工作离散数据。在将数据帧传递给此function.However之前,最好先对其进行离散化,如果不进行手工操作,则对数据进行离散化。它使用了一个称为离散化的R包。它使用了众所周知的MDL离散化技术(有时可能会失败)

默认情况下,arules的离散化函数将其分割为3,这可能对虹膜不够。因此,我首先用arules的离散化来再现结果:

代码语言:javascript
复制
library(arules)
library(gmodels)
library(AnDE)
set.seed(111)
trn = sample(1:nrow(indata),100)
test = setdiff(1:nrow(indata),trn)

indata <- data.frame(lapply(iris[,1:4],discretize,breaks=3),Species=iris$Species)
AODE_Model = aode(indata[trn,])
predict_aode = predict(AODE_Model, indata[test,])
CrossTable(as.numeric(indata$Species)[test], predict_aode)

                                 | predict_aode 
as.numeric(indata$Species)[test] |         1 |         3 | Row Total | 
---------------------------------|-----------|-----------|-----------|
                               1 |        15 |         5 |        20 | 
                                 |     0.500 |     4.500 |           | 
                                 |     0.750 |     0.250 |     0.400 | 
                                 |     0.333 |     1.000 |           | 
                                 |     0.300 |     0.100 |           | 
---------------------------------|-----------|-----------|-----------|
                               2 |        11 |         0 |        11 | 
                                 |     0.122 |     1.100 |           | 
                                 |     1.000 |     0.000 |     0.220 | 
                                 |     0.244 |     0.000 |           | 
                                 |     0.220 |     0.000 |           | 
---------------------------------|-----------|-----------|-----------|
                               3 |        19 |         0 |        19 | 
                                 |     0.211 |     1.900 |           | 
                                 |     1.000 |     0.000 |     0.380 | 
                                 |     0.422 |     0.000 |           | 
                                 |     0.380 |     0.000 |           | 
---------------------------------|-----------|-----------|-----------|
                    Column Total |        45 |         5 |        50 | 
                                 |     0.900 |     0.100 |           | 
---------------------------------|-----------|-----------|-----------|

您可以看到预测中缺少一个类。让我们把它提高到4:

代码语言:javascript
复制
indata <- data.frame(lapply(iris[,1:4],discretize,breaks=4),Species=iris$Species)
AODE_Model = aode(indata[trn,])
predict_aode = predict(AODE_Model, indata[test,])
CrossTable(as.numeric(indata$Species)[test], predict_aode)

                                 | predict_aode 
as.numeric(indata$Species)[test] |         1 |         2 |         3 | Row Total | 
---------------------------------|-----------|-----------|-----------|-----------|
                               1 |        20 |         0 |         0 |        20 | 
                                 |    18.000 |     4.800 |     7.200 |           | 
                                 |     1.000 |     0.000 |     0.000 |     0.400 | 
                                 |     1.000 |     0.000 |     0.000 |           | 
                                 |     0.400 |     0.000 |     0.000 |           | 
---------------------------------|-----------|-----------|-----------|-----------|
                               2 |         0 |        10 |         1 |        11 | 
                                 |     4.400 |    20.519 |     2.213 |           | 
                                 |     0.000 |     0.909 |     0.091 |     0.220 | 
                                 |     0.000 |     0.833 |     0.056 |           | 
                                 |     0.000 |     0.200 |     0.020 |           | 
---------------------------------|-----------|-----------|-----------|-----------|
                               3 |         0 |         2 |        17 |        19 | 
                                 |     7.600 |     1.437 |    15.091 |           | 
                                 |     0.000 |     0.105 |     0.895 |     0.380 | 
                                 |     0.000 |     0.167 |     0.944 |           | 
                                 |     0.000 |     0.040 |     0.340 |           | 
---------------------------------|-----------|-----------|-----------|-----------|
                    Column Total |        20 |        12 |        18 |        50 | 
                                 |     0.400 |     0.240 |     0.360 |           | 
---------------------------------|-----------|-----------|-----------|-----------|

只有3个错误。对我来说,这是一个在不过度适应的情况下进行谨慎处理的问题,这可能是很棘手的。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60647274

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档