Karamos

Reputation: 49

Keep largest & smallest elements in each row of a matrix in Matlab

I am trying to keep the highest and lowest values in each row of a matrix using Matlab. For example, I have the initial matrix:

A = [-6   1   3   9   2  -1;
      3  -3   6   5   0  -8;
      5  10   9   3   2   1;
     20   2  -1   4   9  -4;
      4   6  -2   2   7   9;
     10   5  -3   3   1   4]

and I need to keep the largest & smallest elements in each row of the matrix (let's say 2 largest and 2 smallest), so that the final matrix looks like this:

B = [-6   0   3   9   0  -1;
      0  -3   6   5   0  -8;
      0  10   9   0   2   1;
     20   0  -1   0   9  -4;
      0   0  -2   2   7   9;
     10   5  -3   0   1   0]

I tried sort along with ismember but without any luck.

How can I keep the largest & smallest elements in each row of the input matrix, and remove the rest of the elements in each row?

Upvotes: 1

Views: 153

Answers (2)

nirvana-msu

Reputation: 4077

There is a very elegant way to do this. We need the second output of sort, which gives the order of the elements along a particular dimension. The caveat is that it gives us the inverse of the permutation we need, so we apply sort a second time, to that index output; the result is the rank of each element within its row (1 = smallest). All we need then is to use logical indexing to set to zero every element whose rank is greater than 2 and less than size(A,2)-1, i.e. everything except the 2 smallest and 2 largest elements of each row.
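To make the "inverse permutation" step concrete, here is a small sketch on the first row of the example matrix (the names row, p and rnk are just illustrative). The index output p of the first sort lists column positions in ascending order of value; sorting p again gives, for each column, the rank of that element within the row.

```matlab
% First row of the example matrix
row = [-6 1 3 9 2 -1];

% p(k) is the column index of the k-th smallest element
[~, p] = sort(row);      % p = [1 6 2 5 3 4]

% Sorting p yields its inverse permutation: rnk(j) is the rank of row(j)
[~, rnk] = sort(p);      % rnk = [1 3 5 6 4 2]

% Keep ranks 1-2 (smallest) and 5-6 (largest), zero the rest
kept = row;
kept(rnk > 2 & rnk < numel(row) - 1) = 0;   % kept = [-6 0 3 9 0 -1]
disp(kept)
```

The result matches the first row of the desired B.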

Also, note that it's much easier to work with numeric matrices than with cells as in your example. You can use cell2mat and num2cell to convert back and forth if you need to.

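A = [-6, 1, 3, 9, 2, -1; ...
    3, -3, 6, 5, 0, -8; ...
    5, 10, 9, 3, 2, 1; ...
    20, 2, -1, 4, 9, -4; ...
    4, 6, -2, 2, 7, 9; ...
    10, 5, -3, 3, 1, 4];

[~, idx_inv] = sort(A,2);               % column order of each row's sorted values
[~, idx] = sort(idx_inv,2);             % invert that permutation: rank of each element in its row
toRemove = idx>2 & (idx<size(A,2)-1);   % ranks 3 to n-2: neither 2 smallest nor 2 largest
B = A; B(toRemove) = 0;                 % zero out the middle elements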

>> B
B =
    -6     0     3     9     0    -1
     0    -3     6     5     0    -8
     0    10     9     0     2     1
    20     0    -1     0     9    -4
     0     0    -2     2     7     9
    10     5    -3     0     1     0
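A minimal sketch of the same sort-based idea generalized to keep the m smallest and M largest elements of each row, assuming m + M <= size(A,2). The names m, M, rnk and keep are chosen for illustration. Note that the rank threshold depends on the number of columns, size(A,2), since ranks are computed along each row.

```matlab
A = [-6   1   3   9   2  -1;
      3  -3   6   5   0  -8;
      5  10   9   3   2   1;
     20   2  -1   4   9  -4;
      4   6  -2   2   7   9;
     10   5  -3   3   1   4];
m = 2;   % number of smallest elements to keep per row
M = 2;   % number of largest elements to keep per row

n = size(A, 2);                  % elements per row
[~, idx_inv] = sort(A, 2);       % column indices in ascending order, per row
[~, rnk] = sort(idx_inv, 2);     % rank of each element within its row (1 = smallest)

keep = rnk <= m | rnk > n - M;   % m smallest or M largest in the row
B = A;
B(~keep) = 0;                    % zero out everything else
disp(B)
```

Assigning zeros through the mask, rather than multiplying A by it, also avoids turning Inf entries into NaN (Inf*0 is NaN).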

EDIT: Comparing performance with the approach suggested by @luis-mendo:

K>> tic; for t = 1:1e6, [~, idx_inv] = sort(A,2); [~, idx] = sort(idx_inv,2); B = A; B(idx>2 & (idx<size(A,1)-1)) = 0; end; toc;
Elapsed time is 12.379617 seconds.

K>> tic; for t = 1:1e6, s = sum(bsxfun(@gt, A, permute(A, [1 3 2])), 3); B = A.*(s<2 | s>size(A,2)-3); end; toc;
Elapsed time is 17.630724 seconds.

In this test (1e6 iterations on the 6x6 example matrix), the sort-based solution is about 1.42x faster (12.4 s vs 17.6 s).
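The second sort only serves to invert a permutation, which can also be done with a single indexed assignment (linear work instead of another sort). Below is a sketch of that variant for all rows at once, using sub2ind; the names p, rows and rnk are illustrative. It was not part of the timings above, so any speed-up is an assumption to measure rather than a result.

```matlab
A = [-6   1   3   9   2  -1;
      3  -3   6   5   0  -8;
      5  10   9   3   2   1;
     20   2  -1   4   9  -4;
      4   6  -2   2   7   9;
     10   5  -3   3   1   4];
m = 2; M = 2;                              % keep m smallest and M largest per row
[nr, nc] = size(A);

[~, p] = sort(A, 2);                       % p(i,k): column of the k-th smallest in row i
rows = repmat((1:nr)', 1, nc);             % row index for every entry of p
rnk = zeros(nr, nc);
% Scatter: rnk(i, p(i,k)) = k, i.e. the inverse permutation of each row of p
rnk(sub2ind([nr nc], rows, p)) = repmat(1:nc, nr, 1);

B = A;
B(rnk > m & rnk <= nc - M) = 0;            % zero the elements ranked in the middle
disp(B)
```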

Upvotes: 2

Luis Mendo

Reputation: 112769

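Here's another approach using bsxfun. It is probably not very efficient, since it compares every element with every other element in its row, building a 3-D logical array.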

Let your data be defined as:

A = [ -6     1     3     9     2    -1
       3    -3     6     5     0    -8
       5    10     9     3     2     1
      20     2    -1     4     9    -4
       4     6    -2     2     7     9
      10     5    -3     3     1     4 ];
m = 2;   % number of smallest elements to keep in each row
M = 2;   % number of largest elements to keep in each row

Then:

s = sum(bsxfun(@gt, A, permute(A, [1 3 2])), 3);  % for each element, compute how many
                                                  % elements in its row it exceeds
result = A.*(s<m | s>size(A,2)-M-1);              % apply a mask based on that
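Since R2016b, MATLAB applies implicit expansion to element-wise operators, so the bsxfun call can be written as a plain comparison. The sketch below does that and also shows one behavioral difference between the two answers on a hypothetical matrix A2 with repeated values: the count s gives tied elements the same value, so more or fewer than m + M elements can survive, whereas the double-sort ranks break ties by position and always keep exactly m + M per row.

```matlab
A2 = [5 5 5 5 5 5;
      1 2 3 5 5 5];
m = 2; M = 2;
n = size(A2, 2);

% Count approach with implicit expansion (R2016b+):
% s(i,j) = number of elements in row i strictly smaller than A2(i,j)
s = sum(A2 > permute(A2, [1 3 2]), 3);
keepCount = s < m | s > n - M - 1;

% Double-sort ranks: ranks are always a permutation of 1..n
[~, p] = sort(A2, 2);
[~, rnk] = sort(p, 2);
keepRank = rnk <= m | rnk > n - M;

disp(sum(keepCount, 2)')   % 6 2
disp(sum(keepRank, 2)')    % 4 4
```

Which behavior is preferable for ties depends on the application; for rows without repeated values, as in the question, both give the same B.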

Upvotes: 1
