Reputation: 119
I am currently working on a project where I am reading 19 different parquet files and joining on an ID. Some of these files have multiple rows per consumer, some have none.
I have a key file with one column that I join on and another (userName) that I need; from the other files I need all of the columns.
I create a separate reader for each Parquet file, which reads the file and converts it into a Spark dataset with a structure like this:
GenericStructure1 record;
int id;
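For context, here is a minimal sketch of what one of those readers might look like (the class name, bean name, and path are placeholders, not my actual code; GenericStructure1 is the project-specific type for that file's nested record):
import java.io.Serializable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

// Hypothetical reader for one of the 19 files.
public class DataSet1Reader {

    // Bean mirroring the structure above: the nested record plus the join id.
    public static class DataSet1 implements Serializable {
        private GenericStructure1 record;
        private int id;
        // getters and setters omitted for brevity
    }

    public Dataset<DataSet1> read(SparkSession spark, String path) {
        // Read the Parquet file and map it onto the bean with a bean encoder.
        return spark.read()
                .parquet(path)
                .as(Encoders.bean(DataSet1.class));
    }
}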
I then join all of these created datasets like this (imagine all 19):
keyDataset.join(dataSet1, dataSet1.col("id").equalTo(keyDataset.col("id")), "left_outer")
    // ... same pattern for dataSet2 through dataSet18 ...
    .join(dataSet19, dataSet19.col("id").equalTo(keyDataset.col("id")), "left_outer")
    .groupBy(keyDataset.col("id"), keyDataset.col("userName"))
    .agg(
        collect_set(dataSet1.col("record")).as("set1"),
        // ...
        collect_set(dataSet19.col("record")).as("set19"))
    .select(
        keyDataset.col("id"),
        keyDataset.col("userName"),
        col("set1"),
        // ...
        col("set19"))
    .as(Encoders.bean(Set.class));
where Set.class looks something like this:
public class Set implements Serializable {
    long id;
    String userName;
    List<GenericStructure1> set1;
    // ... set2 through set18 ...
    List<GenericStructure19> set19;
    // getters and setters omitted (required by Encoders.bean)
}
This works fine for 100 records, but when I try to ramp it up to a single part of a 5-million-record Parquet file (roughly 75K records), it churns through memory until it ultimately runs out. In production this needs to run on millions of records, so the fact that it chokes on 75K is a real problem. The trouble is, I don't see a straightforward way to optimize this so it can handle that kind of workload. Does anybody know of an inexpensive way to join a large amount of data like the above?
Upvotes: 2
Views: 254
Reputation: 119
I was able to get it to work. In the question I mention a keyDataset, which contains every key that can appear in any of the different datasets. Instead of trying to join it against all of the other files right out of the gate, I broadcast the keyDataset and join against that after creating a generic dataframe for each dataset.
Dataset<Row> set1RowDataset = set1Dataset
    .groupBy(col("id"))
    .agg(collect_set(set1Dataset.col("record")).as("set1"))
    .select(
        col("id"),
        col("set1"));
Once I create 19 of those, I then join the generic datasets in their own join like so:
broadcast(set1RowDataset)
    .join(set2RowDataset, "id")
    .join(set3RowDataset, "id")
    .join(set4RowDataset, "id")
    // ... set5RowDataset through set18RowDataset ...
    .join(set19RowDataset, "id")
    .as(Encoders.bean(Set.class));
Performance-wise, I'm not sure how much of a hit I'm taking by doing the groupBy separately from the join, but memory stays intact and Spark no longer spills so badly to disk during the shuffle. I was able to run this locally on the one part that was failing before, as mentioned above. I haven't tried it on the cluster with the full Parquet file yet, but that's my next step.
I used this as my example: Broadcast Example
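If anyone wants to verify that the broadcast hint is actually taking effect, the physical plan shows it; this is generic Spark rather than something from my code above ('joined' here just stands for the final joined dataset):
// Expect BroadcastHashJoin operators (rather than SortMergeJoin) in the plan
// for the hinted joins.
joined.explain();

// The automatic broadcast threshold defaults to about 10 MB; it can be raised
// if the datasets being broadcast comfortably fit in executor memory.
spark.conf().set("spark.sql.autoBroadcastJoinThreshold", 50L * 1024 * 1024);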
Upvotes: 2