Reputation: 43
How can I pad a tensor (with dimensions WxHxC) with its edge values?
For example:
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
becomes:
[1, 1, 2, 3, 3]
[1, 1, 2, 3, 3]
[4, 4, 5, 6, 6]
[7, 7, 8, 9, 9]
[7, 7, 8, 9, 9]
Upvotes: 1
Views: 1841
Reputation: 482
As a complement: if you want to pad an image in replicate mode, like OpenCV does, the following code does it. dst_image is the image to pad, and pad_h_up, pad_h_down, pad_w_left and pad_w_right are the four padding amounts:
import tensorflow as tf

def pad_replica(image_pad, up, down, left, right):
    # A 1-wide SYMMETRIC pad on a single side duplicates that edge once.
    paddings_up = tf.constant([[1, 0], [0, 0], [0, 0]])
    paddings_down = tf.constant([[0, 1], [0, 0], [0, 0]])
    paddings_left = tf.constant([[0, 0], [1, 0], [0, 0]])
    paddings_right = tf.constant([[0, 0], [0, 1], [0, 0]])

    # Loop condition: keep padding until pad_len steps have been done.
    c = lambda i, pad_len, pad_mode, image: tf.less(i, pad_len)

    def body(i, pad_len, pad_mode, image):
        # Each iteration copies the current edge row/column one more time.
        image = tf.pad(image, pad_mode, "SYMMETRIC")
        return [i + 1, pad_len, pad_mode, image]

    # Tensors are immutable, so the same zero counter can start every loop.
    i = tf.constant(0)
    [_, _, _, image_pad_up] = tf.while_loop(
        c, body, [i, up, paddings_up, image_pad])
    [_, _, _, image_pad_down] = tf.while_loop(
        c, body, [i, down, paddings_down, image_pad_up])
    [_, _, _, image_pad_left] = tf.while_loop(
        c, body, [i, left, paddings_left, image_pad_down])
    [_, _, _, image_pad_right] = tf.while_loop(
        c, body, [i, right, paddings_right, image_pad_left])
    return image_pad_right

# The image grows on every iteration, so its static shape must be fully
# unknown for the while_loop shape invariant to hold.
dst_image.set_shape([None, None, None])
dst_image = pad_replica(dst_image,
                        tf.cast(pad_h_up, tf.int32),
                        tf.cast(pad_h_down, tf.int32),
                        tf.cast(pad_w_left, tf.int32),
                        tf.cast(pad_w_right, tf.int32))
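To see why pad_replica works, here is a pure-Python model (no TensorFlow involved) of one SYMMETRIC pad on a 2-D list, assuming the mirror includes the edge row/column itself, as tf.pad's "SYMMETRIC" mode does. Repeating a 1-wide pad, as the four tf.while_loop calls do, gives the same result as clamping every index into the valid range, which is what replicate padding means. The helper names are hypothetical.

```python
# Pure-Python model of tf.pad(..., "SYMMETRIC") on a list of rows.
# Assumption: mirroring includes the edge itself; this is not TF code.

def symmetric_pad(m, top, bottom, left, right):
    """Mirror top/bottom rows and left/right columns, edge included."""
    rows = [list(r) for r in m]
    n = len(rows)
    rows = rows[:top][::-1] + rows + rows[n - bottom:][::-1]
    out = []
    for r in rows:
        w = len(r)
        out.append(r[:left][::-1] + r + r[w - right:][::-1])
    return out

def replicate_by_repetition(m, up, down, left, right):
    """Apply a 1-wide SYMMETRIC pad repeatedly, like pad_replica's loops."""
    for _ in range(up):
        m = symmetric_pad(m, 1, 0, 0, 0)
    for _ in range(down):
        m = symmetric_pad(m, 0, 1, 0, 0)
    for _ in range(left):
        m = symmetric_pad(m, 0, 0, 1, 0)
    for _ in range(right):
        m = symmetric_pad(m, 0, 0, 0, 1)
    return m

def edge_pad(m, up, down, left, right):
    """Reference replicate padding: clamp every index into range."""
    h, w = len(m), len(m[0])
    return [[m[min(max(i, 0), h - 1)][min(max(j, 0), w - 1)]
             for j in range(-left, w + right)]
            for i in range(-up, h + down)]

a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(replicate_by_repetition(a, 2, 1, 0, 3) == edge_pad(a, 2, 1, 0, 3))  # True
```

A single pad wider than 1 would mirror interior values too (e.g. a 2-wide left pad of [1, 2, 3] gives [2, 1, 1, 2, 3]), which is why the loop pads one row or column at a time.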
Upvotes: 1
Reputation: 4868
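Use tf.pad() with mode "SYMMETRIC". It mirrors the values across the border with the edge value itself included, so if you pad by only 1 it is equivalent to repeating the edge value. If you need more padding, you have to repeat the operation, but the amount can grow exponentially (1 first, then 2, then 4, etc.): after padding by 1, the run of identical edge values is 2 wide, so a SYMMETRIC pad of 2 still mirrors only edge values, and so on. This code (tested):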
import tensorflow as tf

a = tf.reshape(tf.constant(range(1, 10)), (3, 3))
b = tf.pad(a, [[1, 1], [1, 1]], "SYMMETRIC")

# TF 1.x API; in TF 2.x (eager mode), print(b.numpy()) instead of using a Session.
with tf.Session() as sess:
    print(sess.run(b))
Outputs:
[[1 1 2 3 3]
[1 1 2 3 3]
[4 4 5 6 6]
[7 7 8 9 9]
[7 7 8 9 9]]
as desired.
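Below is a pure-Python sketch of the doubling schedule on one axis, again modelling SYMMETRIC as a mirror that includes the edge (helper names are hypothetical). Each entry of the schedule would correspond to one tf.pad call. The step is capped at the current width of the replicated border, so only edge values are ever mirrored, and the number of pad calls grows roughly with log2 of the padding amount instead of linearly.

```python
# Pure-Python sketch of padding by 1, 2, 4, ... with SYMMETRIC mirroring.
# Assumption: same SYMMETRIC semantics as tf.pad (edge included).

def symmetric_pad_1d(xs, before, after):
    """Mirror `before` leading and `after` trailing values, edge included."""
    return xs[:before][::-1] + xs + xs[len(xs) - after:][::-1]

def doubling_steps(n):
    """Pad widths for one side: 1, 2, 4, ... capped so they sum to n."""
    steps, done = [], 0
    while done < n:
        # The border of identical values is (done + 1) wide (edge + copies),
        # so mirroring at most that many values copies only the edge value.
        step = min(done + 1, n - done)
        steps.append(step)
        done += step
    return steps

def replicate_pad_1d(xs, before, after):
    """Replicate padding built from a logarithmic number of SYMMETRIC pads."""
    for p in doubling_steps(before):
        xs = symmetric_pad_1d(xs, p, 0)
    for p in doubling_steps(after):
        xs = symmetric_pad_1d(xs, 0, p)
    return xs

print(doubling_steps(10))                 # [1, 2, 4, 3]
print(replicate_pad_1d([1, 2, 3], 5, 2))  # [1, 1, 1, 1, 1, 1, 2, 3, 3, 3]
```

Capping the step also keeps each pad no larger than the current axis length, which tf.pad requires for "SYMMETRIC".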
Upvotes: 4