Reputation: 732
I have the following situation, which I've solved in a very inefficient way:
I have a dataframe called dfUniques where every row holds a value distinct from the others (e.g. 1K rows, but it could be many more, or fewer than 100). And a dataframe called dfFull, some of whose rows contain the same values that are present in dfUniques. dfFull is much bigger than dfUniques, and it also has 3 times as many columns. What I want to do is find all rows in dfFull where the columns in common with dfUniques have the same values as some row of dfUniques, because the objective is to count how many rows from dfUniques are present in dfFull.
The way I've implemented it is wrong (I think) because it takes a lot of time, and I'm also calling collect() (which I know is not ideal when the data gets big). This is my code:
dfUniques.collect().foreach {
  row => {
    val singleRowDF = spark.createDataFrame(spark.sparkContext.parallelize(Seq(row)), myschema)
    val matching = dfFull
      .join(singleRowDF, columnsInCommon)
      .select(selColumns.head, selColumns.tail: _*)
    val matchingCount = matching.count()
    println("instances that matched\t" + matchingCount)
    if (matchingCount > 0) {
      val dfAggr = matching.groupBy("name").avg(selColumns: _*)
      resultData = resultData.union(dfAggr)
    }
  }
}
I think a good approach would be to use some kind of join, but I cannot figure out which one does what I want. Any suggestion?
I've found this (https://stackoverflow.com/a/51679966/5081366), but it is not my case, because that post joins each row of one dataframe with all rows of another dataframe, whereas I want to obtain only the rows that match each row of dfUniques. Well, I hope this is clear.
Upvotes: 1
Views: 231
Reputation: 1712
You are right, a join is the best way to go. In your case, a 'left semi' join is applicable. You can also read about the various types of Spark joins here: https://spark.apache.org/docs/2.4.0/api/python/pyspark.sql.html?highlight=join#pyspark.sql.DataFrame.join
# tst plays the role of dfFull, tst1 the role of dfUniques; 'a' is the common column
tst = sqlContext.createDataFrame([(1,2),(1,3),(9,9),(2,4),(2,10),(3,5),(10,9),(3,6),(3,8),(7,9),(4,5),(19,1),(20,4),(22,3),(30,5),(67,4)], schema=['a','b'])
tst1 = sqlContext.createDataFrame([(1,2),(2,5),(7,6)], schema=['a','c'])
# keep only the rows of tst whose 'a' value appears somewhere in tst1
tst_res = tst.join(tst1, on='a', how='left_semi')
tst_res.show()
+---+---+
| a| b|
+---+---+
| 1| 2|
| 1| 3|
| 2| 4|
| 2| 10|
| 7| 9|
+---+---+
Upvotes: 1