In my CNN I need a layer that performs an operation like Conv2D, but subtracts instead of multiplying. I already have working code, where `inputs[0]` is a full image and `inputs[1]` is a tensor with shape e.g. `(None, 5, 3, 512)`. I have implemented this as a custom Keras layer; this is part of its `call()`:
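```python
...
lines = []
for x in range(0, x_max, x_step):
    line_parts = []
    for y in range(0, y_max, y_step):
        # Subtract the patch from one (x_step, y_step) tile of the image
        line_parts.append(inputs[0][:, x:x+x_step, y:y+y_step] - inputs[1])
    # Join the tiles of one row of blocks along axis 2 (y)
    line = K.concatenate(line_parts, 2)
    lines.append(line)
# Stack the rows of blocks along axis 1 (x)
img = K.concatenate(lines, 1)
...
```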
However, with a smaller `x_step` or `y_step` the loop runs over many more tiles and the graph gets too large. How can this kind of loop be implemented without writing a low-level C++ or CUDA op for TensorFlow?
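To make the shapes concrete, here is a plain NumPy sketch of what the loop above computes. It assumes (this is not stated in the question) that the image height and width are exact multiples of the patch size, and that `x_step` and `y_step` equal the patch height and width, which is what the subtraction needs for a patch of shape `(None, 5, 3, 512)`. The name `loop_subtract` is hypothetical.

```python
import numpy as np

def loop_subtract(imgs, patches):
    """Hypothetical NumPy reference for the Keras loop: tile-wise imgs - patches.

    imgs:    (batch, H, W, C)
    patches: (batch, ph, pw, C); assumes H % ph == 0 and W % pw == 0
    """
    x_step, y_step = patches.shape[1], patches.shape[2]
    x_max, y_max = imgs.shape[1], imgs.shape[2]
    lines = []
    for x in range(0, x_max, x_step):
        line_parts = []
        for y in range(0, y_max, y_step):
            # Each slice has the same shape as `patches`, so plain subtraction works
            line_parts.append(imgs[:, x:x + x_step, y:y + y_step] - patches)
        # Join the tiles of one row of blocks along axis 2
        lines.append(np.concatenate(line_parts, axis=2))
    # Stack the rows of blocks along axis 1
    return np.concatenate(lines, axis=1)

imgs = np.arange(2 * 6 * 8 * 2, dtype=np.float32).reshape(2, 6, 8, 2)
patches = 0.1 * np.arange(2 * 3 * 4 * 2, dtype=np.float32).reshape(2, 3, 4, 2)
out = loop_subtract(imgs, patches)
print(out.shape)  # (2, 6, 8, 2)
```

In Keras every iteration of the symbolic version adds its own slice, subtraction and concatenation ops, so the graph grows with `(x_max / x_step) * (y_max / y_step)`.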
I tried slicing `inputs[0]` and then using `tf.map_fn`, but I cannot find an operation that cuts out all the desired smaller tensors at once without a loop. I also tried `tf.while_loop`, but I have trouble creating an empty `tf.Variable` with shape `[None, ...]`, and I don't see how to use `tf.concat` to build the final tensor starting from an empty one.
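Cutting out all non-overlapping tiles at once needs no loop: a reshape followed by a transpose does it. A NumPy sketch, under the same assumption that the image dimensions are multiples of the tile size (`extract_tiles` is a hypothetical name); `tf.reshape` and `tf.transpose` can express the same two steps on tensors.

```python
import numpy as np

def extract_tiles(imgs, ph, pw):
    """Hypothetical helper: cut (batch, H, W, C) into all non-overlapping ph x pw tiles.

    Returns shape (batch, H // ph, W // pw, ph, pw, C); assumes H % ph == 0 and W % pw == 0.
    """
    b, h, w, c = imgs.shape
    # Split H into (block row, row inside block) and W into (block col, col inside block)
    blocks = imgs.reshape(b, h // ph, ph, w // pw, pw, c)
    # Move the two block indices next to each other: (b, nh, nw, ph, pw, c)
    return blocks.transpose(0, 1, 3, 2, 4, 5)

imgs = np.arange(2 * 6 * 8 * 2).reshape(2, 6, 8, 2)
tiles = extract_tiles(imgs, 3, 4)
print(tiles.shape)  # (2, 2, 2, 3, 4, 2)
# Tile (i, j) of image b is exactly the slice the Keras loop takes
assert np.array_equal(tiles[1, 1, 0], imgs[1, 3:6, 0:4])
```

The answer below uses the same reshape but skips the transpose, because broadcasting can subtract the patch directly in the `(batch, nh, ph, nw, pw, C)` layout.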
Thanks in advance!
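I think what you need can be done without any loop, by reshaping the image into blocks of the patch size and letting broadcasting do the subtraction: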
```python
import tensorflow as tf

def subtract_patches(imgs, patches):
    # Get dimensions (dynamic, so the batch size may be None)
    img_shape = tf.shape(imgs)
    img_h = img_shape[1]
    img_w = img_shape[2]
    img_c = img_shape[3]
    patch_shape = tf.shape(patches)
    patch_h = patch_shape[1]
    patch_w = patch_shape[2]
    # Reshape image into blocks of patch size:
    # (batch, h, w, c) -> (batch, h // patch_h, patch_h, w // patch_w, patch_w, c)
    # This requires h and w to be exact multiples of patch_h and patch_w
    imgs = tf.reshape(imgs, [-1, img_h // patch_h, patch_h, img_w // patch_w, patch_w, img_c])
    # Do subtraction: patches becomes (batch, 1, patch_h, 1, patch_w, c)
    # and broadcasts over every block
    out = imgs - tf.expand_dims(tf.expand_dims(patches, 1), 3)
    # Reshape result back to (batch, h, w, c)
    out = tf.reshape(out, img_shape)
    return out

# Test (TF 1.x graph mode; in TF 2.x, call subtract_patches eagerly instead)
with tf.Graph().as_default(), tf.Session() as sess:
    imgs = tf.reshape(tf.range(2 * 6 * 8 * 2, dtype=tf.float32), (2, 6, 8, 2))
    patches = 0.1 * tf.reshape(tf.range(2 * 3 * 4 * 2, dtype=tf.float32), (2, 3, 4, 2))
    out = subtract_patches(imgs, patches)
    print(sess.run(out))
```
Output:
```
[[[[ 0. 0.9]
[ 1.8 2.7]
[ 3.6 4.5]
[ 5.4 6.3]
[ 8. 8.9]
[ 9.8 10.7]
[ 11.6 12.5]
[ 13.4 14.3]]
[[ 15.2 16.1]
[ 17. 17.9]
[ 18.8 19.7]
[ 20.6 21.5]
[ 23.2 24.1]
[ 25. 25.9]
[ 26.8 27.7]
[ 28.6 29.5]]
[[ 30.4 31.3]
[ 32.2 33.1]
[ 34. 34.9]
[ 35.8 36.7]
[ 38.4 39.3]
[ 40.2 41.1]
[ 42. 42.9]
[ 43.8 44.7]]
[[ 48. 48.9]
[ 49.8 50.7]
[ 51.6 52.5]
[ 53.4 54.3]
[ 56. 56.9]
[ 57.8 58.7]
[ 59.6 60.5]
[ 61.4 62.3]]
[[ 63.2 64.1]
[ 65. 65.9]
[ 66.8 67.7]
[ 68.6 69.5]
[ 71.2 72.1]
[ 73. 73.9]
[ 74.8 75.7]
[ 76.6 77.5]]
[[ 78.4 79.3]
[ 80.2 81.1]
[ 82. 82.9]
[ 83.8 84.7]
[ 86.4 87.3]
[ 88.2 89.1]
[ 90. 90.9]
[ 91.8 92.7]]]
[[[ 93.6 94.5]
[ 95.4 96.3]
[ 97.2 98.1]
[ 99. 99.9]
[101.6 102.5]
[103.4 104.3]
[105.2 106.1]
[107. 107.9]]
[[108.8 109.7]
[110.6 111.5]
[112.4 113.3]
[114.2 115.1]
[116.8 117.7]
[118.6 119.5]
[120.4 121.3]
[122.2 123.1]]
[[124. 124.9]
[125.8 126.7]
[127.6 128.5]
[129.4 130.3]
[132. 132.9]
[133.8 134.7]
[135.6 136.5]
[137.4 138.3]]
[[141.6 142.5]
[143.4 144.3]
[145.2 146.1]
[147. 147.9]
[149.6 150.5]
[151.4 152.3]
[153.2 154.1]
[155. 155.9]]
[[156.8 157.7]
[158.6 159.5]
[160.4 161.3]
[162.2 163.1]
[164.8 165.7]
[166.6 167.5]
[168.4 169.3]
[170.2 171.1]]
[[172. 172.9]
[173.8 174.7]
[175.6 176.5]
[177.4 178.3]
[180. 180.9]
[181.8 182.7]
[183.6 184.5]
[185.4 186.3]]]]
```
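For readers without a TF 1.x session at hand, here is a NumPy port of `subtract_patches` that checks the reshape-and-broadcast version against the explicit tile loop from the question on the same test inputs. Both `subtract_patches_np` and `loop_reference` are illustrative names, not part of the answer, and the port keeps the answer's assumption that image sizes are exact multiples of the patch size.

```python
import numpy as np

def subtract_patches_np(imgs, patches):
    """NumPy port of subtract_patches: same reshape + broadcast, no loops."""
    b, h, w, c = imgs.shape
    ph, pw = patches.shape[1], patches.shape[2]
    # (b, h, w, c) -> (b, h // ph, ph, w // pw, pw, c); assumes exact multiples
    blocks = imgs.reshape(b, h // ph, ph, w // pw, pw, c)
    # (b, ph, pw, c) -> (b, 1, ph, 1, pw, c) so it broadcasts over every block
    out = blocks - patches[:, None, :, None, :, :]
    return out.reshape(b, h, w, c)

def loop_reference(imgs, patches):
    """Explicit tile loop, as in the question, for comparison."""
    ph, pw = patches.shape[1], patches.shape[2]
    out = np.empty_like(imgs)
    for x in range(0, imgs.shape[1], ph):
        for y in range(0, imgs.shape[2], pw):
            out[:, x:x + ph, y:y + pw] = imgs[:, x:x + ph, y:y + pw] - patches
    return out

imgs = np.arange(2 * 6 * 8 * 2, dtype=np.float32).reshape(2, 6, 8, 2)
patches = 0.1 * np.arange(2 * 3 * 4 * 2, dtype=np.float32).reshape(2, 3, 4, 2)
out = subtract_patches_np(imgs, patches)
print(np.allclose(out, loop_reference(imgs, patches)))  # True
# First entries are approximately [0, 0.9] and [1.8, 2.7], as in the output above
print(out[0, 0, :2])
```

The point of the design is that the subtraction is a single broadcasted op, so the graph size no longer depends on how many tiles the image contains.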