Reputation: 132
I am puzzled by the superior performance of matrix multiplication over a for loop in my application. Similar questions have been asked a few times on StackOverflow already (e.g., here and here), but I couldn't get to the bottom of it with the provided answers.
Here's a script to reproduce my problem (using Julia 1.7.3):
# Packages and seeds
using Optim, BenchmarkTools, Random
Random.seed!(1704)

# Parameters
N = 10
M = 11
σ = 3.0
ID1 = repeat(1:N, inner = M)
ID2 = repeat(1:M, inner = N)
A = rand(N * M)
w = rand(M)
dv = dummyvar(ID2)
x0 = ones(N*M)

# Function 1
function function_1(x, dv, w, A, σ)
    c = dv * w ./ A
    s = x.^(1-σ) ./ (dv * (dv' * x.^(1-σ)))
    el = σ*(1 .- s) .+ s
    p = (c.*el) ./ (el .- 1)
    return sum(abs2.(x .- p))
end

# Function 2
function function_2(x, dv, w, A, M, ID2, σ)
    f = 0.0
    for m = 1:M
        c = w[m] ./ A[ID2 .== m]
        s = x[ID2 .== m].^(1-σ) ./ sum(x[ID2 .== m].^(1-σ))
        el = σ*(1 .- s) .+ s
        f += sum(abs2.(x[ID2 .== m] .- (c.*el) ./ (el .- 1)))
    end
    return f
end
and dummyvar is a function that creates a matrix of indicator (dummy) variables from a vector (similar to the Matlab dummyvar function):
function dummyvar(input_mat::Vector)
    # Initialize (unique values computed once, in order of first appearance)
    vals = unique(input_mat)
    n1 = length(input_mat)
    n2 = length(vals)
    # Preallocate output
    mat = zeros(n1, n2)
    # Fill in output: column i flags the entries equal to the i-th unique value
    for i in 1:n2
        mat[:, i] = input_mat .== vals[i]
    end
    # Return
    return mat
end
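For illustration, here is a small sketch of what such an indicator matrix looks like. The input `ids` is made up, and `dummyvar_compact` is a hypothetical one-line variant that I believe is equivalent to the loop version; it is only there to keep the example self-contained.

```julia
# Hypothetical compact variant: compare every element against each unique
# value (one column per unique value, in order of first appearance).
dummyvar_compact(v::Vector) = Float64.(v .== permutedims(unique(v)))

ids = [1, 1, 2, 3, 3]              # made-up input with three distinct values
dv_small = dummyvar_compact(ids)   # one row per element, one column per value

# Row i has a single 1.0, in the column that matches ids[i].
println(size(dv_small))
```

With such a matrix, `dv' * x` sums `x` within each group and `dv * w` copies the group-level value `w[m]` back to every element of group `m`, which is how function_1 uses it.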
Here are the benchmark results for function_1
and here are those of function_2
function_1 allocates significantly less memory and takes less time to execute than function_2. However, I expected to find the opposite pattern since function_1 uses matrix multiplication to compute c and s whereas function_2 uses loops. I wonder if pre-allocating f every time function_2 is called explains the difference.
Can function_2 be rewritten to outperform function_1 in terms of execution time and memory allocation?
Upvotes: 1
Views: 659
Reputation: 2301
Take this for example:
w[m] ./ A[ID2 .== m]
This allocates the following intermediate arrays:
1. ID2 .== m
2. A[ID2 .== m]
3. w[m] ./ A[ID2 .== m]
3 is irreducible because you need this for c, but 1 and 2 are not.
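A minimal sketch of one way to drop intermediates 1 and 2 for this line, assuming the same setup as in the question. The name `idx` is introduced here for illustration; `findall` and `view` are standard Base functions.

```julia
using Random
Random.seed!(1704)

N, M = 10, 11
ID2 = repeat(1:M, inner = N)
A = rand(N * M)
w = rand(M)
m = 1

# Original line: builds the mask, then a copy of the selected elements,
# then the result.
c_orig = w[m] ./ A[ID2 .== m]

# Sketch: look up the positions once (this vector can be reused for every
# line of the loop body), then read A through a view instead of copying.
idx = findall(==(m), ID2)
c_view = w[m] ./ view(A, idx)

println(c_orig == c_view)
```

The positions in `idx` do not depend on `x`, so they could also be computed once outside the function rather than on every call.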
Apply this logic to each item in your for-loop.
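Taking that idea to its end, here is a rough sketch of the loop written element by element, so that no temporary arrays should be needed inside it. `function_2_loop` and `groups` are names I made up for this sketch, and I dropped the unused `dv`, `M`, and `ID2` arguments; treat it as a starting point to benchmark, not a guaranteed win over function_1.

```julia
using Random
Random.seed!(1704)

# Same setup as in the question.
N, M, σ = 10, 11, 3.0
ID2 = repeat(1:M, inner = N)
A = rand(N * M)
w = rand(M)
x0 = ones(N * M)

# Positions of each group, computed once outside the objective
# (`groups` is a hypothetical helper, not part of the original script).
groups = [findall(==(m), ID2) for m in 1:M]

function function_2_loop(x, w, A, groups, σ)
    f = 0.0
    for (m, idx) in enumerate(groups)
        # Denominator of the share s: sum of x_i^(1-σ) within group m.
        denom = 0.0
        for i in idx
            denom += x[i]^(1 - σ)
        end
        # Accumulate the squared residuals one element at a time.
        for i in idx
            c  = w[m] / A[i]
            s  = x[i]^(1 - σ) / denom
            el = σ * (1 - s) + s
            f += abs2(x[i] - c * el / (el - 1))
        end
    end
    return f
end

f_loop = function_2_loop(x0, w, A, groups, σ)
```

You can compare it against the original functions with `@btime` from BenchmarkTools, interpolating the arguments with `$` so that global-variable access does not distort the timings.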
Upvotes: 2