Reputation: 12283
When I do
df: DataFrame = ...
df.write.parquet('some://location/')
Can I track and report (for monitoring) the number of rows that were just written to some://location?
df.write.parquet('some://location/')
# I imagine something like:
spark_session.someWeirdApi().mostRecentOperation().number_of_rows_written
Upvotes: 2
Views: 2970
Reputation: 12672
If you are using DataFrames on Spark 3.3+, then the modern way to do this is with DataFrame.observe.
from pyspark.sql.functions import count, max
from pyspark.sql import Observation

data = spark.range(1000)
observation = Observation("write-metrics")

(
    data
    .observe(
        observation,
        count("*").alias("count"),
        max("id").alias("max_id"),
    )
    .write.parquet("output")
)

# Must call this after an action on `data` populates the observation.
observation.get
You can compute multiple metrics at once as part of an observation. In this example, we're counting the number of rows written as well as tracking the maximum value of id.
Here's what the observation results look like:
>>> observation.get
{'count': 1000, 'max_id': 999}
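For reporting, you can then pull individual values out of that dict once the write has finished. A minimal sketch (the report_rows_written call is just a placeholder for whatever monitoring hook you use):

metrics = observation.get          # plain dict, available after the write action has run
rows_written = metrics["count"]    # number of rows that went into the parquet write

# report_rows_written(rows_written)  # hypothetical monitoring hook
print(f"rows written: {rows_written}")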
Upvotes: 3
Reputation: 12283
After doing some digging I found a way to do it:
You can register a QueryExecutionListener (note that it's annotated @DeveloperApi in the source) via py4j's callbacks. This is inspired by a post in the Cloudera community; I had to port it to a more recent Spark version (this uses Spark 3.0.1, the answer suggested over there uses the deprecated SQLContext) and to pyspark (using a py4j callback).
import numpy as np
import pandas as pd

from pyspark.sql import SparkSession, DataFrame


class Listener:
    def onSuccess(self, funcName, qe, durationNs):
        # qe is the Java QueryExecution; its executed plan exposes the write metrics
        print("success", funcName, durationNs, qe.executedPlan().metrics())
        print("rows", qe.executedPlan().metrics().get("numOutputRows").value())
        print("files", qe.executedPlan().metrics().get("numFiles").value())
        print("bytes", qe.executedPlan().metrics().get("numOutputBytes").value())

    def onFailure(self, funcName, qe, exception):
        print("failure", funcName, exception, qe.executedPlan().metrics())

    class Java:
        # tells py4j which Java interface this Python class implements
        implements = ["org.apache.spark.sql.util.QueryExecutionListener"]


def run():
    spark: SparkSession = SparkSession.builder.getOrCreate()

    df: DataFrame = spark.createDataFrame(pd.DataFrame(np.random.randn(20, 3), columns=["foo", "bar", "qux"]))

    # the callback server must be running before the JVM can call back into Python
    gateway = spark.sparkContext._gateway
    gateway.start_callback_server()

    listener = Listener()
    spark._jsparkSession.listenerManager().register(listener)

    df.write.parquet("/tmp/file.parquet", mode='overwrite')

    spark._jsparkSession.listenerManager().unregister(listener)
    spark.stop()
    spark.sparkContext.stop()
    gateway.shutdown()


if __name__ == '__main__':
    run()
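If you want to reuse this around several writes, one option (just a sketch on top of the same Listener class; the write_metrics name is mine) is to wrap the register/unregister steps in a context manager:

from contextlib import contextmanager

@contextmanager
def write_metrics(spark):
    # hypothetical helper: keeps the Listener registered for the duration of a `with` block
    gateway = spark.sparkContext._gateway
    gateway.start_callback_server()
    listener = Listener()
    spark._jsparkSession.listenerManager().register(listener)
    try:
        yield listener
    finally:
        spark._jsparkSession.listenerManager().unregister(listener)

# usage:
# with write_metrics(spark):
#     df.write.parquet("/tmp/file.parquet", mode='overwrite')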
Upvotes: 4