WouterSterkens

Reputation: 17

Create SCD2 table from sourcefile that contains multiple updates for one id using Databricks/Spark

I want to build a slowly changing dimension (SCD type 2) table in Databricks. My source DataFrame contains the following information:

+-------------------+-------------------------+----------+-----------+-------------+
| actionimmediately |          date           | deviceid | patchguid |   status    |
+-------------------+-------------------------+----------+-----------+-------------+
| False             | 2018-08-15 04:01:00.000 |      123 | 00-001    | Install     |
| True              | 2018-08-16 00:00:00.000 |      123 | 00-001    | Install     |
| False             | 2018-08-10 01:00:00.000 |      123 | 00-001    | Not Approved|
| False             | 2020-01-01 00:00:00.000 |      333 | 11-111    | Declined    |
+-------------------+-------------------------+----------+-----------+-------------+

The dataframe I want as output looks like:

+-----------+----------+-----------+--------------+-------------------+-------------------------+-------------------------+---------+
| mergekey  | deviceid | patchguid |    status    | actionimmediately |        starttime        |         endtime         | current |
+-----------+----------+-----------+--------------+-------------------+-------------------------+-------------------------+---------+
| 12300-001 |      123 | 00-001    | Not Approved | False             | 2018-08-10 01:00:00.000 | 2018-08-15 04:01:00.000 | False   |
| 12300-001 |      123 | 00-001    | Install      | False             | 2018-08-15 04:01:00.000 | 2018-08-16 00:00:00.000 | False   |
| 12300-001 |      123 | 00-001    | Install      | True              | 2018-08-16 00:00:00.000 | null                    | True    |
| 33311-111 |      333 | 11-111    | Declined     | False             | 2020-01-01 00:00:00.000 | null                    | True    |
+-----------+----------+-----------+--------------+-------------------+-------------------------+-------------------------+---------+

In reality, the source file contains 275,475 rows.

I have already tried two solutions, but both perform very slowly: each takes roughly 10 hours.

Solution 1: Using Delta Lake merge

First I create a seqId that I use later to iterate. This is needed because a merge cannot update the same row multiple times. I create the seqId using a window.

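from pyspark.sql import Window
from pyspark.sql.functions import col, concat, expr, lit, row_number
from delta.tables import DeltaTable
source_df = source_df.withColumn('mergekey',concat(col('deviceid'),col('patchguid')))
w1 = Window.partitionBy('mergekey').orderBy('date')
source_df = source_df.withColumn('seqid', row_number().over(w1))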
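
To make the numbering concrete, here is a minimal plain-Python sketch of what that window produces on the sample rows above. It is only an illustration, not Spark code: the helper name add_seqid is hypothetical, and it assumes the timestamps are strings in one fixed format, so that sorting them as text sorts them chronologically.

```python
from collections import defaultdict

# Sample rows from the question:
# (actionimmediately, date, deviceid, patchguid, status)
rows = [
    (False, "2018-08-15 04:01:00.000", 123, "00-001", "Install"),
    (True,  "2018-08-16 00:00:00.000", 123, "00-001", "Install"),
    (False, "2018-08-10 01:00:00.000", 123, "00-001", "Not Approved"),
    (False, "2020-01-01 00:00:00.000", 333, "11-111", "Declined"),
]

def add_seqid(rows):
    """Mimic row_number() over (partition by mergekey order by date)."""
    groups = defaultdict(list)
    for action, date, deviceid, patchguid, status in rows:
        mergekey = f"{deviceid}{patchguid}"
        groups[mergekey].append((date, status, action))

    numbered = []
    for mergekey, versions in groups.items():
        # Assumption: same-format timestamp strings sort chronologically.
        ordered = sorted(versions, key=lambda v: v[0])
        for seqid, (date, status, action) in enumerate(ordered, start=1):
            numbered.append({
                "mergekey": mergekey,
                "seqid": seqid,
                "date": date,
                "status": status,
                "actionimmediately": action,
            })
    return numbered

numbered = add_seqid(rows)
max_seq_id = max(r["seqid"] for r in numbered)

for r in numbered:
    print(r["mergekey"], r["seqid"], r["date"], r["status"])
```

On the sample data the key 12300-001 gets seqid 1 to 3 and 33311-111 gets seqid 1, so max_seq_id equals the largest number of versions that any single key has.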

Then I create a for loop that runs over each seqId and merges the rows. In reality, max_seq_id is 1900.

def createTable (df, SeqId):
  df\
  .withColumn('mergekey',concat(col('deviceid'),col('patchguid')))\
  .select(\
          'mergekey',\
          'deviceid',\
          'patchguid',\
          'status',\
          'actionimmediately',\
          col('date').alias('starttime'))\
  .where(col('seqid') == SeqId)\
  .withColumn('endtime',lit(None).cast('timestamp'))\
  .withColumn('current',lit(True))\
  .write.format('delta')\
  .partitionBy("current")\
  .options(header='true',path='/mnt/destinationncentral/patch_approval')\
  .saveAsTable('patch_approval')

def MergePatchApproval (df,deltatable,seqNummer):
  dataframe = df.where(col('seqid') == seqNummer)
  newToInsert = dataframe.alias('updates')\
  .join(deltatable.toDF().alias('table'),['deviceid','patchguid'])\
  .select(\
          'updates.actionimmediately',\
          'updates.date',\
          'updates.deviceid',\
          'updates.patchguid',\
          'updates.status',\
          'updates.seqid')\
  .where('table.current = true and \
  (table.status <> updates.status or table.actionimmediately <> updates.actionimmediately)')

  stagedUpdates = (newToInsert.selectExpr('NULL as mergekey','*')\
                   .union(dataframe\
                          .withColumn('mergekey',concat(col('deviceid'),col('patchguid')))\
                          .select(\
                                  'mergekey',\
                                  'actionimmediately',\
                                  'date',\
                                  'deviceid',\
                                  'patchguid',\
                                  'status',\
                                  'seqid')))

  deltatable.alias('t')\
  .merge(stagedUpdates.alias('s'),'t.current = true and t.mergekey = s.mergekey')\
  .whenMatchedUpdate(condition = 't.current = true and \
  (t.status <> s.status or t.actionimmediately <> s.actionimmediately)', \
  set = {
    'endtime':'s.date',
    'current':'false'
  }).whenNotMatchedInsert(values = {
    'mergekey':'s.mergekey',
    'deviceid':'s.deviceid',
    'patchguid':'s.patchguid',
    'status':'s.status',
    'actionimmediately':'s.actionimmediately',
    'starttime':'s.date',
    'endtime':'NULL',
    'current':'true'
  }).execute()

for i in range(max_seq_id):
  i = i + 1
  print(i)
  df = source_df.where(col('seqid') == i)
  if(i == 1):
    tablecount = spark.sql("show tables like 'patch_approval'").count()
    if(tablecount == 0):
      createTable(df,i)
      approval_table = DeltaTable.forPath(spark,'/mnt/destinationncentral/patch_approval')
    else:
      approval_table = DeltaTable.forPath(spark,'/mnt/destinationncentral/patch_approval')
      MergePatchApproval(df,approval_table,i)
  else:
    MergePatchApproval(df,approval_table,i) 

The problem with this solution is that writing the data to Azure Data Lake takes some time, which I think is normal, but the execution time of each iteration also keeps increasing.

Solution 2: Upsert the dataframes and write one time at the end

In this solution I also use the for loop and seqId, but instead of writing to Azure Data Lake on every iteration I only write once at the end. This solves the write latency issue, but the time each iteration takes still keeps increasing.

def createDestDF(sourceDF):
  dest_df = sourceDF\
    .select(\
            'mergekey',\
            'deviceid',\
            'patchguid',\
            'status',\
            'actionimmediately',\
            col('date').alias('starttime'))\
    .withColumn('endtime',lit(None).cast('timestamp'))\
    .withColumn('current',lit(True))
  return dest_df

def getChangedRecords(sourceDF,destDF):
  changedRecords = sourceDF.alias('u')\
  .join(destDF.alias('t'),['deviceid','patchguid'])\
  .select(\
         'u.actionimmediately',\
         'u.date',\
         'u.deviceid',\
         'u.patchguid',\
         'u.status',\
         'u.seqid',\
         'u.mergekey')\
  .where('t.current = true and \
  (t.status <> u.status or t.actionimmediately <> u.actionimmediately)')

  return changedRecords

def getNewRecords(sourceDF,destDF):
  newRecords = sourceDF.alias('n')\
  .join(destDF.alias('t'),['deviceid','patchguid'],'left')\
  .select(\
          't.mergekey',\
          'n.actionimmediately',\
          'n.date',\
          'deviceid',\
          'patchguid',\
          'n.status',\
          'n.seqid')\
  .where('t.current is null')
  return newRecords

def upsertChangedRecords(sourceDF,destDF):
  endTimeColumn = expr("""IF(endtimeOld IS NULL, date, endtimeOld)""")
  currentColumn = expr("""IF(date IS NULL, currentOld, False)""")

  updateDF = sourceDF.alias('s').join(destDF.alias('t'),'mergekey','right').select(\
                                                                                'mergekey',\
                                                                                't.deviceid',\
                                                                                't.patchguid',\
                                                                                't.status',\
                                                                                't.actionimmediately',\
                                                                                't.starttime',\
                                                                                's.date',\
                                                                                col('t.current').alias('currentOld'),\
                                                                                col('t.endTime').alias('endtimeOld'))\
  .withColumn('endtime',endTimeColumn)\
  .withColumn('current',currentColumn)\
  .drop('currentOld','date','endTimeOld')

  updateInsertDF = sourceDF\
  .select(\
          'mergekey',\
          'deviceid',\
          'patchguid',\
          'status',\
          'actionimmediately',\
          col('date').alias('starttime'))\
  .withColumn('endtime',lit(None).cast('timestamp'))\
  .withColumn('current',lit(True))

  resultDF = updateDF.union(updateInsertDF)
  return resultDF

def insertNewRecords(sourceDF, destDF):
  insertDF = sourceDF\
  .select(\
          'mergekey',\
          'deviceid',\
          'patchguid',\
          'status',\
          'actionimmediately',\
          col('date').alias('starttime'))\
  .withColumn('endtime',lit(None).cast('timestamp'))\
  .withColumn('current',lit(True))

  resultDF = destDF.union(insertDF)

  return resultDF

for i in range(max_seq_id):
  i = i + 1
  print(i)
  seq_df = source_df.where(col('seqid') == i)
  if i == 1:
    tablecount = spark.sql("show tables like 'patch_approval'").count()
    if(tablecount == 0):
      dest_df = createDestDF(seq_df)
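    else:
      # dest_df is not defined yet on the first pass, so load the existing table
      dest_df = spark.table('patch_approval')
      changed_df = getChangedRecords(seq_df,dest_df)
      new_df = getNewRecords(seq_df,dest_df)
      dest_df = upsertChangedRecords(changed_df,dest_df)
      dest_df = insertNewRecords(new_df,dest_df)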
  else:
    changed_df = getChangedRecords(seq_df,dest_df)
    new_df = getNewRecords(seq_df,dest_df)
    dest_df = upsertChangedRecords(changed_df,dest_df)
    dest_df = insertNewRecords(new_df,dest_df)

dest_df\
.write\
.format('delta')\
.partitionBy('current')\
.mode('overwrite')\
.options(header='true',path='/mnt/destinationncentral/patch_approval')\
.saveAsTable('patch_approval')

Any idea how I can solve the increasing execution time in the for loop?

Kind regards,

Upvotes: 1

Views: 1456

Answers (1)

matkurek

Reputation: 771

From what I understand, rows don't disappear from your source table as time passes. If so, your problem can be solved by registering your Spark DataFrame as a temporary view and writing a query against it:

df.createOrReplaceTempView("source")

df_scd = spark.sql("""
WITH stage AS (
  SELECT *,
  LEAD(date,1) OVER (PARTITION BY deviceid, patchguid ORDER BY date) AS next_date
  FROM source
)
SELECT 
  concat(deviceid, patchguid) as mergekey
  ,deviceid
  ,patchguid
  ,status
  ,actionimmediately
  ,date AS starttime
  ,next_date AS endtime
  ,CASE WHEN next_date IS NULL THEN True ELSE False END AS current
FROM stage
""")

It should be very fast and produce exactly the output you want. I checked it on your sample data, and df_scd then shows:

+---------+--------+---------+------------+-----------------+--------------------+--------------------+-------+
| mergekey|deviceid|patchguid|      status|actionimmediately|           starttime|             endtime|current|
+---------+--------+---------+------------+-----------------+--------------------+--------------------+-------+
|12300-001|     123|   00-001|Not Approved|            False|2018-08-10 01:00:...|2018-08-15 04:01:...|  false|
|12300-001|     123|   00-001|     Install|            False|2018-08-15 04:01:...|2018-08-16 00:00:...|  false|
|12300-001|     123|   00-001|     Install|             True|2018-08-16 00:00:...|                null|   true|
|33311-111|     333|   11-111|    Declined|            False|2020-01-01 00:00:...|                null|   true|
+---------+--------+---------+------------+-----------------+--------------------+--------------------+-------+
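
For readers who want to see what the LEAD window does row by row, here is a minimal plain-Python sketch of the same logic on the sample data. It is only an illustration and not a replacement for the Spark query: the function name build_scd2 is hypothetical, and it assumes the timestamps are same-format strings that sort chronologically.

```python
from itertools import groupby

# Sample rows from the question.
rows = [
    {"actionimmediately": False, "date": "2018-08-15 04:01:00.000",
     "deviceid": 123, "patchguid": "00-001", "status": "Install"},
    {"actionimmediately": True, "date": "2018-08-16 00:00:00.000",
     "deviceid": 123, "patchguid": "00-001", "status": "Install"},
    {"actionimmediately": False, "date": "2018-08-10 01:00:00.000",
     "deviceid": 123, "patchguid": "00-001", "status": "Not Approved"},
    {"actionimmediately": False, "date": "2020-01-01 00:00:00.000",
     "deviceid": 333, "patchguid": "11-111", "status": "Declined"},
]

def partition_key(row):
    # Same partitioning as PARTITION BY deviceid, patchguid.
    return (row["deviceid"], row["patchguid"])

def build_scd2(rows):
    """Emulate LEAD(date, 1) per key: the next row's date closes the current row."""
    ordered = sorted(rows, key=lambda r: (partition_key(r), r["date"]))
    result = []
    for (deviceid, patchguid), group in groupby(ordered, key=partition_key):
        versions = list(group)
        for i, row in enumerate(versions):
            # The last version of a key has no successor, so it stays open.
            next_date = versions[i + 1]["date"] if i + 1 < len(versions) else None
            result.append({
                "mergekey": f"{deviceid}{patchguid}",
                "deviceid": deviceid,
                "patchguid": patchguid,
                "status": row["status"],
                "actionimmediately": row["actionimmediately"],
                "starttime": row["date"],
                "endtime": next_date,
                "current": next_date is None,
            })
    return result

scd = build_scd2(rows)
for r in scd:
    print(r["mergekey"], r["status"], r["starttime"], r["endtime"], r["current"])
```

Like the query, this turns every source row into its own version. The merge conditions in the question only start a new version when status or actionimmediately changes, so if the real data contains consecutive rows with identical values, an extra de-duplication step may be needed first.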

Upvotes: 1
