Reputation: 2020
I need to iterate over a data frame in a specific order and apply some complex logic to calculate a new column.
In the example below I use a simple expression where the current value of s is the product of all y values up to and including the current row, so it may seem that this can be done with a UDF or even with analytic functions. However, in reality the logic is much more complex.
The code below does what is needed:
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
val q = """
select 10 x, 1 y
union all select 10, 2
union all select 10, 3
union all select 20, 6
union all select 20, 4
union all select 20, 5
"""
val df = spark.sql(q)
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,1)) {
    case (r1, r2) => {
      val (x1, y1, s1) = r1 match {case Row(x: Int, y: Int, s: Int) => (x, y, s)}
      val (x2, y2) = r2 match {case Row(x: Int, y: Int) => (x, y)}
      Row(x2, y2, s1 * y2)
    }
  }.drop(1) // scanLeft emits the seed Row(0,0,1) first; discard it
}
val schema = new StructType().
add(StructField("x", IntegerType, true)).
add(StructField("y", IntegerType, true)).
add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
scala> df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+---+
|  x|  y|  s|
+---+---+---+
| 20|  4|  4|
| 20|  5| 20|
| 20|  6|120|
| 10|  1|  1|
| 10|  2|  2|
| 10|  3|  6|
+---+---+---+
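The running product in f_row rests on scanLeft, whose result starts with the seed value itself. A minimal plain-Scala sketch (no Spark involved; the list ys is made-up sample data standing in for the y values of one sorted partition) suggests why the seed has to be discarded to get exactly one output per input row:

```scala
// Assumption: ys stands in for the y column of one partition, already sorted.
val ys = List(4, 5, 6)

// scanLeft yields the seed first, then one accumulated value per input element.
val withSeed = ys.iterator.scanLeft(1)(_ * _).toList

// Dropping the seed leaves one running product per input element.
val runningProduct = ys.iterator.scanLeft(1)(_ * _).drop(1).toList

println(withSeed)        // List(1, 4, 20, 120)
println(runningProduct)  // List(4, 20, 120)
```

The second list matches the s column shown for x = 20.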
What I do not like about it:
1) I define the schema explicitly even though Spark can infer the column names and types of the data frame
scala> df
res1: org.apache.spark.sql.DataFrame = [x: int, y: int]
2) If I add a new column to the data frame, I have to declare the schema again and, what is more annoying, redefine the function!
Assume there is a new column z in the data frame. In this case I have to change almost every line of f_row:
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,"",1)) {
    case (r1, r2) => {
      val (x1, y1, z1, s1) = r1 match {case Row(x: Int, y: Int, z: String, s: Int) => (x, y, z, s)}
      val (x2, y2, z2) = r2 match {case Row(x: Int, y: Int, z: String) => (x, y, z)}
      Row(x2, y2, z2, s1 * y2)
    }
  }.drop(1) // scanLeft emits the seed row first; discard it
}
val schema = new StructType().
add(StructField("x", IntegerType, true)).
add(StructField("y", IntegerType, true)).
add(StructField("z", StringType, true)).
add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
scala> df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+-----+---+
|  x|  y|    z|  s|
+---+---+-----+---+
| 20|  4|dummy|  4|
| 20|  5|dummy| 20|
| 20|  6|dummy|120|
| 10|  1|dummy|  1|
| 10|  2|dummy|  2|
| 10|  3|dummy|  6|
+---+---+-----+---+
Is there a way to implement the logic in a more generic way, so that I do not need to write a separate function for every specific data frame?
Or at least a way to avoid code changes after adding new columns that are not used in the calculation logic?
Please see the updated question below.
Below are two options that iterate in a more generic way, but they still have some drawbacks.
// option 1
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  val r = Row.fromSeq(Row(0, 0).toSeq :+ 1)
  iter.scanLeft(r)((r1, r2) =>
    Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y")))
  ).drop(1) // discard the seed row emitted by scanLeft
}
// option 2
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  var s = 1
  iter.map { r =>
    s = s * r.getInt(r.fieldIndex("y"))
    Row.fromSeq(r.toSeq :+ s)
  }
}
If a new column is added to the data frame, the initial value for iter.scanLeft has to be changed in Option 1. I also do not really like Option 2 because it uses a mutable var.
Is there a way to improve the code so that it is purely functional and needs no changes when a new column is added to the data frame?
Upvotes: 1
Views: 2093
Reputation: 2020
Well, a sufficient solution is below:
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  if (iter.hasNext) {
    // Seed the scan with the first row, so no dummy initial row is needed
    val head = iter.next
    val r = Row.fromSeq(head.toSeq :+ head.getInt(head.fieldIndex("y")))
    iter.scanLeft(r)((r1, r2) =>
      Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y"))))
  } else iter
}
val encoder =
RowEncoder(StructType(df.schema.fields :+ StructField("s", IntegerType, false)))
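The encoder above derives the output schema from df.schema instead of spelling out every column. As a rough illustration, the same idea can be wrapped in a small helper; withResult and the two sample schemas below are hypothetical names introduced here, not part of the original code:

```scala
import org.apache.spark.sql.types._

// Hypothetical helper: append one non-nullable Int result column to any input schema.
def withResult(in: StructType, name: String = "s"): StructType =
  StructType(in.fields :+ StructField(name, IntegerType, nullable = false))

// Assumed sample schemas: the original two columns, then the same with an extra z.
val twoCols = new StructType().add("x", IntegerType).add("y", IntegerType)
val threeCols = twoCols.add("z", StringType)

println(withResult(twoCols).fieldNames.mkString(", "))    // x, y, s
println(withResult(threeCols).fieldNames.mkString(", "))  // x, y, z, s
```

Since the helper only appends to whatever schema it receives, adding z to the input requires no change to it.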
Functions like getInt can be avoided in favor of the more generic getAs.
Also, to be able to access the fields of r1 by name, we can generate a GenericRowWithSchema, which is a subclass of Row.
An implicit parameter has been added to f_row so that the function can use the current schema of the data frame and, at the same time, still be passed as the argument of mapPartitions.
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.encoders.RowEncoder
implicit val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
implicit val encoder = RowEncoder(schema)
def mul(x1: Int, x2: Int) = x1 * x2
def f_row(iter: Iterator[Row])(implicit currentSchema : StructType) : Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    // Seed row: the first input row plus its own y as the initial result
    val r =
      new GenericRowWithSchema((head.toSeq :+ head.getAs[Int]("y")).toArray, currentSchema)
    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema((r2.toSeq :+ mul(r1.getAs[Int]("result"), r2.getAs[Int]("y"))).toArray, currentSchema))
  } else iter
}
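The reason for building a GenericRowWithSchema rather than calling Row.fromSeq can be checked locally, without a SparkSession. The following is a small sketch with a made-up two-column schema and made-up values; it assumes only that the Spark SQL classes are on the classpath:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._
import scala.util.Try

// Assumed sample schema and values, for illustration only.
val sampleSchema = new StructType().add("x", IntegerType).add("y", IntegerType)

val plain: Row = Row.fromSeq(Seq(10, 3))                                // carries no schema
val named: Row = new GenericRowWithSchema(Array[Any](10, 3), sampleSchema)

// Lookup by name needs a schema, so it fails on the plain row.
val plainLookup = Try(plain.getAs[Int]("y"))
val namedLookup = named.getAs[Int]("y")

// Positional access works on both kinds of row.
val plainByIndex = plain.getInt(1)
```

This is why the scanLeft variants that use Row.fromSeq read the accumulated value by position (r1.size - 1), while the GenericRowWithSchema variant can read it as "result".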
Finally, the logic can be implemented in a tail-recursive manner.
import scala.annotation.tailrec
def f_row(iter: Iterator[Row]) = {
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      f_row_(iter, mul(tmp, r.getAs[Int]("y")),
        result ++ Iterator(Row.fromSeq(r.toSeq :+ mul(tmp, r.getAs[Int]("y")))))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}
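A partition function of this kind can be exercised without a cluster by feeding it hand-built rows. The sketch below is one way to do that, assuming only the Spark SQL classes on the classpath; runningProduct, the sample schema with an extra z column, and the sample values are all introduced here for illustration and follow the scanLeft-based shape shown earlier:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types._

// Hypothetical name; seeds the scan with the first row, then accumulates y by name.
def runningProduct(iter: Iterator[Row]): Iterator[Row] =
  if (iter.hasNext) {
    val head = iter.next()
    val seed = Row.fromSeq(head.toSeq :+ head.getAs[Int]("y"))
    iter.scanLeft(seed)((acc, r) =>
      Row.fromSeq(r.toSeq :+ acc.getInt(acc.size - 1) * r.getAs[Int]("y")))
  } else Iterator.empty

// Assumed sample partition: three rows that carry an extra column z.
val sampleSchema = new StructType()
  .add("x", IntegerType).add("y", IntegerType).add("z", StringType)
val rows: Seq[Row] =
  Seq(4, 5, 6).map(y => new GenericRowWithSchema(Array[Any](20, y, "dummy"), sampleSchema))

val out = runningProduct(rows.iterator).map(_.toSeq.toList).toList
out.foreach(println)
// List(20, 4, dummy, 4)
// List(20, 5, dummy, 20)
// List(20, 6, dummy, 120)
```

The function never mentions z, yet the column passes through unchanged, which is the property asked for in the question.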
Upvotes: 1