Noob
Noob

Reputation: 67

Scala spark query optimization

I have two dataframes that have 300 columns and 1000 rows each. They have the same column names. The values are of mixed datatypes like Struct/List/Timestamp/String/etc. I am trying to compare the column values for each row, I notice that the query is running for a long time. Is there a way to optimize this?

def compareDatasets(ds1: Dataset[Row], ds2: Dataset[Row]): Dataset[Row] = {
      
    val attributeSet = ds1.columns
    val distinctAsins = ds1.select("item_id").distinct()
      
    val partitionedDs1 = ds1.repartition($"item_id")
    partitionedDs1.cache()
    
    val partitionedDs2 = ds2.repartition($"item_id")
    partitionedDs2.cache()
    
    distinctAsins.collect().foreach { row =>
      val asin = row.get(0)
      
      val ds1Rows = partitionedDs1.filter(col("item_id") === asin)
      val cachedDs1Rows = ds1Rows.cache()
      
      val cipRows = partitionedDs2.filter(col("item_id") === asin)
      val cachedDs2Rows = cipRows.cache()
        attributeSet.foreach { attr => 
        
          val areColumnsEqual: Boolean = cachedDs1Rows.filter(cachedDs1Rows(attr) === cachedDs2Rows(attr)).count() > 0 //<------ trying to optimize this
          println("parsing item_id: " + asin + " attribute: " + attr + " areColumnsEqual: "  + areColumnsEqual)
            
        }
        cachedDs2Rows.unpersist()
        cachedDs1Rows.unpersist()
      }
    partitionedDs1.unpersist()
    partitionedDs2.unpersist()
      
    ds1  
}

Upvotes: 0

Views: 39

Answers (1)

user238607
user238607

Reputation: 2468

You can use Dataframe subtract.

https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.subtract.html

from pyspark.sql import SparkSession
from pyspark.sql import Row

spark = SparkSession.builder.appName("example").getOrCreate()

df1 = spark.createDataFrame([Row(a=1, b=2), Row(a=2, b=3), Row(a=3, b=4)])
df2 = spark.createDataFrame([Row(a=2, b=3), Row(a=4, b=5)])

df_subtract12 = df1.subtract(df2)
df_subtract21 = df2.subtract(df1)


df_subtract12.show()
df_subtract21.show()

Output :

+---+---+
|  a|  b|
+---+---+
|  3|  4|
|  1|  2|
+---+---+

+---+---+
|  a|  b|
+---+---+
|  4|  5|
+---+---+

Upvotes: 1

Related Questions