如何解决如何将自定义回调的自定义指标值附加到我需要在张量板上使用的“日志”?
我需要实现一个自定义回调,以便在每个时期后计算AUC,我需要将其用作基于LSTM的神经网络中的指标。这是自定义回调:
from tensorflow.keras.callbacks import Callback
class RocCallback(Callback):
def __init__(self,training_data,validation_data):
self.x = training_data[0]
self.y = training_data[1]
self.x_val = validation_data[0]
self.y_val = validation_data[1]
def on_train_begin(self,logs={}):
self.roc_train_list = []
self.roc_val_list = []
self.roc_train=0
self.roc_val=0
logs["roc_train"] = []
logs["roc_val"] = []
return
def on_epoch_end(self,epoch,logs):
y_pred_train = self.model.predict(self.x)
roc_train = roc_auc_score(self.y,y_pred_train)
y_pred_val = self.model.predict(self.x_val)
roc_val = roc_auc_score(self.y_val,y_pred_val)
#print('\rroc-auc_train: %s - roc-auc_val: %s' % (str(round(roc_train,4)),str(round(roc_val,4))),end=100*' '+'\n')
# self.history['roc_auc_train'].append(round(roc_train,4))
# self.history['roc_auc_val'].append(round(roc_val,4))
self.roc_train = round(roc_train,4)
self.roc_val = round(roc_val,4)
self.roc_train_list.append(self.roc_train)
self.roc_val_list.append(self.roc_val)
print("\rroc_train: %f — roc_val: %f" %(self.roc_train,self.roc_val))
logs["roc_train"]= self.roc_train
logs["roc_val"] = self.roc_val
return logs
有两件事不能正常工作:
-
print("\rroc_train: %f — roc_val: %f" %(self.roc_train,self.roc_val))
在纪元进度条之前打印 ,但它需要在纪元进度条之后打印 例如:
Epoch 2/20
roc_train: 0.550000 — roc_val: 0.547800
2561/2561 [==============================] - 89s 35ms/step - loss: 0.5326 - val_loss: 0.4513
Epoch 3/20
roc_train: 0.559800 — roc_val: 0.558000
2561/2561 [==============================] - 88s 34ms/step - loss: 0.5049 - val_loss: 0.4406
- 张量板中的日志仅以epoch_loss作为度量值,而没有“ roc_train”或“ roc_val”值。 我尝试过
logs["roc_train"].append(self.roc_train)
logs["roc_val"].append(self.roc_val)
但它会引发关键错误。
解决方法
作为一种快速的替代方法,您是否尝试过使用内置的https://www.tensorflow.org/api_docs/python/tf/keras/metrics/AUC
,指标,
tf.keras.metrics.AUC(
num_thresholds=200,curve='ROC',summation_method='interpolation',name=None,dtype=None,thresholds=None,multi_label=False,label_weights=None
)
它可能会暂时解决您的问题。
您的代码确实没有错;在model.fit()的回调列表中,您能否将回调放置在列表的第一位置;在我的情况下,碰巧我想一次保存到.csv,而CustomMetric()回调是最后一个,因此.csv仅保存有loss和val_loss而不保存我的自定义指标。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。