Oliver Urbann

Reputation: 23

Writing a Conv2D like operation in TensorFlow

In my CNN I need a layer that performs an operation like a Conv2D but subtracts instead of multiplying. I already have working code, where inputs[0] is a full image and inputs[1] is a Tensor with a shape such as (None, 5, 3, 512). I have implemented a custom layer in Keras where this is part of call():

    ...
    lines = []
    for x in range(0, x_max, x_step):
        line_parts = []
        for y in range(0, y_max, y_step):
            line_parts.append(inputs[0][:,x:x+x_step, y:y+y_step] - inputs[1])
        line = K.concatenate(line_parts, 2)
        lines.append(line)
    img = K.concatenate(lines, 1)
    ...

However, with a smaller x_step or y_step it gets too large. How should this kind of loop be implemented without writing it in C++ or CUDA in a low-level part of TensorFlow?

I tried to slice inputs[0] and then use tf.map_fn, but I cannot find an operation that cuts out all the desired smaller tensors at once without a loop. I also tried tf.while_loop, but I had trouble creating an empty tf.Variable with shape [None, ...], and I don't see how to use tf.concat to build the final Tensor starting from an empty one.

Thx in advance!

Upvotes: 2

Views: 127

Answers (1)

javidcf

Reputation: 59731

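I think what you need can be done with a reshape and broadcasting instead of a loop: split the image height and width into tile axes, subtract the patches along those axes, and reshape the result back: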

import tensorflow as tf

def subtract_patches(imgs, patches):
    # Get dimensions
    img_shape = tf.shape(imgs)
    img_h = img_shape[1]
    img_w = img_shape[2]
    img_c = img_shape[3]
    patch_shape = tf.shape(patches)
    patch_h = patch_shape[1]
    patch_w = patch_shape[2]
    # Reshape image into patches (image height/width must be multiples of patch height/width)
    imgs = tf.reshape(imgs, [-1, img_h // patch_h, patch_h, img_w // patch_w, patch_w, img_c])
    # Do subtraction, broadcasting the patches over the two tile-index axes
    out = imgs - tf.expand_dims(tf.expand_dims(patches, 1), 3)
    # Reshape result back
    out = tf.reshape(out, img_shape)
    return out

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    imgs = tf.reshape(tf.range(2 * 6 * 8 * 2, dtype=tf.float32), (2, 6, 8, 2))
    patches = 0.1 * tf.reshape(tf.range(2 * 3 * 4 * 2, dtype=tf.float32), (2, 3, 4, 2))
    out = subtract_patches(imgs, patches)
    print(sess.run(out))

Output:

[[[[  0.    0.9]
   [  1.8   2.7]
   [  3.6   4.5]
   [  5.4   6.3]
   [  8.    8.9]
   [  9.8  10.7]
   [ 11.6  12.5]
   [ 13.4  14.3]]

  [[ 15.2  16.1]
   [ 17.   17.9]
   [ 18.8  19.7]
   [ 20.6  21.5]
   [ 23.2  24.1]
   [ 25.   25.9]
   [ 26.8  27.7]
   [ 28.6  29.5]]

  [[ 30.4  31.3]
   [ 32.2  33.1]
   [ 34.   34.9]
   [ 35.8  36.7]
   [ 38.4  39.3]
   [ 40.2  41.1]
   [ 42.   42.9]
   [ 43.8  44.7]]

  [[ 48.   48.9]
   [ 49.8  50.7]
   [ 51.6  52.5]
   [ 53.4  54.3]
   [ 56.   56.9]
   [ 57.8  58.7]
   [ 59.6  60.5]
   [ 61.4  62.3]]

  [[ 63.2  64.1]
   [ 65.   65.9]
   [ 66.8  67.7]
   [ 68.6  69.5]
   [ 71.2  72.1]
   [ 73.   73.9]
   [ 74.8  75.7]
   [ 76.6  77.5]]

  [[ 78.4  79.3]
   [ 80.2  81.1]
   [ 82.   82.9]
   [ 83.8  84.7]
   [ 86.4  87.3]
   [ 88.2  89.1]
   [ 90.   90.9]
   [ 91.8  92.7]]]


 [[[ 93.6  94.5]
   [ 95.4  96.3]
   [ 97.2  98.1]
   [ 99.   99.9]
   [101.6 102.5]
   [103.4 104.3]
   [105.2 106.1]
   [107.  107.9]]

  [[108.8 109.7]
   [110.6 111.5]
   [112.4 113.3]
   [114.2 115.1]
   [116.8 117.7]
   [118.6 119.5]
   [120.4 121.3]
   [122.2 123.1]]

  [[124.  124.9]
   [125.8 126.7]
   [127.6 128.5]
   [129.4 130.3]
   [132.  132.9]
   [133.8 134.7]
   [135.6 136.5]
   [137.4 138.3]]

  [[141.6 142.5]
   [143.4 144.3]
   [145.2 146.1]
   [147.  147.9]
   [149.6 150.5]
   [151.4 152.3]
   [153.2 154.1]
   [155.  155.9]]

  [[156.8 157.7]
   [158.6 159.5]
   [160.4 161.3]
   [162.2 163.1]
   [164.8 165.7]
   [166.6 167.5]
   [168.4 169.3]
   [170.2 171.1]]

  [[172.  172.9]
   [173.8 174.7]
   [175.6 176.5]
   [177.4 178.3]
   [180.  180.9]
   [181.8 182.7]
   [183.6 184.5]
   [185.4 186.3]]]]
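
To check that the reshape approach matches the nested loop from the question, here is a minimal sketch in NumPy rather than TensorFlow, so it runs without a session. The function names `subtract_patches_loop` and `subtract_patches_reshape` are made up for this illustration, and it assumes the image height and width are exact multiples of the patch height and width.

```python
import numpy as np

def subtract_patches_loop(imgs, patches):
    """Reference version: the nested loop from the question, ported to NumPy."""
    _, x_max, y_max, _ = imgs.shape
    _, x_step, y_step, _ = patches.shape
    lines = []
    for x in range(0, x_max, x_step):
        line_parts = []
        for y in range(0, y_max, y_step):
            line_parts.append(imgs[:, x:x + x_step, y:y + y_step] - patches)
        lines.append(np.concatenate(line_parts, axis=2))
    return np.concatenate(lines, axis=1)

def subtract_patches_reshape(imgs, patches):
    """Loop-free version: reshape into tiles, broadcast, reshape back."""
    n, img_h, img_w, img_c = imgs.shape
    _, patch_h, patch_w, _ = patches.shape
    # Assumption: the patch size divides the image size evenly.
    if img_h % patch_h or img_w % patch_w:
        raise ValueError("image size must be a multiple of patch size")
    tiles = imgs.reshape(n, img_h // patch_h, patch_h, img_w // patch_w, patch_w, img_c)
    # Add singleton axes so the patches broadcast over both tile-index axes.
    out = tiles - patches[:, None, :, None, :, :]
    return out.reshape(imgs.shape)

# Same test data as in the TensorFlow example.
imgs = np.arange(2 * 6 * 8 * 2, dtype=np.float32).reshape(2, 6, 8, 2)
patches = 0.1 * np.arange(2 * 3 * 4 * 2, dtype=np.float32).reshape(2, 3, 4, 2)

loop_out = subtract_patches_loop(imgs, patches)
reshape_out = subtract_patches_reshape(imgs, patches)
print(np.allclose(loop_out, reshape_out))
```

Both versions apply the same element-wise subtraction. The reshape version does it in a single broadcast, so the number of operations does not grow as the patches get smaller.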

Upvotes: 1
