nmr

Reputation: 753

Generate queries for each key in a PySpark data frame

I have a data frame in PySpark like the one below:

df = spark.createDataFrame(
    [
        ('2021-10-01', 'A', 25),
        ('2021-10-02', 'B', 24),
        ('2021-10-03', 'C', 20),
        ('2021-10-04', 'D', 21),
        ('2021-10-05', 'E', 20),
        ('2021-10-06', 'F', 22),
        ('2021-10-07', 'G', 23),
        ('2021-10-08', 'H', 24),
    ],
    ("RUN_DATE", "NAME", "VALUE"),
)

Now, using this data frame, I want to update a table in MySQL.

# query to run should be similar to this
update_query = "UPDATE DB.TABLE SET DATE = '2021-10-01', VALUE = 25 WHERE NAME = 'A'"

# mysql_conn is a helper function I use to connect to MySQL from PySpark and run queries
# Invoking the function:
mysql_conn(host, user_name, password, update_query)

When I invoke the mysql_conn function with these parameters, the query runs successfully and the record is updated in the MySQL table.

Now I want to run the update statement for all the records in the data frame.

For each NAME, it should pick the RUN_DATE and VALUE, substitute them into update_query, and call mysql_conn.

I think I need a for loop, but I'm not sure how to proceed.
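For a small data frame, the per-row loop described above can be sketched on the driver as follows. This is a minimal sketch, assuming `mysql_conn(host, user_name, password, query)` behaves as in the question; `build_update_query` and `run_updates` are hypothetical helper names, and doubling single quotes is only basic escaping, not a substitute for parameterized queries.

```python
def build_update_query(run_date, name, value):
    """Fill the question's UPDATE template for one row (hypothetical helper)."""
    # Double any single quotes so a value such as O'Brien stays a valid MySQL string literal
    safe_date = str(run_date).replace("'", "''")
    safe_name = str(name).replace("'", "''")
    return (
        "UPDATE DB.TABLE SET DATE = '%s', VALUE = %d WHERE NAME = '%s'"
        % (safe_date, int(value), safe_name)
    )


def run_updates(rows, execute):
    """Call `execute` once per row; `execute` stands in for mysql_conn (assumption)."""
    count = 0
    for row in rows:
        # PySpark Row objects expose columns as attributes
        execute(build_update_query(row.RUN_DATE, row.NAME, row.VALUE))
        count += 1
    return count
```

In PySpark this could be driven with something like `run_updates(df.toLocalIterator(), lambda q: mysql_conn(host, user_name, password, q))`. It still issues one UPDATE per row from the driver, so it only suits small data frames.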

Upvotes: 0

Views: 222

Answers (1)

ggordon

Reputation: 10035

Instead of iterating through the data frame with a for loop on the driver, it would be better to distribute the workload across partitions using foreachPartition. Moreover, since you are writing a custom query, it would be more efficient to execute one batch statement per partition instead of one query per row, which reduces round trips, latency and concurrent connections. E.g.:

def update_db(rows):
    # Turn every row in this partition into one SELECT of a derived table
    selects = [
        "SELECT '%s' AS RUNDATE, '%s' AS NAME, %d AS VALUE" % (row.RUN_DATE, row.NAME, row.VALUE)
        for row in rows
    ]
    # An empty partition would produce invalid SQL, so skip it
    if not selects:
        return
    temp_table_query = " UNION ALL ".join(selects)

    # One UPDATE ... JOIN per partition instead of one UPDATE per row
    update_query = """
        UPDATE DBTABLE
        INNER JOIN (
            %s
        ) new_records ON DBTABLE.NAME = new_records.NAME
        SET
            DBTABLE.DATE = new_records.RUNDATE,
            DBTABLE.VALUE = new_records.VALUE
    """ % temp_table_query
    # host, user_name, password and mysql_conn must be available on the executors
    mysql_conn(host, user_name, password, update_query)
    

df.foreachPartition(update_db)
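A very large partition turns into one very long UNION ALL statement, which can exceed MySQL's statement size limit (`max_allowed_packet`). A minimal sketch of splitting each partition into fixed-size batches, assuming the same `DBTABLE` layout as above; `build_batch_update`, `update_partition_in_batches` and the default `batch_size=500` are hypothetical choices, and `execute` stands in for `mysql_conn`.

```python
from itertools import islice


def build_batch_update(rows):
    """Build one UPDATE ... JOIN statement for a batch of rows, or None if the batch is empty."""
    selects = [
        "SELECT '%s' AS RUNDATE, '%s' AS NAME, %d AS VALUE"
        % (str(r.RUN_DATE).replace("'", "''"), str(r.NAME).replace("'", "''"), int(r.VALUE))
        for r in rows
    ]
    if not selects:
        return None
    return (
        "UPDATE DBTABLE INNER JOIN ("
        + " UNION ALL ".join(selects)
        + ") new_records ON DBTABLE.NAME = new_records.NAME "
        "SET DBTABLE.DATE = new_records.RUNDATE, DBTABLE.VALUE = new_records.VALUE"
    )


def update_partition_in_batches(rows, execute, batch_size=500):
    """Send one statement per `batch_size` rows; returns the number of statements sent."""
    rows = iter(rows)
    statements = 0
    while True:
        # Pull at most batch_size rows from the partition iterator
        batch = list(islice(rows, batch_size))
        if not batch:
            return statements
        execute(build_batch_update(batch))
        statements += 1
```

This could be wired in as `df.foreachPartition(lambda rows: update_partition_in_batches(rows, lambda q: mysql_conn(host, user_name, password, q)))`; the right batch size depends on the server's configuration.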

View Demo on how the UPDATE query works
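If the question's `mysql_conn` helper can be replaced by a direct driver connection on the executors, a parameterized statement lets the driver handle quoting instead of formatting values into SQL by hand. This is a sketch assuming the `mysql-connector-python` package is installed on every executor; `to_update_params`, `UPDATE_SQL` and `update_partition_parameterized` are illustrative names. Note that `executemany` still runs one UPDATE per row here (the connector only batches INSERT statements into multi-row form), so it trades the single-statement batching above for safer quoting.

```python
UPDATE_SQL = "UPDATE DB.TABLE SET DATE = %s, VALUE = %s WHERE NAME = %s"


def to_update_params(rows):
    """Map each row to a (RUN_DATE, VALUE, NAME) tuple matching the %s placeholders."""
    return [(row.RUN_DATE, row.VALUE, row.NAME) for row in rows]


def update_partition_parameterized(rows, host, user_name, password):
    # Assumption: mysql-connector-python is available on the executors
    import mysql.connector

    params = to_update_params(rows)
    if not params:
        return  # nothing to update in this partition
    conn = mysql.connector.connect(host=host, user=user_name, password=password)
    try:
        cursor = conn.cursor()
        cursor.executemany(UPDATE_SQL, params)
        conn.commit()
    finally:
        conn.close()
```

Usage: `df.foreachPartition(lambda rows: update_partition_parameterized(rows, host, user_name, password))`.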

Let me know if this works for you.

Upvotes: 1
