Reputation: 53
After several tries and some research, I'm stuck trying to solve the following problem with Spark.
I have a Dataframe of elements with a priority and a quantity.
+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
| f1| elmt 1| 1| 20|
| f1| elmt 2| 2| 40|
| f1| elmt 3| 3| 10|
| f1| elmt 4| 4| 50|
| f1| elmt 5| 5| 40|
| f1| elmt 6| 6| 10|
| f1| elmt 7| 7| 20|
| f1| elmt 8| 8| 10|
+------+-------+--------+---+
I also have a fixed limit quantity per family:
+------+--------+
|family|limitQty|
+------+--------+
| f1| 100|
+------+--------+
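I want to mark as "ok" the elements whose cumulative sum does not exceed the limit. An element that would push the total over the limit is marked "ko" and is not counted in the running total, so later elements can still fit. Here is the expected result: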
+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
| f1| elmt 1| 1| 20| 1| -> 20 < 100 => ok
| f1| elmt 2| 2| 40| 1| -> 20 + 40 < 100 => ok
| f1| elmt 3| 3| 10| 1| -> 20 + 40 + 10 < 100 => ok
| f1| elmt 4| 4| 50| 0| -> 20 + 40 + 10 + 50 > 100 => ko
| f1| elmt 5| 5| 40| 0| -> 20 + 40 + 10 + 40 > 100 => ko
| f1| elmt 6| 6| 10| 1| -> 20 + 40 + 10 + 10 < 100 => ok
| f1| elmt 7| 7| 20| 1| -> 20 + 40 + 10 + 10 + 20 <= 100 => ok
| f1| elmt 8| 8| 10| 0| -> 20 + 40 + 10 + 10 + 20 + 10 > 100 => ko
+------+-------+--------+---+---+
I tried to solve it with a cumulative sum:
initDF
.join(limitQtyDF, Seq("family"), "left_outer")
.withColumn("cumulSum", sum($"qty").over(Window.partitionBy("family").orderBy("priority")))
.withColumn("ok", when($"cumulSum" <= $"limitQty", 1).otherwise(0))
.drop("cumulSum", "limitQty")
But this is not enough: once the cumulative sum passes the limit, every later element is rejected, even those that would still fit if the rejected elements were skipped. I can't find a way to solve this with Spark. Do you have an idea?
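To make the difference concrete, here is a minimal sketch on plain Scala collections (no Spark involved). The object and method names are hypothetical, chosen only for illustration; the quantities and the limit are the ones from the tables above.

```scala
object RunningTotalSketch {
  // Plain cumulative sum: once the total passes the limit it never comes
  // back down, so every later element is rejected as well.
  def cumulativeFlags(qtys: Seq[Int], limit: Int): Seq[Int] =
    qtys.scanLeft(0)(_ + _).tail.map(total => if (total <= limit) 1 else 0)

  // Greedy running total: an element is added to the total only if it still
  // fits, so a rejected element does not block the ones after it.
  def greedyFlags(qtys: Seq[Int], limit: Int): Seq[Int] =
    qtys.foldLeft((0, Vector.empty[Int])) { case ((accepted, flags), q) =>
      if (accepted + q <= limit) (accepted + q, flags :+ 1)
      else (accepted, flags :+ 0)
    }._2
}
```

With the quantities 20, 40, 10, 50, 40, 10, 20, 10 and a limit of 100, `cumulativeFlags` stops accepting after the third element, while `greedyFlags` also accepts elements 6 and 7. The second behavior depends on the previous row's decision, which is why a single window aggregate does not express it directly.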
Here is the corresponding Scala code:
import org.apache.spark.sql.SparkSession
val sparkSession = SparkSession.builder()
.master("local[*]")
.getOrCreate()
import sparkSession.implicits._
val initDF = Seq(
("f1", "elmt 1", 1, 20),
("f1", "elmt 2", 2, 40),
("f1", "elmt 3", 3, 10),
("f1", "elmt 4", 4, 50),
("f1", "elmt 5", 5, 40),
("f1", "elmt 6", 6, 10),
("f1", "elmt 7", 7, 20),
("f1", "elmt 8", 8, 10)
).toDF("family", "element", "priority", "qty")
val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")
val expectedDF = Seq(
("f1", "elmt 1", 1, 20, 1),
("f1", "elmt 2", 2, 40, 1),
("f1", "elmt 3", 3, 10, 1),
("f1", "elmt 4", 4, 50, 0),
("f1", "elmt 5", 5, 40, 0),
("f1", "elmt 6", 6, 10, 1),
("f1", "elmt 7", 7, 20, 1),
("f1", "elmt 8", 8, 10, 0)
).toDF("family", "element", "priority", "qty", "ok")
expectedDF.show()
Thank you for your help !
Upvotes: 3
Views: 1124
Reputation: 268
Please find the answer below.
val initDF = Seq(("f1", "elmt 1", 1, 20),("f1", "elmt 2", 2, 40),("f1", "elmt 3", 3, 10),
("f1", "elmt 4", 4, 50),
("f1", "elmt 5", 5, 40),
("f1", "elmt 6", 6, 10),
("f1", "elmt 7", 7, 20),
("f1", "elmt 8", 8, 10)
).toDF("family", "element", "priority", "qty")
val limitQtyDF = Seq(("f1", 100)).toDF("family", "limitQty")
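import org.apache.spark.sql.functions.broadcast
// limitQtyDF is small, so hint a broadcast join; calling sc.broadcast on the DataFrame and discarding the result does not affect the join
val joinedInitDF=initDF.join(broadcast(limitQtyDF),Seq("family"),"left")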
case class dataResult(family:String,element:String,priority:Int, qty:Int, comutedValue:Int, limitQty:Int,controlOut:String)
val familyIDs=initDF.select("family").distinct.collect.map(_(0).toString).toList
def checkingUDF(inputRows:List[Row])={
var controlVarQty=0
val outputArrayBuffer=collection.mutable.ArrayBuffer[dataResult]()
val setLimit=inputRows.head.getInt(4)
for(inputRow <- inputRows)
{
val currQty=inputRow.getInt(3)
//val outpurForRec=
controlVarQty + currQty match {
case value if value <= setLimit =>
controlVarQty+=currQty
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ok")
case value =>
outputArrayBuffer+=dataResult(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),value,setLimit,"ko")
}
//outputArrayBuffer+=Row(inputRow.getString(0),inputRow.getString(1),inputRow.getInt(2),inputRow.getInt(3),controlVarQty+currQty,setLimit,outpurForRec)
}
outputArrayBuffer.toList
}
val tmpAB=collection.mutable.ArrayBuffer[List[dataResult]]()
for (familyID <- familyIDs) // val familyID="f1"
{
val currentFamily=joinedInitDF.filter(s"family = '${familyID}'").orderBy("priority").collect.toList // process rows in priority order
tmpAB+=checkingUDF(currentFamily)
}
tmpAB.toSeq.flatMap(x => x).toDF.show(false)
This works for me:
+------+-------+--------+---+------------+--------+----------+
|family|element|priority|qty|comutedValue|limitQty|controlOut|
+------+-------+--------+---+------------+--------+----------+
|f1 |elmt 1 |1 |20 |20 |100 |ok |
|f1 |elmt 2 |2 |40 |60 |100 |ok |
|f1 |elmt 3 |3 |10 |70 |100 |ok |
|f1 |elmt 4 |4 |50 |120 |100 |ko |
|f1 |elmt 5 |5 |40 |110 |100 |ko |
|f1 |elmt 6 |6 |10 |80 |100 |ok |
|f1 |elmt 7 |7 |20 |100 |100 |ok |
|f1 |elmt 8 |8 |10 |110 |100 |ko |
+------+-------+--------+---+------------+--------+----------+
You can drop the columns you do not need from the output.
Upvotes: 0
Reputation: 1
Cumulative sum for each group
from pyspark.sql.window import Window as window
import pyspark.sql.functions as f
from pyspark.sql.types import IntegerType,StringType,FloatType,StructType,StructField,DateType
schema = StructType() \
.add(StructField("empno",IntegerType(),True)) \
.add(StructField("ename",StringType(),True)) \
.add(StructField("job",StringType(),True)) \
.add(StructField("mgr",StringType(),True)) \
.add(StructField("hiredate",DateType(),True)) \
.add(StructField("sal",FloatType(),True)) \
.add(StructField("comm",StringType(),True)) \
.add(StructField("deptno",IntegerType(),True))
emp = spark.read.csv('data/emp.csv',schema)
dept_partition = window.partitionBy(emp.deptno).orderBy(emp.sal)
emp_win = emp.withColumn("dept_cum_sal",
f.sum(emp.sal).over(dept_partition.rowsBetween(window.unboundedPreceding, window.currentRow)))
emp_win.show()
Results appear like below:
+-----+------+---------+----+----------+------+-------+------+------------+
|empno| ename| job| mgr| hiredate| sal| comm|deptno|dept_cum_sal|
+-----+------+---------+----+----------+------+-------+------+------------+
| 7369| SMITH| CLERK|7902|1980-12-17| 800.0| null| 20| 800.0|
| 7876| ADAMS| CLERK|7788|1983-01-12|1100.0| null| 20| 1900.0|
| 7566| JONES| MANAGER|7839|1981-04-02|2975.0| null| 20| 4875.0|
| 7788| SCOTT| ANALYST|7566|1982-12-09|3000.0| null| 20| 7875.0|
| 7902| FORD| ANALYST|7566|1981-12-03|3000.0| null| 20| 10875.0|
| 7934|MILLER| CLERK|7782|1982-01-23|1300.0| null| 10| 1300.0|
| 7782| CLARK| MANAGER|7839|1981-06-09|2450.0| null| 10| 3750.0|
| 7839| KING|PRESIDENT|null|1981-11-17|5000.0| null| 10| 8750.0|
| 7900| JAMES| CLERK|7698|1981-12-03| 950.0| null| 30| 950.0|
| 7521| WARD| SALESMAN|7698|1981-02-22|1250.0| 500.00| 30| 2200.0|
| 7654|MARTIN| SALESMAN|7698|1981-09-28|1250.0|1400.00| 30| 3450.0|
| 7844|TURNER| SALESMAN|7698|1981-09-08|1500.0| 0.00| 30| 4950.0|
| 7499| ALLEN| SALESMAN|7698|1981-02-20|1600.0| 300.00| 30| 6550.0|
| 7698| BLAKE| MANAGER|7839|1981-05-01|2850.0| null| 30| 9400.0|
+-----+------+---------+----+----------+------+-------+------+------------+
Upvotes: 0
Reputation: 3519
I am new to Spark so this solution may not be optimal. I am assuming the value of 100 is an input to the program here. In that case:
case class Frame(family:String, element : String, priority : Int, qty :Int)
import scala.collection.JavaConverters._
val ans = df.as[Frame].toLocalIterator
.asScala
.foldLeft((Seq.empty[Int],0))((acc,a) =>
if(acc._2 + a.qty <= 100) (acc._1 :+ a.priority, acc._2 + a.qty) else acc)._1
df.withColumn("OK" , when($"priority".isin(ans :_*), 1).otherwise(0)).show
results in:
+------+-------+--------+---+--------+
|family|element|priority|qty|OK |
+------+-------+--------+---+--------+
| f1| elmt 1| 1| 20| 1|
| f1| elmt 2| 2| 40| 1|
| f1| elmt 3| 3| 10| 1|
| f1| elmt 4| 4| 50| 0|
| f1| elmt 5| 5| 40| 0|
| f1| elmt 6| 6| 10| 1|
| f1| elmt 7| 7| 20| 1|
| f1| elmt 8| 8| 10| 0|
+------+-------+--------+---+--------+
The idea is simply to get a Scala iterator, extract the priority values of the accepted rows from it, and then use those values to flag the matching rows. Because this solution gathers all the data in memory on one machine (the driver), it can run into memory problems if the DataFrame is too large to fit there.
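The fold above hard-codes the limit of 100 and identifies rows by priority alone, which is enough for a single family. As a hedged sketch of how the same fold could keep one running total per family, here is a version on plain Scala collections. `Item`, `acceptedKeys` and `limits` are hypothetical names, and skipping families that have no limit is an assumption of this sketch, not something stated in the question.

```scala
object PerFamilyFoldSketch {
  // Hypothetical local mirror of one row of the question's DataFrame.
  final case class Item(family: String, element: String, priority: Int, qty: Int)

  // Returns the (family, priority) keys of the accepted rows.
  // `limits` maps each family to its limit; rows of a family without a
  // limit are never accepted (an assumption of this sketch).
  def acceptedKeys(items: Seq[Item], limits: Map[String, Int]): Set[(String, Int)] =
    items
      .sortBy(i => (i.family, i.priority)) // process each family in priority order
      .foldLeft((Map.empty[String, Int], Set.empty[(String, Int)])) {
        case ((totals, keys), item) =>
          val total = totals.getOrElse(item.family, 0)
          limits.get(item.family) match {
            case Some(limit) if total + item.qty <= limit =>
              (totals.updated(item.family, total + item.qty),
               keys + ((item.family, item.priority)))
            case _ => (totals, keys) // does not fit, or no limit known
          }
      }._2
}
```

The returned keys could then be used to flag rows in the same spirit as the `isin` call above, matching on both family and priority. The memory caveat is unchanged, since the rows are still processed on one machine.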
Upvotes: 0
Reputation: 94
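Another way to do it is to collect the rows to the driver and iterate over them one by one. Note that this assumes df is already in priority order and small enough to fit in driver memory.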
var bufferRow: collection.mutable.Buffer[Row] = collection.mutable.Buffer.empty[Row]
var tempSum: Double = 0
val iterator = df.collect.iterator
while(iterator.hasNext){
val record = iterator.next()
val y = record.getAs[Integer]("qty")
tempSum = tempSum + y
print(record)
if (tempSum <= 100.0 ) {
bufferRow = bufferRow ++ Seq(transformRow(record,1))
}
else{
bufferRow = bufferRow ++ Seq(transformRow(record,0))
tempSum = tempSum - y
}
}
The transformRow function, which appends a column to a row, has to be defined before the loop is run:
def transformRow(row: Row,flag : Int): Row = Row.fromSeq(row.toSeq ++ Array[Integer](flag))
Next thing to do will be adding an additional column to the schema.
val newSchema = StructType(df.schema.fields ++ Array(StructField("C_Sum", IntegerType, false)))
Followed by creating a new dataframe.
val outputdf = spark.createDataFrame(spark.sparkContext.parallelize(bufferRow.toSeq),newSchema)
Output Dataframe :
+------+-------+--------+---+-----+
|family|element|priority|qty|C_Sum|
+------+-------+--------+---+-----+
| f1| elmt1| 1| 20| 1|
| f1| elmt2| 2| 40| 1|
| f1| elmt3| 3| 10| 1|
| f1| elmt4| 4| 50| 0|
| f1| elmt5| 5| 40| 0|
| f1| elmt6| 6| 10| 1|
| f1| elmt7| 7| 20| 1|
| f1| elmt8| 8| 10| 0|
+------+-------+--------+---+-----+
Upvotes: 0
Reputation: 1586
The solution is shown below:
scala> initDF.show
+------+-------+--------+---+
|family|element|priority|qty|
+------+-------+--------+---+
| f1| elmt 1| 1| 20|
| f1| elmt 2| 2| 40|
| f1| elmt 3| 3| 10|
| f1| elmt 4| 4| 50|
| f1| elmt 5| 5| 40|
| f1| elmt 6| 6| 10|
| f1| elmt 7| 7| 20|
| f1| elmt 8| 8| 10|
+------+-------+--------+---+
scala> val df1 = initDF.groupBy("family").agg(collect_list("qty").as("comb_qty"), collect_list("priority").as("comb_prior"), collect_list("element").as("comb_elem"))
df1: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 2 more fields]
scala> df1.show
+------+--------------------+--------------------+--------------------+
|family| comb_qty| comb_prior| comb_elem|
+------+--------------------+--------------------+--------------------+
| f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|
+------+--------------------+--------------------+--------------------+
scala> val df2 = df1.join(limitQtyDF, df1("family") === limitQtyDF("family")).drop(limitQtyDF("family"))
df2: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]
scala> df2.show
+------+--------------------+--------------------+--------------------+--------+
|family| comb_qty| comb_prior| comb_elem|limitQty|
+------+--------------------+--------------------+--------------------+--------+
| f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...| 100|
+------+--------------------+--------------------+--------------------+--------+
scala> def validCheck = (qty: Seq[Int], limit: Int) => {
| var sum = 0
| qty.map(elem => {
| if (elem + sum <= limit) {
| sum = sum + elem
| 1}else{
| 0
| }})}
validCheck: (scala.collection.mutable.Seq[Int], Int) => scala.collection.mutable.Seq[Int]
scala> val newUdf = udf(validCheck)
newUdf: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,ArrayType(IntegerType,false),Some(List(ArrayType(IntegerType,false), IntegerType)))
scala> val df3 = df2.withColumn("valid", newUdf(col("comb_qty"),col("limitQty"))).drop("limitQty")
df3: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 3 more fields]
scala> df3.show
+------+--------------------+--------------------+--------------------+--------------------+
|family| comb_qty| comb_prior| comb_elem| valid|
+------+--------------------+--------------------+--------------------+--------------------+
| f1|[20, 40, 10, 50, ...|[1, 2, 3, 4, 5, 6...|[elmt 1, elmt 2, ...|[1, 1, 1, 0, 0, 1...|
+------+--------------------+--------------------+--------------------+--------------------+
scala> val myUdf = udf((qty: Seq[Int], prior: Seq[Int], elem: Seq[String], valid: Seq[Int]) => {
| elem zip prior zip qty zip valid map{
| case (((a,b),c),d) => (a,b,c,d)}
| }
| )
scala> val df4 = df3.withColumn("combined", myUdf(col("comb_qty"),col("comb_prior"),col("comb_elem"),col("valid")))
df4: org.apache.spark.sql.DataFrame = [family: string, comb_qty: array<int> ... 4 more fields]
scala> val df5 = df4.drop("comb_qty","comb_prior","comb_elem","valid")
df5: org.apache.spark.sql.DataFrame = [family: string, combined: array<struct<_1:string,_2:int,_3:int,_4:int>>]
scala> df5.show(false)
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|family|combined |
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
|f1 |[[elmt 1, 1, 20, 1], [elmt 2, 2, 40, 1], [elmt 3, 3, 10, 1], [elmt 4, 4, 50, 0], [elmt 5, 5, 40, 0], [elmt 6, 6, 10, 1], [elmt 7, 7, 20, 1], [elmt 8, 8, 10, 0]]|
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------+
scala> val df6 = df5.withColumn("combined",explode(col("combined")))
df6: org.apache.spark.sql.DataFrame = [family: string, combined: struct<_1: string, _2: int ... 2 more fields>]
scala> df6.show
+------+------------------+
|family| combined|
+------+------------------+
| f1|[elmt 1, 1, 20, 1]|
| f1|[elmt 2, 2, 40, 1]|
| f1|[elmt 3, 3, 10, 1]|
| f1|[elmt 4, 4, 50, 0]|
| f1|[elmt 5, 5, 40, 0]|
| f1|[elmt 6, 6, 10, 1]|
| f1|[elmt 7, 7, 20, 1]|
| f1|[elmt 8, 8, 10, 0]|
+------+------------------+
scala> val df7 = df6.select("family", "combined._1", "combined._2", "combined._3", "combined._4").withColumnRenamed("_1","element").withColumnRenamed("_2","priority").withColumnRenamed("_3", "qty").withColumnRenamed("_4","ok")
df7: org.apache.spark.sql.DataFrame = [family: string, element: string ... 3 more fields]
scala> df7.show
+------+-------+--------+---+---+
|family|element|priority|qty| ok|
+------+-------+--------+---+---+
| f1| elmt 1| 1| 20| 1|
| f1| elmt 2| 2| 40| 1|
| f1| elmt 3| 3| 10| 1|
| f1| elmt 4| 4| 50| 0|
| f1| elmt 5| 5| 40| 0|
| f1| elmt 6| 6| 10| 1|
| f1| elmt 7| 7| 20| 1|
| f1| elmt 8| 8| 10| 0|
+------+-------+--------+---+---+
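One caveat with this approach: as far as I know, `collect_list` does not guarantee the order of the collected values after a shuffle, so the quantities may not arrive in priority order on a real cluster. A hedged sketch of the check that does not depend on input order is shown below, on plain Scala collections. `flagsByPriority` is a hypothetical helper that takes (priority, qty) pairs and sorts them itself before the greedy pass.

```scala
object OrderedGreedySketch {
  // Takes (priority, qty) pairs in any order and returns (priority, flag)
  // pairs sorted by priority, where flag is 1 if the element still fits
  // under the limit and 0 otherwise.
  def flagsByPriority(pairs: Seq[(Int, Int)], limit: Int): Seq[(Int, Int)] = {
    val sorted = pairs.sortBy(_._1) // restore priority order explicitly
    val (_, flagged) = sorted.foldLeft((0, Vector.empty[(Int, Int)])) {
      case ((total, acc), (priority, qty)) =>
        if (total + qty <= limit) (total + qty, acc :+ (priority -> 1))
        else (total, acc :+ (priority -> 0))
    }
    flagged
  }
}
```

In the Spark version, the same idea would mean collecting priority and quantity together (for example as one list of structs rather than three separate lists) so that the pairing and the order can both be restored inside the UDF.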
Let me know if it helps!!
Upvotes: 1