首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >ValueError:未知优化器:优化器

ValueError:未知优化器:优化器
EN

Stack Overflow用户
提问于 2020-07-25 18:14:06
回答 1查看 574关注 0票数 0

我想要进行超参数调优,为此我应用了gridsearchCV,但在拟合它的过程中,获得了ValueError

代码语言:javascript
复制
from keras.wrappers.scikit_learn import KerasClassifier 
from sklearn.model_selection import GridSearchCV

def build_classifier(optimizer):
    ann = tf.keras.models.Sequential()
    ann.add(tf.keras.layers.Dense(units = 6, activation = 'relu'))
    ann.add(tf.keras.layers.Dense(units = 6, activation = 'relu'))
    ann.add(tf.keras.layers.Dense(units = 1, activation = 'sigmoid'))     #softmax in case of more than 2 classes
    ann.compile(optimizer = 'optimizer', loss = 'binary_crossentropy', metrics = ['accuracy']) #categorical_crossentropy in case of categories > 2
    return ann

ann = KerasClassifier(build_fn = build_classifier)

parameters = {'batch_size': [25,32],
              'epochs' : [10,100],
              'optimizer' : ['adam', 'rmsprop']}

grid_search = GridSearchCV(estimator = ann,
                           param_grid = parameters,
                           scoring = 'accuracy',
                           cv = 10)
grid_search = grid_search.fit(X_train, y_train)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-07-25 22:50:55

不是将'optimizer'字符串传递给compile(),而是传递函数参数optimizer

代码语言:javascript
复制
import tensorflow as tf
from sklearn.model_selection import GridSearchCV

def build_classifier(optimizer):
    ann = tf.keras.models.Sequential()
    ann.add(tf.keras.layers.Dense(units = 6, activation = 'relu'))
    ann.add(tf.keras.layers.Dense(units = 6, activation = 'relu'))
    ann.add(tf.keras.layers.Dense(units = 1, activation = 'sigmoid'))    
    ann.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
    return ann

ann = tf.keras.wrappers.scikit_learn.KerasClassifier(build_fn = build_classifier)
    
parameters = {'batch_size': [25,32],
              'epochs': [10, 100],
              'optimizer': ['Adam', 'RMSprop']}

grid_search = GridSearchCV(estimator=ann,
                           param_grid=parameters,
                           scoring= 'accuracy',
                           cv=10)

grid_search = grid_search.fit(X, y)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63087107

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档