bountrisv

Reputation: 55

How to omit zeros in a 4-D tensor in TensorFlow?

Say I have a tensor:

import tensorflow as tf
t = tf.Variable([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.], [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
                 [[0., 132., 0., 0., 234., 0., 1., 24., 0.], [43., 0., 0., 124., 0., 0., 0., 52., 645]]]])

I want to omit the zeros and be left with a tensor of shape (1, 2, 2, 4), where 4 is the number of non-zero elements in each innermost vector, like this:

t = tf.Variable([[[[235., 1006., 23., 42], [77., 12., 33., 55.]],
                 [[132., 234., 1., 24.], [43., 124., 52., 645]]]])

I've used a boolean mask to do this on a 1-D tensor. How can I omit the zeros in a 4-D tensor? Can this be generalized to higher ranks?
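For illustration, here is what a boolean mask does to a multi-dimensional array. This sketch uses NumPy instead of TensorFlow (an assumption made so it runs without a session). `tf.boolean_mask` with a mask that covers the whole tensor also returns a 1-D result. The mask collapses every axis, so the batch shape has to be restored with a reshape. That only works if every innermost vector keeps the same number of values.

```python
import numpy as np

# Same values as the question's tensor, shape (1, 2, 2, 9).
t = np.array([[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
                [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
               [[0., 132., 0., 0., 234., 0., 1., 24., 0.],
                [43., 0., 0., 124., 0., 0., 0., 52., 645.]]]])

mask = t != 0
flat = t[mask]  # masking over every axis yields a flat 1-D array
print(flat.shape)  # (16,)

# Reuse all but the last dimension and let -1 infer the per-vector count.
restored = flat.reshape(t.shape[:-1] + (-1,))
print(restored.shape)  # (1, 2, 2, 4)
```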

Answers (1)

Allen Lavoie

Reputation: 5808

Using TensorFlow 1.12:

import tensorflow as tf

def batch_of_vectors_nonzero_entries(batch_of_vectors):
  """Removes zero entries from batched vectors.

  Requires that each vector have the same number of non-zero entries.

  Args:
    batch_of_vectors: A Tensor with length-N vectors, having shape [..., N].
  Returns:
    A Tensor with shape [..., M] where M is the number of non-zero entries in
    each vector.
  """
  nonzero_indices = tf.where(tf.not_equal(
      batch_of_vectors, tf.zeros_like(batch_of_vectors)))
  # gather_nd gives us a vector containing the non-zero entries of the
  # original Tensor
  nonzero_values = tf.gather_nd(batch_of_vectors, nonzero_indices)
  # Next, reshape so that all but the last dimension is the same as the input
  # Tensor. Note that this will fail unless each vector has the same number of
  # non-zero values.
  reshaped_nonzero_values = tf.reshape(
      nonzero_values,
      tf.concat([tf.shape(batch_of_vectors)[:-1], [-1]], axis=0))
  return reshaped_nonzero_values

t = tf.Variable(
    [[[[0., 235., 0., 0., 1006., 0., 0., 23., 42.],
       [77., 0., 0., 12., 0., 0., 33., 55., 0.]],
      [[0., 132., 0., 0., 234., 0., 1., 24., 0.],
       [43., 0., 0., 124., 0., 0., 0., 52., 645]]]])
nonzero_t = batch_of_vectors_nonzero_entries(t)

with tf.Session():
    tf.global_variables_initializer().run()
    result_evaled = nonzero_t.eval()
    print(result_evaled.shape, result_evaled)

Prints:

(1, 2, 2, 4) [[[[  2.35000000e+02   1.00600000e+03   2.30000000e+01   4.20000000e+01]
   [  7.70000000e+01   1.20000000e+01   3.30000000e+01   5.50000000e+01]]

  [[  1.32000000e+02   2.34000000e+02   1.00000000e+00   2.40000000e+01]
   [  4.30000000e+01   1.24000000e+02   5.20000000e+01   6.45000000e+02]]]]

If the vectors can have different numbers of non-zero entries, the result is ragged and the reshape above will fail. In that case it may be useful to look into SparseTensors.
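SparseTensors are one option. A simpler plain-NumPy sketch of the same concern follows (not a SparseTensor, and the helper name `nonzero_entries_or_ragged` is hypothetical). It counts the non-zeros per vector first. When the counts differ, it falls back to a list of per-vector arrays instead of attempting the reshape.

```python
import numpy as np


def nonzero_entries_or_ragged(batch_of_vectors):
    """Hypothetical helper, not from the answer above.

    Returns an array of shape [..., M] when every innermost vector has the
    same number M of non-zero entries; otherwise returns a flat list with one
    1-D array of non-zeros per innermost vector (a ragged result).
    """
    batch = np.asarray(batch_of_vectors)
    counts = np.count_nonzero(batch, axis=-1)  # non-zeros per innermost vector
    if counts.size and np.all(counts == counts.flat[0]):
        # Uniform counts: the mask-and-reshape approach is safe.
        return batch[batch != 0].reshape(batch.shape[:-1] + (-1,))
    # Ragged counts: a single reshape would fail, so keep vectors separate.
    vectors = batch.reshape(-1, batch.shape[-1])
    return [v[v != 0] for v in vectors]


uniform = nonzero_entries_or_ragged([[1., 0., 2.], [0., 4., 3.]])
print(uniform.shape)  # (2, 2)

ragged = nonzero_entries_or_ragged([[1., 0., 2.], [0., 0., 3.]])
print([v.tolist() for v in ragged])  # [[1.0, 2.0], [3.0]]
```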