quantum

Reputation: 3039

Is Scala disregarding types in function signatures?

I am going through the lectures from Martin Odersky's excellent FP course, and one of them demonstrates higher-order functions through Newton's method for finding fixed points of certain functions. There is a crucial step in the lecture where I think the type signature is being violated, so I would like an explanation. (Apologies for the long intro that follows; it felt necessary.)

One way of implementing such an algorithm is given like this:

  import scala.math.abs

  val tolerance = 0.0001

  // Relative error between two successive guesses
  def isCloseEnough(x: Double, y: Double) = abs((x - y) / x) < tolerance

  // Apply f repeatedly until two successive guesses are close enough
  def fixedPoint(f: Double => Double)(firstGuess: Double) = {
    def iterate(guess: Double): Double = {
      val next = f(guess)
      if (isCloseEnough(guess, next)) next
      else iterate(next)
    }
    iterate(firstGuess)
  }

Next, we try to compute the square root via the fixedPoint function, but the naive attempt

def sqrt(x: Double) = fixedPoint(y => x / y)(1)

is foiled because such an approach oscillates (so, for sqrt(2), the result would alternate indefinitely between 1.0 and 2.0).
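To see the oscillation directly, one can list the first few guesses of the undamped map instead of calling fixedPoint (which never terminates for this map). This is a minimal sketch; `iterates` is a hypothetical helper introduced only for illustration, built on the standard `Iterator.iterate`.

```scala
// Hypothetical helper: the first n guesses obtained by repeatedly applying f.
def iterates(f: Double => Double, start: Double, n: Int): List[Double] =
  Iterator.iterate(start)(f).take(n).toList

val x = 2.0

// Undamped map for sqrt(2): every step just flips between two values.
val undamped = iterates(y => x / y, 1.0, 6)

println(undamped) // List(1.0, 2.0, 1.0, 2.0, 1.0, 2.0)
```

Successive guesses always differ by a relative error of 1 or 0.5, so the closeness test can never succeed.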

To deal with that, we introduce average damping: each new guess is the mean of the current guess y and x / y, which stops the oscillation and converges to the solution. The definition therefore becomes

def sqrt(x: Double) = fixedPoint(y => (y + x / y) / 2)(1)
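A self-contained sketch of the damped version follows, assuming the relative-error test abs((x - y) / x) < tolerance for closeness. It repeats the fixedPoint definition so that it runs on its own.

```scala
import scala.math.abs

val tolerance = 0.0001

// Relative error between two successive guesses.
def isCloseEnough(x: Double, y: Double): Boolean = abs((x - y) / x) < tolerance

def fixedPoint(f: Double => Double)(firstGuess: Double): Double = {
  def iterate(guess: Double): Double = {
    val next = f(guess)
    if (isCloseEnough(guess, next)) next else iterate(next)
  }
  iterate(firstGuess)
}

// Each new guess is the mean of y and x / y, which damps the oscillation.
def sqrt(x: Double): Double = fixedPoint(y => (y + x / y) / 2)(1)

println(sqrt(2)) // approximately 1.41421
```

Starting from 1, the guesses go 1.5, 1.41667, 1.41422, ... and the loop stops once two of them agree to within the tolerance.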

Finally, we introduce the averageDamp function, and the task is to write sqrt using fixedPoint and averageDamp. averageDamp is defined as follows:

def averageDamp(f: Double => Double)(x: Double) = (x + f(x)) / 2
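As a quick check that averageDamp captures the same step as the hand-written lambda above, this sketch applies it with both argument lists and compares it against (y + x / y) / 2 on a few sample guesses (the guesses are arbitrary values chosen for illustration).

```scala
def averageDamp(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

val x = 2.0
val guesses = List(1.0, 1.5, 3.0)

// Supplying both argument lists yields a plain Double for each guess.
val viaAverageDamp = guesses.map(g => averageDamp(y => x / y)(g))

// The damped step written out by hand.
val byHand = guesses.map(y => (y + x / y) / 2)

println(viaAverageDamp == byHand) // true
```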

Here comes the part I don't understand. My initial solution was this:

def sqrt(x: Double) = fixedPoint(z => averageDamp(y => x / y)(z))(1)

but Prof. Odersky's solution was more concise:

def sqrt(x: Double) = fixedPoint(averageDamp(y => x / y))(1)

My question is: why does this work? According to its signature, fixedPoint takes a function (Double => Double), yet it doesn't mind being passed what looks like an ordinary Double (which is what averageDamp returns; in fact, if you explicitly annotate averageDamp's return type as Double, the compiler raises no error).

I think my approach follows the types correctly, so what am I missing here? Where is it specified or implied that averageDamp returns a function, especially given that its right-hand side clearly returns a scalar? How can you pass a scalar to a function that clearly expects only functions? How do you reason about code that seems not to honour type signatures?

Upvotes: 4

Views: 249

Answers (2)

lmm

Reputation: 17431

Multiple parameter lists are syntactic sugar for a function that returns another function. You can see this in the Scala REPL:

scala> :t averageDamp _
(Double => Double) => (Double => Double)

We can write the same function without the syntactic sugar; this is how we'd do it in, e.g., Python:

def averageDamp(f: Double => Double): (Double => Double) = {
   def g(x: Double): Double = (x + f(x)) / 2
   g
}
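To check that the two forms really are interchangeable, this sketch defines both under separate names (`averageDampSugared` and `averageDampExplicit` are names chosen here for illustration) and compares the functions they return on a few sample inputs.

```scala
// Multiple-parameter-list version, as in the question.
def averageDampSugared(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

// Explicit version that builds and returns a function.
def averageDampExplicit(f: Double => Double): Double => Double = {
  def g(x: Double): Double = (x + f(x)) / 2
  g
}

val f: Double => Double = y => 2.0 / y

// Both calls yield a Double => Double; the first relies on eta-expansion.
val a: Double => Double = averageDampSugared(f)
val b: Double => Double = averageDampExplicit(f)

val samples = List(1.0, 1.5, 3.0)
println(samples.map(a) == samples.map(b)) // true
```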

Returning a function can look a bit weird to start with, but it's complementary to passing a function as an argument and enables some very powerful programming techniques. Functions are just another type of value, like Int or String.

In your original solution you were reusing the variable name y, which I think makes it slightly confusing; we can translate what you've written into:

def sqrt(x: Double) = fixedPoint(z => averageDamp(y => x / y)(z))(1)

With this form, you can hopefully see the pattern:

def sqrt(x: Double) = fixedPoint(z => something(z))(1)

And hopefully it's now obvious that this is the same as:

def sqrt(x: Double) = fixedPoint(something)(1)

which is Odersky's version.
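The step from z => something(z) to something can be checked empirically. The sketch below, assuming the relative-error closeness test, runs both the wrapped and the direct versions (named `sqrtWrapped` and `sqrtDirect` here purely for illustration) and confirms they return the same value, since they perform identical arithmetic.

```scala
import scala.math.abs

val tolerance = 0.0001

def isCloseEnough(x: Double, y: Double): Boolean = abs((x - y) / x) < tolerance

def fixedPoint(f: Double => Double)(firstGuess: Double): Double = {
  def iterate(guess: Double): Double = {
    val next = f(guess)
    if (isCloseEnough(guess, next)) next else iterate(next)
  }
  iterate(firstGuess)
}

def averageDamp(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

// Wraps the partial application in an extra lambda.
def sqrtWrapped(x: Double): Double = fixedPoint(z => averageDamp(y => x / y)(z))(1)

// Passes the partially applied method directly; the compiler eta-expands it.
def sqrtDirect(x: Double): Double = fixedPoint(averageDamp(y => x / y))(1)

println(sqrtWrapped(2) == sqrtDirect(2)) // true
```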

Upvotes: 1

Herrington Darkholme

Reputation: 6315

Your solution is correct, but it can be more concise.

Let's scrutinize the averageDamp function more closely.

def averageDamp(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

I added the return type annotation to make this clearer. I think this is what you are missing:

but it doesn't mind being passed an ordinary Double (which is what averageDamp returns - in fact, if you try to explicitly specify the return type of Double to averageDamp, the compiler won't throw an error).

But averageDamp(y => x / y) does return a Double => Double function! averageDamp needs TWO argument lists before it returns a Double.

If it receives only one argument list, it still needs the other one. So rather than returning the result immediately, it returns a function, as if saying: "I still need an argument here; feed it to me and I will return what you want."

Prof. Odersky passed only ONE argument list to it, not two, so averageDamp is partially applied: the result is a Double => Double function.

The course will also tell you that a function with multiple argument lists is syntactic sugar for the following:

def f(args1)(args2)...(argsN-1)(argsN) = E
// is equivalent to
def f(args1)(args2)...(argsN-1) = (argsN => E)

If you supply one argument list fewer than f needs, you get back the right-hand side of that equivalence, that is, a function. So noticing that averageDamp(y => x / y), the argument passed to fixedPoint, is itself a function should answer your question.
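Here is a small sketch of that rule with a hypothetical three-list method, `add3`, used only for illustration: leaving off the last argument list gives a function, and moving that last list into an explicit returned lambda gives the same behaviour.

```scala
// Hypothetical method with three argument lists.
def add3(a: Int)(b: Int)(c: Int): Int = a + b + c

// Same method with the last list turned into a returned function,
// following def f(args1)...(argsN-1) = (argsN => E).
def add3Explicit(a: Int)(b: Int): Int => Int = c => a + b + c

// One argument list short: the result is a function waiting for c.
val addOneTwo: Int => Int = add3(1)(2)

println(addOneTwo(3))          // 6
println(add3Explicit(1)(2)(3)) // 6
```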

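Note: there is a difference between a partially applied function (or a curried function) and a method with multiple argument lists.

For example, you cannot declare this:

val a = averageDamp(y => y/2)

In Scala 2 the compiler rejects it with a "missing argument(s)" error that suggests following the method with `_` if you want to treat it as a partially applied function. (Scala 3 performs this eta-expansion automatically, so the line compiles there.)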
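One way to make the declaration compile in both Scala 2 and Scala 3 is to state the expected function type, which tells the compiler to eta-expand the method. A minimal sketch:

```scala
def averageDamp(f: Double => Double)(x: Double): Double = (x + f(x)) / 2

// With an explicit function type, the method is eta-expanded even though
// its second argument list is missing.
val a: Double => Double = averageDamp(y => y / 2)

println(a(4.0)) // 3.0
```

This is the same mechanism that makes Odersky's version work: fixedPoint's parameter already has the type Double => Double, so no annotation is needed there.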

The difference is explained here: What's the difference between multiple parameters lists and multiple parameters per list in Scala?

Upvotes: 6
