我想在不更改源代码的情况下编辑Python类中的方法(包lightgbm)。这个类被其他模块调用,但我所做的更改不会反映在这些模块中。
这是我的代码(我在self.__higher_better_inner_eval下面编辑):
class Booster_fix(lightgbm.basic.Booster):
    """Booster variant that also treats ``average_precision`` as higher-is-better.

    The override below is deliberately spelled with the already-mangled name
    ``_Booster__get_eval_info``.  A method written as ``__get_eval_info``
    inside this class is compiled to ``_Booster_fix__get_eval_info`` (and
    every ``self.__x`` to ``self._Booster_fix__x``), so it would neither
    replace the parent's private method nor read the parent's private state.
    """

    def _Booster__get_eval_info(self):
        """Get inner evaluation count and names.

        Does nothing unless the reload flag is set.  Otherwise the flag is
        cleared, the number of inner evals is queried, and, when that number
        is positive, the eval names and their higher-is-better flags are
        refreshed.

        Raises
        ------
        ValueError
            If the number of returned names differs from the eval count.
        BufferError
            If the reserved per-name buffer is smaller than required.
        """
        if not self._Booster__need_reload_eval_info:
            return
        self._Booster__need_reload_eval_info = False

        # Get num of inner evals
        out_num_eval = ctypes.c_int(0)
        _safe_call(_LIB.LGBM_BoosterGetEvalCounts(
            self.handle,
            ctypes.byref(out_num_eval)))
        num_inner_eval = out_num_eval.value
        self._Booster__num_inner_eval = num_inner_eval
        if num_inner_eval <= 0:
            # Nothing to name; name/higher-better lists are left untouched.
            return

        # Get name of evals
        tmp_out_len = ctypes.c_int(0)
        reserved_string_buffer_size = 255
        required_string_buffer_size = ctypes.c_size_t(0)
        string_buffers = [
            ctypes.create_string_buffer(reserved_string_buffer_size)
            for _ in range(num_inner_eval)
        ]
        ptr_string_buffers = (ctypes.c_char_p * num_inner_eval)(
            *map(ctypes.addressof, string_buffers))
        _safe_call(_LIB.LGBM_BoosterGetEvalNames(
            self.handle,
            ctypes.c_int(num_inner_eval),
            ctypes.byref(tmp_out_len),
            ctypes.c_size_t(reserved_string_buffer_size),
            ctypes.byref(required_string_buffer_size),
            ptr_string_buffers))
        if num_inner_eval != tmp_out_len.value:
            raise ValueError("Length of eval names doesn't equal with num_evals")
        if reserved_string_buffer_size < required_string_buffer_size.value:
            raise BufferError(
                "Allocated eval name buffer size ({}) was inferior to the needed size ({})."
                .format(reserved_string_buffer_size, required_string_buffer_size.value)
            )
        names = [buf.value.decode('utf-8') for buf in string_buffers]
        self._Booster__name_inner_eval = names
        # 'average_precision' is the addition compared with the original
        # tuple ('auc', 'ndcg@', 'map@').
        self._Booster__higher_better_inner_eval = [
            name.startswith(('auc', 'ndcg@', 'map@', 'average_precision'))
            for name in names
        ]
lightgbm.basic.Booster = Booster_fix
在其他一些模块中,有如下调用:
from .basic import Booster
我猜这就是我的猴子补丁不起作用的原因。
基本上,我想实现这个修复:https://github.com/microsoft/LightGBM/pull/3649/commits/f73407f05e389a74e6f44a2cb9c637df6afdb33b,不等待发布的更新,也不编辑源代码。
你能帮我解决这个问题吗?
编辑:@Maurice,要测试代码,您可以运行:
import pandas as pd
from lightgbm import LGBMClassifier # lgbm version 3.1.1
from sklearn.model_selection import train_test_split
# Load the Titanic CSV directly from the course web site.
df = pd.read_csv("https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv")
# Store 'Sex' with pandas' category dtype
# (presumably so LightGBM handles it as a categorical feature -- confirm).
df['Sex'] = df['Sex'].astype('category')
# Features are every column except 'Name' and the target; the target is 'Survived'.
X, y = df.drop(columns=['Name', 'Survived']), df['Survived']
# Hold out 33% of the rows for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
# Classifier with default parameters.
model = LGBMClassifier()
model.fit(X_train, y_train, early_stopping_rounds=10, eval_metric='average_precision', eval_set=(X_test, y_test))
发布于 2021-03-25 18:58:44
您只需为lightgbm.basic.Booster的__get_eval_info打补丁
from lightgbm.basic import Booster
def patched__get_eval_info(self):
    """Replacement for the private ``__get_eval_info`` method of ``Booster``.

    Because this function is defined outside ``class Booster``, Python does
    not mangle names in its body: private attributes of the instance must be
    written in their mangled form, e.g. ``self._Booster__name_inner_eval``
    rather than ``self.__name_inner_eval``.
    """
    print("patched ...")
    # your code here


# Inside ``class Booster`` the method ``__get_eval_info`` is stored, and
# called, under the mangled name ``_Booster__get_eval_info``.  Assigning to
# ``Booster.__get_eval_info`` from module level would create a new, unrelated
# attribute that the class's own code never calls, so patch the mangled name.
Booster._Booster__get_eval_info = patched__get_eval_info
b = Booster(model_file='model.txt')
# From outside the class the method is likewise reached via its mangled name.
b._Booster__get_eval_info()
# ... other imports that use lightgbm follow here
输出:
patched ...
https://stackoverflow.com/questions/66797355
复制相似问题