Some Name

Reputation: 9540

Split a list of an algebraic data type into lists of its branches?

I'm pretty new to Shapeless, so the question might be easy.

Here is the ADT:

sealed trait Test

final case class A() extends Test
final case class B() extends Test
final case class C() extends Test
...
final case class Z() extends Test

Is it possible to write a function without extremely cumbersome pattern matching?

def split(lst: List[Test]): List[A] :: List[B] :: ... :: HNil = //

Upvotes: 1

Views: 108

Answers (1)

Dmytro Mitin

Reputation: 51703

At compile time, all elements of a List have the same static type Test, so there is no way to tell A, B, C... apart using compile-time techniques alone (Shapeless, type classes, implicits, macros, compile-time reflection). The elements are distinguishable only at runtime, so you have to use some runtime technique (pattern matching, casting, runtime reflection).

Why Does This Type Constraint Fail for List[Seq[AnyVal or String]]

Scala: verify class parameter is not instanceOf a trait at compile time

flatMap with Shapeless yield FlatMapper not found
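Here is a minimal sketch of the runtime route. It assumes the leaves are concrete case classes. With an implicit `ClassTag` in scope, a type pattern on an abstract type parameter is checked at runtime, so one generic filter replaces a hand-written `case` per branch. The helper `ofType` and the object `RuntimeSplit` are hypothetical names. The result is a plain tuple rather than an HList, so only the standard library is needed.

```scala
import scala.reflect.ClassTag

sealed trait Test
final case class A() extends Test
final case class B() extends Test
final case class C() extends Test

object RuntimeSplit {
  // Hypothetical helper: keeps the elements whose runtime class conforms to T.
  // The implicit ClassTag[T] lets the type pattern `case t: T` be checked at
  // runtime instead of being erased.
  def ofType[T <: Test : ClassTag](lst: List[Test]): List[T] =
    lst.collect { case t: T => t }

  // One runtime filter per branch; the tuple's static type names each branch.
  def split(lst: List[Test]): (List[A], List[B], List[C]) =
    (ofType[A](lst), ofType[B](lst), ofType[C](lst))

  def main(args: Array[String]): Unit = {
    val (aList, bList, cList) = split(List(C(), B(), A(), C(), B(), A()))
    println(aList) // List(A(), A())
    println(bList) // List(B(), B())
    println(cList) // List(C(), C())
  }
}
```

Each call to `ofType` traverses the whole list. With many branches, a single-pass `groupBy` may be preferable.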

Try splitting into a Map using runtime reflection

def split(lst: List[Test]): Map[String, List[Test]]  =
  lst.groupBy(_.getClass.getSimpleName)

split(List(C(), B(), A(), C(), B(), A()))
// HashMap(A -> List(A(), A()), B -> List(B(), B()), C -> List(C(), C()))
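The map's values are still typed as `List[Test]`. One way to get statically typed branches back is to key the map by the runtime `Class` instead of its simple name and then cast on lookup. This sketch assumes every branch is a concrete leaf class. The names `TypedLookup` and `branch` are hypothetical.

```scala
import scala.reflect.{ClassTag, classTag}

sealed trait Test
final case class A() extends Test
final case class B() extends Test
final case class C() extends Test
final case class Z() extends Test

object TypedLookup {
  // Group by the runtime Class itself rather than by its simple name.
  def split(lst: List[Test]): Map[Class[_], List[Test]] =
    lst.groupBy[Class[_]](_.getClass)

  // Hypothetical helper: recover a statically typed branch from the map.
  // A missing branch defaults to Nil. The cast is sound only because every
  // list stored under classOf[T] holds instances of exactly that class,
  // so T should be a concrete leaf class, not an intermediate trait.
  def branch[T <: Test : ClassTag](m: Map[Class[_], List[Test]]): List[T] =
    m.getOrElse(classTag[T].runtimeClass, Nil).asInstanceOf[List[T]]

  def main(args: Array[String]): Unit = {
    val m = split(List(C(), B(), A(), C(), B(), A()))
    val aList: List[A] = branch[A](m)
    val zList: List[Z] = branch[Z](m)
    println(aList) // List(A(), A())
    println(zList) // List()
  }
}
```

Keying by `Class` also avoids collisions between same-named classes from different packages, which `getSimpleName` cannot tell apart. As in the Shapeless version, absent branches come back as `Nil`.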

or split into an HList using Shapeless plus runtime reflection

import shapeless.labelled.{FieldType, field}
import shapeless.{::, Coproduct, HList, HNil, LabelledGeneric, Poly1, Typeable, Witness}
import shapeless.ops.coproduct.ToHList
import shapeless.ops.hlist.Mapper
import shapeless.ops.record.Values
import shapeless.record._
import scala.annotation.implicitNotFound
    
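object listPoly extends Poly1 {
  // Used only at the type level (by Mapper) to turn FieldType[K, V] into
  // FieldType[K, List[V]]; the case is never applied, hence the null body.
  implicit def cse[K <: Symbol, V]: Case.Aux[FieldType[K, V], FieldType[K, List[V]]] = null
}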

// modified shapeless.ops.maps.FromMap
@implicitNotFound("Implicit not found: FromMapWithDefault[${R}]. Maps can only be converted to appropriate Record types.")
trait FromMapWithDefault[R <: HList] extends Serializable {
  // use `default` when the map has no value for a key; return None if a value can't be cast
  def apply[K, V](m: Map[K, V], default: V): Option[R]
}
object FromMapWithDefault {
  implicit def hnilFromMap[T]: FromMapWithDefault[HNil] =
    new FromMapWithDefault[HNil] {
      def apply[K, V](m: Map[K, V], default: V): Option[HNil] = Some(HNil)
    }


  implicit def hlistFromMap[K0, V0, T <: HList]
  (implicit wk: Witness.Aux[K0], tv: Typeable[V0], fmt: FromMapWithDefault[T]): FromMapWithDefault[FieldType[K0, V0] :: T] =
    new FromMapWithDefault[FieldType[K0, V0] :: T] {
      def apply[K, V](m: Map[K, V], default: V): Option[FieldType[K0, V0] :: T] = {
        val value = m.getOrElse(wk.value.asInstanceOf[K], default)
        for {
          typed <- tv.cast(value)
          rest <- fmt(m, default)
        } yield field[K0](typed) :: rest
      }
    }
}

def split[T, C <: Coproduct, L <: HList, L1 <: HList](lst: List[T])(
  implicit
  labelledGeneric: LabelledGeneric.Aux[T, C],
  toHList: ToHList.Aux[C, L],
  mapper: Mapper.Aux[listPoly.type, L, L1],
  fromMapWithDefault: FromMapWithDefault[L1],
  values: Values[L1]
): values.Out = {
  // runtime step: group by simple class name, keyed by Symbol to match the record keys
  val grouped = lst.groupBy(_.getClass.getSimpleName).map { case (k, v) => Symbol(k) -> v }
  fromMapWithDefault(grouped, Nil).get.values
}

Testing:

sealed trait Test
final case class A() extends Test
final case class B() extends Test
final case class C() extends Test
final case class Z() extends Test

val res = split(List[Test](C(), B(), A(), C(), B(), A()))
// List(A(), A()) :: List(B(), B()) :: List(C(), C()) :: List() :: HNil
res: List[A] :: List[B] :: List[C] :: List[Z] :: HNil // this ascription compiles, so the static type is precise
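To see what each recursive step of `FromMapWithDefault` does without the Shapeless machinery, here is a standard-library analogue of one field step. It looks the key up, falls back to a default, and then checks the value at runtime. This is a hypothetical sketch: `FieldStep` and `fieldFromMap` are illustrative names. It uses a `ClassTag` element check in place of `Typeable.cast`. To my understanding, Shapeless's `Typeable` for a `List` likewise succeeds only when every element can be cast.

```scala
import scala.reflect.ClassTag

sealed trait Test
final case class A() extends Test
final case class B() extends Test
final case class Z() extends Test

object FieldStep {
  // Hypothetical analogue of one hlistFromMap step: look the key up, fall back
  // to `default` when it is absent, then check the value at runtime. The list
  // is accepted as List[T] only if every element is a T (a ClassTag check
  // standing in for Typeable's cast).
  def fieldFromMap[T: ClassTag](m: Map[String, List[Any]], key: String,
                                default: List[Any]): Option[List[T]] = {
    val value = m.getOrElse(key, default)
    if (value.forall { case _: T => true; case _ => false })
      Some(value.asInstanceOf[List[T]])
    else None
  }

  def main(args: Array[String]): Unit = {
    val m: Map[String, List[Any]] = Map("A" -> List(A(), A()), "B" -> List(B()))
    // Chaining the steps mirrors how hlistFromMap combines `typed` and `rest`:
    // any failed cast turns the whole result into None.
    val result: Option[(List[A], List[B], List[Z])] = for {
      aList <- fieldFromMap[A](m, "A", Nil)
      bList <- fieldFromMap[B](m, "B", Nil)
      zList <- fieldFromMap[Z](m, "Z", Nil)
    } yield (aList, bList, zList)
    println(result)
    println(fieldFromMap[A](m, "B", Nil)) // None: the B values are not A
  }
}
```

The for-comprehension plays the role of the `for { typed <- ...; rest <- ... }` in `hlistFromMap`: the first failed cast short-circuits the whole result to `None`.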

Scala 3 collection partitioning with subtypes (Scala 2/3)

Upvotes: 4
