Reputation: 456
Consider the following Julia "compound" iterator: it merges two iterators, `a` and `b`, each of which is assumed to be sorted according to `order`, into a single ordered sequence:
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy iterator over the merge of `a` and `b`, both of which are assumed to be
already sorted according to `order`. The first type parameter is the merged
element type, `promote_type(eltype(a), eltype(b))`.
"""
struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O
    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        elt = promote_type(eltype(A), eltype(B))
        return new{elt,A,B,O}(a, b, order)
    end
end
# Element type of the merged sequence: the promoted type `T` chosen by the constructor.
# Matching `<:MergeSorted{T}` works for every concrete instantiation of the struct.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T
# No `length` method is defined for `MergeSorted`, so the default `HasLength()`
# trait would make `collect` & co. fail; declare the length as unknown instead.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()
# Iteration protocol for `MergeSorted`. The state is the pair of pending
# `iterate` results `(result_a, result_b)`; either entry is `nothing` once that
# input is exhausted. The smaller head (per `self.order`) is emitted, converted
# to `T`; on ties (`lt` is false) the head of `b` is emitted first.
@inline function Base.iterate(self::MergeSorted{T},
                              state=(iterate(self.a), iterate(self.b))) where T
    ra, rb = state
    if rb !== nothing
        xb, sb = rb
        if ra !== nothing
            xa, sa = ra
            if Base.Order.lt(self.order, xa, xb)
                return T(xa), (iterate(self.a, sa), rb)
            end
        end
        return T(xb), (ra, iterate(self.b, sb))
    end
    # `b` is exhausted: drain whatever is left of `a`.
    ra === nothing && return nothing
    xa, sa = ra
    return T(xa), (iterate(self.a, sa), rb)
end
This code works, but it is type-unstable, because Julia's iteration protocol inherently is. In most cases the compiler can work this out automatically; here, however, it does not: the following test code shows that temporaries are created:
>>> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
>>> sum(x);
>>> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)
Note the allocation count.
Is there any way to debug such situations efficiently, other than playing around with the code and hoping that the compiler will be able to optimize away the type ambiguities? Does anyone know of a solution for this specific case that does not create temporaries?
Upvotes: 2
Views: 486
Reputation: 69949
Answer: use @code_warntype
Run:
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Union{}
@_7::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Tuple{Int64, Int64}
a_result::Tuple{Int64, Int64}
Body::Tuple{Int64, Any}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
│ (a_result = Core.getfield(%9, 1))
│ (@_7 = Core.getfield(%9, 2))
│ %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
│ (b_result = Core.getfield(%12, 1))
│ %14 = (b_result === Main.nothing)::Core.Const(false)
└── goto #3 if not %14
2 ─ Core.Const(:(a_result === Main.nothing))
│ Core.Const(:(%16))
│ Core.Const(:(return Main.nothing))
│ Core.Const(:(Base.indexed_iterate(a_result, 1)))
│ Core.Const(:(a_curr = Core.getfield(%19, 1)))
│ Core.Const(:(@_6 = Core.getfield(%19, 2)))
│ Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
│ Core.Const(:(a_state = Core.getfield(%22, 1)))
│ Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
│ Core.Const(:(Base.getproperty(self, :a)))
│ Core.Const(:(Main.iterate(%25, a_state)))
│ Core.Const(:(Core.tuple(%26, b_result)))
│ Core.Const(:(Core.tuple(%24, %27)))
└── Core.Const(:(return %28))
3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Core.Const(true)
└── goto #6 if not %35
4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #6 if not %46
5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Base.getproperty(self, :a)::Vector{Int64}
│ %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
│ %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
└── return %52
6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %55 = a_result::Tuple{Int64, Int64}
│ %56 = Base.getproperty(self, :b)::Vector{Int64}
│ %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
│ %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
└── return %59
and you can see that there are too many possible return types, so Julia gives up specializing them (and just assumes that the second element of the return type is `Any`).
Answer: reduce the number of possible return types of `iterate`.
Here is a quick write-up. I do not claim it is the tersest possible version, and I have not tested it extensively, so there might be some bugs. It was simple enough to write quickly from your code, though, and it shows how one could approach your problem. Note that I use special branches when one of the collections is empty, since it should then be faster to just iterate the other collection:
"""
    MergeSorted(a, b, order=Base.Order.Forward)

Lazy merge of `a` and `b` (each assumed sorted according to `order`).
`T` is `promote_type(eltype(a), eltype(b))`. The fields `fa`/`fb` cache the
results of `iterate(a)`/`iterate(b)`, computed once in the constructor, and
`F1`/`F2` are their types. `F1 === Nothing` (resp. `F2`) therefore marks an
input that was empty at construction, which lets `iterate` dispatch on it.

NOTE(review): because `iterate` is called here, a stateful input iterator is
advanced at construction time, not at first iteration.
"""
struct MergeSorted{T,A,B,O,F1,F2}
    a::A
    b::B
    order::O
    fa::F1
    fb::F2
    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        head_a = iterate(a)
        head_b = iterate(b)
        elt = promote_type(eltype(A), eltype(B))
        return new{elt,A,B,O,typeof(head_a),typeof(head_b)}(a, b, order, head_a, head_b)
    end
end
# Element type of the merged sequence: the promoted type `T` chosen by the constructor.
# `MergeSorted` has six type parameters; writing `Type{MergeSorted{T,A,B,O}}` would
# name a UnionAll over `F1,F2` that no concrete instance matches (Type is invariant),
# silently falling back to `eltype == Any`. Matching `<:MergeSorted{T}` fixes that.
Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T
# No `length` method is defined for `MergeSorted`, so the default `HasLength()`
# trait would make `collect` & co. fail; declare the length as unknown instead.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()
"""
    State{Ta,Tb}

