akshan

Reputation: 363

spark custom sort in python

I have an RDD in Spark (Python code below):

list1 = [(1,1),(10,100)]
df1 = sc.parallelize(list1)
df1.take(2)
## [(1, 1), (10, 100)]

I want to do a custom sort that compares these tuples based on both entries in the tuple. In Python the logic of this comparison is something like:

# THRESH is some constant
def compare_tuple(a, b):
    center = a[0] - b[0]
    dev = a[1] + b[1]
    r = center / dev
    if r < THRESH:
        return -1
    elif r == THRESH:
        return 0
    else:
        return 1

And I would do a custom sort in Python as:

list1.sort(compare_tuple)
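
(Side note: passing the comparison function positionally like this only works in Python 2; Python 3 removed the cmp argument, so a plain-Python equivalent would wrap the comparator with functools.cmp_to_key:)

from functools import cmp_to_key

THRESH = 0  # assumed value, just for illustration
list1.sort(key=cmp_to_key(compare_tuple))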

How can I do this in pyspark? As per the RDD docs:

https://spark.apache.org/docs/1.4.1/api/python/pyspark.html#pyspark.RDD

The sortBy method has no custom sort argument.
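
For reference, the Python signature only takes a key function and ordering options, not a comparator:

rdd.sortBy(keyfunc, ascending=True, numPartitions=None)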

I see that the scala interface sortBy supports this:

https://spark.apache.org/docs/1.4.1/api/scala/index.html#org.apache.spark.rdd.RDD

But I want this in Python Spark. Any workaround-type solutions are also welcome, thanks!

Upvotes: 3

Views: 1877

Answers (1)

zero323

Reputation: 330413

You can always create a custom class and implement rich comparison methods:

  • pair.py

    class Pair(tuple):
        def _cmp(self, other):
            # compare on both entries; fall back to the raw difference
            # when the deviations sum to zero to avoid division by zero
            center = self[0] - other[0]
            dev = self[1] + other[1]
            r = center / dev if dev != 0 else center
            if r < 0:
                return -1
            if r > 0:
                return 1
            return 0

        def __lt__(self, other):
            return self._cmp(other) < 0

        def __le__(self, other):
            return self._cmp(other) <= 0

        def __eq__(self, other):
            return self._cmp(other) == 0

        def __ge__(self, other):
            return self._cmp(other) >= 0

        def __gt__(self, other):
            return self._cmp(other) > 0
    
  • main script

    from pair import Pair
    
    sc.addPyFile("pair.py")
    
    rdd = sc.parallelize([(1, 1),(10, 100), (-1, 1), (-1, -0.5)]).map(Pair)
    rdd.sortBy(lambda x: x).collect()
    ## [(-1, 1), (-1, -0.5), (1, 1), (10, 100)]
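
sortBy(lambda x: x) works here because the keys are Pair instances, so every comparison goes through the rich comparison methods above. Keeping Pair in its own pair.py shipped via sc.addPyFile (instead of defining it interactively) is what lets the executors unpickle the class.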
    

but if dev is a standard deviation, it is non-negative and doesn't affect the outcome, so you can safely sort by identity using plain tuples, or use a keyfunc which extracts the centers (lambda x: x[0]).
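
A minimal sketch of that simpler route (assuming the same sc and sample data as above):

    rdd = sc.parallelize([(1, 1), (10, 100), (-1, 1), (-1, -0.5)])

    # sort by identity -- plain tuples compare lexicographically
    rdd.sortBy(lambda x: x).collect()
    ## [(-1, -0.5), (-1, 1), (1, 1), (10, 100)]

    # or sort only by the first entry (the "center")
    rdd.sortBy(lambda x: x[0]).collect()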

Upvotes: 3
