Some test data, with two columns: the first binary (using alphanumeric bytes in this example), the second an integer:
from pyspark.sql.types import *
from pyspark.sql import functions as F
df = spark.createDataFrame([
    (bytearray(b'0001'), 1),
    (bytearray(b'0001'), 1),
    (bytearray(b'0001'), 2),
    (bytearray(b'0002'), 2)
],
    schema=StructType([
        StructField("bin", BinaryType()),
        StructField("number", IntegerType())
    ]))
Grouping by the integer column and using collect_set to remove duplicate binary values doesn't work, because byte arrays don't support hashing. Hence:
(
    df
    .groupBy('number')
    .agg(F.collect_set("bin").alias('bin_array'))
    .show()
)
+------+------------+
|number|   bin_array|
+------+------------+
|     1|[0001, 0001]|
|     2|[0001, 0002]|
+------+------------+
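(As a quick aside, not from the original post: the hashing limitation is easy to see on the plain Python side as well, since PySpark hands BinaryType values back as mutable, unhashable bytearrays.)

# Illustration of why byte arrays resist hashing on the Python side:
# bytearray is mutable and unhashable, bytes is immutable and hashable.
hash(bytes(b'0001'))                 # works
try:
    hash(bytearray(b'0001'))         # raises TypeError: unhashable type
except TypeError as e:
    print(e)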
One hacky option is to embed the binary column in a struct and then unwrap the structs again afterwards, but I suspect this will result in a huge number of allocations and be very expensive (I haven't actually profiled it, though):
# UDF that pulls the binary value back out of each wrapped struct
def unstruct_array(input):
    return [x.bin for x in input]

unstruct_array_udf = F.udf(unstruct_array, ArrayType(BinaryType()))

(
    df
    .withColumn("bin", F.struct("bin"))
    .groupBy('number')
    .agg(F.collect_set("bin").alias('bin_array'))
    .withColumn('bin_array', unstruct_array_udf('bin_array'))
    .show()
)
+------+------------+
|number|   bin_array|
+------+------------+
|     1|      [0001]|
|     2|[0001, 0002]|
+------+------------+
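For what it's worth, on Spark 2.4 or later the unwrap step could stay on the JVM side by using the transform higher-order SQL function instead of a Python UDF, which avoids the Python serialization round trip. This is only a sketch of that variant (assuming Spark 2.4+), not something from the original post:

(
    df
    .withColumn("bin", F.struct("bin"))
    .groupBy('number')
    .agg(F.collect_set("bin").alias('bin_array'))
    # Spark 2.4+ only: unwrap each struct back to its binary field without a UDF
    .withColumn('bin_array', F.expr("transform(bin_array, x -> x.bin)"))
    .show()
)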
If I try lots of Google searches around binary types and Spark, various answers say you should wrap the arrays if you need hashing. Suggestions include a custom wrapper or calling Scala's toSeq, which creates a Scala WrappedArray. For instance:
ReduceByKey with a byte array as the key
How to use byte array as key in RDD?
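For reference, here is a rough PySpark sketch of the wrapping idea from those links, using Python's immutable bytes as the hashable wrapper rather than a Scala WrappedArray (my own assumption, not code taken from those answers):

# Drop to the RDD API and wrap the unhashable bytearray in an immutable bytes
# object so that (number, bin) can serve as a key for reduceByKey.
deduped = (
    df.rdd
      .map(lambda row: ((row.number, bytes(row.bin)), None))
      .reduceByKey(lambda a, b: a)                       # keeps one entry per key
      .map(lambda kv: (kv[0][0], bytearray(kv[0][1])))   # back to (number, bin)
)
deduped.collect()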
So, options include:
Upvotes: 2
Views: 2213
Here is a hack that will probably be more efficient than wrapping and unwrapping: you could simply call the distinct method beforehand.
df.show()
+-------------+------+
|          bin|number|
+-------------+------+
|[30 30 30 31]|     1|
|[30 30 30 31]|     1|
|[30 30 30 31]|     2|
|[30 30 30 32]|     2|
+-------------+------+
df.distinct().show()
+-------------+------+
|          bin|number|
+-------------+------+
|[30 30 30 31]|     1|
|[30 30 30 31]|     2|
|[30 30 30 32]|     2|
+-------------+------+
Note that I am probably not using the same version of Spark as you (mine is 2.2.1), as the display of binary arrays seems different.
Then, for the collect_set, it simply boils down to:
df.distinct().groupBy("number").agg(F.collect_set("bin"))
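Assuming the same test data as in the question, the end-to-end version would look like this (expected grouping shown as a comment, since the exact binary display depends on the Spark version):

(
    df
    .distinct()
    .groupBy("number")
    .agg(F.collect_set("bin").alias("bin_array"))
    .show()
)
# Expected grouping after distinct() has removed the duplicate row:
#   number 1 -> [0001]
#   number 2 -> [0001, 0002]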
Upvotes: 1