RussAbbott

Reputation: 2738

Is this cast necessary in Scala?

Suppose I have a simple abstract BinaryTree with subclasses Node and Leaf, and I want to write a function that returns a List[Leaf] of all the leaves in a tree.

def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      case Leaf(v) => List(tree.asInstanceOf[Leaf])
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }

Is there a way to avoid the explicit asInstanceOf[Leaf] cast in the Leaf case? If I leave it out, I get a diagnostic saying: found: BinaryTree; required: Leaf.
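The question doesn't show the class definitions, so the sketch below assumes a minimal hypothetical version (the sealed modifier, the Int payload, the field name value and the QuestionVersion wrapper are all assumptions) and keeps the original cast so it compiles. In case Leaf(v) the pattern only checks the shape of tree; the variable tree itself keeps its declared type BinaryTree, which is why List(tree) on its own is rejected.

```scala
object QuestionVersion {
  // Hypothetical definitions; the question only says BinaryTree has subclasses Node and Leaf.
  sealed abstract class BinaryTree
  case class Leaf(value: Int) extends BinaryTree
  case class Node(left: BinaryTree, right: BinaryTree) extends BinaryTree

  def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      // The pattern confirms that tree is a Leaf, but `tree` is still typed BinaryTree here,
      // so the cast is what makes the result conform to List[Leaf].
      case Leaf(v) => List(tree.asInstanceOf[Leaf])
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
}
```

Making the base class sealed is optional, but it lets the compiler warn when a match misses one of the subclasses.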

Upvotes: 0

Views: 101

Answers (3)

RussAbbott

Reputation: 2738

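I saw this construct used elsewhere, and it seems to do the job. The typed pattern leaf: Leaf binds the matched value to leaf with static type Leaf, so no cast is needed: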

def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      case leaf: Leaf => List(leaf)
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
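A self-contained version of the typed-pattern approach, assuming the same hypothetical ADT (sealed base class, Leaf holding an Int; the TypedPattern wrapper is also just for illustration). The binding leaf refers to the very same object that was matched, so nothing is copied or rebuilt.

```scala
object TypedPattern {
  // Hypothetical ADT matching the question's description.
  sealed abstract class BinaryTree
  case class Leaf(value: Int) extends BinaryTree
  case class Node(left: BinaryTree, right: BinaryTree) extends BinaryTree

  def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      // Typed pattern: `leaf` is the same object as `tree`, but with static type Leaf.
      case leaf: Leaf => List(leaf)
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
}
```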

Upvotes: 8

user537862

Reputation: 481

Try binding the whole matched value with @:

def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      case x@Leaf(v) => List(x)
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
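A runnable sketch of the binder form, again assuming a hypothetical ADT with an Int payload and an illustrative BinderPattern wrapper. In x @ Leaf(_), the constructor pattern does the matching and x is bound to the whole value with static type Leaf; the field is written as _ because it isn't used.

```scala
object BinderPattern {
  // Hypothetical ADT matching the question's description.
  sealed abstract class BinaryTree
  case class Leaf(value: Int) extends BinaryTree
  case class Node(left: BinaryTree, right: BinaryTree) extends BinaryTree

  def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      // `x @ Leaf(_)` matches a Leaf and binds the entire matched value to x, typed as Leaf.
      case x @ Leaf(_) => List(x)
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
}
```

The binder form is handy when you also need the fields, for example case x @ Leaf(v) if v > 0 => List(x).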

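Also note that your implementation performs poorly: ++ copies its whole left-hand list, and that happens at every Node, so leaves in deep left subtrees are copied over and over (quadratic time in the worst case).

It can be fixed by threading an accumulator through the recursion: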

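def getLeaves(tree: BinaryTree): List[Leaf] = {
  // Prepends each leaf to acc as it is visited, so acc holds the leaves in reverse visiting order
  def getLeaves0(tree: BinaryTree, acc: List[Leaf]): List[Leaf] =
    tree match {
      case x @ Leaf(_) => x :: acc
      case Node(left, right) =>
        val leftLeaves = getLeaves0(left, acc)
        getLeaves0(right, leftLeaves)
    }
  // Reverse once at the end to restore left-to-right order
  getLeaves0(tree, Nil).reverse
}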

Here you reuse the items you have already collected: :: prepends a leaf by allocating one cons cell that shares the rest of the list, so nothing collected so far is copied during the traversal. Because prepending makes the list work as a LIFO stack, the leaves end up in reverse visiting order, so the result is reversed once at the end to get them in the order they were visited.
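A variant of the same accumulator idea, sketched under the same hypothetical ADT (the RightFirst wrapper and the helper name loop are illustrative): visiting the right subtree before the left one builds the list back to front, so the final reverse isn't needed. Like the version above, it is not fully tail-recursive (one recursive call per Node is not in tail position), so a very deep tree can still overflow the stack.

```scala
object RightFirst {
  // Hypothetical ADT matching the question's description.
  sealed abstract class BinaryTree
  case class Leaf(value: Int) extends BinaryTree
  case class Node(left: BinaryTree, right: BinaryTree) extends BinaryTree

  def getLeaves(tree: BinaryTree): List[Leaf] = {
    // The right subtree's leaves are prepended first, so the left subtree's leaves
    // land in front of them and the result is already in left-to-right order.
    def loop(t: BinaryTree, acc: List[Leaf]): List[Leaf] =
      t match {
        case leaf: Leaf => leaf :: acc
        case Node(left, right) => loop(left, loop(right, acc))
      }
    loop(tree, Nil)
  }
}
```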

Upvotes: 3

Idan Arye

Reputation: 12623

You can use the fact that you have already deconstructed tree into Leaf(v) and rebuild the leaf:

def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
        case Leaf(v) => List(Leaf(v))
        case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
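One consequence worth knowing, shown with the same hypothetical ADT (and an illustrative RebuiltLeaf wrapper): Leaf(v) allocates a new instance. Because Leaf is a case class it compares equal (==) to the original, but it is not the same object (eq), which matters only if identity or mutable state is involved.

```scala
object RebuiltLeaf {
  // Hypothetical ADT matching the question's description.
  sealed abstract class BinaryTree
  case class Leaf(value: Int) extends BinaryTree
  case class Node(left: BinaryTree, right: BinaryTree) extends BinaryTree

  def getLeaves(tree: BinaryTree): List[Leaf] =
    tree match {
      // Rebuilds the leaf from its extracted field, creating a new, structurally equal instance.
      case Leaf(v) => List(Leaf(v))
      case Node(left, right) => getLeaves(left) ++ getLeaves(right)
    }
}
```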

Upvotes: 1
