使用python 3.7和tensorflow 2.1.0,我想知道如何在精确性和召回性都更好的时候保存最好的模型。
当只有精度更高时,下面的代码给出了保存最佳模型的元素。
# some dependencies
from tensorflow.keras.metrics import Precision, Recall
# some coding instructions to define ImageDataGenerator flows.
metrics = {"precision":Precision(name="precision"), "recall":Recall(name="recall")}
checkpoints = {
"precision" : ModelCheckpoint(
"./models/best.h5",
mode="max",
monitor="val_precision",
save_freq="epoch",
save_weights_only=False,
save_best_only=True,
verbose=1
)
}
callbacks = list(checkpoints.values())
model = Sequential()
# neural network architecture ...
model.compile(
loss="binary_crossentropy",
optimizer='adam',
metrics=[metric for metric in metrics.values()]
)
model.fit_generator(train_generator,
validation_data = val_generator,
steps_per_epoch = train_generator.n//train_generator.batch_size,
validation_steps = val_generator.n//val_generator.batch_size,
class_weight={0:0.75, 1:1.5},
callbacks=callbacks,
epochs=300)我不认为为召回添加回调将使我能够为两个指标保存最好的模型.在我看来,如果召回比上个时代计算的召回要好的话,即使精度越来越差,模型也会被保存下来。精度也一样。我认为这种代码只会给我们一种“或”逻辑,而不是“和”逻辑。
也许我错了..。有人能帮我吗?任何帮助都会很感激的。非常欢迎解释!
发布于 2022-10-29 22:15:59
我想出了一个解决办法,希望能帮到一些人。
class CustomCallback(tf.keras.callbacks.Callback):
def __init__(self, model, path, metrics):
self.model = model
self.path = path
self.metrics = metrics
self.history = []
def on_epoch_end(self, epoch, logs):
if self.history:
last_logs = self.history[-1]
last_metrics_values = {metric : last_logs[metric] for metric in self.metrics}
current_metrics_values = {metric : logs[metric] for metric in self.metrics}
checking = [last_metrics_values[metric] < current_metrics_values[metric] for metric in self.metrics]
decision = reduce(operator.and_, checking)
if decision:
print("\nModel performs better at current epoch for all monitored metrics : ", self.metrics)
print("\nSaving current model in ", self.path)
self.model.save(self.path)
else:
print("\nSaving model at first epoch for initialization purpose")
self.model.save(self.path)
self.history.append(logs)您可以更改代码,使其具有更复杂的逻辑,以决定何时保存模型,以及对您来说最好的模型是什么。决策采用checking和decision进行计算。您必须具有以下依赖关系:
from functools import reduce
import operator
import tensorflow as tfhttps://stackoverflow.com/questions/74248901
复制相似问题