Reputation: 6194
I have the following code in Julia, in which the literal constant 2. multiplies array elements. I have now made the literal constant single precision (2.f0), but I would like its type to depend on the other variables (these are either all Float64 or all Float32). How do I do this in an elegant way?
using LoopVectorization  # provides @tturbo

function diff!(
        at, a,
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot)
    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - 2.f0 * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - 2.f0 * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - 2.f0 * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end
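To see why the literal's type matters, here is a small check of Julia's promotion rules in plain Julia (no LoopVectorization; the variable names x32 and x64 are illustrative only). The literal 2. is a Float64 and 2.f0 is a Float32, and mixing either with a value of the other precision promotes the product to Float64.

```julia
x32 = 3.0f0
x64 = 3.0

# Float64 literal times a Float32 value: the result is promoted to Float64.
println(typeof(2.0 * x32))          # Float64
# Float32 literal times a Float32 value: the result stays Float32.
println(typeof(2.0f0 * x32))        # Float32
# Float32 literal times a Float64 value: the result is Float64, and since 2
# is exactly representable in Float32, nothing is lost for this constant ...
println(2.0f0 * x64 == 2.0 * x64)   # true
# ... but a literal that is not exactly representable in Float32 loses precision.
println(0.1f0 * x64 == 0.1 * x64)   # false
```

So a fixed-precision literal either silently promotes single-precision work to double precision, or, for constants other than small integers, injects single-precision rounding into double-precision work.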
Upvotes: 2
Views: 128
Reputation: 20950
There's a nice little function in Base:
help?> oftype
search: oftype
oftype(x, y)
Convert y to the type of x (convert(typeof(x), y)).
Examples
≡≡≡≡≡≡≡≡≡≡
julia> x = 4;
julia> y = 3.;
julia> oftype(x, y)
3
julia> oftype(y, x)
4.0
So you could use something like
two = oftype(at[i,j,k], 2)
in the appropriate place.
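A minimal sketch of this, assuming a hypothetical 1-D analogue called diff1d! that uses plain loops instead of @tturbo so it runs with Base alone. Unlike the line above, the constant is converted once before the loop, from the first element of a; that placement is a choice of this sketch, not something the answer prescribes.

```julia
# Hypothetical 1-D analogue of diff! (plain loops, no @tturbo).
function diff1d!(at, a, visc, dxidxi)
    two = oftype(first(a), 2)   # 2.0f0 for Float32 arrays, 2.0 for Float64 arrays
    @inbounds for i in 2:length(a)-1
        at[i] += visc * (a[i-1] - two * a[i] + a[i+1]) * dxidxi
    end
    return at
end

# a[i] = (i-1)^2, whose second difference is 2 at every interior point.
a32 = Float32[0, 1, 4, 9, 16]
at32 = zeros(Float32, 5)
diff1d!(at32, a32, 1.0f0, 1.0f0)
println(at32)   # Float32[0.0, 2.0, 2.0, 2.0, 0.0]
```

The same function works unchanged for Float64 arrays, because two takes its type from the data rather than from the literal.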
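For multiple variables at once, you could write something like
two, visc, dxidxi, dyidyi, dzidzi = convert.(T, (2, visc, dxidxi, dyidyi, dzidzi))
at the top of the function body (with T a type parameter as in @cbk's answer), since oftype(x, y) = convert(typeof(x), y). Note that the values must be wrapped in a tuple: convert takes exactly two arguments, and broadcasting it over a tuple returns a tuple that can be destructured.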
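Put together, a sketch of that idea might look as follows. The name diff_conv! is hypothetical, the loops are plain nested loops rather than @tturbo so that the example runs without LoopVectorization, and the constraint T <: AbstractFloat is an assumption of this sketch. Broadcasting treats the type T as a scalar, so convert.(T, tuple) converts every element of the tuple to T.

```julia
# Hypothetical variant of diff!: converts the constant and all scalars to the
# array element type T once, then runs plain loops (no @tturbo).
function diff_conv!(at::AbstractArray{T,3}, a::AbstractArray{T,3},
                    visc, dxidxi, dyidyi, dzidzi,
                    itot, jtot, ktot) where {T<:AbstractFloat}
    two, visc, dxidxi, dyidyi, dzidzi = convert.(T, (2, visc, dxidxi, dyidyi, dzidzi))
    @inbounds for k in 2:ktot-1, j in 2:jtot-1, i in 2:itot-1
        at[i, j, k] += visc * (
            (a[i-1, j, k] - two * a[i, j, k] + a[i+1, j, k]) * dxidxi +
            (a[i, j-1, k] - two * a[i, j, k] + a[i, j+1, k]) * dyidyi +
            (a[i, j, k-1] - two * a[i, j, k] + a[i, j, k+1]) * dzidzi)
    end
    return at
end

# Usage: Float32 arrays with Float64 scalars. The Laplacian of
# i^2 + j^2 + k^2 with unit spacing is 6 at every interior point.
a = Float32[i^2 + j^2 + k^2 for i in 1:4, j in 1:4, k in 1:4]
at = zeros(Float32, 4, 4, 4)
diff_conv!(at, a, 1.0, 1.0, 1.0, 1.0, 4, 4, 4)
```

Here the Float64 scalars are converted to Float32 before the loop, so the arithmetic inside the loop stays in single precision.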
Upvotes: 3
Reputation: 4370
In general, if you have a scalar x or an array A, you can get the type with T = typeof(x) or T = eltype(A), respectively, and then use that to convert a literal to the equivalent type, e.g.
julia> A = [1.0]
1-element Vector{Float64}:
1.0
julia> T = eltype(A)
Float64
julia> T(2)
2.0
So you could in principle use that within the function, and if everything is type-stable, this should actually be overhead-free:
julia> @code_native 2 * 1.0f0
.section __TEXT,__text,regular,pure_instructions
; ┌ @ promotion.jl:322 within `*'
; │┌ @ promotion.jl:292 within `promote'
; ││┌ @ promotion.jl:269 within `_promote'
; │││┌ @ number.jl:7 within `convert'
; ││││┌ @ float.jl:94 within `Float32'
vcvtsi2ss %rdi, %xmm1, %xmm1
; │└└└└
; │ @ promotion.jl:322 within `*' @ float.jl:331
vmulss %xmm0, %xmm1, %xmm0
; │ @ promotion.jl:322 within `*'
retq
nopw (%rax,%rax)
; └
julia> @code_native 2.0f0 * 1.0f0
.section __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
vmulss %xmm1, %xmm0, %xmm0
retq
nopw %cs:(%rax,%rax)
; └
julia> @code_native Float32(2) * 1.0f0
.section __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
vmulss %xmm1, %xmm0, %xmm0
retq
nopw %cs:(%rax,%rax)
; └
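Instead of reading assembly, the type-stability precondition can also be checked programmatically with @inferred from the Test standard library, which throws if the actual return type differs from the inferred one. The helpers two_like and double_all below are hypothetical, made up for illustration.

```julia
using Test  # standard library; provides @inferred

# Hypothetical helpers: the literal 2 in the element type of an array,
# or in the type of a scalar (the latter is equivalent to oftype(x, 2)).
two_like(A::AbstractArray) = eltype(A)(2)
two_like(x::Number) = typeof(x)(2)

# Hypothetical kernel: doubles every element without changing its type.
double_all(A::AbstractArray) = [two_like(A) * x for x in A]

A32 = Float32[1, 2, 3]
@inferred two_like(A32)              # throws if the return type is not inferred
B32 = @inferred double_all(A32)
println(eltype(B32))                 # Float32
println(eltype(double_all([1.0, 2.0])))  # Float64
```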
As it happens, however, there is a somewhat more elegant pattern in Julia: write the function signature so that it specializes parametrically on the element type T of the arrays you pass in. You can then use T inside the function, without overhead, to make your literals the appropriate type:
using LoopVectorization  # provides @tturbo

function diff!(at::AbstractArray{T}, a::AbstractArray{T},
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot) where T <: Number
    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - T(2) * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - T(2) * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - T(2) * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end
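A quick way to try this signature without LoopVectorization is a copy with plain loops; the name diff_param! is hypothetical. It also shows a side effect of where T: both arrays must share the same element type, so a Float64/Float32 mix is rejected with a MethodError rather than silently promoted.

```julia
# Same signature as diff! above, but plain loops instead of @tturbo.
function diff_param!(at::AbstractArray{T}, a::AbstractArray{T},
                     visc, dxidxi, dyidyi, dzidzi,
                     itot, jtot, ktot) where T <: Number
    @inbounds for k in 2:ktot-1, j in 2:jtot-1, i in 2:itot-1
        at[i, j, k] += visc * (
            (a[i-1, j, k] - T(2) * a[i, j, k] + a[i+1, j, k]) * dxidxi +
            (a[i, j-1, k] - T(2) * a[i, j, k] + a[i, j+1, k]) * dyidyi +
            (a[i, j, k-1] - T(2) * a[i, j, k] + a[i, j, k+1]) * dzidzi)
    end
    return at
end

# Laplacian of i^2 + j^2 + k^2 with unit spacing is 6 at interior points.
a3 = Float32[i^2 + j^2 + k^2 for i in 1:4, j in 1:4, k in 1:4]
at3 = zeros(Float32, 4, 4, 4)
diff_param!(at3, a3, 1.0f0, 1.0f0, 1.0f0, 1.0f0, 4, 4, 4)

# Mixed element types do not match the method signature.
mixed_rejected = try
    diff_param!(zeros(Float64, 4, 4, 4), a3, 1.0, 1.0, 1.0, 1.0, 4, 4, 4)
    false
catch err
    err isa MethodError
end
```

The scalars are left unconstrained here, as in the answer. If they should also be forced to T, one could annotate them as visc::T and so on, at the cost of rejecting calls that pass Float64 scalars alongside Float32 arrays.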
This sort of approach is discussed to some extent in the Julia manual's section on parametric methods.
Upvotes: 4