我想为滑雪板模型做一个简单的包装。其思想是包装器自动处理各种因素("object"类型的列),用目标的平均值替换它们,同时保持sklearn模型的语法。
如果该因素太稀缺,则由目标的总体平均值取代。这似乎很简单,但当一个因素在测试组而不是在火车组时,就会出现问题。我想出了下面的解决方案,这让我觉得很尴尬。
class ModelEmbedder :
def __init__(self, model, rare_threshold) :
self.model = model
self.means = {}
self.rare_threshold = rare_threshold
self.train = None
self.origin_train = None
self.average = 0
def fit(self,train,target):
self.origin_train = train.copy().fillna(-1)
self.train = train.copy()
self.train = self.train.fillna(-1)
self.train['target'] = target
self.average = target.mean()
for feat in train.columns:
if feat != 'target' :
if self.train[feat].dtype=='object' :
self.train.loc[self.train[feat].value_counts()[self.train[feat]].values < self.rare_threshold, feat] = "RARE"
self.origin_train.loc[self.origin_train[feat].value_counts()[self.origin_train[feat]].values < self.rare_threshold, feat] = "RARE"
self.means[feat] = self.train.groupby([feat])['target'].mean()
self.means[feat]["RARE"] = self.average
self.train[feat] = self.train[feat].replace(self.means[feat], inplace=False)
del self.train['target']
self.model.fit(self.train,target)
def _pre_treat_test(self,test) :
test = test.copy()
test = test.fillna(-1)
for feat in self.origin_train.columns:
if self.origin_train[feat].dtype=='object' :
test.loc[self.origin_train[feat].value_counts()[self.origin_train[feat]].values < self.rare_threshold, feat] = "RARE"
criterion = ~test[feat].isin(set(self.origin_train[feat]))
test.loc[criterion,feat] = self.average
test[feat] = test[feat].replace(self.means[feat], inplace=False)
return test
def predict_proba(self,test) :
test = self._pre_treat_test(test)
return self.model.predict_proba(test)
def get_params(self, deep = True):
return self.model.get_params(deep)然后,每个模型都可以包装:
rf = ensemble.ExtraTreesClassifier(n_jobs=7,
n_estimators = n_estimators,
random_state = 11)
rf_embedded = model_embedder.ModelEmbedder(rf,10)并发送到交叉验证循环或任何管道。
发布于 2016-03-12 02:27:29
几点意见:
:之前添加空格,这听起来很奇怪。我会删除它,因为我从来没有见过任何其他Python代码这样做另外,我不知道你在这里是如何使用下划线的。一般来说,下划线表示Python中的私有成员/方法和它看起来有相当多的地方,您可以使您的成员变量私有。
此外,这一点可能会更加明确:
self.train.loc[self.train[feat].value_counts()[self.train[feat]].values < self.rare_threshold, feat] = "RARE"您可以考虑一些中间步骤来帮助提高可读性。即使是像这样简单的事情:
val = self.train[feat].value_counts()[self.train[feat]].values < self.rare_threshold
self.train.loc[val, feat] = "RARE"更清楚。尝试读取许多嵌套字典查找并不简单。
https://codereview.stackexchange.com/questions/121833
复制相似问题