How to implement a repetition penalty in TensorFlow decoding
This is the decoding part of tensor2tensor. You can also find it at https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py, lines 1155 to 1298:

def fast_decode(encoder_output,
                encoder_decoder_attention_bias,
                symbols_to_logits_fn,
                hparams,
                decode_length,
                vocab_size,
                init_cache_fn=_init_transformer_cache,
                beam_size=1,
                top_beams=1,
                alpha=1.0,
                sos_id=0,
                eos_id=beam_search.EOS_ID,
                batch_size=None,
                force_decode_length=False,
                scope_prefix="body/",
                cache=None):
  if encoder_output is not None:
    batch_size = common_layers.shape_list(encoder_output)[0]

  cache = init_cache_fn(
      cache=cache,
      hparams=hparams,
      batch_size=batch_size,
      attention_init_length=0,
      encoder_output=encoder_output,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      scope_prefix=scope_prefix)

  def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
    """One step of greedy decoding."""
    logits, cache = symbols_to_logits_fn(next_id, i, cache)
    log_probs = common_layers.log_prob_from_logits(logits)
    temperature = getattr(hparams, "sampling_temp", 0.0)
    keep_top = getattr(hparams, "sampling_keep_top_k", -1)
    if hparams.sampling_method == "argmax":
      temperature = 0.0
    next_id = common_layers.sample_with_temperature(
        logits, temperature, keep_top)
    hit_eos |= tf.equal(next_id, eos_id)

    log_prob_indices = tf.stack(
        [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
    log_prob += tf.gather_nd(log_probs, log_prob_indices)

    next_id = tf.expand_dims(next_id, axis=1)
    decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
    return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob

  def is_not_finished(i, hit_eos, *_):
    finished = i >= decode_length
    if not force_decode_length:
      finished |= tf.reduce_all(hit_eos)
    return tf.logical_not(finished)

  decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
  hit_eos = tf.fill([batch_size], False)
  next_id = sos_id * tf.ones([batch_size, 1], dtype=tf.int64)
  initial_log_prob = tf.zeros([batch_size], dtype=tf.float32)

  _, _, _, decoded_ids, cache, log_prob = tf.while_loop(
      is_not_finished,
      inner_loop, [
          tf.constant(0), hit_eos, next_id, decoded_ids, cache,
          initial_log_prob
      ],
      shape_invariants=[
          tf.TensorShape([]),
          tf.TensorShape([None]),
          tf.TensorShape([None, None]),
          tf.TensorShape([None, None]),
          nest.map_structure(beam_search.get_state_shape_invariants, cache),
          tf.TensorShape([None]),
      ])
  scores = log_prob

  return {"outputs": decoded_ids, "scores": scores, "cache": cache}
Is there any way to avoid generating repeated words? For example, the model might generate "I eat eat eat eat eat apples."

I think that after getting the logits in

logits, cache = symbols_to_logits_fn(next_id, i, cache)

there might be a way to reduce the values of the previously generated ids in the logits, but I don't know how to do it.

Any ideas?

Thanks
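One common approach for exactly this "reduce the logits of previous ids" idea is the repetition penalty from the CTRL paper (Keskar et al., 2019): for every token that already appears in the generated sequence, divide its logit by a penalty factor if it is positive, or multiply it by the factor if it is negative, so the token becomes less likely either way. Below is a minimal NumPy sketch of that rule, not tensor2tensor code; `apply_repetition_penalty` is a hypothetical helper name. Inside `inner_loop` you would apply the same rule to `logits` right after `symbols_to_logits_fn` returns, using `decoded_ids` as the history (e.g. with one-hot masks or `tf.tensor_scatter_nd_update` in graph mode).

```python
import numpy as np

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Penalize logits of token ids that were already generated.

    logits: [batch, vocab] array of raw scores.
    generated_ids: per-example list of previously generated token ids.
    penalty > 1.0: positive logits are divided by it, negative logits
    are multiplied by it, so repeated tokens always lose probability.
    """
    logits = logits.copy()  # do not mutate the caller's array
    for row, seq_ids in zip(logits, generated_ids):
        for token_id in set(seq_ids):
            if row[token_id] > 0:
                row[token_id] /= penalty
            else:
                row[token_id] *= penalty
    return logits

# Example: token 0 and 1 were generated before, token 2 was not.
logits = np.array([[2.0, -1.0, 0.5]])
penalized = apply_repetition_penalty(logits, [[0, 1]], penalty=2.0)
# token 0: 2.0 / 2.0 = 1.0; token 1: -1.0 * 2.0 = -2.0; token 2 unchanged
```

A cruder alternative is to set the logits of previously generated ids to a large negative value (effectively -inf), which forbids any repetition outright; the multiplicative penalty above is usually preferable because legitimate repeats (e.g. function words) stay possible, just less likely.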