Randomize

Reputation: 9103

How to aggregate elements in a list?

I have a few Lists of these two kinds (List[Array[String]]):

1) List(Array("Mark","2000","2002"), Array("John","2001","2003"), Array("Andrew","1999","2001"), Array("Erik","1996","1998"))

2) List(Array("Steve","2000","2005"))

Based on this condition:

If the ranges of years overlap, it means the two people know each other; otherwise they don't.

What I expect is the data grouped in this way:

Array(name, start_year, end_year, known_people, unknown_people)

so for the specific example 1) the final result is:

List(
  Array("Mark",   "2000", "2002", "John#Andrew", "Erik"), 
  Array("John",   "2001", "2003", "Mark#Andrew", "Erik"), 
  Array("Andrew", "1999", "2001", "Mark#John",   "Erik"), 
  Array("Erik",   "1996", "1998", "",            "Mark#John#Andrew")
)

For the second case just:

List(Array("Steve","2000","2005", "", ""))

I am not sure what to do. I am stuck after doing a cartesian product and filtering out the pairs with the same name, like:

my_list.cartesian(my_list).filter { case (a,b) => a(0) != b(0) }

but at this point I cannot make aggregateByKey work. A rough sketch of what I have in mind is below.
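
Something along these lines is what I am trying, but I have not been able to get it working. This is an untested sketch; it assumes that my_list is actually an RDD[Array[String]] (cartesian is an RDD method) and that the year strings parse cleanly as Int:

def overlaps(a: Array[String], b: Array[String]): Boolean =
  a(2).toInt >= b(1).toInt && a(1).toInt <= b(2).toInt

val result = my_list.cartesian(my_list)
  .filter { case (a, b) => a(0) != b(0) }
  // key by the left person, value is (other name, do the intervals overlap?)
  .map { case (a, b) => ((a(0), a(1), a(2)), (b(0), overlaps(a, b))) }
  .aggregateByKey((List.empty[String], List.empty[String]))(
    // seqOp: put the other name into the known or unknown bucket
    { case ((known, unknown), (name, k)) =>
        if (k) (name :: known, unknown) else (known, name :: unknown) },
    // combOp: merge buckets coming from different partitions
    { case ((k1, u1), (k2, u2)) => (k1 ::: k2, u1 ::: u2) }
  )
  .map { case ((name, from, to), (known, unknown)) =>
    Array(name, from, to, known.mkString("#"), unknown.mkString("#")) }

Even if that worked, a person with no other entries at all (like Steve in case 2) would disappear after the cartesian/filter step, and I don't see how to handle that.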

Any idea?

Upvotes: 1

Views: 1462

Answers (1)

Martin Senne

Reputation: 6059

Answer

The code

case class Person(name: String, fromYear: Int, toYear: Int)

class UnsortedTestSuite3 extends SparkFunSuite {
  configuredUnitTest("SO - aggregateByKey") { sc =>
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{UserDefinedFunction, Column, SQLContext, DataFrame}

    val persons = Seq(
      Person("Mark",   2000, 2002),
      Person("John",   2001, 2003),
      Person("Andrew", 1999, 2001),
      Person("Erik",   1996, 1998)
    )

    // input
    val personDF = sc.parallelize( persons ).toDF
    val personRenamedDF = personDF.select(
      col("name").as("right_name"),
      col("fromYear").as("right_fromYear"),
      col("toYear").as("right_toYear")
    )

    /**
      * Group the entries of the second column of a DataFrame by the entries in the first column.
      * @param df a dataframe with two string columns
      * @return dataframe, where the second column contains the group of values for each identical entry in the first column
      */
    def groupBySecond( df: DataFrame ) : DataFrame = {
      val st: StructType = df.schema
      if ( (st.size != 2) ||
           (! st(0).dataType.equals(StringType) ) ||
           (! st(1).dataType.equals(StringType) ) ) throw new RuntimeException("Wrong schema for groupBySecond.")

      df.rdd
        .map( row => (row.getString(0), row.getString(1)) )
        .groupByKey().map( x => ( x._1, x._2.toList))
        .toDF( st(0).name, st(1).name )
    }

    val joined = personDF.join(personRenamedDF, col("name") !== col("right_name"), "inner")
    val intervalOverlaps = (col("toYear") >= col("right_fromYear")) && (col("fromYear") <= col("right_toYear"))
    val known = groupBySecond( joined.filter( intervalOverlaps ).select(col("name"), col("right_name").as("knows")) )
    val unknown = groupBySecond( joined.filter( !intervalOverlaps ).select(col("name"), col("right_name").as("does_not_know")) )

    personDF.join( known, "name").join(unknown, "name").show()
  }
}

gives you the expected result

+------+--------+------+--------------+-------------+
|  name|fromYear|toYear|         knows|does_not_know|
+------+--------+------+--------------+-------------+
|  John|    2001|  2003|[Mark, Andrew]|       [Erik]|
|  Mark|    2000|  2002|[John, Andrew]|       [Erik]|
|Andrew|    1999|  2001|  [Mark, John]|       [Erik]|
+------+--------+------+--------------+-------------+

Explanation

  • Using case classes to model a Person, so you don't have to struggle with Array.
  • Using Spark SQL, as it is the most concise approach.
  • Technically:
    • Using an inner join to create the pairs of all people. Pairs with identical names are discarded via the join criterion.
    • Using filter to find overlapping or non-overlapping intervals.
    • Then using the helper method groupBySecond to perform a groupBy on a DataFrame. Currently this is not possible in Spark SQL directly, as no UDAF (User Defined Aggregation Function) exists yet. I will raise a follow-up SO question to hear what the experts say about this. A sketch using collect_list, available in later Spark releases, follows this list.
    • Join the original DataFrame personDF with the known and unknown DataFrames to yield the final result.
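
As an aside on the UDAF point above: in later Spark releases (1.6+) the built-in collect_list aggregate removes the need for the groupBySecond RDD round-trip. A minimal sketch, assuming such a Spark version and reusing joined and intervalOverlaps from the code above (the val name knownViaCollectList is just for illustration):

import org.apache.spark.sql.functions.collect_list

val knownViaCollectList = joined
  .filter(intervalOverlaps)                          // keep only overlapping pairs
  .groupBy(col("name"))                              // group by the left person
  .agg(collect_list(col("right_name")).as("knows"))  // collect the names this person knows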

Edit 2015-11-13 - 2pm

I just discovered that the code above does not deliver the correct result. (Erik is missing!)

Thus

case class Person(name: String, fromYear: Int, toYear: Int)

class UnsortedTestSuite3 extends SparkFunSuite {
  configuredUnitTest("SO - aggregateByKey") { sc =>
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{UserDefinedFunction, Column, SQLContext, DataFrame}

    val persons = Seq(
      Person("Mark",   2000, 2002),
      Person("John",   2001, 2003),
      Person("Andrew", 1999, 2001),
      Person("Erik",   1996, 1998)
    )

    // input
    val personDF = sc.parallelize( persons ).toDF
    val personRenamedDF = personDF.select(
      col("name").as("right_name"),
      col("fromYear").as("right_fromYear"),
      col("toYear").as("right_toYear")
    )

    /**
      * Group the entries of the second column of a DataFrame by the entries in the first column.
      * @param df a dataframe with two string columns
      * @return dataframe, where the second column contains the group of values for each identical entry in the first column
      */
    def groupBySecond( df: DataFrame ) : DataFrame = {
      val st: StructType = df.schema
      if ( (st.size != 2) ||
           (! st(0).dataType.equals(StringType) ) ||
           (! st(1).dataType.equals(StringType) ) ) throw new RuntimeException("Wrong schema for groupBySecond.")

      df.rdd
        .map( row => (row.getString(0), row.getString(1)) )
        // a left outer join without any match yields a single null entry; turn it into an empty list
        .groupByKey().map( x => ( x._1, if (x._2 == List(null)) List() else x._2.toList ))
        .toDF( st(0).name, st(1).name )
    }

    val distinctName = col("name") !== col("right_name")
    val intervalOverlaps = (col("toYear") >= col("right_fromYear")) && (col("fromYear") <= col("right_toYear"))

    // left outer joins keep people whose result list is empty (e.g. Erik has an empty "knows" list)
    val knownDF_t = personDF.join(personRenamedDF, distinctName && intervalOverlaps, "leftouter")
    val knownDF = groupBySecond( knownDF_t.select(col("name").as("kname"), col("right_name").as("knows")) )

    val unknownDF_t = personDF.join(personRenamedDF, distinctName && !intervalOverlaps, "leftouter")
    val unknownDF = groupBySecond( unknownDF_t.select(col("name").as("uname"), col("right_name").as("does_not_know")) )

    personDF
      .join( knownDF, personDF("name") === knownDF("kname"), "leftouter")
      .join( unknownDF, personDF("name") === unknownDF("uname"), "leftouter")
      .select( col("name"), col("fromYear"), col("toYear"), col("knows"), col("does_not_know"))
      .show()

  }
}

does the trick with the result

+------+--------+------+--------------+--------------------+
|  name|fromYear|toYear|         knows|       does_not_know|
+------+--------+------+--------------+--------------------+
|  John|    2001|  2003|[Mark, Andrew]|              [Erik]|
|  Mark|    2000|  2002|[John, Andrew]|              [Erik]|
|Andrew|    1999|  2001|  [Mark, John]|              [Erik]|
|  Erik|    1996|  1998|            []|[Mark, John, Andrew]|
+------+--------+------+--------------+--------------------+
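
If the List[Array[String]] shape with "#"-separated names from the question is really needed, the final DataFrame can be mapped back to it. A rough, untested sketch, reusing personDF, knownDF and unknownDF from above (the val name asRequested is just for illustration):

val asRequested: List[Array[String]] = personDF
  .join( knownDF, personDF("name") === knownDF("kname"), "leftouter")
  .join( unknownDF, personDF("name") === unknownDF("uname"), "leftouter")
  .select( col("name"), col("fromYear"), col("toYear"), col("knows"), col("does_not_know"))
  .rdd
  .map { row =>
    // groups missing after the outer joins come back as null; treat them as empty
    val knows   = Option(row.getSeq[String](3)).getOrElse(Seq.empty)
    val unknown = Option(row.getSeq[String](4)).getOrElse(Seq.empty)
    Array(row.getString(0), row.getInt(1).toString, row.getInt(2).toString,
          knows.mkString("#"), unknown.mkString("#"))
  }
  .collect()
  .toList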

Upvotes: 3
