Mr_and_Mrs_D

Reputation: 34026

Deleting every 3rd element of a tensor in tensorflow

I am looking for the analogue of np.delete in tensorflow. I have batches of tensors, each batch of shape (batch_size, variable_length), and I want to get a tensor of shape (batch_size, 2 * variable_length / 3), i.e. drop every 3rd element along the last dimension. Each batch has a different length, which is stored in and read from the tfrecord. I am a bit at a loss as to which API I should use for this. Related (for numpy):

where the solution would simply be np.delete(x, slice(2, None, 3)) (after performing a reshape to cater for batch_size)
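For illustration, this is roughly what I am after in numpy (the values and shapes here are just a made-up example):

import numpy as np

x = np.array([[1, 2, 3, 4, 5, 6],
              [7, 8, 9, 10, 11, 12]])   # shape (batch_size=2, variable_length=6)
out = np.delete(x, slice(2, None, 3), axis=1)
print(out)
# [[ 1  2  4  5]
#  [ 7  8 10 11]]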

As requested in the comments I post the code for parsing a single example proto - although I am interested in deleting every nth (here 3rd) element of a tensor as a standalone question.

@classmethod
def parse_single_example(cls, example_proto):
    instance = cls()
    features_dict = cls._get_features_dict(example_proto)
    instance.path_length = features_dict['path_length']
    ...
    instance.coords = tf.decode_raw(features_dict['coords'], DATA_TYPE) # the tensor
    ...
    return instance.coords, ...

@classmethod
def _get_features_dict(cls, value):
    features_dict = tf.parse_single_example(value,
        features={'coords': tf.FixedLenFeature([], tf.string),
                  ...
                  'path_length': tf.FixedLenFeature([], tf.int64)})
    return features_dict

Upvotes: 1

Views: 2455

Answers (2)

Mr_and_Mrs_D

Reputation: 34026

Here is a way that avoids tf.py_func:

import numpy as np
import tensorflow as tf

# one example: coords of length 6 and path_length 2 (i.e. 2 groups of 3 elements)
slices = ([[1, 2, 3, 4, 5, 6]], [2])
d = tf.contrib.data.Dataset.from_tensor_slices(slices)
# tile the (keep, keep, drop) pattern path_length times and mask the coords with it
d = d.map(lambda coords, _pl: tf.boolean_mask(coords, tf.tile(
  np.array([True, True, False]), tf.reshape(tf.cast(_pl, tf.int32), [1]))))

it = d.make_one_shot_iterator()

with tf.Session() as sess:
  print(sess.run(it.get_next()))
  # [1 2 4 5]

Like all things TensorFlow, this was a bit hard to get right - note the cast (tf.tile fails when its 'multiples' parameter is int64, which is the type of the length read from the tfrecords), and the rather unintuitive reshape that is needed. Generalizing this example to accept variable-length arrays is left as an exercise.
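One possible generalization (just a sketch, assuming coords is a 1-D tensor whose length is a multiple of 3) derives the tile count from the tensor itself instead of reading it from the tfrecord:

def drop_every_third(coords):
  # number of (keep, keep, drop) groups, derived from the tensor length
  n_groups = tf.size(coords) // 3
  mask = tf.tile(tf.constant([True, True, False]), tf.reshape(n_groups, [1]))
  return tf.boolean_mask(coords, mask)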

I would be interested in a gather_nd version of this code.
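For reference, a rough, untested sketch of what such a gather_nd version might look like (same assumptions as above):

def drop_every_third_gather_nd(coords):
  idx = tf.range(tf.size(coords))
  keep = tf.boolean_mask(idx, tf.not_equal(idx % 3, 2))  # indices 0, 1, 3, 4, ...
  return tf.gather_nd(coords, tf.expand_dims(keep, axis=1))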

Upvotes: 1

GPhilo

Reputation: 19123

Disclaimer: Since you do not provide a minimal, complete and verifiable example, my code cannot be fully tested. You'll need to adapt it to your needs.

This is how you could do it using the tf.data API. Please note that since you're not showing the whole layout of your class, I have to make some assumptions about how and where your data is accessible.

First of all, I'm assuming that your class' constructor knows where the .tfrecord files are stored. Specifically, I'll assume that TFRECORD_FILENAMES is a list containing all the file paths to the files you want to extract the records from.

In your class constructor, you need to instantiate a TFRecordDataset and map() on it functions that modify the data the dataset contains:

class MyClass():
    def __init__(self):
        # more init stuff
        def parse_example(serialized_example):
            features_dict = tf.parse_single_example(serialized_example,
              features={'coords': tf.FixedLenFeature([], tf.string),
              ...
              'path_length': tf.FixedLenFeature([], tf.int64)})
            return features_dict

        def skip_every_third_pyfunc(coords):
            # you mention something about a reshape, I guess that goes here as well
            return np.delete(coords, slice(2, None, 3))  # drop every 3rd element, as in the question

        self.dataset = (tf.data.TFRecordDataset(TFRECORD_FILENAMES)
                        .map(parse_example)
                        .map(lambda features_dict: {
                            **features_dict,
                            **{'coords': tf.py_func(skip_every_third_pyfunc,
                                                    [features_dict['coords']],
                                                    features_dict['coords'].dtype)}}))
        self.iterator = self.dataset.make_one_shot_iterator()  # adapt this to your needs
        self.features_dict = self.iterator.get_next()  # I'm putting this here because I don't know where you'll need it

Note that in skip_every_third_pyfunc you can use numpy functions because we're using tf.py_func to wrap a python function as a tensor operation (all caveats in the link apply).

The ugly lambda in the second .map() call is necessary because you're using a feature dict instead of returning a tuple of tensors. The function wrapped by tf.py_func receives numpy arrays as input and must return numpy arrays, and py_func itself expects a list of input tensors (hence [features_dict['coords']]). To keep the dict format, we use the Python 3.5+ ** dict-unpacking operator. If you're using an older version of Python, you can define your own merge_two_dicts function and use it in the lambda call, as per this answer.
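For example, a minimal merge_two_dicts along the lines of that answer (just a sketch) would be:

def merge_two_dicts(a, b):
    merged = a.copy()   # start from a shallow copy of the first dict
    merged.update(b)    # b's entries take precedence
    return merged

# and the second .map() call then becomes:
# .map(lambda features_dict: merge_two_dicts(
#     features_dict,
#     {'coords': tf.py_func(skip_every_third_pyfunc,
#                           [features_dict['coords']],
#                           features_dict['coords'].dtype)}))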

Upvotes: 0
