首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >有没有办法从R中的BART包模型中检索数据?

有没有办法从R中的BART包模型中检索数据?
EN

Stack Overflow用户
提问于 2022-04-18 15:37:34
回答 1查看 170关注 0票数 2

我想知道是否有办法从R中的BART包构建的模型中检索数据?

使用其他bart包似乎是可能的,例如dbarts.但是我似乎找不到从BART模型中获取原始数据的方法。例如,如果我创建一些数据并运行一个BARTdbarts模型,如下所示:

代码语言:javascript
复制
library(BART)
library(dbarts)

# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)

keeptrees中使用dbarts选项可以使用以下方法检索数据:

代码语言:javascript
复制
# retrieve data from dbarts
DBARTSmodel$fit$data@x

但是,在使用BART时似乎没有任何类似的选项。甚至可以从BART模型中检索数据吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-04-29 11:43:40

Value:部分的?wbart建议它不返回输入作为输出的一部分,并且wbart的任何函数参数都不能表明这是可以更改的。

此外,如果您查看str的输出,您会发现它并不存在。

代码语言:javascript
复制
library(BART)
library(dbarts)

# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)

str(BARTmodel)
#> List of 13
#>  $ sigma          : num [1:1100] 0.258 0.262 0.295 0.278 0.273 ...
#>  $ yhat.train.mean: num [1:100] 0.584 0.457 0.505 0.54 0.403 ...
#>  $ yhat.train     : num [1:1000, 1:100] 0.673 0.62 0.433 0.711 0.634 ...
#>  $ yhat.test.mean : num(0) 
#>  $ yhat.test      : num[1:1000, 0 ] 
#>  $ varcount       : int [1:1000, 1:2] 109 114 111 118 115 114 115 110 114 117 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ varprob        : num [1:1000, 1:2] 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ treedraws      :List of 2
#>   ..$ cutpoints:List of 2
#>   .. ..$ x: num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#>   .. ..$ y: num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#>   ..$ trees    : chr "1000 200 2\n1\n1 0 0 0.01185590432\n3\n1 1 30 -0.01530736435\n2 0 0 0.01064412946\n3 0 0 0.02413784284\n3\n1 0 "| __truncated__
#>  $ proc.time      : 'proc_time' Named num [1:5] 1.406 0.008 1.415 0 0
#>   ..- attr(*, "names")= chr [1:5] "user.self" "sys.self" "elapsed" "user.child" ...
#>  $ mu             : num 0.501
#>  $ varcount.mean  : Named num [1:2] 115 110
#>   ..- attr(*, "names")= chr [1:2] "x" "y"
#>  $ varprob.mean   : Named num [1:2] 0.5 0.5
#>   ..- attr(*, "names")= chr [1:2] "x" "y"
#>  $ rm.const       : int [1:2] 1 2
#>  - attr(*, "class")= chr "wbart"

str()输出的bart输出虽然很长,但确实包含输入:

