首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Monkey patching -更改不会反映在其他模块中

Monkey patching -更改不会反映在其他模块中
EN

Stack Overflow用户
提问于 2021-03-25 18:17:00
回答 1查看 48关注 0票数 1

我想在不更改源代码的情况下编辑Python类中的方法(包lightgbm)。这个类被其他模块调用,但我所做的更改不会反映在这些模块中。

这是我的代码(我在self.__higher_better_inner_eval下面编辑):

代码语言:javascript
复制
class Booster_fix(lightgbm.basic.Booster):

    def __get_eval_info(self):
        """Get inner evaluation count and names."""
        if self.__need_reload_eval_info:
            self.__need_reload_eval_info = False
            out_num_eval = ctypes.c_int(0)
            # Get num of inner evals
            _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
                self.handle,
                ctypes.byref(out_num_eval)))
            self.__num_inner_eval = out_num_eval.value
            if self.__num_inner_eval > 0:
                # Get name of evals
                tmp_out_len = ctypes.c_int(0)
                reserved_string_buffer_size = 255
                required_string_buffer_size = ctypes.c_size_t(0)
                string_buffers = [
                    ctypes.create_string_buffer(reserved_string_buffer_size) for i in range_(self.__num_inner_eval)
                ]
                ptr_string_buffers = (ctypes.c_char_p * self.__num_inner_eval)(*map(ctypes.addressof, string_buffers))
                _safe_call(_LIB.LGBM_BoosterGetEvalNames(
                    self.handle,
                    ctypes.c_int(self.__num_inner_eval),
                    ctypes.byref(tmp_out_len),
                    ctypes.c_size_t(reserved_string_buffer_size),
                    ctypes.byref(required_string_buffer_size),
                    ptr_string_buffers))
                if self.__num_inner_eval != tmp_out_len.value:
                    raise ValueError("Length of eval names doesn't equal with num_evals")
                if reserved_string_buffer_size < required_string_buffer_size.value:
                    raise BufferError(
                        "Allocated eval name buffer size ({}) was inferior to the needed size ({})."
                        .format(reserved_string_buffer_size, required_string_buffer_size.value)
                    )
                self.__name_inner_eval = \
                    [string_buffers[i].value.decode('utf-8') for i in range_(self.__num_inner_eval)]
                self.__higher_better_inner_eval = \
                    [name.startswith(('auc', 'ndcg@', 'map@', 'average_precision')) for name in self.__name_inner_eval] # replacement row
                    # [name.startswith(('auc', 'ndcg@', 'map@')) for name in self.__name_inner_eval] # original row

lightgbm.basic.Booster = Booster_fix

在其他一些模块中,有如下调用:

代码语言:javascript
复制
from .basic import Booster

我猜我的猴子补丁不管用了。

基本上,我想实现这个修复:https://github.com/microsoft/LightGBM/pull/3649/commits/f73407f05e389a74e6f44a2cb9c637df6afdb33b,不等待发布的更新,也不编辑源代码。

你能帮我拿一下吗?

编辑:@Maurice,要测试代码,您可以运行:

代码语言:javascript
复制
import pandas as pd
from lightgbm import LGBMClassifier # lgbm version 3.1.1
from sklearn.model_selection import train_test_split

df = pd.read_csv("https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv")
df['Sex'] = df['Sex'].astype('category')
X, y = df.drop(columns=['Name', 'Survived']), df['Survived']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
model = LGBMClassifier()
model.fit(X_train, y_train, early_stopping_rounds=10, eval_metric='average_precision', eval_set=(X_test, y_test))
EN

回答 1

Stack Overflow用户

发布于 2021-03-25 18:58:44

您只需为lightgbm.basic.Booster__get_eval_info打补丁

代码语言:javascript
复制
from lightgbm.basic import Booster


def patched__get_eval_info(self):
    print("patched ...")
    # your code here


Booster.__get_eval_info = patched__get_eval_info
b = Booster(model_file='model.txt')
b.__get_eval_info()

# ... other imports that use lightgbm follow here

输出:

代码语言:javascript
复制
patched ...
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66797355

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档