Effective_cellist

Expanding an array along a dimension

The following Python (NumPy) function expands an array along a new dimension by repeating it:

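import numpy as np

def expand(x, dim, copies):
    # Permutation that moves the new trailing axis to position `dim`
    trans_cmd = list(range(0, len(x.shape)))
    trans_cmd.insert(dim, len(x.shape))
    # Repeat each element, add the copies as a trailing axis, then move it into place
    new_data = x.repeat(copies).reshape(list(x.shape) + [copies]).transpose(trans_cmd)
    return new_data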

>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> x.shape
(2, 3)
>>> x_new = expand(x, 2, 4)
>>> x_new.shape
(2, 3, 4)
>>> x_new
array([[[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]],

       [[4, 4, 4, 4],
        [5, 5, 5, 5],
        [6, 6, 6, 6]]])

How can this function be replicated in Julia?

Answers (1)

mbauman

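Instead of doing repeat -> reshape -> permutedims (as you do in NumPy), I'd just do reshape -> repeat. Julia arrays are column-major, so the dimensions are reversed relative to NumPy: your 2×3 array becomes a 3×2 matrix, and expanding along NumPy's last dimension corresponds to expanding along Julia's first. Accounting for that translation from row-major to column-major, it looks like this: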

julia> x = [[1,2,3] [4,5,6]]
3×2 Array{Int64,2}:
 1  4
 2  5
 3  6

julia> repeat(reshape(x, (1, 3, 2)), outer=(4,1,1))
4×3×2 Array{Int64,3}:
[:, :, 1] =
 1  2  3
 1  2  3
 1  2  3
 1  2  3

[:, :, 2] =
 4  5  6
 4  5  6
 4  5  6
 4  5  6
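For comparison, the same reshape-then-repeat idea carries over to NumPy: insert a length-1 axis at `dim` with `np.expand_dims`, then repeat along that axis with `np.repeat`, so no transpose is needed. This is a sketch, not part of the original answer; `expand_np` is a hypothetical name. Reversing the axes of the result (`.T`) gives the column-major layout that Julia prints, which makes the dimension reversal described above concrete.

```python
import numpy as np

def expand_np(x, dim, copies):
    # Hypothetical helper (not from the answer): add a length-1 axis
    # at position `dim`, then repeat the array `copies` times along it.
    return np.repeat(np.expand_dims(x, dim), copies, axis=dim)

x = np.array([[1, 2, 3], [4, 5, 6]])  # row-major 2x3, as in the question
y = expand_np(x, 2, 4)
print(y.shape)    # (2, 3, 4)

# Reversing the axes gives a 4x3x2 array; its [:, :, 0] slice holds four
# copies of [1, 2, 3], matching the Julia output's [:, :, 1] block.
print(y.T.shape)  # (4, 3, 2)
print(y.T[:, :, 0])
```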

The tricky part about doing this efficiently in Julia is constructing the tuples (1, 3, 2) and (4, 1, 1). While you could convert them to arrays and mutate those (as you do with a Python list), it's far more efficient to keep them as tuples. ntuple is your friend here; ntuple(f, n) builds the tuple (f(1), f(2), ..., f(n)), so both tuples can be generated directly from dim and the size of the input:

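julia> function expand(x, dim, copies)
           sz = size(x)
           # repetition counts: `copies` along the new dimension, 1 elsewhere
           rep = ntuple(d -> d == dim ? copies : 1, length(sz)+1)
           # reshape target: original sizes with a singleton inserted at `dim`
           new_size = ntuple(d -> d < dim ? sz[d] : d == dim ? 1 : sz[d-1], length(sz)+1)
           return repeat(reshape(x, new_size), outer=rep)
       end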
expand (generic function with 1 method)

julia> expand(x, 1, 4)
4×3×2 Array{Int64,3}:
[:, :, 1] =
 1  2  3
 1  2  3
 1  2  3
 1  2  3

[:, :, 2] =
 4  5  6
 4  5  6
 4  5  6
 4  5  6

julia> expand(x, 2, 4)
3×4×2 Array{Int64,3}:
[:, :, 1] =
 1  1  1  1
 2  2  2  2
 3  3  3  3

[:, :, 2] =
 4  4  4  4
 5  5  5  5
 6  6  6  6

julia> expand(x, 3, 4)
3×2×4 Array{Int64,3}:
[:, :, 1] =
 1  4
 2  5
 3  6

[:, :, 2] =
 1  4
 2  5
 3  6

[:, :, 3] =
 1  4
 2  5
 3  6

[:, :, 4] =
 1  4
 2  5
 3  6
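The ntuple construction above can be mirrored in Python with tuple comprehensions: build the repetition counts (copies at dim, 1 elsewhere) and the reshape target (the original sizes with a 1 spliced in at dim), then reshape and tile. This is a sketch, not the answerer's code; `expand_tuples` is a hypothetical name, `dim` is 0-based here (Julia's `d` is 1-based), and `np.tile` stands in for Julia's `repeat(...; outer=...)`.

```python
import numpy as np

def expand_tuples(x, dim, copies):
    # Hypothetical port of the Julia `expand`; `dim` is 0-based.
    sz = x.shape
    n = len(sz) + 1
    # Repetition counts: `copies` along the new axis, 1 elsewhere
    # (the analogue of ntuple(d -> d == dim ? copies : 1, n)).
    rep = tuple(copies if d == dim else 1 for d in range(n))
    # Reshape target: original sizes with a length-1 axis inserted at `dim`.
    new_size = tuple(sz[d] if d < dim else 1 if d == dim else sz[d - 1]
                     for d in range(n))
    # Tiling the singleton axis `copies` times fills the new dimension.
    return np.tile(x.reshape(new_size), rep)

x = np.array([[1, 2, 3], [4, 5, 6]])
print(expand_tuples(x, 0, 4).shape)  # (4, 2, 3)
print(expand_tuples(x, 1, 4).shape)  # (2, 4, 3)
print(expand_tuples(x, 2, 4).shape)  # (2, 3, 4)
```

Because the axes are reversed between the two languages, Julia's expand(x, 1, 4) on the 3×2 matrix corresponds to expand_tuples(x, 2, 4) on the 2×3 array: (4, 3, 2) versus (2, 3, 4).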
