Litchy

Reputation: 363

How to transform a list of 1-D ndarrays into a 2-D ndarray (MXNet NDArray)

In this example, I have a list of 9 one-dimensional arrays, each with shape=(2048,), so 9 × (2048,) in total. I get these arrays from MXNet, so each one is an <NDArray 2048 @cpu(0)> with dtype=numpy.float32.

If I use np.asarray to convert this list, I get the following result:

shape=<class 'tuple'>: (9, 2048, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)

Obviously, I want a 2-D array with shape=(9, 2048). How can I solve this problem?

PS: I discovered this problem when saving the data to an .npy file and loading it again. I saved the list directly, without converting it to an ndarray first (so np.save converted the list to an ndarray automatically), and after loading it I found the shape shown above, which is clearly abnormal.
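A minimal sketch of the save/load round trip done the other way around: stack the rows into one 2-D array first, then save it, so the .npy file holds a plain (9, 2048) array. It assumes the elements are already NumPy arrays (for MXNet data, each element would first go through asnumpy()); the zero-filled rows and the file name features.npy are hypothetical stand-ins.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in data: 9 float32 vectors of length 2048.
# With MXNet, each element would first be converted via .asnumpy().
rows = [np.zeros(2048, dtype=np.float32) for _ in range(9)]

# Stack into a single (9, 2048) array *before* saving, instead of saving the raw list.
features = np.stack(rows)

# Save to a temporary directory and load it back.
path = os.path.join(tempfile.mkdtemp(), "features.npy")
np.save(path, features)
loaded = np.load(path)

print(loaded.shape, loaded.dtype)  # (9, 2048) float32
```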

The answer below suggests np.vstack. Both np.vstack and np.array work for the usual list-to-ndarray problem, but neither solved mine, so I suspect this is a special case caused by MXNet.

Upvotes: 2

Views: 768

Answers (2)

Litchy

Reputation: 363

Since the person who gave the correct answer in a comment solved my problem but did not post it as an answer, I am posting it here for others who may run into the same problem.

In fact, numpy.ndarray and mxnet.ndarray.NDArray are not the same type, so calling NumPy functions directly on an MXNet NDArray is risky. To use NumPy functions on an MXNet NDArray, first convert it to a NumPy array:

import mxnet

mx_ndarray = mxnet.ndarray.zeros(5)  # an MXNet NDArray
np_array = mx_ndarray.asnumpy()      # a copy as a numpy.ndarray

NumPy functions can then be used on np_array.
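Applied to the original question, the same idea means converting every element of the list before stacking. A minimal sketch; the helper names to_numpy and stack_rows are hypothetical, and the example runs on NumPy stand-ins so it does not need MXNet installed. It only relies on MXNet NDArray objects having an asnumpy() method.

```python
import numpy as np


def to_numpy(x):
    """Return x as a numpy.ndarray.

    MXNet NDArray objects expose .asnumpy(); anything else
    (a NumPy array, a plain list) is passed through np.asarray.
    """
    if hasattr(x, "asnumpy"):
        return x.asnumpy()
    return np.asarray(x)


def stack_rows(arrays):
    """Convert each 1-D array to NumPy first, then stack them as rows."""
    return np.stack([to_numpy(a) for a in arrays], axis=0)


# NumPy stand-ins; with MXNet, `li` would hold the NDArray objects.
li = [np.ones(2048, dtype=np.float32) * i for i in range(9)]
result = stack_rows(li)
print(result.shape)  # (9, 2048)
```

Converting first means np.stack only ever sees genuine NumPy arrays, so it never has to guess how to unpack an MXNet object.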

Since the other answer (np.vstack()) is more general, I accepted it and am posting this one only for reference. Note that, once the elements are NumPy arrays, np.array() gives the same result as np.vstack() in that answer's example.
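A quick check of that last point, assuming the list already holds equal-length 1-D NumPy arrays (the arange-based data is a hypothetical stand-in): np.array(), np.vstack(), and np.stack(axis=0) all produce the same (9, 2048) array.

```python
import numpy as np

# Hypothetical stand-in: 9 equal-length 1-D float32 arrays.
li = [np.arange(2048, dtype=np.float32) + i for i in range(9)]

a = np.array(li)          # builds a new array from the nested sequence
b = np.vstack(li)         # each (2048,) array is treated as a (1, 2048) row, then concatenated
c = np.stack(li, axis=0)  # joins the arrays along a new leading axis

print(a.shape, b.shape, c.shape)  # (9, 2048) (9, 2048) (9, 2048)
print(np.array_equal(a, b) and np.array_equal(b, c))  # True
```

np.vstack() is more general in one respect: it concatenates along the first axis, so it also accepts 2-D blocks with different numbers of rows (for example shapes (2, 2048) and (7, 2048)), which np.array() does not merge into a single (9, 2048) array.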

Upvotes: 0

Matt Messersmith

Reputation: 13747

You can use np.vstack. Here's an example:

import numpy as np

li = [np.zeros(2048) for _ in range(9)]  # nine 1-D arrays, each of shape (2048,)
result = np.vstack(li)                   # stack them as the rows of a 2-D array
print(result.shape)

This outputs (9, 2048) as desired.

Upvotes: 2