State of the two-sided merge: the pending `iterate` result of each input
(`a` and `b`), or `nothing` once that input is exhausted. Fixing the field
types to `Union{Ta,Nothing}` / `Union{Tb,Nothing}` keeps `iterate`'s return
type concrete enough for the compiler to avoid allocations.
"""
struct State{Ta,Tb}
    a::Union{Ta,Nothing}
    b::Union{Tb,Nothing}
end
# Both inputs were empty at construction (`F1 === F2 === Nothing`):
# the merged sequence is empty.
Base.iterate(::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing
# Only `a` has elements (`b` was empty at construction, `F2 === Nothing`):
# replay the cached first result of `a`. The element is converted to `T`, exactly
# as in the two-sided merge, so yielded values agree with `eltype(self)` even when
# the empty `b` had a wider element type.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
    head = self.fa
    head === nothing && return nothing
    x, s = head
    return T(x), s
end
# Only `a` has elements: `state` is simply `a`'s own iteration state.
# Elements are converted to `T` for consistency with `eltype(self)` and with
# the two-sided merge path.
function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
    res = iterate(self.a, state)
    res === nothing && return nothing
    x, s = res
    return T(x), s
end
# Only `b` has elements (`a` was empty at construction, `F1 === Nothing`):
# replay the cached first result of `b`. The element is converted to `T`, exactly
# as in the two-sided merge, so yielded values agree with `eltype(self)` even when
# the empty `a` had a wider element type.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
    head = self.fb
    head === nothing && return nothing
    x, s = head
    return T(x), s
end
# Only `b` has elements: `state` is simply `b`'s own iteration state.
# Elements are converted to `T` for consistency with `eltype(self)` and with
# the two-sided merge path.
function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
    res = iterate(self.b, state)
    res === nothing && return nothing
    x, s = res
    return T(x), s
end
# Both inputs were non-empty at construction: start the merge from the cached
# first results of `a` and `b`, wrapped in a `State` whose field types are
# fixed by the struct parameters `F1`/`F2`.
@inline Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2} =
    iterate(self, State{F1,F2}(self.fa, self.fb))
# One step of the two-sided merge. `state` holds the pending `iterate` result of
# each input (`nothing` once exhausted). The smaller head per `self.order` is
# emitted, converted to `T`; on ties (`lt` is false) the head of `b` goes first.
# Every branch returns the same `State{F1,F2}` type, which keeps the return type small.
@inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2},
                              state::State{F1,F2}) where {T,A,B,O,F1,F2}
    ra, rb = state.a, state.b
    S = State{F1,F2}
    if rb !== nothing
        xb, sb = rb
        if ra !== nothing
            xa, sa = ra
            if Base.Order.lt(self.order, xa, xb)
                return T(xa), S(iterate(self.a, sa), rb)
            end
        end
        return T(xb), S(ra, iterate(self.b, sb))
    end
    # `b` is exhausted: drain whatever is left of `a`.
    ra === nothing && return nothing
    xa, sa = ra
    return T(xa), S(iterate(self.a, sa), rb)
end
And now you have:
julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
julia> sum(x)
269
julia> @allocated sum(x)
0
julia> @code_warntype iterate(x, iterate(x)[2])
Variables
#self#::Core.Const(iterate)
self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
@_4::Int64
@_5::Int64
@_6::Int64
b_state::Int64
b_curr::Int64
a_state::Int64
a_curr::Int64
b_result::Union{Nothing, Tuple{Int64, Int64}}
a_result::Union{Nothing, Tuple{Int64, Int64}}
Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
1 ─ nothing
│ Core.NewvarNode(:(@_4))
│ Core.NewvarNode(:(@_5))
│ Core.NewvarNode(:(@_6))
│ Core.NewvarNode(:(b_state))
│ Core.NewvarNode(:(b_curr))
│ Core.NewvarNode(:(a_state))
│ Core.NewvarNode(:(a_curr))
│ %9 = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
│ %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
│ (a_result = %9)
│ (b_result = %10)
│ %13 = (b_result === Main.nothing)::Bool
└── goto #5 if not %13
2 ─ %15 = (a_result === Main.nothing)::Bool
└── goto #4 if not %15
3 ─ return Main.nothing
4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%18, 1))
│ (@_6 = Core.getfield(%18, 2))
│ %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%21, 1))
│ %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %25 = Base.getproperty(self, :a)::Vector{Int64}
│ %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %28
5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (b_curr = Core.getfield(%30, 1))
│ (@_5 = Core.getfield(%30, 2))
│ %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (b_state = Core.getfield(%33, 1))
│ %35 = (a_result !== Main.nothing)::Bool
└── goto #8 if not %35
6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│ (a_curr = Core.getfield(%37, 1))
│ (@_4 = Core.getfield(%37, 2))
│ %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
│ (a_state = Core.getfield(%40, 1))
│ %42 = Base.Order::Core.Const(Base.Order)
│ %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
│ %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
│ %45 = a_curr::Int64
│ %46 = (%43)(%44, %45, b_curr)::Bool
└── goto #8 if not %46
7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
│ %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %50 = Base.getproperty(self, :a)::Vector{Int64}
│ %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %53
8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
│ %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
│ %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
│ %58 = Base.getproperty(self, :b)::Vector{Int64}
│ %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
│ %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
│ %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
└── return %61
EDIT: I have now realized that my implementation is not fully correct. It assumes that the return value of `iterate`, when it is not `nothing`, is type-stable, which it does not have to be. If it is not type-stable, the compiler must allocate anyway. So a fully correct solution would first check whether `iterate` is type-stable: if it is, use my solution; if it is not, use e.g. your solution.
Upvotes: 3