Alon

Reputation: 11935

Assert RDD is not sorted

I have a method called split that accepts an RDD[T] and a splitSize and returns an Array[RDD[T]].

Now, one of the test cases I write for it should verify that this function also randomly shuffles the RDD.

So I create a sorted RDD, and then see the results:

  it should "randomize shuffle" in {
    val inputRDD = sc.parallelize((0 until 16))
    val result = RDDUtils.split(inputRDD, 2)

    result.foreach(rdd => {
      rdd.collect.foreach(println)
    })

    // Assert result is not sorted
  }

If the results are:

0 1 2 3 .. 15

Then it's not working as expected.

A good result can be something like:

11 3 9 14 ... 1 6

How can I assert that the output Array[RDD[T]] is not sorted?

Upvotes: 4

Views: 161

Answers (1)

Chema

Reputation: 2838

You could try something like this

val resultOrder = result.sortBy(....)
assert(!resultOrder.sameElements(result))

or

val resultOrder = result.sortBy(....)
assert(resultOrder.toList != result.toList)

The key is knowing how to sort the Array. For an Int element type this is easy, but for a complex data type you may need an implicit Ordering for that type, e.g.:

implicit val ordering: Ordering[T] =
    Ordering.fromLessThan[T]((sa: T, sb: T) => sa < sb)

// OR

implicit val ordering: Ordering[MyClass] =
    Ordering.fromLessThan[MyClass]((sa: MyClass, sb: MyClass) => sa.field1 < sb.field1)

The exact code depends on your data type.

Here is a full example:

package tests

import org.apache.log4j.{Level, Logger}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object SortArrayRDD {

  val spark = SparkSession
    .builder()
    .appName("SortArrayRDD")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions","4") //Change to a more reasonable default number of partitions for our data
    .config("spark.app.id","SortArrayRDD") // To silence Metrics warning
    .getOrCreate()

  val sc = spark.sparkContext

  def main(args: Array[String]): Unit = {
    try {

      Logger.getRootLogger.setLevel(Level.ERROR)

      // Eight small RDDs, deliberately placed out of order
      val arrRDD: Array[RDD[Int]] = Array(
        sc.parallelize(List(2, 3)), sc.parallelize(List(10, 11)),
        sc.parallelize(List(6, 7)), sc.parallelize(List(8, 9)),
        sc.parallelize(List(4, 5)), sc.parallelize(List(0, 1)),
        sc.parallelize(List(12, 13)), sc.parallelize(List(14, 15)))

      // Order RDDs by the sum of their elements (each sum() call runs a Spark job)
      implicit val ordering: Ordering[RDD[Int]] =
        Ordering.fromLessThan[RDD[Int]]((sa: RDD[Int], sb: RDD[Int]) => sa.sum() < sb.sum())

      // sorted returns a new array; arrRDD itself is left unchanged
      val resultOrder = arrRDD.sorted
      resultOrder.foreach(rdd => println(rdd.collect().mkString(",")))

      // sameElements compares the RDD references position by position
      assert(!resultOrder.sameElements(arrRDD))
      println("It's unordered")
    } finally {
      sc.stop()
    }
  }
}
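
If the goal is to check the order of the elements themselves, another option is to flatten the collected values and test whether they are in ascending order. The following is a minimal sketch, not part of the code above: SortedCheck and isSorted are hypothetical names, and it works on a plain Seq so it runs without Spark.

```scala
object SortedCheck {
  // Returns true when every element is <= its successor under the implicit Ordering.
  // Empty and single-element sequences count as sorted.
  def isSorted[T](xs: Seq[T])(implicit ord: Ordering[T]): Boolean =
    xs.zip(xs.drop(1)).forall { case (a, b) => ord.lteq(a, b) }
}
```

In the test from the question this could be used as assert(!SortedCheck.isSorted(result.toSeq.flatMap(_.collect()))), assuming result is the Array[RDD[Int]] returned by split. Keep in mind that a truly random shuffle can, very rarely, produce sorted output, so such a test is not strictly deterministic.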

Upvotes: 1
