tst

Reputation: 1172

Could someone point out the type instability in my code?

I am doing some interval computations using the library ValidatedNumerics and I am trying to optimize my code. I define the following:

using ValidatedNumerics

immutable Vector2D{T}
  x::Complex{ValidatedNumerics.Interval{T}}
  y::Complex{ValidatedNumerics.Interval{T}}
end

function F(x::Complex{ValidatedNumerics.Interval{BigFloat}},y::Complex{ValidatedNumerics.Interval{BigFloat}})
  g = 3(1+(1+x+y)^2)/4
  Vector2D{BigFloat}(x + 2y + g,y + g)
end

Then using the following gives the correct answer:

x = @biginterval(1+1im)

F(x,x)

However, when I type:

@code_warntype F(x,x)

I get:

Variables:                                                                                                                                                                                                                                                            
  #self#::#F                                                                                                                                                                                                                                                          
  x::Complex{ValidatedNumerics.Interval{BigFloat}}                                                                                                                                                                                                                    
  y::Complex{ValidatedNumerics.Interval{BigFloat}}                                                                                                                                                                                                                    
  g::Complex{T<:Real}                                                                                                                                                                                                                                                 

Body:                                                                                                                                                                                                                                                                 
  begin                                                                                                                                                                                                                                                               
      # meta: location operators.jl + 138                                                                                                                                                                                                                             
      # meta: location complex.jl + 162                                                                                                                                                                                                                               
      SSAValue(0) = (Core.getfield)(x::Complex{ValidatedNumerics.Interval{BigFloat}},:re)::ValidatedNumerics.Interval{BigFloat}                                                                                                                                       
      SSAValue(2) = $(Expr(:invoke, LambdaInfo for +(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.+), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Down}), :(Base.Rounding._convert_rounding), BigFloat, 1, :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Up}), :(Base.Rounding._convert_rounding), BigFloat, 1, :(ValidatedNumerics.RoundUp))))))), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(0),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(0),:hi)::BigFloat), :(ValidatedNumerics.RoundUp)))))))))                                                                                                                                                                                             
      SSAValue(1) = (Core.getfield)(x::Complex{ValidatedNumerics.Interval{BigFloat}},:im)::ValidatedNumerics.Interval{BigFloat}                                                                                                                                       
      # meta: pop location                                                                                                                                                                                                                                            
      SSAValue(7) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(2),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(2),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))                                     
      SSAValue(8) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(1),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(1),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))                                     
      # meta: location complex.jl + 125                                                                                                                                                                                                                               
      SSAValue(5) = $(Expr(:invoke, LambdaInfo for +(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.+), SSAValue(7), :((Core.getfield)(y,:re)::ValidatedNumerics.Interval{BigFloat})))                                       
      SSAValue(4) = $(Expr(:invoke, LambdaInfo for +(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.+), SSAValue(8), :((Core.getfield)(y,:im)::ValidatedNumerics.Interval{BigFloat})))                                       
      # meta: pop location
      SSAValue(6) = $(Expr(:new, Complex{ValidatedNumerics.Interval{BigFloat}}, :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(5),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(5),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(4),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(4),:hi)::BigFloat), :(ValidatedNumerics.RoundUp)))))))))
      # meta: pop location
      SSAValue(9) = (3 * (1 + (Core._apply)(Base.^,$(Expr(:invoke, LambdaInfo for promote(::Complex{ValidatedNumerics.Interval{BigFloat}}, ::Complex{Int64}), :(Base.promote), SSAValue(6), :($(Expr(:new, Complex{Int64}, 2, 0))))))::Complex{T<:Real})::Complex{T<:Real})::Complex{T<:Real}
      g::Complex{T<:Real} = (Base.Complex)(((Core.getfield)(SSAValue(9),:re)::Real / 4)::Any,((Core.getfield)(SSAValue(9),:im)::Real / 4)::Any)::Complex{T<:Real} # line 3:
      # meta: location complex.jl * 170
      SSAValue(11) = (Core.getfield)(y::Complex{ValidatedNumerics.Interval{BigFloat}},:re)::ValidatedNumerics.Interval{BigFloat}
      SSAValue(12) = $(Expr(:invoke, LambdaInfo for *(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.*), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Down}), :(Base.Rounding._convert_rounding), BigFloat, 2, :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Up}), :(Base.Rounding._convert_rounding), BigFloat, 2, :(ValidatedNumerics.RoundUp))))))), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(11),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(11),:hi)::BigFloat), :(ValidatedNumerics.RoundUp)))))))))
      SSAValue(10) = (Core.getfield)(y::Complex{ValidatedNumerics.Interval{BigFloat}},:im)::ValidatedNumerics.Interval{BigFloat}
      SSAValue(13) = $(Expr(:invoke, LambdaInfo for *(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.*), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Down}), :(Base.Rounding._convert_rounding), BigFloat, 2, :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for _convert_rounding(::Type{BigFloat}, ::Int64, ::RoundingMode{:Up}), :(Base.Rounding._convert_rounding), BigFloat, 2, :(ValidatedNumerics.RoundUp))))))), :($(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(10),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(10),:hi)::BigFloat), :(ValidatedNumerics.RoundUp)))))))))
      # meta: pop location
      SSAValue(22) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(12),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(12),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))
      SSAValue(23) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(13),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(13),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))
      # meta: location operators.jl + 138
      # meta: location complex.jl + 125
      SSAValue(16) = $(Expr(:invoke, LambdaInfo for +(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.+), :((Core.getfield)(x,:re)::ValidatedNumerics.Interval{BigFloat}), SSAValue(22)))
      SSAValue(15) = $(Expr(:invoke, LambdaInfo for +(::ValidatedNumerics.Interval{BigFloat}, ::ValidatedNumerics.Interval{BigFloat}), :(Base.+), :((Core.getfield)(x,:im)::ValidatedNumerics.Interval{BigFloat}), SSAValue(23)))
      # meta: pop location
      SSAValue(19) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(16),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(16),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))
      SSAValue(20) = $(Expr(:invoke, LambdaInfo for ValidatedNumerics.Interval{BigFloat}(::BigFloat, ::BigFloat), ValidatedNumerics.Interval{BigFloat}, :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Down}), BigFloat, :((Core.getfield)(SSAValue(15),:lo)::BigFloat), :(ValidatedNumerics.RoundDown)))), :($(Expr(:invoke, LambdaInfo for BigFloat(::BigFloat, ::RoundingMode{:Up}), BigFloat, :((Core.getfield)(SSAValue(15),:hi)::BigFloat), :(ValidatedNumerics.RoundUp))))))
      SSAValue(18) = (Base.Complex)((SSAValue(19) + (Core.getfield)(g::Complex{T<:Real},:re)::Real)::Any,(SSAValue(20) + (Core.getfield)(g::Complex{T<:Real},:im)::Real)::Any)::Complex{T<:Real}
      # meta: pop location
      SSAValue(21) = (Base.Complex)(((Core.getfield)(y::Complex{ValidatedNumerics.Interval{BigFloat}},:re)::ValidatedNumerics.Interval{BigFloat} + (Core.getfield)(g::Complex{T<:Real},:re)::Real)::Any,((Core.getfield)(y::Complex{ValidatedNumerics.Interval{BigFloat}},:im)::ValidatedNumerics.Interval{BigFloat} + (Core.getfield)(g::Complex{T<:Real},:im)::Real)::Any)::Complex{T<:Real}
      return $(Expr(:new, Vector2D{BigFloat}, :($(Expr(:new, Complex{ValidatedNumerics.Interval{BigFloat}}, :((Base.convert)(ValidatedNumerics.Interval{BigFloat},(Core.getfield)(SSAValue(18),:re)::Real)), :((Base.convert)(ValidatedNumerics.Interval{BigFloat},(Core.getfield)(SSAValue(18),:im)::Real))))), :($(Expr(:new, Complex{ValidatedNumerics.Interval{BigFloat}}, :((Base.convert)(ValidatedNumerics.Interval{BigFloat},(Core.getfield)(SSAValue(21),:re)::Real)), :((Base.convert)(ValidatedNumerics.Interval{BigFloat},(Core.getfield)(SSAValue(21),:im)::Real)))))))
  end::Vector2D{BigFloat}

In that, the following are red:

::Complex{T<:Real} 
::Real
::Any

Does this actually mean that my definition is type unstable? How can I fix that?

Upvotes: 3

Views: 122

Answers (1)

Dan Getz

Reputation: 18217

As @David Sanders noted, SSAValue(9) is the source of the problem for F. Digging further (quite a bit of digging follows), it is the ^2 that breaks type inference. The argument types passed to ^ are

Complex{ValidatedNumerics.Interval{BigFloat}}, Int64

This calls ^(Complex,Complex) after converting 2 to Complex{Int64}(2,0). In turn, ^(Complex,Complex) promotes the Complex{Int64} to a Complex{ValidatedNumerics.Interval{BigFloat}}. Now we get to the actual computation in ^(T<:Complex,T<:Complex) at complex.jl:506. Here is a code snippet, with some lines marked:

function ^{T<:Complex}(z::T, p::T)
    if isinteger(p)
        rp = real(p)        # <---------- (1)
        if rp < 0
            return power_by_squaring(inv(float(z)), convert(Integer, -rp))
        else
            return power_by_squaring(float(z), convert(Integer, rp))
        end
    end
    pr, pim = reim(p)
    zr, zi = reim(z)
    r = abs(z)            # <---------- (*)
    rp = r^pr             # <---------- (2)
    theta = atan2(zi, zr)

This function has two execution branches: the first is taken when the power is an integer, the second when it is not. Both branches use the variable rp but assign it values of different types (lines (1) and (2)), so the type of rp is inferred as Any. A second problem is the line marked (*). This line is not actually executed in this calculation, because the power is an integer. If it were executed, the problem would not stay at the type-inference level but would become a stack overflow.
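The effect of reusing one variable name for values of different types can be reproduced with a toy function. This is only a sketch, assuming a Julia 1.x session; two_branches and one_type are hypothetical names and are not the Base code. Newer compilers usually infer a small Union here rather than Any, but the result is still not a concrete type.

```julia
# Toy function (hypothetical, not from Base): `r` is bound to an Int on one
# path and to a Float64 on the other, mirroring how `rp` is reused in `^`.
function two_branches(p::Float64)
    if isinteger(p)
        r = convert(Int, p)      # r is an Int on this path
        return r
    end
    r = sqrt(abs(p))             # r is a Float64 on this path
    return r
end

# Variant where both paths produce a Float64, so the result type is concrete.
function one_type(p::Float64)
    if isinteger(p)
        return float(convert(Int, p))
    end
    return sqrt(abs(p))
end

# Ask the compiler what it infers for a Float64 argument.
unstable_type = only(Base.return_types(two_branches, (Float64,)))
stable_type = only(Base.return_types(one_type, (Float64,)))

println(isconcretetype(unstable_type))
println(isconcretetype(stable_type))
```

Running @code_warntype two_branches(2.0) should highlight the return type in the same way as the red entries in the question.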

In (*), abs(z) for z of type Complex{T} calls hypot(T,T) on the real and imaginary parts. This in turn calls sqrt(float(T)) (for Pythagoras' theorem). This is defined in complex.jl:320 as:

sqrt(z::Complex) = sqrt(float(z))

Since float(Complex{ValidatedNumerics.Interval{BigFloat}}) is Complex{ValidatedNumerics.Interval{Float64}}, this still calls sqrt(Complex), causing an infinite loop and a stack overflow (not this site's name this time). This self-loop also makes type inference return Union{}, which causes type inference of the calling function to fail too.
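The self-recursion pattern can be illustrated with a small stand-alone sketch. Everything here is hypothetical (Wrapped, to_float and my_sqrt are made-up names, not ValidatedNumerics or Base definitions), and it assumes Julia 1.x syntax. It shows a generic fallback of the form f(x) = f(to_float(x)), and how a method for the concrete type keeps that fallback from calling itself.

```julia
# Hypothetical wrapper type standing in for an interval-like number.
struct Wrapped
    v::Float64
end

# Hypothetical conversion: ordinary reals become floats, while Wrapped is
# already floating point and is returned unchanged (same type in, same type out).
to_float(x::Real) = float(x)
to_float(w::Wrapped) = w

# Generic fallback in the style of sqrt(z::Complex) = sqrt(float(z)).
# If to_float(x) has the same type as x and no more specific method exists,
# this method would call itself until the stack overflows.
my_sqrt(x) = my_sqrt(to_float(x))
my_sqrt(x::Float64) = sqrt(x)

# A method for the concrete type breaks the loop for Wrapped values.
my_sqrt(w::Wrapped) = Wrapped(sqrt(w.v))

println(my_sqrt(4))             # fallback converts to Float64, then terminates
println(my_sqrt(Wrapped(4.0)))  # handled by the specific method
```

Without the last method, my_sqrt(Wrapped(4.0)) would dispatch to the fallback again and again, which is the kind of loop described above.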

In summary, there is room for more careful implementation and testing of complex.jl.

This is all the digging I have time for now, but if clarification or fixes are needed, comments would be welcome.

As a practical fix, an explicit type assertion with :: on the result of ^ might resolve the problem with F.
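A minimal sketch of the type-assertion idea, using only Base types so that it runs without ValidatedNumerics. It assumes Julia 1.x; EXPONENT, without_assertion and with_assertion are hypothetical names, and the untyped Ref is just a device to make the call to ^ uninferable in this toy case. Whether the same assertion is enough for F depends on what ^ actually returns for interval arguments, which this sketch does not check.

```julia
# Hypothetical untyped container: the compiler cannot know the type of the
# exponent, so the result of `^` below is inferred as Any.
const EXPONENT = Ref{Any}(2)

function without_assertion(z::Complex{Float64})
    w = z ^ EXPONENT[]                       # inferred as Any
    return 3 * (1 + w) / 4
end

function with_assertion(z::Complex{Float64})
    w = (z ^ EXPONENT[])::Complex{Float64}   # assert the type we expect
    return 3 * (1 + w) / 4
end

println(only(Base.return_types(without_assertion, (Complex{Float64},))))
println(only(Base.return_types(with_assertion, (Complex{Float64},))))
```

The assertion is checked at run time, so if ^ ever returns a different type the call throws a TypeError instead of silently propagating the wrong type. In F, the analogous change would be to annotate the (1+x+y)^2 term with the expected Complex{ValidatedNumerics.Interval{BigFloat}} type.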

Upvotes: 3
