I am implementing a dynamic programming solution in MATLAB and would like to know whether there is a faster (run-time-wise) implementation than the one I came up with.
The problem (all values are real):
Input: let X be a T-by-d matrix, W a k-by-d matrix, and A a k-by-k matrix.
Output: Y, a T-by-1 array such that, for each row i of X, Y(i) is the index of the row of W that maximizes our goal.
A(i,j) gives us the cost of choosing row j if the previous row we chose was i.
To calculate the weight of the output, for each row i of X we take the dot product of row i of X with row Y(i) of W, add the relevant cost from A, and sum the results over all rows.
Our goal is to maximize this weight.
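To make the objective concrete, here is a minimal sketch that scores one candidate assignment Y. It is an illustration only: the variable names dots, trans and total_weight are made up, and it assumes the cost term is indexed as A(Y(t-1), Y(t)), following the definition of A above, with no cost charged for the first row.

```matlab
% Sketch: weight of a candidate assignment Y (illustrative names, assumed indexing).
T = 8; d = 10; k = 20;
X = rand(T,d);
W = rand(k,d);
A = rand(k);
Y = randi(k, T, 1);                 % any candidate: one row index of W per row of X

% Dot product of row t of X with row Y(t) of W, for every t at once.
dots = sum(X .* W(Y,:), 2);         % T-by-1

% Transition costs A(Y(t-1), Y(t)) for t = 2..T (assumption: no cost for t = 1).
trans = A(sub2ind(size(A), Y(1:end-1), Y(2:end)));   % (T-1)-by-1

total_weight = sum(dots) + sum(trans);
```

Any proposed Y can be scored this way, which gives a simple way to compare the output of different implementations.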
Dynamic programming solution:
Instantiate a k-by-T matrix.
Fill the first column with the dot products of the first row of X with each row of W.
For each of the remaining columns (denote its index as i), fill it with the dot products of row i of X with each row of W, plus the cost from A for j, where j is the row index of the cell with the maximum value in the previous column.
Backtrack from the last column, each time choosing the row index of the cell with the highest value.
MATLAB implementation (with variable initialization):
T = 8;
d = 10;
k = 20;
X = rand(T,d);
W = rand(k,d);
A = rand(k);
Y = zeros(T,1);
weight_table = zeros(k,T);
weight_table(:,1) = W*X(1,:)';
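for t = 2 : T
    [~, prev_ind] = max(weight_table(:,t-1)); % best row in the previous column
    weight_table(:,t) = W*X(t,:)' + A(:,prev_ind);
end
[~, Y] = max(weight_table); % row index of the maximum in each column
Y = Y(:); % make Y T-by-1, as specified above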
Since there is a data dependency across iterations, I would advise keeping the loop but pre-calculating a few things, namely the products of W with the transpose of each row of X. This is done here (showing just the weight_table calculation part, as the rest of the code stays the same as in the original post):
weight_table = zeros(k,T);
WXt = W*X.'; % pre-calculate all dot products at once (k-by-T)
weight_table(:,1) = WXt(:,1);
for t = 2 : T
    [~, prev_ind] = max(weight_table(:,t-1));
    weight_table(:,t) = WXt(:,t) + A(:,prev_ind); % use the pre-calculated column and avoid a multiplication in each iteration
end
For bigger inputs such as T = 800; d = 1000; k = 2000; I am getting an 8-10x performance improvement with it on my system.
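If you want to check this on your own machine, a rough sketch along the following lines should do. The sizes and the names wt1, wt2, t_loop and t_pre are arbitrary choices for illustration, and tic/toc gives only a coarse measurement, so treat the numbers as indicative.

```matlab
% Sketch: compare the per-iteration product with the pre-calculated one.
T = 200; d = 300; k = 400;          % arbitrary sizes, adjust as needed
X = rand(T,d); W = rand(k,d); A = rand(k);

% Original loop: one matrix-vector product per iteration.
tic
wt1 = zeros(k,T);
wt1(:,1) = W*X(1,:)';
for t = 2 : T
    [~, prev_ind] = max(wt1(:,t-1));
    wt1(:,t) = W*X(t,:)' + A(:,prev_ind);
end
t_loop = toc;

% Pre-calculated version: a single matrix-matrix product up front.
tic
WXt = W*X.';
wt2 = zeros(k,T);
wt2(:,1) = WXt(:,1);
for t = 2 : T
    [~, prev_ind] = max(wt2(:,t-1));
    wt2(:,t) = WXt(:,t) + A(:,prev_ind);
end
t_pre = toc;

% Both tables should agree up to rounding error for random data like this.
max_diff = max(abs(wt1(:) - wt2(:)));
fprintf('loop: %.4f s, pre-calculated: %.4f s, max diff: %.2e\n', t_loop, t_pre, max_diff);
```

The actual ratio depends on the sizes, the MATLAB version and the underlying BLAS library, so it may differ from the figure quoted above.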