Reputation: 7962
I create a dataset by reading TFRecords and map the values, and I want to filter the dataset for specific values. Since the result is a dict of tensors, I am not able to get the actual value of a tensor or to check it with tf.cond() / tf.equal. How can I do that?
def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
Upvotes: 19
Views: 18169
Reputation: 1356
Reading and filtering a dataset is very easy, and there is no need to unstack anything.
To read the dataset:
print(my_dataset, '\n\n')
## let us print the first 3 records
for record in my_dataset.take(3):
    ## the output below could be large in the case of images
    print(record)
    ## let us print a specific key
    print(record['key2'])
To filter is equally simple:
my_filtereddataset = my_dataset.filter(_filtcond1)
where you define _filtcond1 however you want. Let us say there is a true/false boolean flag in your dataset, stored as 1 or 0; then:
@tf.function
def _filtcond1(x):
    return x['key_bool'] == 1
or even a lambda function:
my_filtereddataset = my_dataset.filter(lambda x: x['key_int']>13)
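As a minimal sketch of both filter styles, assuming TensorFlow 2.x in eager mode: the small in-memory dataset below is hypothetical and only stands in for a parsed TFRecord dataset, and the keys 'key_bool' and 'key_int' are the illustrative names used in this answer.

```python
import tensorflow as tf

# Hypothetical in-memory dataset of dicts, standing in for parsed TFRecords.
my_dataset = tf.data.Dataset.from_tensor_slices({
    "key_bool": [1, 0, 1, 0],
    "key_int": [5, 20, 14, 13],
})

def _filtcond1(x):
    # Each element is a dict of scalar tensors, so this yields a scalar bool.
    return x["key_bool"] == 1

# Keep only the records whose flag is set.
flagged = my_dataset.filter(_filtcond1)
# Keep only the records whose integer value exceeds 13.
large = my_dataset.filter(lambda x: x["key_int"] > 13)

flagged_ints = [int(r["key_int"].numpy()) for r in flagged]
large_ints = [int(r["key_int"].numpy()) for r in large]
print(flagged_ints, large_ints)
```

The predicate has to return a scalar boolean tensor for each element, which is why scalar per-record features keep the filter function this short.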
If you are reading a dataset that you haven't created, or you are unaware of the keys (as seems to be the OP's case), you can use this to get an idea of the keys and structure first:
import json
from google.protobuf.json_format import MessageToJson

for raw_record in noidea_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    ## print(example)  ## for images this will be far too long
    m = json.loads(MessageToJson(example))
    print(m['features']['feature'].keys())
Now you can proceed with the filtering.
Upvotes: 1
Reputation: 43
I think you don't need to make label a 1-dimensional array in the first place.
With:
feature = {'label': tf.FixedLenFeature((), tf.string)}
you won't need to unstack the label in your filter_func.
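A minimal end-to-end sketch of this suggestion, assuming TensorFlow 2.x (where the parsing functions live under tf.io) and eager execution. The helper make_example and the in-memory list of serialized records are hypothetical stand-ins for the TFRecord files in the question.

```python
import tensorflow as tf

def make_example(label):
    """Hypothetical helper: serialize an Example with one string feature 'label'."""
    feature = {
        "label": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[label.encode("utf-8")])
        )
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

def mapping_func(serialized_example):
    # Shape () makes 'label' a scalar string tensor instead of a [1] array.
    feature = {"label": tf.io.FixedLenFeature((), tf.string)}
    return tf.io.parse_single_example(serialized_example, features=feature)

def filter_func(features):
    # Scalar compared with scalar gives the scalar bool that filter() expects.
    return tf.equal(features["label"], "some_label_value")

# In-memory records used here in place of the TFRecord files.
serialized = [
    make_example(name)
    for name in ["some_label_value", "other_label", "some_label_value"]
]
dataset = tf.data.Dataset.from_tensor_slices(serialized)
dataset = dataset.map(mapping_func).filter(filter_func)

kept_labels = [f["label"].numpy().decode("utf-8") for f in dataset]
print(kept_labels)
```

With a real file list, tf.data.TFRecordDataset(file_names) would take the place of from_tensor_slices; the map and filter steps stay the same.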
Upvotes: 1
Reputation: 7962
I am answering my own question. I found the issue!
What I needed to do was tf.unstack() the label, like this:
label = tf.unstack(features['label'])
label = label[0]
before passing it to tf.equal():
result = tf.reshape(tf.equal(label, 'some_label_value'), [])
I suppose the problem was that the label is defined as an array with one element of type string, tf.FixedLenFeature([1], tf.string), so in order to get the first and only element I had to unpack it (which creates a list) and then take the element at index 0. Correct me if I'm wrong.
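To make the shapes involved concrete, here is a small sketch assuming TensorFlow 2.x in eager mode. The constant below is a hypothetical stand-in for a parsed features['label'] of shape [1]; it only compares result shapes and does not reproduce the original TF 1.x pipeline.

```python
import tensorflow as tf

# Stand-in for a label parsed with FixedLenFeature([1], tf.string): shape [1].
label = tf.constant(["some_label_value"])

# Comparing the whole array is elementwise, so the result keeps shape [1].
elementwise = tf.equal(label, "some_label_value")

# tf.unstack splits along axis 0 into a Python list of scalar tensors.
via_unstack = tf.equal(tf.unstack(label)[0], "some_label_value")

# Plain indexing also yields a scalar, without building a list first.
via_index = tf.equal(label[0], "some_label_value")

print(elementwise.shape, via_unstack.shape, via_index.shape)
```

Both the unstacked and the indexed comparison produce a scalar boolean, which is the shape a filter predicate has to return.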
Upvotes: 6
Reputation: 79
You could try the apply function of tf.data.TFRecordDataset (see the TensorFlow documentation).
Otherwise, read this article to get a better understanding of TFRecords: TFRecords for humans.
Most likely, though, you can neither access nor modify a TFRecord directly; there is a request on GitHub about this topic: TFRecords request.
My advice is to keep things as simple as you can. Keep in mind that you are working with graphs and sessions.
In any case, if everything else fails, try running the part of the code that does not work inside a TensorFlow session, as simply as you can. Probably all of these operations need to be done while a tf.Session is running.
Upvotes: -6