首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tidymodels工作流的文件大小

tidymodels工作流的文件大小
EN

Stack Overflow用户
提问于 2022-08-29 13:20:08
回答 1查看 89关注 0票数 6

我试图在流程中采用tidymodel,但在保存工作流方面遇到了挑战。工作流对象的文件大小比用于构建模型的数据大很多倍,因此,在尝试将工作流应用于新数据时,我最终耗尽了内存。我不知道这是正确的结果还是我遗漏了什么。

为了对新数据进行预测,难道我们不需要菜谱步骤、模型系数,或者可能需要一些来自培训集的汇总数据(例如,用于缩放目的的培训数据的sd和平均值)吗?那么,为什么工作流对象如此大呢?

下面是一个使用iris数据集的简单示例。我试图跟随朱莉娅的例子,但工作流程最终仍然比数据本身大24倍。我知道泰迪模特进化得很快,所以也许现在有更好的方法了?如有任何建议,敬请谅解!

代码语言:javascript
复制
library(tidyverse)
library(tidymodels)
library(lobstr)
library(butcher)

set.seed(8675309)

#Create an indicator for whether the species is Setosa
df <- iris %>% 
    mutate(is_setosa = factor(Species == "setosa"))

#Split into train/test
df_split <- initial_split(df, prop = 0.80)
df_train <- training(df_split)
df_test <- testing(df_split)

#Create the workflow object
my_workflow <- workflow() %>% 
    #use a logistic regression model using glm
    add_model({
        logistic_reg() %>% 
            set_engine("glm")
    }) %>% 
    #Add the recipe
    add_recipe({
        recipe(is_setosa ~ Sepal.Length + Sepal.Width + 
                   Petal.Length + Petal.Width,
               data = df_train) %>% 
            #Add a few arbitrary transformations
            step_log(Sepal.Length) %>% 
            step_mutate(across(matches("Width"),
                               .fns = ~ as.numeric(.x > quantile(.x, 0.9)),
                               .names = "is_{.col}_top_decile")) %>% 
            step_zv(all_predictors()) %>% 
            step_normalize()
    })


#Do a final fit using the workflow.
#The model doesn't converge, but that's not the point.
my_fit <- my_workflow %>% 
    last_fit(df_split)

#How big is our data? 8.3kb
size_data <- df %>% 
    obj_size()

#What's the smallest we can make the workflow? 197kb
size_fit <- my_fit %>% 
    extract_workflow() %>% 
    butcher() %>% 
    obj_size()

#What's the ratio of size between the original data and the fit object?
as.numeric(size_fit / size_data)
#The fit object is 24x bigger than our data.  
#Is that the expected result?


#In order to make predictions on future data, 
#we'd save/load the butchered workflow?
my_fit %>% 
    extract_workflow() %>% 
    butcher() %>% 
    write_rds("my_fit.rds")
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-08-31 17:41:30

这似乎是使用glm()作为模型引擎的预期行为。有关更多细节,请参见这个GitHub问题;非常感谢Emil & Julia对此进行了调查。

我将模型引擎从glm转换为LiblineaR,使被屠宰的工作流的文件大小减少了4倍。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73529453

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档