I have two PySpark DataFrames like the following:
DataFrame A:
+-----+------+
|nodes|counts|
+-----+------+
|  [0]|     1|
|  [1]|     0|
|  [2]|     1|
|  [3]|     0|
|  [4]|     0|
|  [5]|     0|
|  [6]|     1|
|  [7]|     0|
|  [8]|     0|
|  [9]|     0|
| [10]|     0|
+-----+------+
And DataFrame B:
+-----+------+
|nodes|counts|
+-----+------+
|  [0]|     1|
|  [1]|     0|
|  [2]|     3|
|  [6]|     0|
|  [8]|     2|
+-----+------+
I would like to create a new DataFrame C in which each value in the "counts" column of DataFrame A is summed with the "counts" value of DataFrame B for the same "nodes" entry (nodes that do not appear in B keep their count from A), so that DataFrame C looks like:
+-----+------+
|nodes|counts|
+-----+------+
|  [0]|     2|
|  [1]|     0|
|  [2]|     4|
|  [3]|     0|
|  [4]|     0|
|  [5]|     0|
|  [6]|     1|
|  [7]|     0|
|  [8]|     2|
|  [9]|     0|
| [10]|     0|
+-----+------+
I've tried a few different tricks using lambda functions and SQL statements but am coming up short on a solution. I appreciate the help!
You can left-join the two DataFrames on "nodes", replace the resulting nulls (nodes missing from B) with 0, and add the two count columns to get the sum:
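import org.apache.spark.sql.functions._
import spark.implicits._

// Nodes missing from B get a null countsB; na.fill(0) turns it into 0
A.join(B.withColumnRenamed("counts", "countsB"), Seq("nodes"), "left")
  .na.fill(0)
  .withColumn("counts", $"counts" + $"countsB")
  .drop("countsB")
  .show(false)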
You can also merge both DataFrames into a single one using union, then group by "nodes" and calculate the sum:
// union matches columns by position, so A and B must have the same column order
A.union(B).groupBy("nodes").agg(sum($"counts").alias("counts"))
  .orderBy("nodes")
  .show(false)
This is in Scala; hopefully you can translate it to PySpark.
Hope this helps!
There's probably a more efficient way, but this should work:
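import pyspark.sql.functions as func
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()  # reuses the active session if there is one

dfA = spark.createDataFrame([([0], 1),([1], 0),([2], 1),([3], 0), ([4], 0),([5], 0),([6], 1),([7], 0), ([8], 0),([9], 0),([10], 0)], ["nodes", "counts"])
dfB = spark.createDataFrame([([0], 1),([1], 0),([2], 3),([6], 0), ([8], 2)], ["nodes", "counts"])
# Left join keeps every row of dfA; nodes missing from dfB get nulls in dfB's columns.
# Where dfB has no match, keep dfA's count; otherwise add both counts.
dfC = dfA.join(dfB, dfA.nodes == dfB.nodes, "left")\
    .withColumn("sum", func.when(dfB.nodes.isNull(), dfA.counts).otherwise(dfA.counts + dfB.counts))\
    .select(dfA.nodes.alias("nodes"), func.col("sum").alias("counts"))
dfC.orderBy("nodes").show()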
+-----+------+
|nodes|counts|
+-----+------+
|  [0]|     2|
|  [1]|     0|
|  [2]|     4|
|  [3]|     0|
|  [4]|     0|
|  [5]|     0|
|  [6]|     1|
|  [7]|     0|
|  [8]|     2|
|  [9]|     0|
| [10]|     0|
+-----+------+