Reputation: 31
I want to do a scanLeft-type operation on one column of a dataframe. scanLeft is not parallelizable, but in my case I only want to apply the function to elements that are already in the same partition, so the operation can be executed in parallel within each partition (no data shuffling).
Consider the following example:
| partitionKey | orderColumn | value | scanLeft(0)(_+_) |
|--------------|-------------|-------|------------------|
| 1            | 1           | 1     | 1                |
| 1            | 2           | 2     | 3                |
| 2            | 1           | 3     | 3                |
| 2            | 2           | 4     | 7                |
| 1            | 3           | 5     | 8                |
| 2            | 3           | 6     | 13               |
I want to scanLeft the values within the same partition, and create a new column to store the result.
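As a side note, Scala's scanLeft also emits the seed as its first element, so the seed has to be dropped for the results to line up one-to-one with the rows:

// values of partitionKey 1, in order
List(1.0, 2.0, 5.0).scanLeft(0.0)(_ + _)      // List(0.0, 1.0, 3.0, 8.0)
List(1.0, 2.0, 5.0).scanLeft(0.0)(_ + _).tail // List(1.0, 3.0, 8.0) -- the column above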
My code currently looks something like this:
inDataframe
  .repartition(col("partitionKey"))
  .foreachPartition { partition =>
    partition
      .map(row => row(2).asInstanceOf[Double]) // index of the "value" column
      .scanLeft(0.0)(_ + _)
      .foreach(println(_))
  }
This aggregates the values as I want and prints out the result; however, I want to add these values as a new column of the dataframe instead.
Any idea how to do that?
----edit---- The real use case is to calculate the time-weighted rate of return (https://www.investopedia.com/terms/t/time-weightedror.asp). The expected input looks something like this:
| product | valuation date | daily return |
|---------|----------------|--------------|
| 1       | 2019-01-01     | 0.1          |
| 1       | 2019-01-02     | 0.2          |
| 1       | 2019-01-03     | 0.3          |
| 2       | 2019-01-01     | 0.4          |
| 2       | 2019-01-02     | 0.5          |
| 2       | 2019-01-03     | 0.6          |
I want to calculate the cumulated return per product for all dates up to the current one. The dataframe is partitioned by product, and the partitions are ordered by valuation date. I already wrote the aggregation function to pass into scanLeft:
def chain_ret(x: Double, y: Double): Double = {
  (1 + x) * (1 + y) - 1
}
Expected return data:
| product | valuation date | daily return | cumulated return |
|---------|----------------|--------------|------------------|
| 1       | 2019-01-01     | 0.1          | 0.1              |
| 1       | 2019-01-02     | 0.2          | 0.32             |
| 1       | 2019-01-03     | 0.3          | 0.716            |
| 2       | 2019-01-01     | 0.4          | 0.4              |
| 2       | 2019-01-02     | 0.5          | 1.1              |
| 2       | 2019-01-03     | 0.6          | 2.36             |
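Feeding the daily returns of product 1 through scanLeft with chain_ret reproduces this column: the seed 0.0 works because chain_ret(0.0, y) = y, and only the seed element itself has to be dropped (the printed values differ only by floating-point rounding):

List(0.1, 0.2, 0.3).scanLeft(0.0)(chain_ret).tail
// List(0.1, 0.32, 0.716)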
I already solved this issue by filtering the dataframe for each range of dates and applying a UDAF to it (see below). It takes very long, and I think it would be much faster with scanLeft!
while (endPeriod.isBefore(end)) {
  val filtered = inDataframe
    .where("VALUATION_DATE >= '" + start + "' AND VALUATION_DATE <= '" + endPeriod + "'")
  val aggregated = aggregate_returns(filtered)
    .withColumn("VALUATION_DATE", lit(Timestamp.from(endPeriod)).cast(TimestampType))
  df_ret = df_ret.union(aggregated)
  endPeriod = endPeriod.plus(1, ChronoUnit.DAYS)
}
def aggregate_returns(inDataframe: DataFrame): DataFrame = {
  val returnChain = new ReturnChain // instance of the UDAF defined below
  val groupedByKey = inDataframe
    .groupBy("product")
  groupedByKey
    .agg(
      returnChain(col("RETURN_LOCAL")).as("RETURN_LOCAL_CUMUL"),
      returnChain(col("RETURN_FX")).as("RETURN_FX_CUMUL"),
      returnChain(col("RETURN_CROSS")).as("RETURN_CROSS_CUMUL"),
      returnChain(col("RETURN")).as("RETURN_CUMUL")
    )
}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class ReturnChain extends UserDefinedAggregateFunction {
  // Define the schema of the input data
  override def inputSchema: StructType =
    StructType(StructField("return", DoubleType) :: Nil)

  // Define the schema of the aggregation buffer:
  // it holds the running product of (1 + return)
  override def bufferSchema: StructType = StructType(
    StructField("product", DoubleType) :: Nil
  )

  // Define the return type
  override def dataType: DataType = DoubleType

  // The function returns the same value for the same input
  override def deterministic: Boolean = true

  // Initial value: the neutral element of the running product
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 1.0
  }

  // Update based on input: multiply the running product by (1 + return)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Double](0) * (1.0 + input.getAs[Double](0))
  }

  // Merge two buffers: partial products multiply
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) * buffer2.getAs[Double](0)
  }

  // Output: convert the product back into a cumulated return
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0) - 1.0
  }
}
Upvotes: 1
Views: 350
Reputation: 710
foreachPartition doesn't return anything; you need to use .mapPartitions() instead.
The difference between foreachPartition and mapPartitions is the same as that between foreach and map. Look here for a good explanation: Foreach vs Map in Scala
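A minimal sketch of what that could look like for your time-weighted-return use case, going through the RDD API so no Row encoder is needed. The column names (product, valuation_date, daily_return) and the output column cumulated_return are illustrative; a mutable accumulator plays the role of the scanLeft and restarts whenever the product changes, since several products can hash into the same partition:

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

def chain_ret(x: Double, y: Double): Double = (1 + x) * (1 + y) - 1

def withCumulatedReturn(spark: SparkSession, in: DataFrame): DataFrame = {
  // original schema plus the new result column
  val outSchema = StructType(in.schema.fields :+ StructField("cumulated_return", DoubleType))
  val outRdd = in
    .repartition(col("product"))                                 // same product -> same partition
    .sortWithinPartitions(col("product"), col("valuation_date")) // chain in date order
    .rdd
    .mapPartitions { rows =>
      var currentProduct: Any = null
      var acc = 0.0
      rows.map { row =>
        val product = row.getAs[Any]("product")
        if (product != currentProduct) { // product boundary: restart the chain
          currentProduct = product
          acc = 0.0
        }
        acc = chain_ret(acc, row.getAs[Double]("daily_return"))
        Row.fromSeq(row.toSeq :+ acc) // append the running value as a new column
      }
    }
  spark.createDataFrame(outRdd, outSchema)
}

repartition(col("product")) guarantees that all rows of one product land in a single partition, so each product's chain is computed sequentially there while different partitions still run in parallel, with no further shuffling.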
Upvotes: 1