Reputation: 4231
I have a sequence of elements, specifically bytes, that I would like to pattern-match into higher-level values, e.g. 16-bit integers. The naive approach would look like this:
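val stream: Seq[Byte] = ???
// Combines a high byte and a low byte into a 16-bit value (assuming big-endian order)
def short(b1: Byte, b0: Byte): Short = ((b1 & 0xFF) << 8 | (b0 & 0xFF)).toShort
stream match {
  case Seq(a1, a0, b1, b0, _*) => (short(a1, a0), short(b1, b0))
  // ...
}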
What I managed to do is something like this, using an object short_:: with an unapply method:
stream match {
  case a short_:: b short_:: _ => (a, b)
  // ...
}
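The post doesn't show how short_:: is defined. A minimal sketch, assuming big-endian byte order (the first byte is the high byte) and made-up sample bytes, might look like this. The name ends in a colon, so the infix pattern groups to the right and each extractor receives whatever is left of the sequence:

```scala
// Hypothetical definition of the short_:: extractor described above.
// Assumption: big-endian order, i.e. the first byte is the high byte.
object short_:: {
  def unapply(seq: Seq[Byte]): Option[(Short, Seq[Byte])] =
    if (seq.lengthCompare(2) >= 0) {
      // Mask with 0xFF so a negative (signed) byte is read as 128..255
      val value = (((seq(0) & 0xFF) << 8) | (seq(1) & 0xFF)).toShort
      Some((value, seq.drop(2)))
    } else None
}

// Made-up sample input: 0x01 0x02 0x7F 0x10 0x00
val stream: Seq[Byte] = Seq[Byte](0x01, 0x02, 0x7F, 0x10, 0x00)

// Reads two shorts and ignores the rest of the stream
val result: Option[(Short, Short)] = stream match {
  case a short_:: b short_:: _ => Some((a, b))
  case _                       => None
}
```

With this input, result holds (0x0102, 0x7F10); a stream shorter than four bytes falls through to None.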
However, I don't really like this syntax, because it doesn't look much like regular pattern matching. I'd be happier writing something like this:
stream match {
  case Short(a) :: Short(b) :: _ => (a, b)
  // ...
}
Of course, using the identifiers Short and :: is probably hard or a bad idea, but I think it gets the point across for the purpose of this question.
Is it possible to write custom pattern-matching code that produces a syntax similar to this? I'm restricting myself to fixed-width contents of the stream here (although not a single width: e.g. Short and Int should both be possible), but I need to be able to match the remainder of the stream, like :: tail or Seq(..., tail @ _*).
Upvotes: 1
Views: 1205
Reputation: 8139
Try this out
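object ::: {
  // Splits off the first two bytes as a big-endian 16-bit value (returned as an Int)
  def unapply(seq: Seq[Byte]): Option[(Int, Seq[Byte])] =
    if (seq.lengthCompare(2) >= 0)
      // Mask with 0xFF so negative (signed) bytes count as 128..255
      Some(((seq(0) & 0xFF) << 8 | (seq(1) & 0xFF), seq.drop(2)))
    else None
}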
val list = List[Byte](1,2,3,4,5,6,7,8)
list match {
  case a ::: b ::: rest => println(s"$a, $b, $rest")
  case _ => println("not enough bytes")
}
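When combining bytes into a wider integer, keep in mind that Scala's Byte is signed (-128 to 127), so arithmetic such as b0 * 256 + b1 gives the wrong result once a byte has its high bit set. Masking each byte with 0xFF first recovers its unsigned value. A small illustration with made-up sample bytes:

```scala
// Assumed sample bytes: 0x01 and 0x80 (0x80 has its high bit set)
val hi: Byte = 0x01
val lo: Byte = 0x80.toByte // stored as -128, because Byte is signed

// Naive combination: lo contributes -128 instead of +128
val naive = hi * 256 + lo // 256 + (-128) = 128

// Masking each byte with 0xFF first yields its unsigned value (0..255)
val masked = ((hi & 0xFF) << 8) | (lo & 0xFF) // 0x0180 = 384
```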
You can also mix bytes, shorts and ints in a single pattern:
object +: {
  // Splits off a single byte
  def unapply(seq: Seq[Byte]): Option[(Byte, Seq[Byte])] =
    if (seq.nonEmpty) Some((seq.head, seq.tail))
    else None
}
object ++: {
  // Combines two bytes into a big-endian Short
  def unapply(seq: Seq[Byte]): Option[(Short, Seq[Byte])] = seq match {
    case a +: b +: rest => Some((((a & 0xFF) << 8 | (b & 0xFF)).toShort, rest))
    case _ => None
  }
}
object +++: {
  // Combines two Shorts into a big-endian Int
  def unapply(seq: Seq[Byte]): Option[(Int, Seq[Byte])] = seq match {
    case a ++: b ++: rest => Some(((a & 0xFFFF) << 16 | (b & 0xFFFF), rest))
    case _ => None
  }
}
val list = List[Byte](1,2,3,4,5,6,7,8,9,10,11,12)
list match {
  case a +: b ++: c +++: rest => println(s"$a, $b, $c, $rest")
  case _ => println("not enough bytes")
}
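+: means the variable to its left is a byte, ++: means it is a short, and +++: means it is an int. (The standard library already provides a +: extractor for any Seq; the object above shadows it with a Byte-only version.) By the way, this only works if the object names end with ":". Infix patterns follow the same associativity rules as infix expressions, so an operator ending in ":" is right-associative: a +: b ++: rest is read as +:(a, ++:(b, rest)), and each extractor receives the remaining sequence. Without the trailing colon, the pattern would group to the left, as op(op(a, b), rest), which does not fit the types of these extractors.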
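To see the associativity rule in isolation, here is a small sketch using the standard library's own +: extractor, which works on any Seq (the Byte-only version above behaves the same for this input). The infix pattern and the equivalent prefix (constructor-pattern) form bind the same values:

```scala
// `+:` here is the standard library's extractor for any Seq
val bytes = List[Byte](1, 2, 3)

// Infix form: `+:` ends in ':', so this groups to the right,
// i.e. a +: (b +: rest)
val infix = bytes match {
  case a +: b +: rest => (a, b, rest)
}

// Equivalent prefix (constructor-pattern) form of the same match
val prefix = bytes match {
  case +:(a, +:(b, rest)) => (a, b, rest)
}
```

Both matches bind a = 1, b = 2 and rest = List(3).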
Upvotes: 2