首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么xgboost只计算n_estimators回合的梯度?

为什么xgboost只计算n_estimators回合的梯度?
EN

Stack Overflow用户
提问于 2022-01-19 15:04:21
回答 1查看 47关注 0票数 0

在这里,n_estimators指的是xgboost中的树数(周学习者)。我定义了一个自定义的目标函数,其中我必须计算一阶和二阶梯度。我还在自定义目标函数中添加了一个打印函数,这样我就可以计算被调用(调用)的定制目标函数的数量。最后,我发现该函数仅由n_estimators时间调用。

以前,我认为在做叶分割时应该调用自定义的目标函数,甚至一棵树也可以有多个叶分裂。所以我现在很困惑。

EN

回答 1

Stack Overflow用户

发布于 2022-01-22 22:20:34

回答

在XGBoost和LightGBM中,梯度都是在每个助推轮开始时计算一次。

然后构建每一棵树,试图解释基于候选特征划分的梯度。你可以这样想,“每一棵树都是一个决策树,训练到目前为止预测模型的残差”,但是“残差”的值是由选定的目标函数决定的。这就是为什么每次将一个节点添加到模型的一个树时都不需要重新计算梯度的原因。

示例

这里有一个minimal, reproducible example,演示了您询问过的行为。今天,我用Python3.9.7在macOS上运行了这段代码,使用了来自conda的lightgbm==3.3.2xgboost==1.5.0

我提供了lightgbmxgboost示例,因为最初的帖子有这两个标记。

代码语言:javascript
复制
import lightgbm as lgb
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression
from typing import Tuple, Union

COUNTER = 0

def _objective_least_squares(
    y_pred: np.ndarray,
    train_data: Union[lgb.Dataset, xgb.DMatrix]
) -> Tuple[np.ndarray, np.ndarray]:
    global COUNTER
    print(f"iteration: {COUNTER}")
    y_true = train_data.get_label()
    grad = y_pred - y_true
    hess = np.ones(len(y_true))
    COUNTER += 1
    return grad, hess

X, y = make_regression(n_samples=1_000, n_features=10, n_informative=8, random_state=708)

#--- xgboost example ---#
COUNTER = 0
xgb_dtrain = xgb.DMatrix(data=X, label=y)
xgb_model = xgb.train(
    params={'tree_method': 'hist', 'verbosity': 0},
    dtrain=xgb_dtrain,
    num_boost_round=10,
    obj=_objective_least_squares
)
print(f"Model has {len(xgb_model.trees_to_dataframe())} total tree nodes")

#--- lightgbm example ---#
COUNTER = 0
lgb_dtrain = lgb.Dataset(data=X, label=y)
lgb_model = lgb.train(
    params={'seed': 1994, 'verbosity': -1},
    train_set=lgb_dtrain,
    num_boost_round=10,
    fobj=_objective_least_squares
)
print(f"Model has {len(lgb_model.trees_to_dataframe())} total tree nodes")

这两个示例都显示了在原始post...the目标函数中报告的行为是每次迭代一次,而不是每个树节点一次。

代码语言:javascript
复制
iteration: 0
iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9

在培训后添加的print()语句确认了在本例中生成的模型有10个以上的树节点,只是为了确认这是一种准确的方法来显示调用目标函数的频率。

代码语言:javascript
复制
# xgboost example
Model has 1100 total tree nodes

# lightgbm example
Model has 610 total tree nodes
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70772943

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档