Shafique Jamal

Reputation: 1688

How to keep all elements when aggregating on AggregateMessages on a GraphFrame?

Suppose I have the following graph:

scala> v.show()
+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|           null|
|BBB|           null|
|QQQ|           null|
|DDD|           null|
|FFF|           null|
|EEE|           null|
|AAA|           null|
|GGG|           null|
+---+---------------+


scala> e.show()
+---+---+---+
| iD|src|dst|
+---+---+---+
|  1|CCC|AAA| 
|  2|CCC|BBB| 
...
+---+---+---+

I would like to run an aggregation that collects all of the messages (not just the sum, first, last, etc.) sent from the destination vertices to the source vertices. The command I would like to run is something like:

g.aggregateMessages.sendToSrc(AM.edge("id")).agg(all(AM.msg).as("downstreamEdges")).show()

except that the function all does not exist (as far as I'm aware). The output would be something like:

+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|         [1, 2]|
... 
+---+---------------+

I can run the above with first or last in place of the (non-existent) all, but those would give me only

+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              1|
... 
+---+---------------+

or

+---+---------------+
| id|downstreamEdges|
+---+---------------+
|CCC|              2|
... 
+---+---------------+

respectively. How could I keep all of the entries? (There could be many, not just 1 and 2: e.g. 1, 2, 23, 45, etc.) Thanks.

Upvotes: 1

Views: 341

Answers (2)

Alex Ortner

Reputation: 1228

I solved something similar by using the aggregation function collect_set():

from pyspark.sql import functions as f
from graphframes.lib import AggregateMessages as AM

agg = gx.aggregateMessages(
    f.collect_set(AM.msg).alias("aggMess"),
    sendToSrc=AM.edge["id"],
    sendToDst=None)

Another option (which keeps duplicates) would be collect_list().
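In Scala, to match the question's code, the same approach would look something like the sketch below (untested, assuming the graph g from the question):

import org.apache.spark.sql.functions.collect_set
import org.graphframes.lib.AggregateMessages

val AM = AggregateMessages

// Send each edge's id to its source vertex, then gather all received
// messages into one array per vertex (collect_set drops duplicates).
val agg = g.aggregateMessages
  .sendToSrc(AM.edge("id"))
  .agg(collect_set(AM.msg).as("downstreamEdges"))

agg.show()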

Upvotes: 1

Shafique Jamal

Reputation: 1688

I adapted this answer to come up with the following:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.graphframes.lib.AggregateMessages

class KeepAllString extends UserDefinedAggregateFunction {
  private val AM = AggregateMessages

  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", StringType) :: Nil)

  // These are the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("ids", ArrayType(StringType, containsNull = true), nullable = true) :: Nil
  )

  // This is the output type of your aggregation function.
  override def dataType: DataType = ArrayType(StringType, containsNull = true)

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Seq[String]()


  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getAs[Seq[String]](0) ++ Seq(input.getAs[String](0))

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getAs[Seq[String]](0) ++ buffer2.getAs[Seq[String]](0)

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = buffer.getAs[Seq[String]](0)
}

Then my all method above is just: val all = new KeepAllString().
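With that in place, the call from the question should work as written (again assuming the same g as in the question):

import org.graphframes.lib.AggregateMessages
val AM = AggregateMessages

val all = new KeepAllString()

g.aggregateMessages
  .sendToSrc(AM.edge("id"))
  .agg(all(AM.msg).as("downstreamEdges"))
  .show()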

But how would one make it generic, so that for BigDecimal, Timestamp, etc., I could just do something like:

val allTimestamp = new KeepAll[Timestamp]()

?
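One possibility that avoids the per-type boilerplate (a sketch, untested) is to parameterize the class by the runtime DataType rather than by a Scala type parameter, since the UDAF schemas need a DataType value anyway:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepAll(elementType: DataType) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType =
    StructType(StructField("value", elementType) :: Nil)

  override def bufferSchema: StructType = StructType(
    StructField("values", ArrayType(elementType, containsNull = true), nullable = true) :: Nil
  )

  override def dataType: DataType = ArrayType(elementType, containsNull = true)

  override def deterministic: Boolean = true

  // Start from an empty sequence; elements are kept as untyped values
  // because the element type is only known at runtime.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[Any]

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer(0) = buffer.getSeq[Any](0) :+ input.get(0)

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getSeq[Any](0) ++ buffer2.getSeq[Any](0)

  override def evaluate(buffer: Row): Any = buffer.getSeq[Any](0)
}

// Hypothetical usage, e.g. for timestamps:
val allTimestamp = new KeepAll(TimestampType)

This isn't the KeepAll[Timestamp] syntax above, but it does get down to one reusable class.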

Upvotes: 0
