RobR


Static shape for tf.nn.crelu() undefined

As of TensorFlow 0.11.0, tf.nn.crelu() does not set the static shape of its output, even when the input shape is fully known:

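import tensorflow as tf

f = tf.random_normal([50, 5, 7, 10])
f2 = tf.nn.relu(f)
print(f2.get_shape().as_list())  # [50, 5, 7, 10]
f3 = tf.nn.crelu(f)
print(f3.get_shape().as_list())  # [None, None, None, None]

with tf.Session() as sess:
    print(tf.__version__)
    py_f2, py_f3 = sess.run([f2, f3])
    # Only the static shape is lost; the evaluated array has the expected shape.
    print(py_f3.shape)  # (50, 5, 7, 20)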

This has been filed as GitHub issue #5912.
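To see what static shape crelu should report, it helps to recall what CReLU computes: it concatenates relu(x) and relu(-x) along the last axis, so every dimension is preserved except the last, which doubles. The NumPy sketch below is a reference of that definition only, not TensorFlow's implementation; the function name crelu_reference is hypothetical.

```python
import numpy as np

def crelu_reference(x, axis=-1):
    """Hypothetical reference CReLU: concatenate relu(x) and relu(-x) along `axis`."""
    pos = np.maximum(x, 0)   # relu(x): keeps the positive part
    neg = np.maximum(-x, 0)  # relu(-x): keeps the magnitude of the negative part
    return np.concatenate([pos, neg], axis=axis)

f = np.random.standard_normal((50, 5, 7, 10))
out = crelu_reference(f)
print(out.shape)  # (50, 5, 7, 20)
```

This matches the runtime shape from the session above, and it is the static shape tf.nn.crelu() ought to report.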

Answers (1)

RobR

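Workaround: tf.nn.crelu() concatenates relu(x) and relu(-x) along the last axis, so the output shape is the input shape with the last dimension doubled. You can restore it manually with set_shape():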

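import tensorflow as tf

f = tf.random_normal([50, 5, 7, 10])
[b, nx, ny, nz] = f.get_shape().as_list()
f3 = tf.nn.crelu(f)
# crelu concatenates relu(f) and relu(-f) along the last axis,
# so only the last dimension doubles.
f3.set_shape([b, nx, ny, 2*nz])
print(f3.get_shape().as_list())  # [50, 5, 7, 20]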
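The workaround above unpacks exactly four dimensions. A possible generalization, sketched here under the assumption that crelu always concatenates along the last axis, is a small helper that works for any rank and passes unknown (None) dimensions through. The name crelu_output_shape is hypothetical and is not part of TensorFlow.

```python
from typing import List, Optional

Dim = Optional[int]

def crelu_output_shape(input_shape: List[Dim]) -> List[Dim]:
    """Hypothetical helper: static output shape of a CReLU over the last axis.

    Leading dimensions (known or None) are kept as-is; the last dimension
    is doubled only when it is known, otherwise it stays None.
    """
    if not input_shape:
        raise ValueError("CReLU needs an input with at least one dimension")
    *leading, last = input_shape
    return list(leading) + [None if last is None else 2 * last]

print(crelu_output_shape([50, 5, 7, 10]))    # [50, 5, 7, 20]
print(crelu_output_shape([None, 5, 7, 10]))  # [None, 5, 7, 20]
```

In the workaround, the result could then be passed as f3.set_shape(crelu_output_shape(f.get_shape().as_list())), which also covers a batch dimension that is None.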
