Reputation: 33
I have the following code:
circ(x) = x./sqrt(sum(x .* x))
x -> cat(circ(x), circ(x); dims = 1)
but I want to create a function that takes a number and concatenates that many copies of circ(x).
So, for example:
function Ncircs(n)
#some way to make cat() have as its parameter circ n number of times
end
and I could call Ncircs(2)
and get
x -> cat(circ(x), circ(x); dims = 1)
or Ncircs(3)
and get
x -> cat(circ(x), circ(x), circ(x); dims = 1)
or Ncircs(4)
and get
x -> cat(circ(x), circ(x), circ(x), circ(x); dims = 1)
etc.
Is there a way of doing this? Do I have to use a macro?
Upvotes: 3
Views: 137
Reputation: 69839
You can write:
Ncircs(n) = x -> cat(Iterators.repeated(circ(x), n)...; dims = 1)
If you know that you will always use dims=1, then replacing cat with vcat and reduce:
Ncircs(n) = x -> reduce(vcat, Iterators.repeated(circ(x), n))
will be more efficient for large n.
As a side note: the second option (vcat) produces a type-stable result, while the first option (cat) is not type stable.
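To make the two variants above concrete, here is a minimal sketch that defines both and checks that they agree. The names ncircs_cat and ncircs_vcat are made up for this illustration (they are not from the question), and the input vector is just an arbitrary example.

```julia
# circ from the question: scale x to unit Euclidean length.
circ(x) = x ./ sqrt(sum(x .* x))

# Hypothetical names for the two variants discussed above.
ncircs_cat(n)  = x -> cat(Iterators.repeated(circ(x), n)...; dims = 1)
ncircs_vcat(n) = x -> reduce(vcat, Iterators.repeated(circ(x), n))

x = [3.0, 4.0]
a = ncircs_cat(3)(x)
b = ncircs_vcat(3)(x)

# Both should be three copies of circ(x) stacked vertically.
println(a == b)
println(length(b))
```

Note that circ(x) is evaluated once per call and then repeated n times, so the cost of normalizing does not grow with n.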
EDIT
In general, reducing over an empty collection is not allowed because then you are not able to tell what the result of the reduction should be. If you want to allow an empty collection you should add the init keyword argument. Here is an example:
julia> reduce(vcat, [])
ERROR: ArgumentError: reducing over an empty collection is not allowed
julia> reduce(vcat, [], init = [1])
1-element Array{Int64,1}:
1
julia> reduce(vcat, [[2,3], [4,5]], init = [1])
5-element Array{Int64,1}:
1
2
3
4
5
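Applied to the original problem, the init keyword could be used to let n = 0 return an empty vector instead of raising an error. This is only a sketch: the name ncircs_or_empty is hypothetical, and using empty(c) as the starting value is one possible choice, not something the question asked for.

```julia
circ(x) = x ./ sqrt(sum(x .* x))

# Hypothetical variant that also accepts n == 0.
function ncircs_or_empty(n)
    return function (x)
        c = circ(x)
        # empty(c) is a zero-length vector with the same element type as c,
        # so the reduction has a well-defined result even with no elements.
        return reduce(vcat, Iterators.repeated(c, n); init = empty(c))
    end
end

println(isempty(ncircs_or_empty(0)([3.0, 4.0])))
println(length(ncircs_or_empty(2)([3.0, 4.0])))
```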
Type stability means that Julia is able to tell the type of the return value of a function at compilation time (before you execute the code). Type-stable code usually runs faster (this is a broad topic though; I recommend reading the Julia manual to understand it in detail). You can check whether a function is type stable using @code_warntype and Test.@inferred.
Here let me give you an explanation in your specific case (I truncated some of the output to shorten the answer).
julia> x = [1,2,3]
3-element Array{Int64,1}:
1
2
3
julia> y = [4,5,6]
3-element Array{Int64,1}:
4
5
6
julia> @code_warntype vcat(x,y)
Body::Array{Int64,1}
...
julia> @code_warntype cat(x,y, dims=1)
Body::Any
...
julia> using Test
julia> @inferred vcat(x,y)
6-element Array{Int64,1}:
1
2
3
4
5
6
julia> @inferred cat(x,y, dims=1)
ERROR: return type Array{Int64,1} does not match inferred return type Any
Any above means that the compiler does not know what the type of the answer will be. The reason in this case is that this type depends on the dims parameter: if it is 1 the result will be a vector, if it is 2 it will be a matrix.
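If you want to run this kind of check from a script rather than the REPL, one option is to wrap Test.@inferred in a small helper that returns true or false. This is a sketch only: is_inferred, check_vcat and check_cat are hypothetical names, and the result for cat may differ between Julia versions (the compiler can sometimes propagate a literal dims value), so no expected output is shown for it.

```julia
using Test

# Hypothetical helper: true if Test.@inferred accepts the call f(), false otherwise.
function is_inferred(f)
    try
        @inferred f()
        return true
    catch
        return false
    end
end

# The closures are created inside functions so that x and y have known types.
check_vcat(x, y) = is_inferred(() -> vcat(x, y))
check_cat(x, y)  = is_inferred(() -> cat(x, y; dims = 1))

x = [1, 2, 3]
y = [4, 5, 6]

println("vcat inferred: ", check_vcat(x, y))
println("cat inferred:  ", check_cat(x, y))
```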
As for why reduce is more efficient for large n, you can run the @which macro:
julia> @which reduce(vcat, [[1,2,3], [4,5,6]])
reduce(::typeof(vcat), A::AbstractArray{#s72,1} where #s72<:(Union{AbstractArray{T,2}, AbstractArray{T,1}} where T)) in Base at abstractarray.jl:1321
And you see that there is a specialized reduce method for vcat.
Now if you run:
@edit reduce(vcat, [[1,2,3], [4,5,6]])
An editor will open and you will see that it calls an internal function _typed_vcat that is optimized for vcat-ing a lot of arrays. This optimization was introduced because splatting, as in vcat([[1,2,3], [4,5,6]]...), gives an equivalent result, but the splatting itself (the ...) has some cost that can be avoided by using the reduce version.
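The equivalence of the two forms can be checked directly. The sketch below uses a modest number of small vectors (the size 1000 is an arbitrary choice for illustration) and compares the splatted call with the reduce call.

```julia
# A vector of 1000 one-element vectors: [[1], [2], ..., [1000]].
y = [[i] for i in 1:1000]

splatted = vcat(y...)        # passes 1000 separate arguments to vcat
reduced  = reduce(vcat, y)   # passes the collection itself

println(splatted == reduced)
println(length(reduced))
```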
To verify this, you can run the following benchmark:
julia> using BenchmarkTools
julia> y = [[i] for i in 1:10000];
julia> @benchmark vcat($y...)
BenchmarkTools.Trial:
memory estimate: 156.45 KiB
allocs estimate: 3
--------------
minimum time: 67.200 μs (0.00% GC)
median time: 77.800 μs (0.00% GC)
mean time: 102.804 μs (8.50% GC)
maximum time: 35.179 ms (99.47% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark reduce(vcat, $y)
BenchmarkTools.Trial:
memory estimate: 78.20 KiB
allocs estimate: 2
--------------
minimum time: 67.700 μs (0.00% GC)
median time: 69.700 μs (0.00% GC)
mean time: 82.442 μs (6.39% GC)
maximum time: 32.719 ms (99.58% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark cat($y..., dims=1)
ERROR: StackOverflowError:
And you see that the reduce version is slightly faster than the splatting version of vcat in median and mean time and allocates about half the memory, while cat simply fails for very large n (for smaller n it would work but simply be slower).
Upvotes: 3