Reputation: 3348
I am trying to write a function that reduces a NumPy ndarray to a given shape, effectively "unbroadcasting" the array. For example, using add as the universal function, I want the following results:
This is the solution I came up with:
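import itertools
import numpy as np

def unbroadcast(A, reduced_shape):
    fill = reduced_shape[-1] if reduced_shape else None  # pad value for the shorter shape
    # compare the two shapes position by position (left-aligned); mismatched positions are summed over
    reduced_axes = tuple(i for i, (a, b) in enumerate(itertools.zip_longest(A.shape, reduced_shape, fillvalue=fill)) if a != b)
    return np.add.reduce(A, axis=reduced_axes).reshape(reduced_shape)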
But it feels unnecessarily complex. Is there a way to implement this that relies on NumPy's public API?
Upvotes: 0
Views: 441
Reputation: 1
You can come up with a much more general solution if you iterate from the rightmost dimensions. This is what NumPy does when it broadcasts: https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules.
Just use itertools.zip_longest(reversed(A.shape), reversed(reduced_shape)) instead, and map each reversed position i back to the real axis A.ndim - 1 - i.
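A minimal sketch of that right-to-left idea, assuming the goal is to undo NumPy-style broadcasting with addition. It walks A.shape (not list(A), which would yield the rows of A) against reduced_shape from the right; a missing target dimension or a target length of 1 marks an axis to sum over. The validation and error messages are my own additions, not part of the original suggestion.

```python
import itertools
import numpy as np

def unbroadcast(A, reduced_shape):
    A = np.asarray(A)
    reduced_shape = tuple(reduced_shape)
    if len(reduced_shape) > A.ndim:
        raise ValueError("reduced_shape has more dimensions than A")
    axes = []
    # Walk both shapes from the rightmost dimension, as NumPy's broadcasting rules do.
    # fillvalue=None marks leading axes of A that broadcasting would have prepended.
    pairs = itertools.zip_longest(reversed(A.shape), reversed(reduced_shape), fillvalue=None)
    for i, (a, b) in enumerate(pairs):
        if b is None or (b == 1 and a != 1):
            axes.append(A.ndim - 1 - i)  # map the reversed position back to a real axis
        elif a != b:
            raise ValueError(f"cannot unbroadcast {A.shape} to {reduced_shape}")
    # Sum over the collected axes (or just copy if nothing needs reducing)
    out = np.add.reduce(A, axis=tuple(axes)) if axes else A.copy()
    # reshape restores any length-1 axes that the reduction dropped
    return out.reshape(reduced_shape)

print(unbroadcast(np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]), (2, 2)))
```

Because dimensions are aligned from the right, reducing (2, 2, 2) to (2, 2) sums over axis 0, which is the choice NumPy's broadcasting rules imply.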
Upvotes: -1
Reputation: 231335
It's not clear how this is an 'un-broadcasting'.
The straightforward way to do your calculations is to use the axis parameter of sum:
In [124]: np.array([1,2,3,4,5]).sum()
Out[124]: 15
In [125]: np.array([[1,2],[1,2]]).sum(axis=0)
Out[125]: array([2, 4])
In [126]: np.array([[1,2],[1,2]]).sum(axis=(0,1))
Out[126]: 6
In [128]: np.array([[[1,2], [1,2]], [[1,2], [1,2]]]).sum(axis=0)
Out[128]:
array([[2, 4],
       [2, 4]])
I don't use reduce as much, but it looks like its axis parameter does the same:
In [130]: np.add.reduce(np.array([1,2,3,4,5]))
Out[130]: 15
In [132]: np.add.reduce(np.array([[[1,2], [1,2]], [[1,2], [1,2]]]),axis=0)
Out[132]:
array([[2, 4],
       [2, 4]])
But I haven't worked out the logic for going from your reduced_shape to the necessary axis values. With shapes like (2,) and (2,2,2), there's potential ambiguity when you say "reduce the shape to (2,2)". It might be clearer if you worked with sample arrays like np.arange(24).reshape(2,3,4).
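To make the ambiguity concrete, here is a small check (the variable names are just for illustration) that lists which single-axis sums reach the target shape. When every axis has length 2, all of them do; with distinct lengths, as in np.arange(24).reshape(2,3,4), only one does.

```python
import numpy as np

A = np.arange(8).reshape(2, 2, 2)
# Every single-axis sum of a (2, 2, 2) array has shape (2, 2), so the target
# shape alone cannot tell us which axis was meant to be reduced.
ambiguous = [ax for ax in range(A.ndim) if A.sum(axis=ax).shape == (2, 2)]

B = np.arange(24).reshape(2, 3, 4)
# With distinct axis lengths, only one single-axis sum reaches shape (3, 4).
unique = [ax for ax in range(B.ndim) if B.sum(axis=ax).shape == (3, 4)]

print(ambiguous)  # [0, 1, 2]
print(unique)     # [0]
```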
Upvotes: 1