Stuart

Reputation: 1492

Restricting function signatures while using ForwardDiff in Julia

I am trying to use ForwardDiff in a library where almost all functions are restricted to take only Floats. I want to generalise these function signatures so that ForwardDiff can be used, while keeping them restrictive enough that functions only accept numeric values and not things like Dates. I have a lot of functions with the same name but different argument types (i.e. functions that take "time" as either a Float or a Date under the same function name), so I do not want to remove the type annotations throughout.

Minimal Working Example

using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
    return sum(exp.(x))
end
function grad_F(x::Array)
  return ForwardDiff.gradient(G, x)
end
G(x) # MethodError
grad_F(x) # MethodError

function G(x::Array{Float64,1})
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This still throws a MethodError

function G(x)
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now G is no longer restricted to numeric arrays; it would also accept, for instance, arrays of Dates.

Is there a way to restrict functions to take only numeric values (Ints and Floats) plus whatever dual-number structs ForwardDiff uses, but not Symbols, Dates, etc.?

Upvotes: 0

Views: 383

Answers (2)

hckr

Reputation: 5583

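ForwardDiff.Dual is a subtype of the abstract type Real. The issue you have, however, is that Julia's parametric types are invariant, not covariant, in their type parameters: even though Float64 <: Real, an Array{Float64,1} is not an Array{Real,1}. The following therefore returns false.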

# check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
julia> Array{Float64, 1} <: Array{Real, 1}
false
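A quick way to see the difference in the REPL, using only Base Julia: the element type Float64 is a subtype of Real, the concrete array type is not a subtype of Vector{Real}, but it is a subtype of the UnionAll type Vector{<:Real}. (Vector{T} is an alias for Array{T,1}.)

```julia
# Element types: Float64 is a subtype of the abstract type Real.
println(Float64 <: Real)                     # true

# Invariance: Vector{Float64} is NOT a subtype of Vector{Real},
# because Vector{Real} is its own concrete type whose elements may be any Real.
println(Vector{Float64} <: Vector{Real})     # false

# Vector{<:Real} is shorthand for `Vector{T} where T<:Real`,
# i.e. every vector whose element type is some subtype of Real.
println(Vector{Float64} <: Vector{<:Real})   # true
println(Vector{Int} <: Vector{<:Real})       # true
```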

That makes your function definition

function G(x::Array{Real,1})
    return sum(exp.(x))
end

incorrect (not suitable for your use). That's why you get the following error.

julia> G(x)
ERROR: MethodError: no method matching G(::Array{Float64,1})

The correct definition should rather be

function G(x::Array{<:Real,1})
    return sum(exp.(x))
end

or, if you need easy access to the concrete element type of the array inside the function body,

function G(x::Array{T,1}) where {T<:Real}
    return sum(exp.(x))
end

The same goes for your grad_F function.
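A minimal end-to-end sketch, assuming ForwardDiff is installed: with the <:Real signature, the plain call and ForwardDiff.gradient both find the same method, because ForwardDiff.gradient calls G with an array whose elements are ForwardDiff.Dual numbers, and Dual is a subtype of Real. The gradient of sum(exp.(x)) is exp.(x), which gives an easy check.

```julia
using ForwardDiff

# Accept any vector whose element type is some subtype of Real:
# Float64, Int, and ForwardDiff's Dual numbers all qualify.
function G(x::Array{<:Real,1})
    return sum(exp.(x))
end

# grad_F can use the same restriction.
function grad_F(x::Array{<:Real,1})
    return ForwardDiff.gradient(G, x)
end

x = [1.0, 2.0, 3.0, 4.0, 5.0]

println(ForwardDiff.Dual <: Real)   # true
println(G(x))                       # ordinary Float64 evaluation
println(grad_F(x) ≈ exp.(x))        # true, since d/dx_i sum(exp.(x)) = exp(x_i)
```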

You might find it useful to read the relevant section of the Julia documentation for types.


You might also want to annotate your functions with AbstractArray{<:Real,1} rather than Array{<:Real,1}, so that they work with other array types, such as StaticArrays or OffsetArrays, without needing redefinitions.
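A minimal Base-only sketch of the difference: a range such as 1:3 or a view is an AbstractVector but not an Array, so only the AbstractVector{<:Real} method accepts it. The names G_array and G_abstract are hypothetical; applicable (from Base) reports whether a call would find a matching method.

```julia
# Restricted to the concrete Array type.
G_array(x::Array{<:Real,1}) = sum(exp.(x))

# Restricted only to vectors of reals, whatever their storage.
G_abstract(x::AbstractVector{<:Real}) = sum(exp.(x))

r = 1:3                          # UnitRange{Int}: an AbstractVector, not an Array
v = view([1.0, 2.0, 3.0], 1:2)   # SubArray: also not an Array

println(applicable(G_array, r))      # false
println(applicable(G_abstract, r))   # true
println(applicable(G_abstract, v))   # true
println(G_abstract(r) ≈ exp(1) + exp(2) + exp(3))  # true
```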

Upvotes: 2

HarmonicaMuse

Reputation: 7893

This would accept any kind of array parameterized by any kind of number:

function foo(xs::AbstractArray{<:Number})
  @show typeof(xs)
end

or:

function foo(xs::AbstractArray{T}) where T<:Number
  @show typeof(xs)
end

The second form is useful in case you need to refer to the type parameter T inside the function body.
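To illustrate what the bound T is good for, here is a minimal sketch (the function name total is hypothetical): because T is the element type, zero(T) gives an accumulator of matching type, which also keeps the result well defined for an empty array.

```julia
# T is bound to the element type of the argument (Int, Float64, Rational, ...).
function total(xs::AbstractArray{T}) where T<:Number
    acc = zero(T)      # additive identity with the same type as the elements
    for x in xs
        acc += x
    end
    return acc
end

println(total([1, 2, 3]))       # 6
println(total([1.5, 2.5]))      # 4.0
println(total(Float64[]))       # 0.0
println(total([1//2, 1//4]))    # 3//4
```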

x1 = [1.0, 2.0, 3.0, 4.0, 5.0]
x2 = [1, 2, 3, 4, 5]
x3 = 1:5
x4 = 1.0:5.0
x5 = [1//2, 1//4, 1//8]

xss = [x1, x2, x3, x4, x5]

function foo(xs::AbstractArray{T}) where T<:Number
  @show xs typeof(xs) T
  println()
end

for xs in xss
  foo(xs)
end

Outputs:

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(xs) = Array{Float64,1}
T = Float64

xs = [1, 2, 3, 4, 5]
typeof(xs) = Array{Int64,1}
T = Int64

xs = 1:5
typeof(xs) = UnitRange{Int64}
T = Int64

xs = 1.0:1.0:5.0
typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
T = Float64

xs = Rational{Int64}[1//2, 1//4, 1//8]
typeof(xs) = Array{Rational{Int64},1}
T = Rational{Int64}
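The loop above only passes numeric arrays. Since the question also has same-named methods that take dates, here is a minimal sketch of the other side, assuming the standard-library Dates module: a method restricted to AbstractArray{<:Number} and a separate method for arrays of Date can share one name, and dispatch picks the right one. The name describe and its return strings are hypothetical. Note that Number also admits Complex values; use Real instead if those should be excluded.

```julia
using Dates

# Numeric arrays: Int, Float64, Rational, ForwardDiff's Dual (a Real), ...
describe(xs::AbstractArray{<:Number}) = "numeric"

# Arrays of dates get their own method under the same name.
describe(xs::AbstractArray{Date}) = "dates"

println(describe([1.0, 2.0]))                            # numeric
println(describe(1:5))                                   # numeric
println(describe([Date(2020, 1, 1), Date(2020, 1, 2)]))  # dates
println(applicable(describe, [:a, :b]))                  # false: Symbol is not a Number
```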

You can run the example code here: https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia

Upvotes: 1
