Amir

Custom partitioning on JDBC in PySpark

I have a huge table in an Oracle database that I want to work on in PySpark, but I want to partition it using a custom query. For example, imagine there is a column in the table that contains the user's name, and I want to partition the data based on the first letter of the user's name. Or imagine that each record has a date, and I want to partition it based on the month. And because the table is huge, I absolutely need the data for each partition to be fetched directly by its executor and NOT by the driver. So can I do that in PySpark?

P.S.: The reason I need to control the partitioning is that I need to perform some aggregations on each partition (the partitions have meaning; they are not just a way to distribute the data), so I want the related records to be on the same machine to avoid any shuffles. Is this possible, or am I wrong about something?

NOTE

I don't care whether the partitions are even or skewed! I want all the related records (like all the records of a user, or all the records from a city, etc.) to be partitioned together, so that they reside on the same machine and I can aggregate them without any shuffling.

Answers (1)

Amir

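It turns out that Spark does have a way of controlling the partitioning logic exactly: the predicates parameter of spark.read.jdbc. Each predicate is an expression for a WHERE clause, and each one defines one partition of the resulting DataFrame, which the executor processing that partition reads directly from the database.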
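Applied to the two examples in the question, the predicates could be generated roughly as in the sketch below. It is only an illustration: the column names USER_NAME and RECORD_DATE are hypothetical, and the SQL fragments use Oracle's UPPER, SUBSTR and EXTRACT. Each list entry becomes one partition; the extra catch-all entry matters because rows that match no predicate are simply not read.

```python
import string


def first_letter_predicates(column):
    """One predicate per initial letter A..Z, plus a catch-all partition
    for NULL names and names starting with any other character."""
    letter = f"UPPER(SUBSTR({column}, 1, 1))"
    preds = [f"{letter} = '{c}'" for c in string.ascii_uppercase]
    preds.append(f"{column} IS NULL OR {letter} NOT BETWEEN 'A' AND 'Z'")
    return preds


def month_predicates(column):
    """One predicate per calendar month, plus one for NULL dates."""
    preds = [f"EXTRACT(MONTH FROM {column}) = {m}" for m in range(1, 13)]
    preds.append(f"{column} IS NULL")
    return preds


print(first_letter_predicates("USER_NAME")[0])
print(month_predicates("RECORD_DATE")[11])
```

Either list would then be passed as, e.g., `spark.read.jdbc(url=url, table='MY_TABLE', predicates=month_predicates('RECORD_DATE'), properties={'driver': 'oracle.jdbc.driver.OracleDriver'})`. Note that `EXTRACT(MONTH ...)` puts the same month of different years into one partition; if each year-month should be its own partition, the predicates also need a condition on the year.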

What I came up with eventually is as follows:

(For the sake of the example, imagine that we have the purchase records of a store and need to partition them by userId and productId, so that all the records of an entity are kept together on the same machine and we can perform aggregations on these entities without shuffling.)

  • First, produce the histogram of every column that you want to partition by (the count of each value):

    userId      count
    123456      1640
    789012      932
    345678      1849
    901234      11
    ...         ...

    productId   count
    123456789   5435
    523485447   254
    363478326   2343
    326484642   905
    ...         ...
  • Then, use the MultiFit algorithm to divide the values of each column into n bins whose total counts are roughly balanced (n being the number of partitions that you want):

    userId      bin
    123456      1
    789012      1
    345678      1
    901234      2
    ...         ...

    productId   bin
    123456789   1
    523485447   2
    363478326   2
    326484642   3
    ...         ...
  • Then, store these value-to-bin mappings in the database as tables (USER_PARTITION and PRODUCT_PARTITION in the query below).

  • Then, update your query to join on these tables to get the bin numbers of every record:

url = 'jdbc:oracle:thin:username/password@address:port:dbname'

# Subquery that attaches the bin numbers of each record's user and product
query = """(SELECT
  MY_TABLE.*,
  USER_PARTITION.BIN AS USER_BIN,
  PRODUCT_PARTITION.BIN AS PRODUCT_BIN
FROM MY_TABLE
LEFT JOIN USER_PARTITION
  ON MY_TABLE.USER_ID = USER_PARTITION.USER_ID
LEFT JOIN PRODUCT_PARTITION
  ON MY_TABLE.PRODUCT_ID = PRODUCT_PARTITION.PRODUCT_ID) MY_QUERY"""

# `predicates` (one WHERE clause per partition) is generated in the last step
df = spark.read \
    .option('driver', 'oracle.jdbc.driver.OracleDriver') \
    .jdbc(url=url, table=query, predicates=predicates)
  • And finally, generate the predicates, one for each partition:

# n is the number of bins (= number of partitions) used in the MultiFit step
predicates = [f'USER_BIN = {i} OR PRODUCT_BIN = {i}' for i in range(1, n + 1)]
# ['USER_BIN = 1 OR PRODUCT_BIN = 1', 'USER_BIN = 2 OR PRODUCT_BIN = 2', ..., 'USER_BIN = n OR PRODUCT_BIN = n']
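The MultiFit binning step from the second bullet above can be sketched in plain Python. This is a minimal version of the textbook algorithm (a binary search over the bin capacity, packing with First-Fit Decreasing at each step), not the code used in this answer. The input is assumed to be a {value: count} histogram, for example the result of `SELECT USER_ID, COUNT(*) FROM MY_TABLE GROUP BY USER_ID`, and the output maps each value to a bin number from 1 to n.

```python
def first_fit_decreasing(items, capacity):
    """Pack (value, count) pairs, already sorted by count descending, into
    bins of the given capacity. Returns a list of [load, [values]] bins."""
    bins = []
    for value, count in items:
        for b in bins:
            if b[0] + count <= capacity:  # first open bin with room
                b[0] += count
                b[1].append(value)
                break
        else:  # no open bin had room: open a new one
            bins.append([count, [value]])
    return bins


def multifit(histogram, n, iterations=10):
    """Map each value of a {value: count} histogram to a bin 1..n so that
    the total counts per bin are roughly balanced."""
    items = sorted(histogram.items(), key=lambda kv: kv[1], reverse=True)
    total = sum(histogram.values())
    largest = items[0][1]
    lower = max(total / n, largest)      # no n-bin packing can do better
    upper = max(2 * total / n, largest)  # FFD always fits in n bins here
    best = first_fit_decreasing(items, upper)
    for _ in range(iterations):          # binary search on the capacity
        capacity = (lower + upper) / 2
        bins = first_fit_decreasing(items, capacity)
        if len(bins) <= n:
            upper, best = capacity, bins
        else:
            lower = capacity
    return {value: i + 1 for i, (_, values) in enumerate(best) for value in values}


print(multifit({'a': 10, 'b': 9, 'c': 8, 'd': 7, 'e': 1, 'f': 1}, 2))
```

The returned mapping is what would be written to a table such as USER_PARTITION(USER_ID, BIN). The first-fit scan checks the open bins for every value, which is fine for a sketch but can be slow in pure Python for histograms with millions of distinct IDs.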

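Spark adds each predicate to the query as a WHERE clause and reads the matching rows into their own partition, which means that all the records of the users in bin 1 end up in the same partition (and therefore on the same machine). All the records of the products in bin 1 end up in that same partition as well. Because the predicates overlap, a record whose USER_BIN and PRODUCT_BIN differ is read twice, once into each of the two partitions.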
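To see the effect of these predicates without an Oracle instance, the sketch below uses Python's built-in sqlite3 as a stand-in database; the tables, rows and the AMOUNT column are made up for illustration. It runs the partition query once per predicate, roughly mirroring the `SELECT ... FROM <query> WHERE <predicate>` statement that Spark's JDBC source issues for each partition, and shows that a record whose user and product are in different bins is read into both partitions.

```python
import sqlite3


def rows_per_predicate(conn, query, predicates):
    """Return the rows each predicate selects, one list per partition."""
    return [conn.execute(f"SELECT * FROM {query} WHERE {p}").fetchall()
            for p in predicates]


conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE MY_TABLE (USER_ID INTEGER, PRODUCT_ID INTEGER, AMOUNT REAL);
CREATE TABLE USER_PARTITION (USER_ID INTEGER, BIN INTEGER);
CREATE TABLE PRODUCT_PARTITION (PRODUCT_ID INTEGER, BIN INTEGER);
INSERT INTO MY_TABLE VALUES (1, 10, 5.0), (1, 20, 3.0), (2, 10, 4.0);
INSERT INTO USER_PARTITION VALUES (1, 1), (2, 2);
INSERT INTO PRODUCT_PARTITION VALUES (10, 2), (20, 1);
""")

query = """(SELECT
  MY_TABLE.*,
  USER_PARTITION.BIN AS USER_BIN,
  PRODUCT_PARTITION.BIN AS PRODUCT_BIN
FROM MY_TABLE
LEFT JOIN USER_PARTITION ON MY_TABLE.USER_ID = USER_PARTITION.USER_ID
LEFT JOIN PRODUCT_PARTITION ON MY_TABLE.PRODUCT_ID = PRODUCT_PARTITION.PRODUCT_ID) MY_QUERY"""

predicates = ["USER_BIN = 1 OR PRODUCT_BIN = 1", "USER_BIN = 2 OR PRODUCT_BIN = 2"]
partitions = rows_per_predicate(conn, query, predicates)
for i, rows in enumerate(partitions):
    # columns: USER_ID, PRODUCT_ID, AMOUNT, USER_BIN, PRODUCT_BIN
    print(i, sorted(rows))
```

Three source records produce four partition rows here: the record (1, 10, 5.0) belongs to user bin 1 and product bin 2, so it lands in both partitions.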

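Note that there is no relation between the user bins and the product bins here. We don't care which products end up in which partition or on which machine. But since we want to perform some aggregations on both the users and the products (separately), we need to keep all the records of each entity (user or product) together, and this method achieves that. To actually avoid the shuffle, though, the aggregation has to run per partition (for example with df.rdd.mapPartitionsWithIndex), because Spark does not know that the data is already grouped, so a plain groupBy would still add an exchange. Within the partition read with predicate number k, aggregate users only over the rows with USER_BIN = k and products only over the rows with PRODUCT_BIN = k, so that the records read into two partitions are not counted twice.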
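A minimal sketch of such a per-partition aggregation, assuming a hypothetical numeric AMOUNT column to sum. The function has the `(index, iterator)` signature that `RDD.mapPartitionsWithIndex` expects; in Spark's JDBC source the partition index follows the order of the predicates list, so partition `i` holds bin `i + 1` when the predicates are generated for bins 1..n in order. A row counts toward its user only if its USER_BIN matches this partition, and toward its product only if its PRODUCT_BIN matches, which filters out the duplicated rows.

```python
from collections import defaultdict


def aggregate_partition(index, rows):
    """Yield ('user', id, total) and ('product', id, total) for the entities
    whose bin corresponds to this partition (bin = index + 1)."""
    bin_no = index + 1
    # int start value: works for float amounts and for Decimal (Oracle NUMBER)
    user_totals = defaultdict(int)
    product_totals = defaultdict(int)
    for row in rows:  # rows may be pyspark Rows or dicts: both support row['COL']
        if row['USER_BIN'] == bin_no:
            user_totals[row['USER_ID']] += row['AMOUNT']
        if row['PRODUCT_BIN'] == bin_no:
            product_totals[row['PRODUCT_ID']] += row['AMOUNT']
    for user_id, total in user_totals.items():
        yield ('user', user_id, total)
    for product_id, total in product_totals.items():
        yield ('product', product_id, total)
```

On the DataFrame from the answer this would be applied as `df.rdd.mapPartitionsWithIndex(aggregate_partition)`. Since every user and every product is fully contained in the partition of its own bin, each entity appears exactly once in the output and no groupBy is needed.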

Also, note that if some users or products have more records than fit in a worker's memory, you need to sub-partition them. First add a new random integer column to your data (between 0 and some upper bound, like 10000), then do the partitioning based on the combination of that number and the original ID (like userId). This splits each entity into smaller chunks, one per distinct random value, so that each chunk fits in a worker's memory. After the aggregations, group the partial results on the original IDs to aggregate all the chunks together and make each entity whole again.
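The sub-partitioning idea (a random "salt" plus a second aggregation on the original ID) can be simulated in plain Python to check that the two phases give the same totals as a direct aggregation. The record layout `(user_id, amount)` and the sum aggregation are assumptions for illustration; in Spark, phase 1 would run per partition and phase 2 is the final shuffle.

```python
import random
from collections import defaultdict


def salted_totals(records, num_salts, seed=0):
    """Two-phase sum of amount per user_id using a random salt.
    Returns (totals per user_id, number of (user_id, salt) chunks)."""
    rng = random.Random(seed)
    partial = defaultdict(float)
    for user_id, amount in records:
        salt = rng.randrange(num_salts)     # the new random column
        partial[(user_id, salt)] += amount  # phase 1: per-chunk aggregate
    totals = defaultdict(float)
    for (user_id, _salt), subtotal in partial.items():
        totals[user_id] += subtotal         # phase 2: recombine the chunks
    return dict(totals), len(partial)


totals, n_chunks = salted_totals([(1, 1.0)] * 1000 + [(2, 2.0)] * 10, num_salts=8)
print(totals, n_chunks)
```

This works directly only for aggregations that can be merged from partial results (sums, counts, min, max; an average would be carried as a sum and a count). If the salt is computed in the database, a deterministic hash of a unique column (for example Oracle's `ORA_HASH(<unique column>, 9999)`) is safer than a truly random value, because Spark may re-run a partition's query when a task is retried, and the chunk assignment should not change between runs.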

The shuffle at the end is inevitable because of the memory restriction and the nature of the data, but it only moves the per-chunk partial aggregates rather than the raw records, so it is much smaller than a full shuffle of the table.
