miha priimek

Reputation: 979

How to dispatch based on the type of any of the splatted args?

Consider an existing function in Base that takes a variable number of arguments of some abstract type T. I have defined a subtype S <: T and would like to write a method that is selected if any of the arguments is of my subtype S.

As an example, consider the function Base.cat, with T being AbstractArray and S being some MyCustomArray <: AbstractArray.

Desired behaviour:

julia> v = [1, 2, 3];

julia> cat(v, v, v, dims=2)
3×3 Array{Int64,2}:
 1  1  1
 2  2  2
 3  3  3

julia> w = MyCustomArray([1,2,3])

julia> cat(v, v, w, dims=2)
"do something fancy"

Attempt:

function Base.cat(w::MyCustomArray, a::AbstractArray...; dims)
    println("do something fancy")
end

But this only works if the first argument is MyCustomArray.

What is an elegant way of achieving this?

Upvotes: 2

Views: 275

Answers (2)

phipsgabler

Reputation: 20970

For general usage, I concur with Bogumił, but let me make an additional comment. If you have control over how cat is called, you can at least write some kind of trait-dispatch code:

struct MyCustomArray{T, N} <: AbstractArray{T, N}
    x::Array{T, N}
end

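# Walk the arguments recursively: Val(true) as soon as a MyCustomArray is found,
# Val(false) once all arguments have been consumed
HasCustom() = Val(false)
HasCustom(::MyCustomArray, rest...) = Val(true)
HasCustom(::AbstractArray, rest...) = HasCustom(rest...)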

# `IsCustom` or something would be more elegant, but `Val` is quicker for now
Base.cat(::Val{true}, args...; dims) = println("something fancy")
Base.cat(::Val{false}, args...; dims) = cat(args...; dims=dims)

And the compiler is cool enough to optimize that away:

julia> args = (v, v, w);

julia> @code_warntype cat(HasCustom(args...), args...; dims=2);
Variables
  #self#::Core.Compiler.Const(cat, false)
  #unused#::Core.Compiler.Const(Val{true}(), false)
  args::Tuple{Array{Int64,1},Array{Int64,1},MyCustomArray{Int64,1}}

Body::Nothing
1 ─ %1 = Main.println("something fancy")::Core.Compiler.Const(nothing, false)
└──      return %1
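For comparison, here is a sketch of the same idea that follows the comment in the code about a dedicated trait being more elegant than `Val`. The names `CustomArgTrait`, `HasCustomArg`, `NoCustomArg`, `customarg` and `cat_with_trait` are hypothetical, as are the `size`/`getindex` methods added to `MyCustomArray`. As a further assumption, it defines a separate function instead of adding `Val`-based methods to `Base.cat`:

```julia
# Hypothetical minimal wrapper; size/getindex are forwarded so it is a usable AbstractArray
struct MyCustomArray{T, N} <: AbstractArray{T, N}
    x::Array{T, N}
end
Base.size(a::MyCustomArray) = size(a.x)
Base.getindex(a::MyCustomArray, i::Int...) = a.x[i...]

# Dedicated singleton trait types in place of Val{true}/Val{false}
abstract type CustomArgTrait end
struct HasCustomArg <: CustomArgTrait end
struct NoCustomArg <: CustomArgTrait end

# Recursive scan over the arguments; as in the Val version, inference can
# usually fold this to a constant for a few concretely typed arguments
customarg() = NoCustomArg()
customarg(::MyCustomArray, rest...) = HasCustomArg()
customarg(::Any, rest...) = customarg(rest...)   # cat itself accepts any argument

# `cat_with_trait` is a hypothetical name; owning the function means no new
# methods are added to Base.cat for types we don't own
cat_with_trait(args...; dims) = cat_with_trait(customarg(args...), args...; dims = dims)
cat_with_trait(::HasCustomArg, args...; dims) = "do something fancy"
cat_with_trait(::NoCustomArg, args...; dims) = cat(args...; dims = dims)
```

With `v = [1, 2, 3]` and `w = MyCustomArray([1, 2, 3])`, `cat_with_trait(v, v, w; dims=2)` returns the placeholder string, while `cat_with_trait(v, v, v; dims=2)` falls back to `Base.cat`.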

If you don't have control over the calls to cat, the only resort I can think of for making the above technique work is to overdub (in the Cassette.jl sense) the methods containing such calls, replacing the matching calls with the custom implementation. In that case you don't even need to overload cat; you can directly replace it with some mycat that does your fancy stuff.
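Once the calls are being rewritten anyway, the replacement does not need the trait machinery. Below is a minimal sketch of such a `mycat` (the name comes from the answer; its body is an assumption): it checks the arguments at run time with `any` and otherwise forwards to `Base.cat`. The `MyCustomArray` definition with forwarded `size`/`getindex` is assumed for illustration, and the overdubbing itself is not shown.

```julia
# Hypothetical wrapper type, as in the question
struct MyCustomArray{T, N} <: AbstractArray{T, N}
    x::Array{T, N}
end
Base.size(a::MyCustomArray) = size(a.x)
Base.getindex(a::MyCustomArray, i::Int...) = a.x[i...]

# `mycat` stands in for `cat` at the rewritten call sites; since we own it,
# the run-time check does not involve any type piracy
function mycat(args...; dims)
    if any(a -> a isa MyCustomArray, args)
        return "do something fancy"
    end
    return cat(args...; dims = dims)
end
```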

Upvotes: 1

Bogumił Kamiński

Reputation: 69949

I would say that it is not possible to do it cleanly without type piracy (but if it is possible I would also like to learn how).

For example, consider cat, which you asked about. It has one very general signature in Base (which, unlike what you wrote, does not require A to be an AbstractArray):

julia> methods(cat)
# 1 method for generic function "cat":
[1] cat(A...; dims) in Base at abstractarray.jl:1654
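A quick check that this signature really accepts non-array arguments, here plain numbers (based on my understanding of Julia 1.x behaviour; not part of the original answer):

```julia
# Base's `cat(A...; dims)` has no AbstractArray annotation on A,
# so numbers are accepted and treated as single elements
cat(1, 2, 3; dims = 1)      # a 3-element Vector{Int}
cat([1, 2], 3; dims = 1)    # arrays and scalars can be mixed
```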

You could write a specific method:

Base.cat(A::AbstractArray...; dims) = ...

and check whether any of the elements of A is your special array, but this would be type piracy.

Now the problem is that you cannot even write Union{S, T}, because S <: T means it is simplified to just T.
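A minimal illustration of this collapse, using placeholder definitions for T and S (they are assumptions, not types from the question) plus a hypothetical function h:

```julia
# Stand-ins for the abstract type T and its subtype S
abstract type T end
struct S <: T end

# Union{S, T} is simplified on construction: S is dropped because S <: T
Union{S, T} === T                  # true
Union{Int, Integer} === Integer    # true, same effect with Base types

# Hence this signature is just h(x::T) and cannot single out S
h(x::Union{S, T}) = "same as h(x::T)"
```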

This would mean that you would have to use S explicitly in the signature, but then even:

f(::S, ::T) = ...
f(::T, ::S) = ...

is problematic: calling f with two S arguments throws an ambiguity MethodError whose message suggests defining f(::S, ::S), because neither of the above methods is more specific than the other. So, even if you were willing to limit the varargs to some maximum number n of arguments, you would have to write a method for every way of choosing which positions are annotated as S (2^n - 1 combinations) to avoid dispatch ambiguity. This is doable using macros, but the number of required methods grows exponentially.
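A sketch of both points, assuming placeholder types T, S and an extra subtype U, a function f for the ambiguity and a function g for the generated methods (all hypothetical). It uses a plain `@eval` loop rather than a macro to generate the methods for a fixed arity of 3:

```julia
abstract type T end
struct S <: T end   # the special subtype
struct U <: T end   # some other subtype of T

# Each method marks one position as S; for f(S(), S()) both apply and
# neither is more specific, so the call throws an ambiguity MethodError
f(::S, ::T) = :first
f(::T, ::S) = :second

# Avoiding that for a fixed arity means one method per non-empty set of
# S-typed positions (2^n - 1 of them), plus a fallback for the all-T case
const NARGS = 3
for mask in 1:(2^NARGS - 1)
    # bit i of `mask` decides whether position i is annotated as S or T
    params = [:($(Symbol(:x, i))::$(isodd(mask >> (i - 1)) ? :S : :T)) for i in 1:NARGS]
    @eval g($(params...)) = "fancy"
end
g(::T, ::T, ::T) = "plain"
```

With n = 3 this already needs 7 S-annotated methods plus the fallback, and each additional argument doubles that count.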

Upvotes: 2
