Asmd

Reputation: 21

Load csv file that has a column with a numpy array written as string

I have a CSV file with 3 columns, where the third one is an array but written as a string. The table format is the following:

int, int, array

1, 4, [array([-3.98456901e-02,  1.11602008e-01,  6.12380356e-03, -4.49424982e-02,\n        7.13399425e-03, -4.11176607e-02,  8.72574970e-02,  9.94107723e-02])]
0, 2, [array([-3.98456901e-02,  1.11602008e-01,  6.12380356e-03, -4.49424982e-02,\n        7.13399425e-03, -4.11176607e-02,  8.72574970e-02,  9.94107723e-02])]

I can easily load the CSV with pandas using read_csv and end up with a DataFrame, but the third column is a string and I need to use it as an array. How can I convert the third column into an array object?

Upvotes: 2

Views: 1342

Answers (2)

user10343540

Reputation:

This is a common problem when reading/writing to .csv files.

Suppose we have a string such as:

>>> s = "[array([-3.98456901e-02,  1.11602008e-01,  6.12380356e-03, -4.49424982e-02,\n        7.13399425e-03, -4.11176607e-02,  8.72574970e-02,  9.94107723e-02])]"

We can clean it up by slicing off the leading "[array(" and the trailing ")]", removing the newline, and converting what remains to a Python list with ast.literal_eval.

>>> from ast import literal_eval
>>> literal_eval(s[7:-2].replace("\n", ""))
[-0.0398456901, 0.111602008, 0.00612380356, -0.0449424982, 0.00713399425, -0.0411176607, 0.087257497, 0.0994107723]

Then you can wrap this in a function.

from ast import literal_eval

def parse_mystring(s: str) -> list:
    # s[7:-2] drops the leading "[array(" and the trailing ")]"
    return literal_eval(s[7:-2].replace("\n", ""))

Then apply it to the appropriate column of your DataFrame and assign the result back. Let's call the column C.

>>> df["C"] = df["C"].apply(parse_mystring)
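If you need actual NumPy arrays rather than lists, the same idea can be applied while the file is being read. This is only a sketch under some assumptions: the column names A, B, C and the helper parse_array_cell are made up for illustration, the third field is assumed to be quoted in the file (otherwise read_csv would split it on its commas), and the line break inside the cell may be either a real newline or a literal backslash-n, so both are stripped.

```python
import io
from ast import literal_eval

import numpy as np
import pandas as pd


def parse_array_cell(cell: str) -> np.ndarray:
    """Turn '[array([...])]' text into a 1-D float array (hypothetical helper)."""
    # Remove literal backslash-n sequences and real newlines, whichever is present.
    cleaned = cell.strip().replace("\\n", "").replace("\n", "")
    # Drop the leading "[array(" (7 chars) and the trailing ")]" (2 chars).
    return np.array(literal_eval(cleaned[7:-2]), dtype=float)


# Made-up sample data: row 1 has a real newline inside the quoted cell,
# row 2 has a literal backslash-n instead.
csv_text = (
    "A,B,C\n"
    '1,4,"[array([-3.98456901e-02,  1.11602008e-01,\n        6.12380356e-03])]"\n'
    '0,2,"[array([ 7.13399425e-03, -4.11176607e-02,\\n        8.72574970e-02])]"\n'
)

# converters runs the parser on every cell of column C during loading.
df = pd.read_csv(io.StringIO(csv_text), converters={"C": parse_array_cell})

# If all arrays have the same length, they can be stacked into one 2-D array.
matrix = np.stack(df["C"].tolist())
```

Each cell of df["C"] is then an ndarray, and matrix holds one row per CSV line.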

Upvotes: 2

Tim Roberts

Reputation: 54635

It's not actually an array. It is a list containing an array. You will have to process this by hand. Read it a line at a time, convert it, make a list, and pass that to pandas. Your CSV file should not have been created that way.

import numpy as np

data = """\
1, 4, [array([-3.98456901e-02,  1.11602008e-01,  6.12380356e-03, -4.49424982e-02,        7.13399425e-03, -4.11176607e-02,  8.72574970e-02,  9.94107723e-02])]
0, 2, [array([-3.98456901e-02,  1.11602008e-01,  6.12380356e-03, -4.49424982e-02,        7.13399425e-03, -4.11176607e-02,  8.72574970e-02,  9.94107723e-02])]"""

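rows = []
for line in data.splitlines():
    # Split off the two integer columns; everything else is the array text.
    a, b, rest = line.split(',', 2)
    # rest[9:-3] drops the leading " [array([" and the trailing "])]".
    rows.append([int(a), int(b)] + [float(f) for f in rest[9:-3].split(',')])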

print(np.array(rows))

Output:

[[ 1.          4.         -0.03984569  0.11160201  0.0061238  -0.0449425
   0.00713399 -0.04111766  0.0872575   0.09941077]
 [ 0.          2.         -0.03984569  0.11160201  0.0061238  -0.0449425
   0.00713399 -0.04111766  0.0872575   0.09941077]]

Upvotes: -1
