junichiro

Reputation: 5492

Spark/PySpark collect_set with a binary column

Some test data, with two columns: the first binary (using alphanumeric bytes in this example), the second an integer:

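from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, BinaryType, IntegerType, StructField, StructType

spark = SparkSession.builder.getOrCreate()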

df = spark.createDataFrame([
    (bytearray(b'0001'), 1),
    (bytearray(b'0001'), 1),
    (bytearray(b'0001'), 2),
    (bytearray(b'0002'), 2)
],
schema=StructType([
    StructField("bin", BinaryType()),
    StructField("number", IntegerType())
]))

Using collect_set to group by the integer column and remove duplicates doesn't work, because binary values are held as Java byte arrays, which use reference equality and identity hash codes rather than comparing their contents. Hence:

(
    df
    .groupBy('number')
    .agg(F.collect_set("bin").alias('bin_array'))
    .show()
)

+------+------------+
|number|   bin_array|
+------+------------+
|     1|[0001, 0001]|
|     2|[0001, 0002]|
+------+------------+

One hacky option is to wrap each binary value in a struct and unwrap it again afterwards, but I suspect this will cause a huge number of allocations and be very expensive (I haven't actually profiled it, though):

def unstruct_array(input):
    return [x.bin for x in input]

unstruct_array_udf = F.udf(unstruct_array, ArrayType(BinaryType()))

(
    df
    .withColumn("bin", F.struct("bin"))
    .groupBy('number')
    .agg(F.collect_set("bin").alias('bin_array'))
    .withColumn('bin_array', unstruct_array_udf('bin_array'))
    .show()
)

+------+------------+
|number|   bin_array|
+------+------------+
|     1|      [0001]|
|     2|[0001, 0002]|
+------+------------+

Searching for various combinations of binary types and Spark turns up several answers saying you should wrap arrays if you need hashing. Suggestions include a custom wrapper, or calling Scala's toSeq, which creates a Scala WrappedArray. For instance:

ReduceByKey with a byte array as the key

How to use byte array as key in RDD?

So, options include:

  1. Mapping the underlying RDD to make the binary field a WrappedArray. I'm not sure how to do that in Python, though.
  2. Creating a Python wrapper for the array, and then somehow hashing the underlying Java array in Python? I'm not sure that has any advantage over using a struct, though.
  3. Wrapping in a struct and never unwrapping. That would be a bit more efficient processing-wise, but would presumably make the Parquet files bigger and more expensive to parse in all downstream tasks.

Upvotes: 2

Views: 2213

Answers (1)

Oli

Reputation: 10406

Here is a hack that will probably be more efficient than wrapping and unwrapping: simply call the distinct method beforehand.

df.show()
+-------------+------+
|          bin|number|
+-------------+------+
|[30 30 30 31]|     1|
|[30 30 30 31]|     1|
|[30 30 30 31]|     2|
|[30 30 30 32]|     2|
+-------------+------+

df.distinct().show()
+-------------+------+
|          bin|number|
+-------------+------+
|[30 30 30 31]|     1|
|[30 30 30 31]|     2|
|[30 30 30 32]|     2|
+-------------+------+

Note that I am probably not using the same Spark version as you (mine is 2.2.1), since binary arrays are displayed differently.

Then, for the collect_set, it simply boils down to:

df.distinct().groupBy("number").agg(F.collect_set("bin"))

Upvotes: 1
