Mehran

Reputation: 16831

How to get the list of features alongside their schema and stats using TFX

Let's say that I have the following, very straightforward pipeline:

import os
from tfx import v1 as tfx


_dataset_folder = './tfrecords/train/*'
_pipeline_data_folder = './pipeline_data'
_serving_model_dir = os.path.join(_pipeline_data_folder, 'serving_model')

example_gen = tfx.components.ImportExampleGen(input_base=_dataset_folder)
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
schema_gen = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'],
    infer_feature_shape=True)
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_gen.outputs['schema'])

_transform_module_file = 'preprocessing_fn.py'
transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file=os.path.abspath(_transform_module_file),
    custom_config={'statistics_gen': statistics_gen.outputs['statistics'],
                   'schema_gen': schema_gen.outputs['schema']})

_trainer_module_file = 'run_fn.py'
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=10),
    eval_args=tfx.proto.EvalArgs(num_steps=6))


pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))

components = [
    example_gen,
    statistics_gen,
    schema_gen,
    example_validator,
    transform,
    trainer,
    pusher,
]

pipeline = tfx.dsl.Pipeline(
    pipeline_name='straightforward_pipeline',
    pipeline_root=_pipeline_data_folder,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        f'{_pipeline_data_folder}/metadata.db'),
    components=components)

tfx.orchestration.LocalDagRunner().run(pipeline)

The only part of the snippet above that is a bit out of the ordinary is that I'm passing statistics_gen and schema_gen outputs to the Transform step in the custom_config argument. What I'm hoping to achieve is to iterate over the list of features in the dataset in order to transform them.

This is what I need for that:

  1. The list of features in the dataset (I don't want to hardcode/assume this list; I want my code to come up with it automatically)
  2. Each feature's type (again, without hardcoding/assuming it)
  3. Each feature's statistical attributes, like min and max (again, without hardcoding/assuming them!)

My question is, how can I do this in my preprocessing_fn.py function?

BTW, I know how to do this if I have access to the CSV version of the dataset:

import tensorflow_data_validation as tfdv

dataset_stats = tfdv.generate_statistics_from_csv(examples_file)
feature_1_stats = tfdv.get_feature_stats(dataset_stats.datasets[0],
                                         tfdv.FeaturePath(['feature_1']))

But there is a problem: this extracts all the information from the dataset from scratch, while in my pipeline I believe it has already been extracted by the statistics_gen and schema_gen steps. I don't want to redo the whole process; I just need to learn how to use the outputs of those two steps to get the info I need.
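
For reference, here is a minimal sketch of loading the already-materialized statistics back with tfdv.load_statistics instead of regenerating them from a CSV; the artifact path below is an assumption and depends on the TFX version:

import tensorflow_data_validation as tfdv

# Path to the statistics artifact written by statistics_gen; the exact
# layout (execution id, split folder, file name) depends on the TFX version.
_stats_path = './pipeline_data/StatisticsGen/statistics/1/Split-train/FeatureStats.pb'

dataset_stats = tfdv.load_statistics(_stats_path)
for feature in dataset_stats.datasets[0].features:
    if feature.HasField('num_stats'):  # numeric features carry min/max here
        print(feature.path.step, feature.num_stats.min, feature.num_stats.max)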

Upvotes: 3

Views: 131

Answers (1)

Pritam Dodeja

Reputation: 326

I love this line of thinking!

The reason this currently doesn't work is that custom_config needs to be a "pure" dictionary; it cannot carry anything dynamic that gets populated at runtime. The outputs of statistics_gen are "live" channels during pipeline execution, but, as you have likely seen, custom_config just ends up with the string representation of the channel information.

One potential approach is to read the schema and statistics that are stored on disk from inside preprocessing_fn and do the dynamic transformations there. I know that function gets traced by TensorFlow Transform, so I don't know whether this approach breaks the tracing required to produce the computation graph.
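
Something like the following is what I have in mind. This is an untested sketch: the glob patterns, file names, and the assumption of dense float inputs are all mine, not guaranteed by the TFX API.

import glob

import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2

_pipeline_root = './pipeline_data'


def _load_schema_and_stats():
    # SchemaGen writes a text proto, StatisticsGen a binary proto; pick the
    # artifact of the latest execution. The layout depends on the TFX version.
    schema_path = sorted(glob.glob(
        f'{_pipeline_root}/SchemaGen/schema/*/schema.pbtxt'))[-1]
    stats_path = sorted(glob.glob(
        f'{_pipeline_root}/StatisticsGen/statistics/*/Split-train/FeatureStats.pb'))[-1]
    return tfdv.load_schema_text(schema_path), tfdv.load_statistics(stats_path)


def preprocessing_fn(inputs):
    schema, stats = _load_schema_and_stats()

    # Index the per-feature statistics by top-level feature name.
    stats_by_name = {f.path.step[0]: f for f in stats.datasets[0].features}

    outputs = {}
    for feature in schema.feature:             # 1. the feature list
        name = feature.name
        if feature.type == schema_pb2.FLOAT:   # 2. the feature type
            num_stats = stats_by_name[name].num_stats  # 3. min, max, etc.
            # Min-max scale, assuming a dense float feature.
            outputs[name] = ((inputs[name] - num_stats.min)
                             / (num_stats.max - num_stats.min))
        else:
            outputs[name] = inputs[name]
    return outputs

The file reads happen in plain Python while preprocessing_fn is being traced, so they should not end up in the graph itself, but I have not verified this against the Transform executor.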

The second approach, which might be equivalent because of the same potential tracing problem, is to update the component spec, component, and executor so that Transform takes the statistics and schema as additional inputs and processes them. This doesn't avoid the problem above, but it is cleaner from an implementation perspective.

I ran into this type of problem when I wanted to make Trainer and Tuner more dynamic by using the number of examples in the input data on the fly. It turns out this hits the same custom_config-must-be-a-pure-dictionary problem you're seeing, but without the tracing complexity associated with the Transform component.

Upvotes: 0
