Reputation: 1
My TFX model is implemented using a custom tf.estimator.EstimatorV2 class. I'm trying to export the model so that it accepts individual features (extracted from a serialized tf.Example) as inputs, rather than the default single serialized string.
To achieve this, I'm defining a tf_estimator.export.ServingInputReceiver using the following code:
def parsing_transforming_serving_input_fn():
"""Serving input_fn that applies transforms to raw data in tf.Examples."""
raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn(
feature_spec, default_batch_size=None)
serving_input_receiver = raw_input_fn()
transformed_features = tf_transform_output.transform_raw_features(
serving_input_receiver.features, drop_unused_features=True)
def _parse_function(serialized_example):
# Parse the input tf.Example proto using the feature schema
input_features = tf.parse_single_example(serialized_example, feature_spec)
return input_features
# Define a placeholder for the serialized tf.Example
serialized_tf_example = tf.placeholder(
dtype=tf.string,
shape=[None],
name='input_example_tensor',
)
def get_tensor_spec(feature_spec):
tensor_spec_dict = {}
for key, feature in feature_spec.items():
if isinstance(feature, tf.FixedLenFeature):
tensor_spec_dict[key] = tf.TensorSpec(
shape=feature.shape, dtype=feature.dtype
)
elif isinstance(feature, tf.VarLenFeature):
tensor_spec_dict[key] = tf.SparseTensorSpec(
shape=None, dtype=feature.dtype
)
else:
raise ValueError(f'Unsupported feature type: {type(feature)}')
return tensor_spec_dict
dtype_dict = get_tensor_spec(feature_spec)
# Parse the serialized tf.Example
input_features = tf.map_fn(
_parse_function, serialized_tf_example, dtype=dtype_dict
)
# Create placeholders for each individual feature
receiver_tensors = {}
for feature_name, feature_tensor in input_features.items():
if isinstance(feature_tensor, tf.SparseTensor):
receiver_tensors[feature_name] = tf.sparse_placeholder(
dtype=feature_tensor.dtype, name=feature_name
)
else:
receiver_tensors[feature_name] = tf.placeholder(
dtype=feature_tensor.dtype,
shape=feature_tensor.shape,
name=feature_name,
)
return tf_estimator.export.ServingInputReceiver(
transformed_features, receiver_tensors
)
However, while this code doesn't produce any runtime errors, the exported model lacks signature values when loaded. This suggests that the ServingInputReceiver inputs might be incorrectly configured.
Any insights into what I might be doing wrong would be greatly appreciated. Thanks!
Upvotes: 0
Views: 20