Reputation: 65
I have a very large directory of TFRecord files and need to filter it on a column to generate new TFRecord files.
The code looks like this:
val df = spark.read.format("tfrecords").option("recordType", "Example")
  .load(input_path)
  .filter(udf_filter(col("label")))
df.write.format("tfrecords").option("recordType", "Example")
  .mode(SaveMode.Overwrite)
  .save(output_path)
When I run it on a Spark cluster, I see that it executes in two steps (an aggregate plus the write).
I checked the code at https://github.com/tensorflow/ecosystem/blob/master/spark/spark-tensorflow-connector/src/main/scala/org/tensorflow/spark/datasources/tfrecords/TensorFlowInferSchema.scala#L39, and the schema inference there is what runs the aggregate step.
Can I avoid it?
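I was wondering whether supplying an explicit schema would let the connector skip the inference pass entirely; something like the sketch below, if the data source honors a user-provided schema (the field names and types here are just placeholders for my actual columns):

import org.apache.spark.sql.types._

// Hypothetical explicit schema -- field names/types stand in for my real data.
val mySchema = StructType(Seq(
  StructField("label", IntegerType),
  StructField("feature", ArrayType(FloatType))
))

// If the connector accepts a user-supplied schema, this should avoid calling
// TensorFlowInferSchema, and therefore the extra aggregate job.
val df = spark.read.format("tfrecords")
  .option("recordType", "Example")
  .schema(mySchema)
  .load(input_path)
  .filter(udf_filter(col("label")))

I am not sure whether the tfrecords data source actually respects .schema() or always infers, so confirmation either way would help.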
The corresponding GitHub issue is here: https://github.com/tensorflow/ecosystem/issues/201
Upvotes: 2
Views: 157