Norbert Radyk

Reputation: 2618

Type based collection partitioning in Scala

Given the following data model:

sealed trait Fruit

case class Apple(id: Int, sweetness: Int) extends Fruit

case class Pear(id: Int, color: String) extends Fruit

I've been looking to implement a segregateBasket function that, given a basket of fruits, returns separate baskets of apples and pears:

def segregateBasket(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear])

I've tried a couple of approaches, but none of them quite fits the bill. Below are my attempts:

  def segregateBasket1(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) = fruitBasket
    .partition(_.isInstanceOf[Apple])
    .asInstanceOf[(Set[Apple], Set[Pear])]

This is the most concise solution I've found, but it relies on an explicit asInstanceOf cast and will be a pain to extend should I decide to add more types of fruit. Therefore:

  def segregateBasket2(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) = {
    val mappedFruits = fruitBasket.groupBy(_.getClass)
    val appleSet = mappedFruits.getOrElse(classOf[Apple], Set()).asInstanceOf[Set[Apple]]
    val pearSet = mappedFruits.getOrElse(classOf[Pear], Set()).asInstanceOf[Set[Pear]]
    (appleSet, pearSet)
  }

This makes adding new fruit types really easy, but it still depends on risky asInstanceOf casts, which I'd rather avoid. Therefore:

  def segregateBasket3(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) = {
    val appleSet = collection.mutable.Set[Apple]()
    val pearSet = collection.mutable.Set[Pear]()

    fruitBasket.foreach {
      case a: Apple => appleSet += a
      case p: Pear => pearSet += p
    }
    (appleSet.toSet, pearSet.toSet)
  }

This avoids explicit casting, but it uses mutable collections; ideally I'd like to stick with immutable collections and idiomatic code.

I've looked at "Scala: Filtering based on type" for some inspiration, but couldn't find a better approach there either.

Does anyone have any suggestions on how this functionality can be better implemented in Scala?

Upvotes: 25

Views: 3635

Answers (5)

Xavier Guihot

Reputation: 61736

Starting in Scala 2.13, Set (and most other collections) provides a partitionMap method, which partitions elements according to a function that returns either a Left or a Right.

By pattern matching on the type, we can map pears to Left[Pear] and apples to Right[Apple], so that partitionMap produces a tuple of pears and apples:

val (pears, apples) =
  Set(Apple(1, 10), Pear(2, "red"), Apple(4, 6)).partitionMap {
    case pear: Pear   => Left(pear)
    case apple: Apple => Right(apple)
  }
// pears: Set[Pear] = Set(Pear(2, "red"))
// apples: Set[Apple] = Set(Apple(1, 10), Apple(4, 6))
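
Applied to the exact signature from the question, a sketch (assuming Scala 2.13 or later) just swaps the sides: apples go Left and pears go Right, so the tuple order matches (Set[Apple], Set[Pear]):

```scala
// Requires Scala 2.13+, where partitionMap was added to the collections.
sealed trait Fruit
case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

// Left values end up in the first set, Right values in the second,
// so apples -> Left and pears -> Right gives (Set[Apple], Set[Pear]).
def segregateBasket(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) =
  fruitBasket.partitionMap {
    case apple: Apple => Left(apple)
    case pear: Pear   => Right(pear)
  }

val (apples, pears) = segregateBasket(Set(Apple(1, 10), Pear(2, "red"), Apple(4, 6)))
```

Because Fruit is sealed, the compiler can warn if a new fruit type is added but not handled in the match, although an Either still limits the result to two sets.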

Upvotes: 9

Travis Brown

Reputation: 139058

It's possible to do this in a very clean and generic way using Shapeless 2.0's LabelledGeneric type class. First we define a type class that describes how to partition a list of elements of some algebraic data type into an HList containing one collection per constructor:

import shapeless._, record._

trait Partitioner[C <: Coproduct] extends DepFn1[List[C]] { type Out <: HList }

And then for the instances:

object Partitioner {
  type Aux[C <: Coproduct, Out0 <: HList] = Partitioner[C] { type Out = Out0 }

  implicit def cnilPartitioner: Aux[CNil, HNil] = new Partitioner[CNil] {
    type Out = HNil

    def apply(c: List[CNil]): Out = HNil
  }

  implicit def cpPartitioner[K, H, T <: Coproduct, OutT <: HList](implicit
    cp: Aux[T, OutT]
  ): Aux[FieldType[K, H] :+: T, FieldType[K, List[H]] :: OutT] =
    new Partitioner[FieldType[K, H] :+: T] {
      type Out = FieldType[K, List[H]] :: OutT

      def apply(c: List[FieldType[K, H] :+: T]): Out =
        field[K](c.collect { case Inl(h) => (h: H) }) ::
        cp(c.collect { case Inr(t) => t })
    }
}

And then the partition method itself:

def partition[A, C <: Coproduct, Out <: HList](as: List[A])(implicit
  gen: LabelledGeneric.Aux[A, C],
  partitioner: Partitioner.Aux[C, Out]
) = partitioner(as.map(gen.to))

Now we can write the following:

val fruits: List[Fruit] = List(
  Apple(1, 10),
  Pear(2, "red"),
  Pear(3, "green"),
  Apple(4, 6),
  Pear(5, "purple")
)

And then:

scala> val baskets = partition(fruits)
baskets: shapeless.:: ...

scala> baskets('Apple)
res0: List[Apple] = List(Apple(1,10), Apple(4,6))

scala> baskets('Pear)
res1: List[Pear] = List(Pear(2,red), Pear(3,green), Pear(5,purple))

We could also write a version that returns a tuple of the lists instead of using the record('symbol) syntax; see my blog post for details.

Upvotes: 10

nlim

Reputation: 287

  val emptyBaskets: (List[Apple], List[Pear]) = (Nil, Nil)

  def separate(fruits: List[Fruit]): (List[Apple], List[Pear]) = {
    fruits.foldRight(emptyBaskets) { case (f, (as, ps)) =>
      f match {
        case a @ Apple(_, _) => (a :: as, ps)
        case p @ Pear(_, _)  => (as, p :: ps)
      }
    }
  }
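
For the Set-based signature in the question, the same fold can be written with foldLeft over a pair of immutable sets; since sets are unordered, folding left rather than right makes no difference to the result. A sketch:

```scala
sealed trait Fruit
case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

// Single pass over the basket; each step returns a new pair of
// immutable sets instead of mutating an accumulator.
def segregateBasket(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) =
  fruitBasket.foldLeft((Set.empty[Apple], Set.empty[Pear])) {
    case ((apples, pears), apple: Apple) => (apples + apple, pears)
    case ((apples, pears), pear: Pear)   => (apples, pears + pear)
  }
```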

Upvotes: 14

Rex Kerr

Reputation: 167911

An "immutable" solution would use your mutable solution internally and simply not expose the mutable collections. I'm not sure there's a strong reason to think it's okay for library designers to do that but anathema for you. However, if you want to stick to purely immutable constructs, this is probably about as good as it gets:

def segregate4(basket: Set[Fruit]) = {
  val apples = basket.collect{ case a: Apple => a }
  val pears = basket.collect{ case p: Pear => p }
  (apples, pears)
}
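
The first point above can be made concrete with a sketch that keeps the single pass of segregateBasket3 but confines the mutable state to local builders from Set.newBuilder, so callers only ever see immutable sets:

```scala
sealed trait Fruit
case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

// The builders are mutable, but they never escape this method:
// callers only receive the immutable sets produced by result().
def segregateBasket(fruitBasket: Set[Fruit]): (Set[Apple], Set[Pear]) = {
  val apples = Set.newBuilder[Apple]
  val pears = Set.newBuilder[Pear]
  fruitBasket.foreach {
    case a: Apple => apples += a
    case p: Pear  => pears += p
  }
  (apples.result(), pears.result())
}
```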

Upvotes: 15

ach

Reputation: 6234

I'm a little confused by your examples. The return type of each of your "segregate" methods is a Tuple2, yet you want to be able to add more types of Fruit freely. Your method will need to return something of variable length (Iterable, Seq, etc.), since the length of a tuple must be known at compile time.

With that said, maybe I'm oversimplifying, but what about just using groupBy?

val fruit = Set(Apple(1, 1), Pear(1, "Green"), Apple(2, 2), Pear(2, "Yellow"))
val grouped = fruit.groupBy(_.getClass)

And then do whatever you want with the keys/values:

grouped.keys.map(_.getSimpleName).mkString(", ") // Apple, Pear (order may vary)
grouped.values.map(_.size).mkString(", ") // 2, 2

link: http://ideone.com/M4N0Pd
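
If statically typed sets are still wanted while keeping it easy to add fruit types, one option is to extract one type at a time with a ClassTag-based helper. In this sketch ofType is a hypothetical helper name, not a library method; it assumes Scala 2.10 or later, where a ClassTag in scope lets the `case t: T` pattern check the runtime class despite erasure, so no cast is needed:

```scala
import scala.reflect.ClassTag

sealed trait Fruit
case class Apple(id: Int, sweetness: Int) extends Fruit
case class Pear(id: Int, color: String) extends Fruit

// Hypothetical helper: keeps only the elements whose runtime class is T.
// The ClassTag context bound is what makes the type pattern checkable.
def ofType[T <: Fruit : ClassTag](basket: Set[Fruit]): Set[T] =
  basket.collect { case t: T => t }

val basket: Set[Fruit] = Set(Apple(1, 1), Pear(1, "Green"), Apple(2, 2), Pear(2, "Yellow"))
val apples: Set[Apple] = ofType[Apple](basket)
val pears: Set[Pear] = ofType[Pear](basket)
```

This makes one pass over the basket per requested type, trading a little efficiency for the ability to add a new fruit type without touching existing code.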

Upvotes: 2
