如何解决在Tensorflow中为自定义静态张量保留未知的批次尺寸
一些注意事项:我正在使用tensorflow 2.3.0,python 3.8.2和numpy 1.18.5(不确定这是否很重要)
我正在编写一个自定义层,该层在内部存储形状为(a,b)的不可训练张量N,其中a,b是已知值(此张量是在初始化期间创建的)。在输入张量上调用时,它会展平输入张量,展平其存储的张量,并将两者串联在一起。不幸的是,我似乎无法弄清楚如何在此串联期间保留未知的批次尺寸。这是最少的代码:
import tensorflow as tf
from tensorflow.keras.layers import Layer,Flatten
class CustomLayer(Layer):
def __init__(self,N): # N is a tensor of shape (a,b),where a,b > 1
super(CustomLayer,self).__init__()
self.N = self.add_weight(name="N",shape=N.shape,trainable=False,initializer=lambda *args,**kwargs: N)
# correct me if I'm wrong in using this initializer approach,but for some reason,when I
# just do self.N = N,this variable would disappear when I saved and loaded the model
def build(self,input_shape):
pass # my reasoning is that all the necessary stuff is handled in init
def call(self,input_tensor):
input_flattened = Flatten()(input_tensor)
N_flattened = Flatten()(self.N)
return tf.concat((input_flattened,N_flattened),axis=-1)
我注意到的第一个问题是Flatten()(self.N)
将返回与原始self.N
具有相同形状(a,b)的张量,结果,返回值的形状将为(a,num_input_tensor_values + b)。我这样做的原因是,第一个维度a被视为批量大小。我修改了call
函数:
def call(self,input_tensor):
input_flattened = Flatten()(input_tensor)
N = tf.expand_dims(self.N,axis=0) # N would now be shape (1,a,b)
N_flattened = Flatten()(N)
return tf.concat((input_flattened,axis=-1)
这将返回一个形状为(1,num_input_vals + a * b)的张量,这很好,但是现在批次尺寸为永久1,这是我在开始使用该层训练模型时意识到的,并且只能工作批量大小为1。这在模型摘要中也很明显-如果我将这一层放在输入之后,然后再添加其他层,则输出张量的第一个维度将像None,1,1...
一样。有没有一种方法可以存储此内部张量并在call
中使用它,同时保留可变的批量大小? (例如,批量大小为4时,会将相同的展平N的副本连接到4个展平输入张量的每个张量的末尾。)
解决方法
输入中要包含样本,因此必须具有平整的N
向量,因为要串联到每个样本。将其想象成将行配对并串联在一起。如果只有一个N
向量,则只能连接一对。
要解决此问题,您应使用tf.tile()
重复N
多次,直到批次中有样品为止。
示例:
def call(self,input_tensor):
input_flattened = Flatten()(input_tensor) # input_flattened shape: (None,..)
N = tf.expand_dims(self.N,axis=0) # N shape: (1,a,b)
N_flattened = Flatten()(N) # N_flattened shape: (1,a*b)
N_tiled = tf.tile(N_flattened,[tf.shape(input_tensor)[0],1]) # repeat along the first dim as many times,as there are samples and leave the second dim alone
return tf.concat((input_flattened,N_tiled),axis=-1)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。