Luigi Plinge

Reputation: 51109

Method to check reference equality for Any type

I'm trying to make a method that will match on reference equality for any type, including primitives. How best to do this?

eq is only defined on AnyRef. If we try

def refEquals[A <% AnyRef, B <% AnyRef](a: A, b: B) = a eq b

then calling refEquals(1, 2) fails: Predef contains implicit methods, including int2IntegerConflict, that exist to scupper such conversions.

I tried this:

def refEquals(a: Any, b: Any) = a match {
  case x: AnyRef => b match {
    case y: AnyRef => x eq y
    case _ => false
  }
  case x: Any => b match {
    case y: AnyRef => false
    case y: Any => x == y
  }
}

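But this doesn't work (refEquals(1.0, 1.0) gives false), for reasons given by Rex Kerr here: Strange pattern matching behaviour with AnyRef. In short, a Double passed as Any is boxed into a java.lang.Double, which matches the AnyRef case, and the two arguments are boxed separately, so eq compares two different objects.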
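The boxing that defeats this attempt is easy to observe. The sketch below keeps only the AnyRef branch of the attempt above; the names BoxingDemo and sameBox are hypothetical. The commented results assume a standard JVM, where Integer.valueOf always returns a cached box for -128..127 and Double.valueOf allocates a fresh box on each call (what current JDKs do).

```scala
object BoxingDemo {
  // Hypothetical helper: the AnyRef-vs-AnyRef branch of the first attempt.
  // Any non-null value stored in an Any is a boxed object, so it matches AnyRef.
  def sameBox(a: Any, b: Any): Boolean = (a, b) match {
    case (x: AnyRef, y: AnyRef) => x eq y
    case _                      => false // only reached when a or b is null
  }

  def main(args: Array[String]): Unit = {
    println(sameBox(1.0, 1.0)) // false: each 1.0 is boxed into its own java.lang.Double
    println(sameBox(1, 1))     // true: Integer.valueOf reuses its cached box for 1
  }
}
```

So the result of eq on boxed primitives depends on the JVM's boxing cache, which is why the primitive cases need to be handled separately.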

So how do we implement such a method?

Edit: I should have said "reference equality for reference types, or value equality for primitive types".

Edit: here's the method using the idea from Rex's answer, for anyone who needs this and doesn't like typing:

def refEquals(a: Any, b: Any): Boolean = a match {
  // Same primitive type on both sides: compare by value
  case x: Boolean if b.isInstanceOf[Boolean] => x == b
  case x: Byte    if b.isInstanceOf[Byte]    => x == b
  case x: Short   if b.isInstanceOf[Short]   => x == b
  case x: Char    if b.isInstanceOf[Char]    => x == b
  case x: Int     if b.isInstanceOf[Int]     => x == b
  case x: Float   if b.isInstanceOf[Float]   => x == b
  case x: Double  if b.isInstanceOf[Double]  => x == b
  case x: Long    if b.isInstanceOf[Long]    => x == b
  // Reference types, mismatched primitive types, Unit and null: compare by reference
  case _ => a.asInstanceOf[AnyRef] eq b.asInstanceOf[AnyRef]
}

Upvotes: 1

Views: 430

Answers (2)

Iulian Dragos

Reputation: 5712

Reference equality is undefined for primitive types because they are not references. The only notion of equality in that case is value equality.

However, if you want your code to work with both primitives and reference types, you can either use '==' and make sure you pass objects that don't redefine 'equals', or define your own equality object and pass it around. You could probably use 'scala.math.Equiv[T]' for this.

// Note the "=": without it (procedure syntax) the method would return Unit
def myRefEquals[A](x: A, y: A)(implicit eq: Equiv[A]): Boolean = {
  eq.equiv(x, y)
}

implicit def anyRefHasRefEquality[A <: AnyRef] = new Equiv[A] {
  def equiv(x: A, y: A) = x eq y
}

implicit def anyValHasUserEquality[A <: AnyVal] = new Equiv[A] {
  def equiv(x: A, y: A) = x == y
}

println(myRefEquals(Some(1), Some(1)))

This assumes you want both objects to have the same type.
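The same idea can lean on two factory methods that the scala.math.Equiv companion already provides, Equiv.reference (compares with eq) and Equiv.universal (compares with ==), instead of anonymous classes. This is a sketch assuming Scala 2.13; the object name EquivDemo is hypothetical, and the implicits must be in lexical scope (here, members of the enclosing object) so that they take priority over Equiv's own fallback instance.

```scala
import scala.math.Equiv

object EquivDemo {
  // Returns the comparison result (declared with "=", not procedure syntax).
  def myRefEquals[A](x: A, y: A)(implicit ev: Equiv[A]): Boolean = ev.equiv(x, y)

  // Reference types: identity comparison via the library's Equiv.reference.
  implicit def anyRefHasRefEquality[A <: AnyRef]: Equiv[A] = Equiv.reference[A]

  // Value types: value comparison via Equiv.universal.
  implicit def anyValHasUserEquality[A <: AnyVal]: Equiv[A] = Equiv.universal[A]

  def main(args: Array[String]): Unit = {
    println(myRefEquals(Some(1), Some(1))) // false: two distinct Some instances
    println(myRefEquals(1.0, 1.0))         // true: Doubles compared by value
  }
}
```

Callers outside the object need import EquivDemo._ so the two implicits are found before the library's default.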

Upvotes: 5

Rex Kerr

Reputation: 167891

You catch all the primitives first, then fall through to eq:

def refEquals(a: Any, b: Any) = a match {
  case x: Boolean => b match {
    case y: Boolean => x==y
    case _ => false
  }
  case x: Byte =>
  ...
  case _ => a.asInstanceOf[AnyRef] eq b.asInstanceOf[AnyRef]
}
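The same fall-through idea can also be written by matching on the pair (a, b), which replaces one nested match per primitive type with a single case. This is a sketch, not the answer's code; the object name PairMatch is hypothetical. It follows the same rules as the question's final version: same-type primitives compare by value, everything else by reference.

```scala
object PairMatch {
  // Matching on the pair avoids the nested matches and isInstanceOf guards.
  def refEquals(a: Any, b: Any): Boolean = (a, b) match {
    case (x: Boolean, y: Boolean) => x == y
    case (x: Byte, y: Byte)       => x == y
    case (x: Short, y: Short)     => x == y
    case (x: Char, y: Char)       => x == y
    case (x: Int, y: Int)         => x == y
    case (x: Long, y: Long)       => x == y
    case (x: Float, y: Float)     => x == y
    case (x: Double, y: Double)   => x == y
    // Unit, mismatched primitive types, reference types and null end up here.
    case _ => a.asInstanceOf[AnyRef] eq b.asInstanceOf[AnyRef]
  }

  def main(args: Array[String]): Unit = {
    val s = new String("abc")
    println(refEquals(1.0, 1.0))             // true: same primitive type, compared by value
    println(refEquals(1, 1L))                // false: Int vs Long falls through to eq
    println(refEquals(s, s))                 // true: same reference
    println(refEquals(s, new String("abc"))) // false: equal contents, different objects
  }
}
```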

Upvotes: 4
