Reputation: 77
I implemented a custom layer for Minibatch Standard Deviation:
class MinibatchStd(Layer):
    def __init__(self, group_size=4, epsilon=1e-8):
        super(MinibatchStd, self).__init__()
        self.epsilon = epsilon
        self.group_size = group_size

    def call(self, input_tensor):
        n, h, w, c = input_tensor.shape
        self.group_size = tf.keras.backend.minimum(self.group_size, tf.cast(input_tensor[0], dtype=tf.int32))
        x = tf.reshape(input_tensor, [self.group_size, -1, h, w, c])
        group_mean, group_var = tf.nn.moments(x, axes=(0), keepdims=False)
        group_std = tf.sqrt(group_var + self.epsilon)
        avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
        x = tf.tile(avg_std, [self.group_size, h, w, 1])
        return tf.concat([input_tensor, x], axis=-1)
After executing it, I get the following error:
ValueError: in user code:
<ipython-input-30-9b80a1ea4799>:20 call *
x = tf.reshape(input_tensor, [self.group_size, -1, h, w, c])
C:\ProgramData\Anaconda3\envs\gputest\lib\site-packages\tensorflow\python\ops\array_ops.py:193 reshape **
result = gen_array_ops.reshape(tensor, shape, name)
C:\ProgramData\Anaconda3\envs\gputest\lib\site-packages\tensorflow\python\ops\gen_array_ops.py:8087 reshape
"Reshape", tensor=tensor, shape=shape, name=name)
C:\ProgramData\Anaconda3\envs\gputest\lib\site-packages\tensorflow\python\framework\op_def_library.py:488 _apply_op_helper
(input_name, err))
ValueError: Tried to convert 'shape' to a tensor and failed. Error: Shapes must be equal rank, but are 3 and 0
From merging shape 0 with other shapes. for '{{node minibatch_std_4/Reshape/packed}} = Pack[N=5, T=DT_INT32, axis=0](minibatch_std_4/Minimum, minibatch_std_4/Reshape/packed/1, minibatch_std_4/Reshape/packed/2, minibatch_std_4/Reshape/packed/3, minibatch_std_4/Reshape/packed/4)' with input shapes: [4,4,256], [], [], [], [].
It only appears when I add the line:
self.group_size = tf.keras.backend.minimum(self.group_size, tf.cast(input_tensor[0], dtype=tf.int32))
I also tried tf.math.minimum, but it failed in the same way.
I am using Keras 2.4.3 and TF 2.2.0.
Upvotes: 1
Views: 109
Reputation: 3764
There are two ways to get the shape of a tensor (say x): x.shape and tf.shape(x). These two are fundamentally different: the former returns the static shape known when the model is built (a TensorShape whose entries can be None, e.g. the batch dimension), while the latter adds an op to the computation graph and returns the actual shape as a tensor at run time, so dimensions that are statically None get concrete values.
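As a quick illustration (the input_signature shape below is just an example, not taken from the question), this is how the two behave inside a traced function:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 4, 4, 256])])
def demo(x):
    print(x.shape)      # (None, 4, 4, 256) -- static shape, batch dimension unknown
    print(tf.shape(x))  # symbolic int32 tensor of length 4, filled in at run time
    return tf.shape(x)[0]

print(demo(tf.zeros([8, 4, 4, 256])))  # tf.Tensor(8, shape=(), dtype=int32)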
In short, instead of
n, h, w, c = input_tensor.shape
use
shape = tf.shape(input_tensor)
n = shape[0]
h = shape[1]
w = shape[2]
c = shape[3]
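Applied to the layer from the question, call might look like the sketch below. This is only an illustration, not the canonical implementation: it keeps the original logic, swaps in the dynamic shape, and keeps the clipped group size in a local variable instead of overwriting the layer attribute with a tensor.

import tensorflow as tf
from tensorflow.keras.layers import Layer

class MinibatchStd(Layer):
    def __init__(self, group_size=4, epsilon=1e-8):
        super(MinibatchStd, self).__init__()
        self.epsilon = epsilon
        self.group_size = group_size

    def call(self, input_tensor):
        # Dynamic shape: evaluated at run time, so the batch size is a
        # concrete value instead of None.
        shape = tf.shape(input_tensor)
        n, h, w, c = shape[0], shape[1], shape[2], shape[3]

        # Clip the group size to the actual batch size; keep it local
        # rather than assigning a tensor to self.group_size.
        group_size = tf.minimum(self.group_size, n)

        x = tf.reshape(input_tensor, [group_size, -1, h, w, c])
        group_mean, group_var = tf.nn.moments(x, axes=[0], keepdims=False)
        group_std = tf.sqrt(group_var + self.epsilon)
        avg_std = tf.reduce_mean(group_std, axis=[1, 2, 3], keepdims=True)
        x = tf.tile(avg_std, [group_size, h, w, 1])
        return tf.concat([input_tensor, x], axis=-1)

# Quick check in eager mode:
layer = MinibatchStd()
out = layer(tf.random.normal([8, 4, 4, 256]))
print(out.shape)  # (8, 4, 4, 257)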
Upvotes: 1