首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >R_squared vs epochs

R_squared vs epochs
EN

Stack Overflow用户
提问于 2020-01-25 17:39:30
回答 1查看 412关注 0票数 0

早上好,我是python初学者,我正在尝试构建我的第一个神经网络。有没有一种方法可以绘制R2的演变与时代的关系?我用以下方式评估R2:r2_score(y_test_pred, y_test)。我用这种方式构建了一个完全连接的神经网络:

代码语言:javascript
复制
optimizer = tf.keras.optimizers.Adam(lr=0.001)
model = Sequential()

# ,kernel_regularizer=l2(c), bias_regularizer=l2(c)
model.add(Dense(100, input_shape = (X_train.shape[1],), activation = 'relu',kernel_initializer='glorot_uniform'))
model.add(Dropout(0.2))
model.add(Dense(100, activation = 'relu',kernel_initializer='glorot_uniform'))
model.add(Dropout(0.2))
model.add(Dense(100, activation = 'relu',kernel_initializer='glorot_uniform'))

model.add(Dense(1,activation = 'linear',kernel_initializer='glorot_uniform'))

model.compile(loss = 'mse', optimizer = optimizer, metrics = ['mse'])

history = model.fit(X_train, y_train, epochs = 100,
                    validation_split = 0.1, shuffle=False, batch_size=250
                    )

history_dict = history.history
`

数据集由18个特征和1个标签组成,这是一个回归任务。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-01-25 17:43:12

您只需将其添加到您的compile行中。

代码语言:javascript
复制
model.compile(loss = 'mse', optimizer = optimizer, metrics = ['mse', r2_score])

如果你想这样做,你需要创建一个keras可以理解的指标,

代码语言:javascript
复制
import tf.keras.backend as K

def r2_score(y_true, y_pred):
    SS_res =  K.sum(K.square(y_true - y_pred)) 
    SS_tot = K.sum(K.square(y_true - K.mean(y_true))) 
    return ( 1 - SS_res/(SS_tot + K.epsilon()) )

代码取自kaggle

对不起,我忘了添加Tensorboard部分。

如果你想在训练过程中看到损失/度量的演变,你可以使用Tensorboard,就像在the doc中一样

代码语言:javascript
复制
logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)


history = model.fit(X_train, y_train, epochs = 100,
                    validation_split = 0.1, shuffle=False, batch_size=250, calllbacks=[tensorboard_callback])

然后在终端中使用以下行访问Tensorboard

tensorboard --logdir logs

然后你可以通过访问localhost:6006在浏览器上访问tensorboard

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59908069

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档