代码语言:javascript
复制
str(DBARTSmodel)
#> List of 11
#>  $ call           : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#>  $ first.sigma    : num [1:100] 0.289 0.311 0.268 0.253 0.242 ...
#>  $ sigma          : num [1:1000] 0.288 0.307 0.248 0.257 0.293 ...
#>  $ sigest         : num 0.295
#>  $ yhat.train     : num [1:1000, 1:100] 0.715 0.677 0.508 0.51 0.827 ...
#>  $ yhat.train.mean: num [1:100] 0.583 0.456 0.504 0.544 0.404 ...
#>  $ yhat.test      : NULL
#>  $ yhat.test.mean : NULL
#>  $ varcount       : int [1:1000, 1:2] 128 118 120 142 130 145 145 150 138 138 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ y              : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#>  $ fit            :Reference class 'dbartsSampler' [package "dbarts"] with 5 fields
#>   ..$ pointer:<externalptr> 
#>   ..$ control:Formal class 'dbartsControl' [package "dbarts"] with 18 slots
#>   .. .. ..@ binary          : logi FALSE
#>   .. .. ..@ verbose         : logi TRUE
#>   .. .. ..@ keepTrainingFits: logi TRUE
#>   .. .. ..@ useQuantiles    : logi FALSE
#>   .. .. ..@ keepTrees       : logi TRUE
#>   .. .. ..@ n.samples       : int 1000
#>   .. .. ..@ n.burn          : int 100
#>   .. .. ..@ n.trees         : int 200
#>   .. .. ..@ n.chains        : int 1
#>   .. .. ..@ n.threads       : int 1
#>   .. .. ..@ n.thin          : int 1
#>   .. .. ..@ printEvery      : int 100
#>   .. .. ..@ printCutoffs    : int 0
#>   .. .. ..@ rngKind         : chr "default"
#>   .. .. ..@ rngNormalKind   : chr "default"
#>   .. .. ..@ rngSeed         : int NA
#>   .. .. ..@ updateState     : logi TRUE
#>   .. .. ..@ call            : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#>   ..$ model  :Formal class 'dbartsModel' [package "dbarts"] with 9 slots
#>   .. .. ..@ p.birth_death  : num 0.5
#>   .. .. ..@ p.swap         : num 0.1
#>   .. .. ..@ p.change       : num 0.4
#>   .. .. ..@ p.birth        : num 0.5
#>   .. .. ..@ node.scale     : num 0.5
#>   .. .. ..@ tree.prior     :Formal class 'dbartsCGMPrior' [package "dbarts"] with 3 slots
#>   .. .. .. .. ..@ power             : num 2
#>   .. .. .. .. ..@ base              : num 0.95
#>   .. .. .. .. ..@ splitProbabilities: num(0) 
#>   .. .. ..@ node.prior     :Formal class 'dbartsNormalPrior' [package "dbarts"] with 0 slots
#>  list()
#>   .. .. ..@ node.hyperprior:Formal class 'dbartsFixedHyperprior' [package "dbarts"] with 1 slot
#>   .. .. .. .. ..@ k: num 2
#>   .. .. ..@ resid.prior    :Formal class 'dbartsChiSqPrior' [package "dbarts"] with 2 slots
#>   .. .. .. .. ..@ df      : num 3
#>   .. .. .. .. ..@ quantile: num 0.9
#>   ..$ data   :Formal class 'dbartsData' [package "dbarts"] with 10 slots
#>   .. .. ..@ y                    : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#>   .. .. ..@ x                    : num [1:100, 1:2] 0.152 0.666 0.967 0.248 0.668 ...
#>   .. .. .. ..- attr(*, "dimnames")=List of 2
#>   .. .. .. .. ..$ : NULL
#>   .. .. .. .. ..$ : chr [1:2] "x" "y"
#>   .. .. .. ..- attr(*, "drop")=List of 2
#>   .. .. .. .. ..$ x: logi FALSE
#>   .. .. .. .. ..$ y: logi FALSE
#>   .. .. .. ..- attr(*, "term.labels")= chr [1:2] "x" "y"
#>   .. .. ..@ varTypes             : int [1:2] 0 0
#>   .. .. ..@ x.test               : NULL
#>   .. .. ..@ weights              : NULL
#>   .. .. ..@ offset               : NULL
#>   .. .. ..@ offset.test          : NULL
#>   .. .. ..@ n.cuts               : int [1:2] 100 100
#>   .. .. ..@ sigma                : num 0.295
#>   .. .. ..@ testUsesRegularOffset: logi NA
#>   ..$ state  :List of 1
#>   .. ..$ :Formal class 'dbartsState' [package "dbarts"] with 6 slots
#>   .. .. .. ..@ trees     : int [1:1055] 0 18 -1 0 49 -1 -1 0 60 -1 ...
#>   .. .. .. ..@ treeFits  : num [1:100, 1:200] -0.02252 0.00931 0.00931 0.02688 0.00931 ...
#>   .. .. .. ..@ savedTrees: int [1:2340360] 0 797997482 1070928224 1 -402902351 1070268808 -1 -1094651769 -1081938039 -1 ...
#>   .. .. .. ..@ sigma     : num 0.297
#>   .. .. .. ..@ k         : num 2
#>   .. .. .. ..@ rng.state : int [1:18] 0 1078575104 0 1078575104 -1657977906 1075613906 0 1078558720 277209871 -1068236140 ...
#>   .. ..- attr(*, "runningTime")= num 0.477
#>   .. ..- attr(*, "currentNumSamples")= int 1000
#>   .. ..- attr(*, "currentSampleNum")= int 0
#>   .. ..- attr(*, "numCuts")= int [1:2] 100 100
#>   .. ..- attr(*, "cutPoints")=List of 2
#>   .. .. ..$ : num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#>   .. .. ..$ : num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#>   ..and 40 methods, of which 26 are  possibly relevant:
#>   ..  copy#envRefClass, getLatents, getPointer, getTrees, initialize, plotTree,
#>   ..  predict, printTrees, run, sampleNodeParametersFromPrior,
#>   ..  sampleTreesFromPrior, setControl, setCutPoints, setData, setModel,
#>   ..  setOffset, setPredictor, setResponse, setSigma, setState, setTestOffset,
#>   ..  setTestPredictor, setTestPredictorAndOffset, setWeights,
#>   ..  show#envRefClass, storeState
#>  - attr(*, "class")= chr "bart"
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71913902

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档