SørenHN

Reputation: 696

How to pattern match on the types of list elements

I would like to pattern match on a list of objects based on their types. But specifying the pattern as case x: List[ObjectType] does not seem to work.

Take for example this program.

sealed trait A
case class B() extends A
case class C() extends A

def func(theList: List[A]) = theList match
{
    case listOfB: List[B] => println("All B's")
    case listOfC: List[C] => println("All C's")
    case _ => println("Something else")
}

func(List(C(), C(), C())) // prints: "All B's"

Although the list contains only C's, the match selects the first case, whose pattern specifies a list of B's. Why?

I know I can check each element of the list like this:

case listOfA: List[A] if listOfA.forall{case B() => true case _ => false} => println("All B's")

But this is more cumbersome, and I still have to cast (listOfA.asInstanceOf[List[B]]) when I want to use the value as a list of B's.

How can I do this in a smarter / better way?

Upvotes: 1

Views: 1989

Answers (2)

Dmytro Mitin

Reputation: 51658

Try custom extractors to make the pattern matching less cumbersome:

object AllB {
  def unapply(listOfA: List[A]): Boolean = 
    listOfA.forall { case B() => true; case _ => false }
}
object AllC {
  def unapply(listOfA: List[A]): Boolean = 
    listOfA.forall { case C() => true; case _ => false }
}

def func(theList: List[A]) = theList match {
  case AllB() => println("All B's")
  case AllC() => println("All C's")
  case _      => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

Or, with Cats, let the extractor return the narrowed list so that you can bind it in the pattern:

import cats.implicits._

object AllB {
  def unapply(listOfA: List[A]): Option[List[B]] = 
    listOfA.traverse { case b@B() => Some(b); case _ => None }
}
object AllC {
  def unapply(listOfA: List[A]): Option[List[C]] = 
    listOfA.traverse { case c@C() => Some(c); case _ => None }
}

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

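Or you can define a single class that creates all the necessary extractors and removes the code repetition:

import scala.reflect.ClassTag
import cats.implicits._

// The ClassTag makes the `case x: SubT` type test checkable at runtime.
class All[SubT: ClassTag] {
  def unapply[T >: SubT](listOfA: List[T]): Option[List[SubT]] = 
    listOfA.traverse { case x: SubT => Some(x); case _ => None }
}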

object AllB extends All[B]
object AllC extends All[C]
// val AllB = new All[B]
// val AllC = new All[C]

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else
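
If you would rather not depend on Cats, the same idea can be sketched with the standard library only. This is a minimal sketch, not a drop-in replacement: the names `AllOf`, `describe` and the wrapper object `ErasureSafeExtractors` are made up for illustration, and `describe` returns a String instead of printing so that the result is easy to check.

```scala
import scala.reflect.ClassTag

object ErasureSafeExtractors {
  sealed trait A
  case class B() extends A
  case class C() extends A

  // Hypothetical Cats-free variant of the `All` class above.
  class AllOf[SubT: ClassTag] {
    def unapply[T >: SubT](xs: List[T]): Option[List[SubT]] = {
      // The ClassTag in scope turns `case x: SubT` into a real runtime check.
      val narrowed = xs.collect { case x: SubT => x }
      // Succeed only if every element passed the check.
      if (narrowed.length == xs.length) Some(narrowed) else None
    }
  }

  object AllB extends AllOf[B]
  object AllC extends AllOf[C]

  def describe(theList: List[A]): String = theList match {
    case AllB(_) => "All B's"
    case AllC(_) => "All C's"
    case _       => "Something else"
  }
}
```

Note that an empty list satisfies every such extractor, so with this ordering `describe(Nil)` falls into the first case. The `forall`-based extractors above behave the same way.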

I guess the simplest option is to use Shapeless:

import shapeless.TypeCase

val AllB = TypeCase[List[B]]
val AllC = TypeCase[List[C]]

def func(theList: List[A]) = theList match {
  case AllB(listOfB) => println("All B's")
  case AllC(listOfC) => println("All C's")
  case _             => println("Something else")
}

func(List(B(), B(), B()))    // All B's
func(List[A](B(), B(), B())) // All B's
func(List(C(), C(), C()))    // All C's
func(List(C(), B(), C()))    // Something else

https://github.com/milessabin/shapeless/wiki/Feature-overview:-shapeless-2.0.0#type-safe-cast

Shapeless defines a type class called Typeable. Its instances for collections are a little trickier than the type class in @LuisMiguelMejíaSuárez's answer: they use runtime reflection to test the container and then every element.

/** Typeable instance for `Traversable`.    
 *  Note that the contents be will tested for conformance to the element type. */  
implicit def genTraversableTypeable[CC[X] <: Iterable[X], T]
  (implicit mCC: ClassTag[CC[_]], castT: Typeable[T]): Typeable[CC[T] with Iterable[T]] =
  // Nb. the apparently redundant `with Iterable[T]` is a workaround for a
  // Scala 2.10.x bug which causes conflicts between this instance and `anyTypeable`.
  new Typeable[CC[T]] {
    def cast(t: Any): Option[CC[T]] =
      if(t == null) None
      else if(mCC.runtimeClass isInstance t) {
        val cc = t.asInstanceOf[CC[Any]]
        if(cc.forall(_.cast[T].isDefined)) Some(t.asInstanceOf[CC[T]])
        else None
      } else None
    def describe = s"${safeSimpleName(mCC)}[${castT.describe}]"
  }

https://github.com/milessabin/shapeless/blob/master/core/src/main/scala/shapeless/typeable.scala#L235-L250
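
To make the two-step check in that instance easier to follow, here is a much-reduced sketch of the same idea using only the standard library. It is not Shapeless code: `castElement`, `castList` and the wrapper object `MiniCast` are hypothetical names, and the sketch handles only `List`, with a `ClassTag` standing in for the element's Typeable instance.

```scala
import scala.reflect.ClassTag

object MiniCast {
  sealed trait A
  case class B() extends A
  case class C() extends A

  // Element check: ClassTag.unapply tests the runtime class of the value.
  def castElement[T](t: Any)(implicit ct: ClassTag[T]): Option[T] =
    ct.unapply(t)

  // Container check first, then every element, mirroring the structure of
  // genTraversableTypeable above (hypothetical helper, List only).
  def castList[T: ClassTag](t: Any): Option[List[T]] = t match {
    case xs: List[_] if xs.forall(x => castElement[T](x).isDefined) =>
      // Safe here because each element was just checked individually.
      Some(xs.asInstanceOf[List[T]])
    case _ => None
  }
}
```

The element-by-element test is needed because the JVM erases generic type arguments: at runtime a pattern like `case x: List[B]` can only check that the value is a `List`.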

Also see "Ways to pattern match generic types in Scala": https://gist.github.com/jkpl/5279ee05cca8cc1ec452fc26ace5b68b

Upvotes: 2

Assuming you have at compile time a List[B] or a List[C] and you want to manipulate them in a different way, you can use a typeclass.

Something like this:

trait MyTypeClass[T] {
  def process(data: List[T]): String
}

sealed trait A extends Product with Serializable
final case class B() extends A
final case class C() extends A
object A extends ALowerPriority {
  implicit final val AllOfB: MyTypeClass[B] =
    (_: List[B]) => "All B's"
  
  implicit final val AllOfC: MyTypeClass[C] =
    (_: List[C]) => "All C's"
}

trait ALowerPriority {
  implicit final val Mixed: MyTypeClass[A] =
    (_: List[A]) => "Something else"
}

def func[T](theList: List[T])
          (implicit ev: MyTypeClass[T]): Unit =
  println(ev.process(data = theList))

Which works like this:

val bs = List(B(), B(), B())
val cs = List(C(), C(), C())
val mixed = List(C(), B(), C())

func(bs) // All B's
func(cs) // All C's
func(mixed) // Something else

Note: you would need to decide which interface your typeclass exposes, so that you can write generic functions that behave differently according to the underlying type.

However, remember that typeclass instances are selected at compile time, using only the static types. So if you have a value whose compile-time type is List[A], it will pick "Something else" even if the list is full of Bs:

val as: List[A] = List(B(), B(), B())
func(as) // Something else

You can see the code running here.
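
If all you have is a List[A], one possible workaround is to narrow the list at runtime first and only then let the compiler pick the instance. The following is a rough sketch under that assumption: `describe`, `describeAtRuntime` and the wrapper object `NarrowThenDispatch` are hypothetical names, and the instances return Strings rather than printing.

```scala
object NarrowThenDispatch {
  trait MyTypeClass[T] { def process(data: List[T]): String }

  sealed trait A extends Product with Serializable
  final case class B() extends A
  final case class C() extends A

  object A {
    implicit val allOfB: MyTypeClass[B] =
      new MyTypeClass[B] { def process(data: List[B]): String = "All B's" }
    implicit val allOfC: MyTypeClass[C] =
      new MyTypeClass[C] { def process(data: List[C]): String = "All C's" }
    implicit val mixed: MyTypeClass[A] =
      new MyTypeClass[A] { def process(data: List[A]): String = "Something else" }
  }

  // Instance chosen from the static element type T only.
  def describe[T](xs: List[T])(implicit ev: MyTypeClass[T]): String =
    ev.process(xs)

  // Hypothetical helper: check the elements at runtime, then call `describe`
  // with the narrowed static type so the matching instance is selected.
  def describeAtRuntime(xs: List[A]): String = {
    val bs = xs.collect { case b: B => b }
    val cs = xs.collect { case c: C => c }
    if (bs.length == xs.length) describe(bs)
    else if (cs.length == xs.length) describe(cs)
    else describe(xs)
  }
}
```

With this, `describeAtRuntime` gives "All B's" for a List[A] that holds only Bs, while calling `describe` directly on the same value still gives "Something else". An empty list counts as "All B's" in this sketch because the B check runs first.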

Upvotes: 2
