Reputation: 16831
Let's say that I have the following, very straightforward pipeline:
import os
from tfx import v1 as tfx

_dataset_folder = './tfrecords/train/*'
_pipeline_data_folder = './pipeline_data'
_serving_model_dir = os.path.join(_pipeline_data_folder, 'serving_model')

example_gen = tfx.components.ImportExampleGen(input_base=_dataset_folder)

statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])

schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)

example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

_transform_module_file = 'preprocessing_fn.py'
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file),
    custom_config={'statistics_gen': statistics_gen.outputs['statistics'],
                   'schema_gen': schema_gen.outputs['schema']})

_trainer_module_file = 'run_fn.py'
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10),
    eval_args=tfx.proto.EvalArgs(num_steps=6))

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))

components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    pusher,
]

pipeline = tfx.dsl.Pipeline(
    pipeline_name='straightforward_pipeline',
    pipeline_root=_pipeline_data_folder,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        f'{_pipeline_data_folder}/metadata.db'),
    components=components)

tfx.orchestration.LocalDagRunner().run(pipeline)
The only part that's a bit out of the ordinary in the snippet above is that I'm passing statistics_gen and schema_gen outputs to the Transform step through the custom_config argument. What I'm hoping to achieve is to iterate over the list of features in the dataset, using their statistics and schema, in order to transform them. My question is: how can I access this information in my preprocessing_fn.py function?
BTW, I know how to do this if I have access to the CSV version of the dataset:
import tensorflow_data_validation as tfdv

dataset_stats = tfdv.generate_statistics_from_csv(examples_file)
feature_1_stats = tfdv.get_feature_stats(dataset_stats.datasets[0],
                                         tfdv.FeaturePath(['feature_1']))
But there is a problem: this extracts everything from the raw dataset again, while in my pipeline that information has, I believe, already been extracted by the statistics_gen and schema_gen steps. I don't want to redo the whole process; I just need to learn how to use those two steps to get the info I need.
Upvotes: 3
Views: 131
Reputation: 326
I love this line of thinking!
The reason this currently doesn't work is that custom_config needs to be a "pure" dictionary: it cannot carry anything dynamic that's populated at runtime. The outputs of statistics_gen are live channels during pipeline execution, but as you have likely seen, custom_config ends up holding only the string representation of the channel information.
One potential approach is to read the schema and statistics that are stored on disk from inside preprocessing_fn and do the dynamic transformations there, as sketched below. I know that function gets parsed by TensorFlow Transform to produce the computation graph, so I don't know whether this approach will break that parsing.
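Here's a minimal, untested sketch of that idea. Everything version-specific here is an assumption: the artifact paths below are hypothetical (under LocalDagRunner they land somewhere beneath pipeline_root, and the exact layout differs across TFX versions), and depending on your TFDV version you may need tfdv.load_stats_binary instead of tfdv.load_statistics.

# preprocessing_fn.py
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

# Hypothetical artifact locations -- resolve these for your own runs.
_SCHEMA_PATH = './pipeline_data/SchemaGen/schema/1/schema.pbtxt'
_STATS_PATH = './pipeline_data/StatisticsGen/statistics/1/Split-train/FeatureStats.pb'

def preprocessing_fn(inputs):
    schema = tfdv.load_schema_text(_SCHEMA_PATH)
    stats = tfdv.load_statistics(_STATS_PATH)
    outputs = {}
    # Iterate over the features declared in the schema and transform each
    # one using its precomputed statistics.
    for feature in schema.feature:
        feature_stats = tfdv.get_feature_stats(stats.datasets[0],
                                               tfdv.FeaturePath([feature.name]))
        if feature.type == schema_pb2.FLOAT:
            # e.g. z-score dense numeric features from the precomputed
            # mean/std (assumes dense float tensors, not sparse).
            mean = feature_stats.num_stats.mean
            std = feature_stats.num_stats.std_dev
            outputs[feature.name] = (inputs[feature.name] - mean) / (std + 1e-9)
        else:
            outputs[feature.name] = inputs[feature.name]
    return outputs

Note that the file loading happens at graph-tracing time, which is exactly where the parsing concern above could bite.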
The second approach, which might be equivalent due to the potential parsing problem, is to update the component spec, component, and executor so that Transform takes the additional inputs and processes them itself. This doesn't avoid the problem above, but it is cleaner from an implementation perspective.
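An untested sketch of that wiring, assuming TFX v1's public source layout (module paths and executor-spec classes shift between versions):

from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.standard_component_specs import TransformSpec
from tfx.components.transform import executor as transform_executor

class TransformWithStatsSpec(TransformSpec):
    # Extend the stock spec with a real input channel for statistics, so
    # the artifact is resolved at run time instead of being stringified
    # inside custom_config.
    INPUTS = dict(TransformSpec.INPUTS,
                  statistics=ChannelParameter(type=standard_artifacts.ExampleStatistics))

class StatsAwareExecutor(transform_executor.Executor):
    def Do(self, input_dict, output_dict, exec_properties):
        # input_dict['statistics'] holds the resolved artifacts; .uri points
        # at StatisticsGen's output on disk. Read what you need here, then
        # delegate to the stock Transform behavior.
        stats_uri = input_dict['statistics'][0].uri  # read/process as needed
        super().Do(input_dict, output_dict, exec_properties)

You'd still need a Transform subclass whose constructor accepts the new channel and builds TransformWithStatsSpec, plus an executor spec pointing at StatsAwareExecutor; TFX's custom component guide covers that boilerplate.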
The way I ran into this type of problem: I wanted to make Trainer and Tuner more dynamic by using the number of examples in the input data on the fly. It turns out this hits the same "custom_config must be a pure dictionary" problem you're seeing, but without the parsing complexity associated with the Transform component.
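For that Trainer/Tuner case, one workaround sketch, assuming the input files are readable at pipeline-construction time: resolve the value to a plain Python int up front, so it sits in custom_config as a static entry rather than a live channel (count_examples is a hypothetical helper, not a TFX API).

import glob
import tensorflow as tf

def count_examples(file_pattern):
    # Count records across all TFRecord shards matching the pattern.
    return sum(1 for path in glob.glob(file_pattern)
               for _ in tf.data.TFRecordDataset(path))

# Resolved before the pipeline is constructed, so it's a plain value.
trainer_custom_config = {'num_examples': count_examples('./tfrecords/train/*')}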
Upvotes: 0