Iter Ator

Reputation: 9269

How to use the conv1d_transpose in Tensorflow?

conv1d_transpose is not yet in the stable version of TensorFlow, but an implementation is available on GitHub.

I would like to create a 1D deconvolution network. The shape of the input is [-1, 256, 16] and the output should be [-1,1024,8]. The kernel's size is 5 and the stride is 4.

I tried to build a 1D convolutional layer with this function:

    (output_depth, input_depth) = (8, 16)
    kernel_width = 7
    f_shape = [kernel_width, output_depth, input_depth]
    layer_1_filter = tf.Variable(tf.random_normal(f_shape))

    layer_1 = tf_exp.conv1d_transpose(
        x,
        layer_1_filter,
        [-1,1024,8],
        stride=4, padding="VALID"
    )

The shape of layer_1 is TensorShape([Dimension(None), Dimension(None), Dimension(None)]), but it should be [-1,1024,8]

What am I doing wrong? How can I implement 1D deconvolution in TensorFlow?

Upvotes: 2

Views: 5443

Answers (2)

Ali Yazdizadeh

Reputation: 36

tf.contrib.nn.conv1d_transpose has been added to the TensorFlow API as of r1.8.

Upvotes: 2

Maxim

Reputation: 53758

The pull request is open as of this moment, so the API and behavior can and probably will change. Some features that one might expect from conv1d_transpose aren't supported:

  • output_shape requires batch size to be known statically, can't pass -1;
  • on the other hand, output shape is dynamic (this explains None dimension).

Also, kernel_width=7 expects in_width=255, not 256. With stride 4 and "VALID" padding, kernel_width must be at most 4 to match in_width=256. The result is this demo code:

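    import numpy as np
    import tensorflow as tf

    # Assumption: TF 1.8+, where the function is exposed under tf.contrib.nn.
    # Before that release it has to be taken from the pull request.
    conv1d_transpose = tf.contrib.nn.conv1d_transpose

    x = tf.placeholder(shape=[None, 256, 16], dtype=tf.float32)
    kernel = tf.Variable(tf.random_normal([3, 8, 16]))    # [kernel_width, output_depth, input_depth]
    out = conv1d_transpose(x, kernel, output_shape=[100, 1024, 8], stride=4, padding="VALID")

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(out, feed_dict={x: np.zeros([100, 256, 16])})
        print(result.shape)  # prints (100, 1024, 8)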
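
The width mismatch for kernel_width=7 can be checked with plain arithmetic, without running TensorFlow. This is a minimal sketch, assuming the usual "VALID" rule for a forward convolution (floor((in - kernel) / stride) + 1) and that the transposed op expects its input to be as wide as that forward output. The helper names are hypothetical and not part of any TensorFlow API.

```python
def valid_conv_width(in_width, kernel_width, stride):
    """Width produced by a forward convolution with "VALID" padding."""
    if in_width < kernel_width:
        raise ValueError("input is narrower than the kernel")
    return (in_width - kernel_width) // stride + 1


def expected_transpose_input_width(out_width, kernel_width, stride):
    """Input width a "VALID" transposed convolution needs for out_width.

    A transposed convolution is the gradient of a forward convolution,
    so its input must match the forward convolution's output width.
    """
    return valid_conv_width(out_width, kernel_width, stride)


# Hypothetical check for the shapes in the question: out_width=1024, stride=4.
for kernel_width in (7, 5, 4, 3):
    in_width = expected_transpose_input_width(1024, kernel_width, stride=4)
    print(kernel_width, in_width)
# kernel widths 7 and 5 need in_width=255; widths 4 and 3 need in_width=256
```

Under these assumptions, only kernels no wider than the stride line up with an input width of 256 here, which is why the demo uses a kernel of width 3.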

Upvotes: 3
