如何解决在model.fit中自定义train_step Tensorflow Keras-InvalidArgumentError:操作'while'没有名为'_XlaCompile'的attr
我正在尝试实现一个自定义的多输入和输出模型,该模型使用this论文中提出的学习算法。该模型本身可以很好地运行,而无需使用我用作基准的自定义学习算法。我遇到的问题是,代码在代码行的DebiasModel类的train_step函数中被卡住了:
mc_pred = self.main_classifier([xu,xs],training=True)
它没有返回错误。运行一个小时后,我中断了内核,并返回错误消息:
InvalidArgumentError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception,another exception occurred:
InvalidArgumentError: Operation 'gradients/while_grad/Placeholder_28' has no attr named '_read_only_resource_inputs'.
我不确定问题是什么,我也尝试在tf.GradientTape中使用persistent = True,而不是在单个手表中声明两个gradientTapes。但是,会发生完全相同的错误。
有人知道这个问题是什么吗?以及如何解决?
我正在使用Tensorflow V2.3.0和Keras V2.4.0
源代码
class model_components:
def mitigation_expert():
inputs = Input(shape=(300,),dtype=tf.int32,name="me_input")
x = Embedding(num_tokens,300,weights=[embedding_matrix],input_length=max_length,trainable=False,name="me_embedding")(inputs)
x = LSTM(300,return_sequences=False,name="me_lstm")(x)
model = Model(inputs,x)
return model
def control_expert():
inputs = Input(shape=(22,name="ce_input")
y = Dense(19,activation='relu',name="ce_hidden")(inputs)
model = Model(inputs,y)
return model
def main_classifier():
# Expert components
me = model_components.mitigation_expert()
ce = model_components.control_expert()
# Main classifier
ensemble = concatenate([me.output,ce.output],name="pred_ensemble")
pred_output = Dense(319,activation="relu",name="pred_hidden")(ensemble)
pred_output = Dense(3,activation="softmax",name="pred_output")(pred_output)
model = Model(inputs=[me.input,ce.input],outputs=pred_output,name="main_classifier")
return model
def adversary_classifier():
# Mitigation Expert component
me = model_components.mitigation_expert()
# Adversary classifier
adv_output = Dense(300,name="adv_hidden")(me.output)
adv_output = Dense(1,activation='sigmoid',name="adv_output")(adv_output)
model = Model(inputs=me.input,outputs=adv_output,name="adversary_classifier")
return model
def tf_normalize(x):
return x / (tf.norm(x) + np.finfo(np.float32).tiny)
class DebiasModel(keras.Model):
def __init__(self,main_classifier,adversary_classifier):
super(DebiasModel,self).__init__()
self.main_classifier = main_classifier
self.adversary_classifier = adversary_classifier
def compile(self,mc_optimizer,adv_optimizer,mc_loss,adv_loss,debias_param):
super(DebiasModel,self).compile()
self.mc_optimizer = mc_optimizer
self.adv_optimizer = adv_optimizer
self.mc_loss = mc_loss
self.adv_loss = adv_loss
self.debias_param = debias_param
def train_step(self,data):
# Unpack data from model.fit()
x,y,sample_weight = data
# Unpack input and output features
xu,xs = x
y_mc = y['pred_output']
z_adv = y['adv_output']
# Unpack sample_weights
mainClass_weights = sample_weight["pred_output"]
protectClass_weights = sample_weight["adv_output"]
# Generate prediction and compute loss for Main_Classifier
with tf.GradientTape() as mc_tape,tf.GradientTape() as me_mc_tape:
mc_pred = self.main_classifier([xu,training=True)
mc_loss = self.mc_loss(y_mc,mc_pred,sample_weight=mainClass_weights)
# Compute and Apply Gradients for CE & Main Classifier
mc_trainable_vars = self.main_classifier.trainable_weights[3:]
mc_grads = mc_tape.gradient(mc_loss,mc_trainable_vars)
self.mc_optimizer.apply_gradients(zip(mc_grads,mc_trainable_vars))
# Generate prediction and compute loss for Adversary_Classifier
with tf.GradientTape() as adv_tape,tf.GradientTape() as me_adv_tape:
adv_pred = self.adversary_classifier(xu)
adv_loss = self.adv_loss(z_adv,adv_pred,sample_weight=protectClass_weights)
# Compute and Apply Gradients for CE & Main Classifier
adv_trainable_vars = self.adversary_classifier.trainable_weights[3:]
adv_grads = adv_tape.gradient(adv_loss,adv_trainable_vars)
self.adv_optimizer.apply_gradients(zip(adv_grads,adv_trainable_vars))
# Compute and Apply Gradients to debias ME
me_adv_debias_trainable_vars = self.adversary_classifier.trainable_weights[:3]
adv_debias_grads = me_adv_tape.gradient(adv_loss,me_adv_debias_trainable_vars)
adv_debias_dict = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(me_adv_debias_trainable_vars,adv_debias_grads),0)
me_mc_debias_trainable_vars = self.main_classifier.trainable_weights[:3]
mc_debias_grads = me_mc_tape.gradient(mc_loss,me_mc_debias_trainable_vars)
me_grads = []
for g,v in zip(mc_debias_grads,me_mc_debias_trainable_vars):
unit_adv = tf_normalize(adv_debias_dict.lookup(v))
g -= tf.math.reduce_sum(g * unit_adv) * unit_adv
g -= self.debias_param * adv_debias_dict.lookup(v)
me_grads.append(zip(g,v))
self.mc_optimizer.apply_gradients(me_grads)
return {"pred_loss": mc_loss,"adv_loss": adv_loss}
model = DebiasModel(model_components.main_classifier(),model_components.adversary_classifier())
model.compile(mc_optimizer=tf.keras.optimizers.Adam(),adv_optimizer=tf.keras.optimizers.Adam(),mc_loss=tf.keras.losses.CategoricalCrossentropy(),adv_loss=tf.keras.losses.BinaryCrossentropy(),debias_param=1)
epoch = 5
sample_weights = {
"pred_output": mainClass_weight,"adv_output": protectClass_weight,}
model.fit(x=[xu_train,xs_train],y={"pred_output": y_train,"adv_output": z_train},validation_data=([xu_val,xs_val],{"pred_output": y_val,"adv_output": z_val}),sample_weight=sample_weights,epochs=epoch,batch_size=256,verbose=1)
错误回溯
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in get_attr(self,name)
2485 with c_api_util.tf_buffer() as buf:
-> 2486 pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op,name,buf)
2487 data = pywrap_tf_session.TF_GetBuffer(buf)
InvalidArgumentError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception,another exception occurred:
ValueError Traceback (most recent call last)
51 frames
ValueError: Operation 'while' has no attr named '_XlaCompile'.
During handling of the above exception,another exception occurred:
InvalidArgumentError Traceback (most recent call last)
InvalidArgumentError: Operation 'gradients/while_grad/Placeholder_28' has no attr named '_read_only_resource_inputs'.
注意:我还没有添加完整的追溯,但是如果需要,我可以提供它。提前非常感谢!
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。