Silly Freak

matching more than one element from a sequence of elements

I have a sequence of elements, specifically bytes, that I would like to match into higher-level values, e.g. 16-bit integers. The naive approach looks like this:

val stream: Seq[Byte] = ???
def short(b1: Byte, b0: Byte): Short = ???

stream match {
    case Seq(a1, a0, b1, b0, _*) => (short(a1, a0), short(b1, b0))
    // ...
}
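For reference, here is a runnable version of the naive approach, written as a Scala script. The body of short is an assumption (big-endian, so the first byte is the high byte), since the question leaves it as ???. The sample bytes are made up.

```scala
// Assumed big-endian decoding; the question leaves `short` unimplemented.
// Masking with 0xFF stops negative bytes from being sign-extended.
def short(b1: Byte, b0: Byte): Short =
  (((b1 & 0xFF) << 8) | (b0 & 0xFF)).toShort

// Made-up sample input.
val stream: Seq[Byte] = Seq[Byte](0, 1, 0, 2, 42)

val result = stream match {
  // Each 16-bit value needs two binders, and decoding only happens
  // on the right-hand side of the case.
  case Seq(a1, a0, b1, b0, _*) => Some((short(a1, a0), short(b1, b0)))
  case _                       => None
}

println(result)
```

The drawback the question describes is visible here: the pattern works on raw bytes, so the 16-bit structure is not expressed in the pattern itself.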

What I managed to do is something like this, using an object short_:: with an unapply method:

stream match {
    case a short_:: b short_:: _ => (a, b)
    // ...
}
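The question does not show the body of short_::, so the following is only a hypothetical reconstruction of such an extractor (big-endian decoding assumed). It returns the decoded Short together with the remaining bytes, which is what lets the pattern chain.

```scala
// Hypothetical reconstruction of `short_::`; the body is an assumption.
object short_:: {
  def unapply(bytes: Seq[Byte]): Option[(Short, Seq[Byte])] = bytes match {
    // Take two bytes, decode them as a big-endian Short, keep the rest.
    case Seq(hi, lo, rest @ _*) =>
      Some(((((hi & 0xFF) << 8) | (lo & 0xFF)).toShort, rest))
    case _ => None
  }
}

val stream: Seq[Byte] = Seq[Byte](0, 1, 0, 2, 42)

val pair = stream match {
  // The name ends in ':', so this is read as short_::(a, short_::(b, _)).
  case a short_:: b short_:: _ => Some((a, b))
  case _                       => None
}

println(pair)
```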

However, I don't really like this syntax, because it doesn't look much like regular pattern matching. I'd prefer to write something like this:

stream match {
    case Short(a) :: Short(b) :: _ => (a, b)
    // ...
}

Of course, using the identifiers Short and :: is probably hard or a bad idea, but I think it gets the point across for the purpose of this question.

Is it possible to write custom pattern-matching code that produces a syntax similar to this? I'm restricting myself to fixed-width contents of the stream here (though not to a single width: both Short and Int should be possible, for example), but I need to be able to match the remainder of the stream, as with :: tail or Seq(..., tail @ _*).


Answers (1)

Tesseract

Try this:

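object ::: {
  // Reads an unsigned big-endian 16-bit value from the first two bytes.
  // Masking with 0xFF prevents negative bytes from being sign-extended;
  // lengthCompare avoids walking the whole list just to check its size.
  def unapply(seq: Seq[Byte]): Option[(Int, Seq[Byte])] = {
    if (seq.lengthCompare(2) >= 0) Some(((seq(0) & 0xFF) * 256 + (seq(1) & 0xFF), seq.drop(2)))
    else None
  }
}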


val list = List[Byte](1,2,3,4,5,6,7,8)
list match {
  case a ::: b ::: rest => println(s"$a, $b, $rest")
  case _ => println("fewer than 4 bytes")
}
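The masking matters because Scala's Byte is signed (-128 to 127). A quick check with two made-up byte values, -1 and -2 (0xFF and 0xFE when read as unsigned), shows the difference between plain arithmetic and masked arithmetic:

```scala
// Made-up sample bytes: -1 is 0xFF and -2 is 0xFE as unsigned values.
val hi: Byte = -1
val lo: Byte = -2

// Byte arithmetic widens to Int with sign extension, so the sign leaks in.
val unmasked = hi * 256 + lo
// Masking with 0xFF reinterprets each byte as 0..255 before combining.
val masked = (hi & 0xFF) * 256 + (lo & 0xFF)

println(s"unmasked = $unmasked, masked = $masked")
```

Without the mask, any byte with its high bit set produces a wrong (often negative) result.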

You can also mix bytes, shorts and ints in a single pattern:

object +: {
  // Splits off a single Byte.
  def unapply(seq: Seq[Byte]): Option[(Byte, Seq[Byte])] = {
    if (seq.nonEmpty) Some((seq.head, seq.tail))
    else None
  }
}

object ++: {
  // Combines two bytes into a big-endian Short.
  def unapply(seq: Seq[Byte]): Option[(Short, Seq[Byte])] = seq match {
    case a +: b +: rest => Some((((a & 0xFF) << 8 | (b & 0xFF)).toShort, rest))
    case _ => None
  }
}

object +++: {
  // Combines two Shorts into a big-endian Int.
  def unapply(seq: Seq[Byte]): Option[(Int, Seq[Byte])] = seq match {
    case a ++: b ++: rest => Some(((a & 0xFFFF) << 16 | (b & 0xFFFF), rest))
    case _ => None
  }
}

val bytes = List[Byte](1,2,3,4,5,6,7,8,9,10,11,12)
bytes match {
  case a +: b ++: c +++: rest => println(s"$a, $b, $c, $rest")
  case _ => println("not enough bytes")
}

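+: means the variable to its left is a byte, ++: means it is a short, and +++: means it is an int. Note that this only works because the object names end with ":". In Scala, an operator whose name ends in a colon is right-associative, in patterns as well as in expressions, so a +: b ++: rest is read as a +: (b ++: rest), i.e. +:(a, ++:(b, rest)). Without the trailing colon, the chain would be grouped from the left and the extractors would not line up.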
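The associativity point can be checked directly. The sketch below restates the answer's extractors with explicit tuples, then writes the same pattern both in infix form and as the nested extractor calls it abbreviates. The ++++: extractor (Long) and the decode/decodeNested helpers are hypothetical additions, included only to show that a further width composes the same way.

```scala
// Self-contained restatement of the answer's extractors (big-endian).
// `++++:`, `decode` and `decodeNested` are hypothetical additions.
object +: {
  // Splits off a single Byte.
  def unapply(seq: Seq[Byte]): Option[(Byte, Seq[Byte])] =
    if (seq.nonEmpty) Some((seq.head, seq.tail)) else None
}

object ++: {
  // Combines two bytes into a Short.
  def unapply(seq: Seq[Byte]): Option[(Short, Seq[Byte])] = seq match {
    case a +: b +: rest => Some((((a & 0xFF) << 8 | (b & 0xFF)).toShort, rest))
    case _              => None
  }
}

object +++: {
  // Combines two Shorts into an Int.
  def unapply(seq: Seq[Byte]): Option[(Int, Seq[Byte])] = seq match {
    case a ++: b ++: rest => Some(((a & 0xFFFF) << 16 | (b & 0xFFFF), rest))
    case _                => None
  }
}

object ++++: {
  // Hypothetical: combines two Ints into a Long.
  def unapply(seq: Seq[Byte]): Option[(Long, Seq[Byte])] = seq match {
    case a +++: b +++: rest => Some(((a.toLong << 32) | (b & 0xFFFFFFFFL), rest))
    case _                  => None
  }
}

// Infix form: every name ends in ':', so it groups from the right.
def decode(bytes: Seq[Byte]): Option[(Byte, Short, Int, Seq[Byte])] = bytes match {
  case a +: b ++: c +++: rest => Some((a, b, c, rest))
  case _                      => None // fewer than 7 bytes
}

// The same pattern spelled out as the nested extractor calls it stands for.
def decodeNested(bytes: Seq[Byte]): Option[(Byte, Short, Int, Seq[Byte])] = bytes match {
  case +:(a, ++:(b, +++:(c, rest))) => Some((a, b, c, rest))
  case _                            => None
}

val input = List[Byte](1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
println(decode(input))
println(decodeNested(input))
```

Both functions should produce the same result, which is the point: the infix pattern is only shorthand for the nested form, and the trailing colon is what makes the shorthand nest in the right direction.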
