Chiel

Reputation: 6194

Make type of literal constant depend on other variables

I have the following Julia code, in which array elements are multiplied by the literal constant 2. I have currently written the literal in single precision (2.f0), but I would like its type to depend on the other variables (these are either all Float64 or all Float32). How do I do this in an elegant way?

function diff!(
        at, a,
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot)

    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - 2.f0 * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - 2.f0 * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - 2.f0 * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

Upvotes: 2

Views: 128

Answers (2)

phipsgabler

Reputation: 20950

There's a nice little function in Base:

help?> oftype
search: oftype

  oftype(x, y)

  Convert y to the type of x (convert(typeof(x), y)).

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  julia> x = 4;
  
  julia> y = 3.;
  
  julia> oftype(x, y)
  3
  
  julia> oftype(y, x)
  4.0

So you could use something like

two = oftype(at[i,j,k], 2)

in the appropriate place.
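As a minimal sketch of what that could look like, here is a reduced 1D second difference. The function name `second_difference` is hypothetical and only serves to illustrate the idea; it is not part of the original code.

```julia
# Hypothetical helper: second difference of a 1D array at index i.
# The literal 2 is converted to the type of an element of `a`,
# so the result stays Float32 for Float32 input and Float64 for Float64 input.
function second_difference(a, i)
    two = oftype(a[i], 2)
    return a[i-1] - two * a[i] + a[i+1]
end

r32 = second_difference(Float32[1, 4, 9], 2)   # a Float32 value
r64 = second_difference(Float64[1, 4, 9], 2)   # a Float64 value
```

Because `two` is computed once from an element of the array, the same function body should work unchanged for both precisions.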

For multiple variables at once, you could write something like

two, visc, dxidxi, dyidyi, dzidzi = convert.(T, (2, visc, dxidxi, dyidyi, dzidzi))

at the top (with T a type parameter as in @cbk's answer), since oftype(x, y) = convert(typeof(x), y).
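A small sketch of converting several constants at once, assuming the values are wrapped in a tuple so that `convert` is broadcast over them one by one. The wrapper `converted_constants` is a hypothetical name used only for this illustration.

```julia
# Hypothetical wrapper: converts a literal and two scalar arguments
# to the element type T of the array `at`.
function converted_constants(at::AbstractArray{T}, visc, dxidxi) where T
    # Broadcasting treats the type T as a scalar and maps over the tuple.
    two, visc, dxidxi = convert.(T, (2, visc, dxidxi))
    return two, visc, dxidxi
end

consts = converted_constants(zeros(Float32, 2), 0.1, 4.0)
```

Here every element of `consts` should be a `Float32`, even though `0.1` and `4.0` were passed in as `Float64`.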

Upvotes: 3

cbk

Reputation: 4370

In general, if you have a scalar x or an array A, you can get the type with T = typeof(x) or T = eltype(A), respectively, and then use that to convert a literal to the equivalent type, e.g.

julia> A = [1.0]
1-element Vector{Float64}:
 1.0

julia> T = eltype(A)
Float64

julia> T(2)
2.0

So you could in principle use that within the function, and if everything is type-stable, this should actually be overhead-free:

julia> @code_native 2 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ promotion.jl:322 within `*'
; │┌ @ promotion.jl:292 within `promote'
; ││┌ @ promotion.jl:269 within `_promote'
; │││┌ @ number.jl:7 within `convert'
; ││││┌ @ float.jl:94 within `Float32'
    vcvtsi2ss   %rdi, %xmm1, %xmm1
; │└└└└
; │ @ promotion.jl:322 within `*' @ float.jl:331
    vmulss  %xmm0, %xmm1, %xmm0
; │ @ promotion.jl:322 within `*'
    retq
    nopw    (%rax,%rax)
; └

julia> @code_native 2.0f0 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

julia> @code_native Float32(2) * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

As it happens, however, Julia has a somewhat more elegant pattern: write the function signature so that it specializes parametrically on the element type of the arrays you pass in. You can then use that type parameter, without overhead, to ensure your literals are of the appropriate type:

function diff!(at::AbstractArray{T}, a::AbstractArray{T},
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot) where T <: Number

    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - T(2) * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - T(2) * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - T(2) * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

This sort of approach is discussed to some extent in the documentation on parametric methods in Julia.
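For a quick check that does not need LoopVectorization, the same pattern can be sketched in 1D with a plain loop. The name `diff1d!` and the test arrays are illustrative assumptions, not part of the original code.

```julia
# Hypothetical 1D analogue of diff!; T is taken from the array element type.
function diff1d!(at::AbstractArray{T}, a::AbstractArray{T},
        visc, dxidxi) where T <: Number
    for i in 2:length(a)-1
        at[i] += visc * (a[i-1] - T(2) * a[i] + a[i+1]) * dxidxi
    end
    return at
end

a  = Float32[0, 1, 4, 9, 16]   # samples of i^2, so the second difference is 2
at = zeros(Float32, 5)
diff1d!(at, a, 1.0f0, 1.0f0)   # updates the interior entries of at in place
```

Since both arrays share the parameter `T`, calling this with one `Float32` array and one `Float64` array would not match the method, which can help catch accidental precision mixing.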

Upvotes: 4
