首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >CATBoost和GridSearch

CATBoost和GridSearch
EN

Stack Overflow用户
提问于 2019-11-20 09:43:20
回答 1查看 4.2K关注 0票数 3
代码语言:javascript
复制
model.fit(train_data, y=label_data, eval_set=eval_dataset)
eval_dataset = Pool(val_data, val_labels)
model = CatBoostClassifier(depth=8 or 10, iterations=10, task_type="GPU", devices='0-2', eval_metric='Accuracy', boosting_type="Ordered", bagging_temperature=0, use_best_model=True)

当我运行上面的代码(在两个单独的运行/深度设置为8或10)时,我得到以下结果:

深度10: 0.6864865深8: 0.6756757

我想以一种方式设置和运行GridSearch --因此它运行完全相同的组合并产生完全相同的结果--就像我手动运行代码时一样。

GridSearch代码:

代码语言:javascript
复制
model = CatBoostClassifier(iterations=10, task_type="GPU", devices='0-2', eval_metric='Accuracy', boosting_type="Ordered", depth=10, bagging_temperature=0, use_best_model=True)

grid = {'depth': [8,10]}
grid_search_result = GridSearchCV(model, grid, cv=2)
results = grid_search_result.fit(train_data, y=label_data, eval_set=eval_dataset) 

问题:

  1. --我希望GridSearch使用我的"eval_set“来比较/验证所有不同的运行(比如手动运行时)--但是它使用的是一些我不知道的东西,它似乎根本不看"eval_set”?
  2. 不仅产生两个结果,而且取决于"cv“(交叉验证拆分策略)。它运行3,5,7,9或11次?我不想那样。
  3. 我试着通过调试器检查整个“结果”对象--但是我只是找不到验证“准确性”的分数,以便获得最好的或者所有的其他结果。我可以找到很多其他的价值--但它们都比不上我要找的东西。这些数字与"eval_set“数据集生成的数字不匹配?

我通过实现我自己的简单GridSearch来解决我的问题(万一它可以帮助/激励其他人:- ):如果您对代码有任何评论,请告诉我:-)

代码语言:javascript
复制
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import GridSearchCV
import csv
from datetime import datetime

# Initialize data

train_data = pd.read_csv('./train_x.csv')
label_data = pd.read_csv('./labels_train_x.csv')
val_data = pd.read_csv('./val_x.csv')
val_labels = pd.read_csv('./labels_val_x.csv')

eval_dataset = Pool(val_data, val_labels)

ite = [1000,2000]
depth = [6,7,8,9,10]
max_bin = [None,32,46,100,254]
l2_leaf_reg = [None,2,10,20,30]
bagging_temperature = [None,0,0.5,1]
random_strength = [None,1,5,10]
total_runs = len(ite) * len(depth) * len(max_bin) * len(l2_leaf_reg) * len(bagging_temperature) * len(random_strength)

print('Total runs: ' + str(total_runs))

counter = 0

file_name = './Results/Catboost_' + str(datetime.now().strftime("%d_%m_%Y_%H_%M_%S")) + '.csv'

row = ['Validation Accuray','Logloss','Iterations', 'Depth', 'Max_bin', 'L2_leaf_reg', 'Bagging_temperature', 'Random_strength']
with open(file_name, 'a') as csvFile:
    writer = csv.writer(csvFile)
    writer.writerow(row)
csvFile.close()

for a in ite:
    for b in depth:
        for c in max_bin:
            for d in l2_leaf_reg:
                for e in bagging_temperature:
                    for f in random_strength:
                        model = CatBoostClassifier(task_type="GPU", devices='0-2', eval_metric='Accuracy', boosting_type="Ordered", use_best_model=True,
                        iterations=a, depth=b, max_bin=c, l2_leaf_reg=d, bagging_temperature=e, random_strength=f)
                        counter += 1
                        print('Run # ' + str(counter) + '/' + str(total_runs))
                        result = model.fit(train_data, y=label_data, eval_set=eval_dataset, verbose=1)

                        accuracy = float(result.best_score_['validation']['Accuracy'])
                        logLoss = result.best_score_['validation']['Logloss']

                        row = [ accuracy, logLoss,
                                ('Auto' if a == None else a),
                                ('Auto' if b == None else b),
                                ('Auto' if c == None else c),
                                ('Auto' if d == None else d),
                                ('Auto' if e == None else e),
                                ('Auto' if f == None else f)]

                        with open(file_name, 'a') as csvFile:
                            writer = csv.writer(csvFile)
                            writer.writerow(row)
                        csvFile.close()
EN

回答 1

Stack Overflow用户

发布于 2019-11-20 12:03:37

Catboost中的eval集充当了一个抵抗集。

在GridSearchCV中,cv是在train_data上执行的。

一种解决方案是将您的train_data和eval_dataset合并,并在GridSearchCV中传递train和eval索引。尝试在cv param中生成这两组索引。然后,您将只有一个分割和准确的数字,将给您相同的结果。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58951164

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档