我想更新值为0的2D张量中的索引.因此,数据是2D张量,其第2行第2列索引值将被替换为0.但是,我收到类型错误.任何人都可以帮助我吗?
TypeError: Input ‘ref’ of ‘ScatterUpdate’ Op requires l-value input
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()
sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])
解决方法:
tf.scatter_update只能应用于Variable类型.代码中的数据是变量,而data2不是,因为tf.reshape的返回类型是Tensor.
解:
对于v1.0之后的张量流
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)
对于v1.0之前的张量流
data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。