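I would like to know whether there is a way to retrieve the data from a model built with the BART package in R.
This seems to be possible with other BART packages, such as dbarts, but I can't find a way to get the original data back from a BART model. For example, if I create some data and fit both a BART and a dbarts model as follows: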
library(BART)
library(dbarts)
# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)
# create BART
BARTmodel <- wbart(x.train = df[, 1:2],
                   y.train = df[, 3])
# create dbarts
DBARTSmodel <- bart(x.train = df[, 1:2],
                    y.train = df[, 3],
                    keeptrees = TRUE)
With dbarts, setting the keeptrees option makes it possible to retrieve the data like this:
# retrieve data from dbarts
DBARTSmodel$fit$data@x
However, there doesn't seem to be any similar option when using BART. Is it even possible to retrieve the data from a BART model?
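Posted on 2022-04-29 11:43:40
The Value: section of ?wbart suggests that it does not return the input as part of its output, and none of wbart's function arguments indicate that this can be changed.
Moreover, if you look at the output of str, you will see that the input isn't there.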
library(BART)
library(dbarts)
# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)
# create BART
BARTmodel <- wbart(x.train = df[, 1:2],
                   y.train = df[, 3])
# create dbarts
DBARTSmodel <- bart(x.train = df[, 1:2],
                    y.train = df[, 3],
                    keeptrees = TRUE)
str(BARTmodel)
#> List of 13
#> $ sigma : num [1:1100] 0.258 0.262 0.295 0.278 0.273 ...
#> $ yhat.train.mean: num [1:100] 0.584 0.457 0.505 0.54 0.403 ...
#> $ yhat.train : num [1:1000, 1:100] 0.673 0.62 0.433 0.711 0.634 ...
#> $ yhat.test.mean : num(0)
#> $ yhat.test : num[1:1000, 0 ]
#> $ varcount : int [1:1000, 1:2] 109 114 111 118 115 114 115 110 114 117 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : NULL
#> .. ..$ : chr [1:2] "x" "y"
#> $ varprob : num [1:1000, 1:2] 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : NULL
#> .. ..$ : chr [1:2] "x" "y"
#> $ treedraws :List of 2
#> ..$ cutpoints:List of 2
#> .. ..$ x: num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#> .. ..$ y: num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#> ..$ trees : chr "1000 200 2\n1\n1 0 0 0.01185590432\n3\n1 1 30 -0.01530736435\n2 0 0 0.01064412946\n3 0 0 0.02413784284\n3\n1 0 "| __truncated__
#> $ proc.time : 'proc_time' Named num [1:5] 1.406 0.008 1.415 0 0
#> ..- attr(*, "names")= chr [1:5] "user.self" "sys.self" "elapsed" "user.child" ...
#> $ mu : num 0.501
#> $ varcount.mean : Named num [1:2] 115 110
#> ..- attr(*, "names")= chr [1:2] "x" "y"
#> $ varprob.mean : Named num [1:2] 0.5 0.5
#> ..- attr(*, "names")= chr [1:2] "x" "y"
#> $ rm.const : int [1:2] 1 2
#> - attr(*, "class")= chr "wbart"
By contrast, the str() output for the dbarts bart fit, although long, does contain the input:
str(DBARTSmodel)
#> List of 11
#> $ call : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#> $ first.sigma : num [1:100] 0.289 0.311 0.268 0.253 0.242 ...
#> $ sigma : num [1:1000] 0.288 0.307 0.248 0.257 0.293 ...
#> $ sigest : num 0.295
#> $ yhat.train : num [1:1000, 1:100] 0.715 0.677 0.508 0.51 0.827 ...
#> $ yhat.train.mean: num [1:100] 0.583 0.456 0.504 0.544 0.404 ...
#> $ yhat.test : NULL
#> $ yhat.test.mean : NULL
#> $ varcount : int [1:1000, 1:2] 128 118 120 142 130 145 145 150 138 138 ...
#> ..- attr(*, "dimnames")=List of 2
#> .. ..$ : NULL
#> .. ..$ : chr [1:2] "x" "y"
#> $ y : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#> $ fit :Reference class 'dbartsSampler' [package "dbarts"] with 5 fields
#> ..$ pointer:<externalptr>
#> ..$ control:Formal class 'dbartsControl' [package "dbarts"] with 18 slots
#> .. .. ..@ binary : logi FALSE
#> .. .. ..@ verbose : logi TRUE
#> .. .. ..@ keepTrainingFits: logi TRUE
#> .. .. ..@ useQuantiles : logi FALSE
#> .. .. ..@ keepTrees : logi TRUE
#> .. .. ..@ n.samples : int 1000
#> .. .. ..@ n.burn : int 100
#> .. .. ..@ n.trees : int 200
#> .. .. ..@ n.chains : int 1
#> .. .. ..@ n.threads : int 1
#> .. .. ..@ n.thin : int 1
#> .. .. ..@ printEvery : int 100
#> .. .. ..@ printCutoffs : int 0
#> .. .. ..@ rngKind : chr "default"
#> .. .. ..@ rngNormalKind : chr "default"
#> .. .. ..@ rngSeed : int NA
#> .. .. ..@ updateState : logi TRUE
#> .. .. ..@ call : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#> ..$ model :Formal class 'dbartsModel' [package "dbarts"] with 9 slots
#> .. .. ..@ p.birth_death : num 0.5
#> .. .. ..@ p.swap : num 0.1
#> .. .. ..@ p.change : num 0.4
#> .. .. ..@ p.birth : num 0.5
#> .. .. ..@ node.scale : num 0.5
#> .. .. ..@ tree.prior :Formal class 'dbartsCGMPrior' [package "dbarts"] with 3 slots
#> .. .. .. .. ..@ power : num 2
#> .. .. .. .. ..@ base : num 0.95
#> .. .. .. .. ..@ splitProbabilities: num(0)
#> .. .. ..@ node.prior :Formal class 'dbartsNormalPrior' [package "dbarts"] with 0 slots
#> list()
#> .. .. ..@ node.hyperprior:Formal class 'dbartsFixedHyperprior' [package "dbarts"] with 1 slot
#> .. .. .. .. ..@ k: num 2
#> .. .. ..@ resid.prior :Formal class 'dbartsChiSqPrior' [package "dbarts"] with 2 slots
#> .. .. .. .. ..@ df : num 3
#> .. .. .. .. ..@ quantile: num 0.9
#> ..$ data :Formal class 'dbartsData' [package "dbarts"] with 10 slots
#> .. .. ..@ y : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#> .. .. ..@ x : num [1:100, 1:2] 0.152 0.666 0.967 0.248 0.668 ...
#> .. .. .. ..- attr(*, "dimnames")=List of 2
#> .. .. .. .. ..$ : NULL
#> .. .. .. .. ..$ : chr [1:2] "x" "y"
#> .. .. .. ..- attr(*, "drop")=List of 2
#> .. .. .. .. ..$ x: logi FALSE
#> .. .. .. .. ..$ y: logi FALSE
#> .. .. .. ..- attr(*, "term.labels")= chr [1:2] "x" "y"
#> .. .. ..@ varTypes : int [1:2] 0 0
#> .. .. ..@ x.test : NULL
#> .. .. ..@ weights : NULL
#> .. .. ..@ offset : NULL
#> .. .. ..@ offset.test : NULL
#> .. .. ..@ n.cuts : int [1:2] 100 100
#> .. .. ..@ sigma : num 0.295
#> .. .. ..@ testUsesRegularOffset: logi NA
#> ..$ state :List of 1
#> .. ..$ :Formal class 'dbartsState' [package "dbarts"] with 6 slots
#> .. .. .. ..@ trees : int [1:1055] 0 18 -1 0 49 -1 -1 0 60 -1 ...
#> .. .. .. ..@ treeFits : num [1:100, 1:200] -0.02252 0.00931 0.00931 0.02688 0.00931 ...
#> .. .. .. ..@ savedTrees: int [1:2340360] 0 797997482 1070928224 1 -402902351 1070268808 -1 -1094651769 -1081938039 -1 ...
#> .. .. .. ..@ sigma : num 0.297
#> .. .. .. ..@ k : num 2
#> .. .. .. ..@ rng.state : int [1:18] 0 1078575104 0 1078575104 -1657977906 1075613906 0 1078558720 277209871 -1068236140 ...
#> .. ..- attr(*, "runningTime")= num 0.477
#> .. ..- attr(*, "currentNumSamples")= int 1000
#> .. ..- attr(*, "currentSampleNum")= int 0
#> .. ..- attr(*, "numCuts")= int [1:2] 100 100
#> .. ..- attr(*, "cutPoints")=List of 2
#> .. .. ..$ : num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#> .. .. ..$ : num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#> ..and 40 methods, of which 26 are possibly relevant:
#> .. copy#envRefClass, getLatents, getPointer, getTrees, initialize, plotTree,
#> .. predict, printTrees, run, sampleNodeParametersFromPrior,
#> .. sampleTreesFromPrior, setControl, setCutPoints, setData, setModel,
#> .. setOffset, setPredictor, setResponse, setSigma, setState, setTestOffset,
#> .. setTestPredictor, setTestPredictorAndOffset, setWeights,
#> .. show#envRefClass, storeState
#> - attr(*, "class")= chr "bart"
https://stackoverflow.com/questions/71913902
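Since the wbart object does not carry its inputs, one possible workaround is to keep them next to the fit yourself. The following is only a minimal sketch: fit_with_data() is a hypothetical helper (it is not part of BART or dbarts) that calls whatever fitting function it is given and returns the fit together with the training data that was passed in.

```r
# Hypothetical helper, not provided by BART or dbarts:
# run a fitting function and return the fit alongside its inputs.
fit_with_data <- function(fit_fun, x.train, y.train, ...) {
  fit <- fit_fun(x.train = x.train, y.train = y.train, ...)
  list(fit = fit, x.train = x.train, y.train = y.train)
}

# same kind of toy data as in the question
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# only run the BART fit if the package is installed
if (requireNamespace("BART", quietly = TRUE)) {
  res <- fit_with_data(BART::wbart, df[, 1:2], df[, 3])
  print(class(res$fit))   # the ordinary wbart fit, unchanged
  str(res$x.train)        # the predictors exactly as they were passed in
}
```

The fit itself is left untouched, so anything that expects a wbart object can still be given res$fit, while res$x.train and res$y.train play the role that DBARTSmodel$fit$data@x plays for dbarts.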