首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >“config=wandb.config”的含义是什么?

“config=wandb.config”的含义是什么?
EN

Stack Overflow用户
提问于 2022-08-19 06:50:49
回答 1查看 170关注 0票数 0

我试图为我的Logistic回归模型做一个扫描设置。我阅读了wandb的教程,无法理解如何进行配置,特别是教程中config=wandb.config的含义。如果有人能给我一个很好的解释,我会非常感激的。以下是我所做的:

代码语言:javascript
复制
sweep_config = {
    'method': 'grid'
}

metric = {
    'name': 'f1-score',
    'goal': 'maximize'
}

sweep_config['metric'] = metric

parameters = {
    'penalty': {
        'values': ['l2']
    },
    'C': {
        'values': [0.01, 0.1, 1.0, 10.0, 100.0]
    }
}

sweep_config['parameters'] = parameters

然后创建yaml文件:

代码语言:javascript
复制
stream = open('config.yaml', 'w')
yaml.dump(sweep_config, stream) 

那么是时候训练了:

代码语言:javascript
复制
with wandb.init(project=WANDB_PROJECT_NAME):
    config = wandb.config
    
    features = pd.read_csv('data/x_features.csv')
    vectorizer = TfidfVectorizer(ngram_range=(1,2))

    X_features = features = vectorizer.fit_transform(features['lemmatized_reason'])

    y_labels = pd.read_csv('data/y_labels.csv')

    split_data = train_test_split(X_features, y_labels, train_size = 0.85, test_size = 0.15, stratify=y_labels)
    features_train, labels_train = split_data[0], split_data[2]
    features_test, labels_test = split_data[1], split_data[3]
    
    config = wandb.config
    log_reg = LogisticRegression(
        penalty=config.penalty,
        C = config.C
    )
    
    log_reg.fit(features_train, labels_train)
    
    labels_pred = log_reg.predict(features_test)
    labels_proba = log_reg.predict_proba(features_test)
    labels=list(map(str,y_labels['label'].unique()))
    
    # Visualize single plot
    cm = wandb.sklearn.plot_confusion_matrix(labels_test, labels_pred, labels)
    
    score_f1 = f1_score(labels_test, labels_pred, average='weighted')
    
    sm = wandb.sklearn.plot_summary_metrics(
    log_reg, features_train, labels_train, features_test, labels_test)
    
    roc = wandb.sklearn.plot_roc(labels_test, labels_proba)
    
    wandb.log({
        "f1-weighted-log-regression-tfidf-skf": score_f1, 
        "roc-log-regression-tfidf-skf": roc, 
        "conf-mat-logistic-regression-tfidf-skf": cm,
        "summary-metrics-logistic-regression-tfidf-skf": sm
        })

最后是sweep_id和with语句之外的代理:

代码语言:javascript
复制
sweep_id = wandb.sweep(sweep_config, project="multiple-classifiers")
wandb.agent(sweep_id)

在这个配置方面,我遗漏了一些重要的东西,我只是无法理解。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-08-19 11:17:27

我的工作是举重和偏见。使用wandb清理,想法是wandb需要能够改变扫描中的超参数。

将超参数传递给LogisticRegression的下面部分也可以重写。

代码语言:javascript
复制
config = wandb.config
log_reg = LogisticRegression(
    penalty=config.penalty,
    C = config.C
)

就像这样:

代码语言:javascript
复制
log_reg = LogisticRegression(
    penalty=wandb.config.penalty,
    C = wandb.config.C
)

但是,我认为您缺少定义列车函数或列车脚本,这也需要传递给wandb。没有它,上面的例子就行不通了。

下面是一个最低限度的例子,应该有帮助。希望清扫文件也能提供帮助。

代码语言:javascript
复制
import numpy as np 
import random
import wandb

#  Step 1: Define sweep config
sweep_configuration = {
    'method': 'random',
    'name': 'sweep',
    'metric': {'goal': 'maximize', 'name': 'val_acc'},
    'parameters': 
    {
        'batch_size': {'values': [16, 32, 64]},
        'epochs': {'values': [5, 10, 15]},
        'lr': {'max': 0.1, 'min': 0.0001}
     }
}

#  Step 2: Initialize sweep by passing in config
sweep_id = wandb.sweep(sweep_configuration)

def train_one_epoch(epoch, lr, bs): 
  acc = 0.25 + ((epoch/30) +  (random.random()/10))
  loss = 0.2 + (1 - ((epoch-1)/10 +  random.random()/5))
  return acc, loss

def evaluate_one_epoch(epoch): 
  acc = 0.1 + ((epoch/20) +  (random.random()/10))
  loss = 0.25 + (1 - ((epoch-1)/10 +  random.random()/6))
  return acc, loss

def train():
    run = wandb.init()

    #  Step 3: Use hyperparameter values from `wandb.config`
    lr  =  wandb.config.lr
    bs = wandb.config.batch_size
    epochs = wandb.config.epochs

    for epoch in np.arange(1, epochs):
      train_acc, train_loss = train_one_epoch(epoch, lr, bs)
      val_acc, val_loss = evaluate_one_epoch(epoch)

      wandb.log({
        'epoch': epoch, 
        'train_acc': train_acc,
        'train_loss': train_loss, 
        'val_acc': val_acc, 
        'val_loss': val_loss
      })

#  Step 4: Launch sweep by making a call to `wandb.agent`
wandb.agent(sweep_id, function=train, count=4)

最后,您能分享上面找到代码的链接吗?也许我们需要更新一些例子:)

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73412851

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档