scravy

Reputation: 12283

pyspark: count number of rows written

When I do

df: DataFrame = ...
df.write.parquet('some://location/')

Can I track and report (for monitoring) the number of rows that were just written to some://location?

df.write.parquet('some://location/')
# I imagine something like:
spark_session.someWeirdApi().mostRecentOperation().number_of_rows_written

Upvotes: 2

Views: 2970

Answers (2)

Nick Chammas

Reputation: 12672

If you are using DataFrames on Spark 3.3+, then the modern way to do this is with DataFrame.observe.

from pyspark.sql.functions import count, max
from pyspark.sql import Observation

data = spark.range(1000)
observation = Observation("write-metrics")
(
    data
    .observe(
        observation,
        count("*").alias("count"),
        max("id").alias("max_id"),
    )
    .write.parquet("output")
)

# Must call this after an action on `data` populates the observation.
observation.get

You can compute multiple metrics at once as part of an observation. In this example, we're counting the number of rows written as well as tracking the maximum value for id.

Here's what the observation results look like:

>>> observation.get
{'count': 1000, 'max_id': 999}
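
If the goal is monitoring, you could wrap this pattern in a small helper that writes the DataFrame and returns the row count. This is just a sketch on top of DataFrame.observe; the helper name and the logging call are mine, not part of any Spark API:

import logging

from pyspark.sql import DataFrame, Observation
from pyspark.sql.functions import count

logger = logging.getLogger(__name__)


def write_parquet_with_count(df: DataFrame, path: str) -> int:
    """Hypothetical helper: write `df` as parquet and return the number of rows written."""
    observation = Observation("write-metrics")
    df.observe(observation, count("*").alias("count")).write.parquet(path)
    # `observation.get` waits until the observed DataFrame has run an action
    # (here, the parquet write), then returns the collected metrics as a dict.
    rows_written = observation.get["count"]
    logger.info("wrote %d rows to %s", rows_written, path)
    return rows_written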

Upvotes: 3

scravy

Reputation: 12283

After doing some digging I found a way to do it:

  • You can register a QueryExecutionListener (beware, this is annotated @DeveloperApi in the source) via py4j's callback mechanism,
  • but you need to start the callback server and shut down the gateway manually at the end of your application's run.

This is inspired by a post in the Cloudera community; I had to port it to a more recent Spark version (this uses Spark 3.0.1, while the answer suggested over there uses the deprecated SQLContext) and to PySpark (using a py4j callback).

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession, DataFrame


class Listener:
    """QueryExecutionListener implemented in Python; the JVM invokes it via a py4j callback."""

    def onSuccess(self, funcName, qe, durationNs):
        print("success", funcName, durationNs, qe.executedPlan().metrics())
        print("rows", qe.executedPlan().metrics().get("numOutputRows").value())
        print("files", qe.executedPlan().metrics().get("numFiles").value())
        print("bytes", qe.executedPlan().metrics().get("numOutputBytes").value())

    def onFailure(self, funcName, qe, exception):
        print("failure", funcName, exception, qe.executedPlan().metrics())

    class Java:
        # Tells py4j which Java interface this Python callback object implements.
        implements = ["org.apache.spark.sql.util.QueryExecutionListener"]


def run():
    spark: SparkSession = SparkSession.builder.getOrCreate()

    df: DataFrame = spark.createDataFrame(pd.DataFrame(np.random.randn(20, 3), columns=["foo", "bar", "qux"]))

    # Start py4j's callback server so the JVM can invoke the Python listener.
    gateway = spark.sparkContext._gateway
    gateway.start_callback_server()

    listener = Listener()
    spark._jsparkSession.listenerManager().register(listener)

    df.write.parquet("/tmp/file.parquet", mode='overwrite')

    spark._jsparkSession.listenerManager().unregister(listener)

    spark.stop()
    spark.sparkContext.stop()
    # Shut down the gateway manually; otherwise the callback server can keep the application from exiting.
    gateway.shutdown()


if __name__ == '__main__':
    run()
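
If you use this listener in a larger application, a context manager keeps the register/unregister bookkeeping in one place. Here is a minimal sketch assuming the Listener class above (the query_execution_listener name is mine, not a Spark API); the gateway still has to be shut down at the end of the application, as in run():

from contextlib import contextmanager


@contextmanager
def query_execution_listener(spark: SparkSession, listener):
    # Start py4j's callback server so the JVM can call back into Python.
    spark.sparkContext._gateway.start_callback_server()
    spark._jsparkSession.listenerManager().register(listener)
    try:
        yield listener
    finally:
        # Always unregister, even if the write fails.
        spark._jsparkSession.listenerManager().unregister(listener)


# Usage:
# with query_execution_listener(spark, Listener()):
#     df.write.parquet("/tmp/file.parquet", mode="overwrite")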

Upvotes: 4
