Usman Khan

Reputation: 147

How to avoid cross-join to find pairwise distance between each two rows in Spark dataframe

I have a Spark DataFrame with 3 columns that indicate the positions of atoms, i.e. Position X, Y & Z. I need to find the distance between every two atoms, for which I apply the distance formula d = sqrt((x2−x1)^2 + (y2−y1)^2 + (z2−z1)^2). For a smaller dataset a cross-join was recommended, but for a larger dataset that is very inefficient and time-consuming. Currently I am using the following piece of code.

from pyspark.sql import functions as F
from pyspark.sql.window import Window as W

# Assign a stable sequential id to each row, then pair every row with every other row.
df = atomsDF.withColumn("id", F.monotonically_increasing_id())
windowSpec = W.orderBy("id")
df = df.withColumn("id", F.row_number().over(windowSpec))
df_1 = df.select(*(F.col(col).alias("{}_1".format(col)) for col in df.columns))
df_3 = df_1.crossJoin(df).where("id_1 != id")

df_3 = df_3.withColumn(
        "Distance",
        F.sqrt(
            F.pow(df_3["Position_X_1"] - df_3["Position_X"], F.lit(2))
            + F.pow(df_3["Position_Y_1"] - df_3["Position_Y"], F.lit(2))
            + F.pow(df_3["Position_Z_1"] - df_3["Position_Z"], F.lit(2))
        )
    )

My DataFrame looks like the following:

+----------+----------+----------+
|Position_X|Position_Y|Position_Z|
+----------+----------+----------+
|    27.545|     6.743|    12.111|
|    27.708|     7.543|    13.332|
|    27.640|     9.039|    12.970|
|    26.991|     9.793|    13.693|
|    29.016|     7.166|    14.106|
|    29.286|     8.104|    15.273|
|    28.977|     5.725|    14.603|
|    28.267|     9.456|    11.844|
|    28.290|    10.849|    11.372|
|    26.869|    11.393|    11.161|
+----------+----------+----------+

Now, how can I avoid the cross-join, given that the number of rows grows quadratically after it? For example, for a dataset with just 3000 rows, the cross-join produces 3000 * 2999 = 8,997,000 rows, which makes it very time-consuming. Is there any more efficient way of finding the pairwise distance between every two rows?

Upvotes: 0

Views: 1343

Answers (1)

Michael Entin

Reputation: 7744

You say that you need to find the distance between every 2 atoms. Since the result size is N^2, the run time is by definition quadratic. You can optimize the constant factor somewhat, but it will still be quadratic.

You can do better only if you don't actually need all N^2 distances between every 2 atoms, but only the pairs that satisfy some criterion.

E.g. a common requirement is to find all pairs closer than some threshold distance; for that, R-trees provide much better scalability. In Spark it may be easier to split the atoms into a grid of cubes whose side equals the threshold distance; then you only need to join each atom with the atoms in the same or neighboring cubes.

Upvotes: 1

Related Questions