Reputation: 882
I am trying to use tf.data.Dataset.map to port over this old code because I get a deprecation warning.
Old code which reads a set of custom protos from a TFRecord file:
record_iterator = tf.python_io.tf_record_iterator(path=filename)
for record in record_iterator:
    example = MyProto()
    example.ParseFromString(record)
I am trying to use eager mode and map, but I get this error.
def parse_proto(string):
    proto_object = MyProto()
    proto_object.ParseFromString(string)
raw_tf_dataset = tf.data.TFRecordDataset(dataset_paths)
parsed_protos = raw_tf_dataset.map(parse_proto)
This code works:
for raw_record in raw_tf_dataset:
    proto_object = MyProto()
    proto_object.ParseFromString(raw_record.numpy())
But the map gives me an error:
TypeError: a bytes-like object is required, not 'Tensor'
What is the right way to take the argument that map passes to the function and treat it like a string?
Upvotes: 5
Views: 5092
Reputation:
You need to extract the string from the tensor and use it in the map function. Two steps are needed to achieve this.
First, wrap the map function in tf.py_function, for example tf.py_function(get_path, [x], [tf.float32]). See the tf.py_function documentation for details. In tf.py_function, the first argument is the function to call, the second argument is the list of elements to pass to that function, and the final argument is the return type.
Second, inside the map function, call .numpy() on the tensor to get its bytes, and decode them if you need a Python string, for example bytes.decode(file_path.numpy()).
So modify your program as below,
parsed_protos = raw_tf_dataset.map(parse_proto)
to
parsed_protos = raw_tf_dataset.map(lambda x: tf.py_function(parse_proto, [x], [function return type]))
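To make the three arguments of tf.py_function concrete, here is a minimal sketch on a toy dataset of strings. The function name text_length and the sample strings are made up for illustration. It passes a single dtype rather than a list as the return type, so each element comes back as one tensor.

```python
import tensorflow as tf


def text_length(x):
    # Inside tf.py_function, x is an eager tf.string tensor, so .numpy()
    # is available and returns bytes; bytes.decode turns them into a str.
    text = bytes.decode(x.numpy())
    return len(text)


# Toy dataset of strings (hypothetical sample data).
ds = tf.data.Dataset.from_tensor_slices(["a", "bcd", "ef"])

# Arguments: the Python function, the list of inputs, the return dtype.
ds = ds.map(lambda x: tf.py_function(text_length, [x], tf.int64))

lengths = [int(t.numpy()) for t in ds]
print(lengths)
```

Without the tf.py_function wrapper, map traces text_length as a graph function, and the argument is a symbolic tensor with no .numpy() method.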
Also modify parse_proto as below,
def parse_proto(string):
    proto_object = MyProto()
    proto_object.ParseFromString(string)
to
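def parse_proto(string):
    proto_object = MyProto()
    # .numpy() returns the raw serialized bytes that ParseFromString expects;
    # do not decode them, since a serialized proto is not UTF-8 text
    proto_object.ParseFromString(string.numpy())
    # return the fields you need, matching the return types given to tf.py_function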
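Putting this together, here is a self-contained sketch of the whole round trip. Since MyProto is not shown in the question, it uses tf.train.Example as a stand-in proto class; the feature name "value", the helper make_example and the temporary file are all assumptions for illustration. The point is that tf.py_function can only return tensor-compatible values, so the function returns a field of the parsed proto rather than the proto object itself.

```python
import os
import tempfile

import tensorflow as tf


def make_example(value):
    # tf.train.Example stands in for the custom MyProto class (assumption).
    feature = tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    return tf.train.Example(
        features=tf.train.Features(feature={"value": feature}))


def parse_proto(string):
    # string is an eager tensor here, so .numpy() yields the serialized bytes.
    proto_object = tf.train.Example()
    proto_object.ParseFromString(string.numpy())
    # Return a plain value that matches the dtype given to tf.py_function.
    return proto_object.features.feature["value"].float_list.value[0]


# Write three serialized protos to a temporary TFRecord file.
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for v in (1.0, 2.0, 3.0):
        writer.write(make_example(v).SerializeToString())

raw_tf_dataset = tf.data.TFRecordDataset([path])
parsed = raw_tf_dataset.map(
    lambda x: tf.py_function(parse_proto, [x], tf.float32))

values = [float(t.numpy()) for t in parsed]
print(values)
```

For a custom proto, the same pattern should apply with MyProto in place of tf.train.Example, returning whichever fields the pipeline needs.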
In the simple program below, we use tf.data.Dataset.list_files to read the path of the image. Then, in the map function, we load the image with load_img and apply tf.image.central_crop to crop the central part of the image.
Code -
%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    # path is an eager string tensor inside tf.py_function; decode it to a str
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    # keep a random central fraction between 50% and 100% of the image
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
    for l in f:
        image = np.array(array_to_img(l))
        plt.imshow(image)
Hope this answers your question. Happy Learning.
Upvotes: 3