Reputation: 34026
I am looking for the analogue of np.delete in TensorFlow. I have batches of tensors, each batch of shape (batch_size, variable_length), and I want to drop every third element to get a tensor of shape (batch_size, 2 * variable_length / 3). As seen, each batch has a different length, which is stored in and read from the tfrecord. I am a bit at a loss as to which API I should use for that. Related (for numpy), where the solution would simply be np.delete(x, slice(2, None, 3)) (after performing a reshape to cater for batch_size).
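For reference, the NumPy behaviour being asked about can be sketched on a small batch. The array values below are made up for illustration, and axis=1 is used here in place of the reshape mentioned above; this only shows the target result, not a TensorFlow solution.

```python
import numpy as np

# Hypothetical batch: batch_size=2, variable_length=6 (values are illustrative).
x = np.arange(12).reshape(2, 6)

# Drop columns 2, 5, ... of every row, i.e. every third element.
y = np.delete(x, slice(2, None, 3), axis=1)

print(y.shape)  # (2, 4), i.e. (batch_size, 2 * variable_length / 3)
print(y)
```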
As requested in the comments, here is the code for parsing a single example proto, although I am interested in deleting every nth (3rd) element of a tensor as a standalone question.
@classmethod
def parse_single_example(cls, example_proto):
    instance = cls()
    features_dict = cls._get_features_dict(example_proto)
    instance.path_length = features_dict['path_length']
    ...
    instance.coords = tf.decode_raw(features_dict['coords'], DATA_TYPE)  # the tensor
    ...
    return instance.coords, ...

@classmethod
def _get_features_dict(cls, value):
    features_dict = tf.parse_single_example(
        value,
        features={'coords': tf.FixedLenFeature([], tf.string),
                  ...
                  'path_length': tf.FixedLenFeature([], tf.int64)})
    return features_dict
Upvotes: 1
Views: 2455
Reputation: 34026
Here is a way that avoids tf.py_func:
import numpy as np
import tensorflow as tf

# One example: six coords and a path_length of 2 (two groups of three).
slices = ([[1, 2, 3, 4, 5, 6]], [2])
d = tf.contrib.data.Dataset.from_tensor_slices(slices)
# Tile the [keep, keep, drop] pattern path_length times and use it as a mask.
d = d.map(lambda coords, _pl: tf.boolean_mask(
    coords,
    tf.tile(np.array([True, True, False]),
            tf.reshape(tf.cast(_pl, tf.int32), [1]))))
it = d.make_one_shot_iterator()
with tf.Session() as sess:
    print(sess.run(it.get_next()))
    # [1 2 4 5]
Like all things TensorFlow, this was a bit hard to get right. Note the cast (tf.tile fails when its 'multiples' parameter is int64, which is the type of the length I read from the tfrecords) and the rather unintuitive reshape, needed because 'multiples' must be a 1-D tensor while the length is a scalar. Generalizing this example to accept variable length arrays is left as an exercise.
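One possible direction for that exercise is to build the mask from positions rather than by tiling a fixed pattern, so the length need not be a multiple of three. The sketch below is in NumPy only; keep_mask is a hypothetical helper, and a TensorFlow version would presumably build the same mask with tf.range and pass it to tf.boolean_mask (not tested here).

```python
import numpy as np

def keep_mask(length, n=3):
    """Boolean mask that is False at positions n-1, 2n-1, ... and True elsewhere."""
    return (np.arange(length) + 1) % n != 0

# Length 7 is deliberately not a multiple of 3.
coords = np.array([1, 2, 3, 4, 5, 6, 7])
filtered = coords[keep_mask(len(coords))]
print(filtered)  # [1 2 4 5 7]
```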
I would be interested in a gather_nd version of this code.
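A gather-based variant mostly comes down to computing the indices to keep. Here is a minimal sketch in NumPy, where np.take stands in for tf.gather and keep_indices is a hypothetical helper; in TensorFlow the same index vector would presumably be passed to tf.gather, or reshaped to a column of shape (k, 1) for tf.gather_nd (untested assumption).

```python
import numpy as np

def keep_indices(length, n=3):
    """Indices 0..length-1 with every n-th position (n-1, 2n-1, ...) removed."""
    idx = np.arange(length)
    return idx[idx % n != n - 1]

coords = np.array([1, 2, 3, 4, 5, 6])
idx = keep_indices(len(coords))
gathered = np.take(coords, idx)

print(idx)       # [0 1 3 4]
print(gathered)  # [1 2 4 5]
```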
Upvotes: 1
Reputation: 19123
Disclaimer: Since you do not provide a minimal, complete, and verifiable example, my code cannot be fully tested. You'll need to try it and adapt it to your needs.
This is how you could do it using the tf.data API.
Please note that since you're not showing the whole layout of your class, I have to make some assumptions on how and where your data is accessible.
First of all, I'm assuming that your class's constructor knows where the .tfrecord files are stored. Specifically, I'll assume that TFRECORD_FILENAMES is a list containing the paths of all the files you want to extract records from.
In your class constructor, you need to instantiate a TFRecordDataset and map() onto it the functions that modify the data the dataset contains:
import numpy as np
import tensorflow as tf

class MyClass():
    def __init__(self):
        # more init stuff

        def parse_example(serialized_example):
            features_dict = tf.parse_single_example(
                serialized_example,
                features={'coords': tf.FixedLenFeature([], tf.string),
                          ...
                          'path_length': tf.FixedLenFeature([], tf.int64)})
            # decode the raw bytes as in the question (DATA_TYPE is defined there)
            features_dict['coords'] = tf.decode_raw(features_dict['coords'], DATA_TYPE)
            return features_dict

        def skip_every_third_pyfunc(coords):
            # you mention something about a reshape, I guess that goes here as well
            # slice(2, None, 3) drops indices 2, 5, 8, ... (every third element)
            return np.delete(coords, slice(2, None, 3))

        self.dataset = (tf.data.TFRecordDataset(TFRECORD_FILENAMES)
                        .map(parse_example)
                        # py_func takes its inputs as a list of tensors
                        .map(lambda features_dict: {
                            **features_dict,
                            **{'coords': tf.py_func(skip_every_third_pyfunc,
                                                    [features_dict['coords']],
                                                    features_dict['coords'].dtype)}}))
        self.iterator = self.dataset.make_one_shot_iterator()  # adapt this to your needs
        self.features_dict = self.iterator.get_next()  # I'm putting this here because I don't know where you'll need it
Note that in skip_every_third_pyfunc you can use numpy functions because we're using tf.py_func to wrap a Python function as a tensor operation (all caveats in the link apply).
The ugly lambda in the second .map() call is necessary because you're using a feature dict instead of returning a tuple of tensors. The function wrapped by py_func takes numpy arrays as input and returns numpy arrays. To keep the dict format, we use the ** dict-unpacking operator available in Python 3.5+. If you're using an older version of Python, you can define your own merge_two_dicts function and use it in the lambda call instead, as per this answer.
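For older Python versions, such a helper could look roughly like the sketch below. The name merge_two_dicts comes from the text above, but this body and the sample dict are only illustrative assumptions, not necessarily the code from the linked answer.

```python
def merge_two_dicts(a, b):
    """Return a new dict with the entries of a, overridden by those of b."""
    merged = a.copy()   # shallow copy so the inputs are left untouched
    merged.update(b)
    return merged

# Illustrative stand-ins for the feature dict and the replaced 'coords' entry.
features = {'coords': 'raw', 'path_length': 2}
result = merge_two_dicts(features, {'coords': 'filtered'})
print(result)
```

In the lambda, {**features_dict, **{'coords': ...}} would then become merge_two_dicts(features_dict, {'coords': ...}).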
Upvotes: 0