Reputation: 21
I have a CSV file with 3 columns, where the third one is an array but written as a string. The table format is the following:
int, int, array
1, 4, [array([-3.98456901e-02, 1.11602008e-01, 6.12380356e-03, -4.49424982e-02,\n 7.13399425e-03, -4.11176607e-02, 8.72574970e-02, 9.94107723e-02])]
0, 2, [array([-3.98456901e-02, 1.11602008e-01, 6.12380356e-03, -4.49424982e-02,\n 7.13399425e-03, -4.11176607e-02, 8.72574970e-02, 9.94107723e-02])]
I can easily load the CSV with pandas using read_csv and I end up with a data frame, but the third column is stored as a string and I need to use it as an array. How can I convert the third column into an array object?
Upvotes: 2
Views: 1342
Reputation:
This is a common problem when reading/writing to .csv files.
If we have a string like this:
>>> s = "[array([-3.98456901e-02, 1.11602008e-01, 6.12380356e-03, -4.49424982e-02,\n 7.13399425e-03, -4.11176607e-02, 8.72574970e-02, 9.94107723e-02])]"
We can strip the leading [array( and the trailing )] wrapper with a slice (s[7:-2]), remove the newline, and convert what is left to a Python list using literal_eval.
>>> from ast import literal_eval
>>> literal_eval(s[7:-2].replace("\n", ""))
[-0.0398456901, 0.111602008, 0.00612380356, -0.0449424982, 0.00713399425, -0.0411176607, 0.087257497, 0.0994107723]
Then you can wrap this in a function.
from ast import literal_eval
def parse_mystring(s: str) -> list:
    # Drop the "[array(" prefix and ")]" suffix, then parse the bare list.
    return literal_eval(s[7:-2].replace("\n", ""))
And apply it to the appropriate column in your dataframe. Let's call the column C.
>>> df["C"] = df["C"].apply(parse_mystring)
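The fixed slice above depends on every cell starting with exactly "[array(" and ending with ")]". A slightly more tolerant variant can be sketched with a regular expression. This is only a sketch: parse_array_string is a hypothetical name, and it assumes each cell holds a single flat (1-D) array of plain numbers.

```python
import re
from ast import literal_eval

# Captures the bracketed number list inside the first "array([...])" wrapper.
_ARRAY_BODY = re.compile(r"array\(\s*(\[.*?\])", re.DOTALL)


def parse_array_string(s: str) -> list:
    """Return the numbers from the first array([...]) found in s.

    Hypothetical helper; assumes a flat (1-D) array of plain numbers.
    """
    match = _ARRAY_BODY.search(s)
    if match is None:
        raise ValueError(f"no array([...]) found in {s!r}")
    # Real newlines are fine inside brackets; a literal backslash-n
    # left over in the text is replaced with a space before parsing.
    body = match.group(1).replace("\\n", " ")
    return literal_eval(body)


s = "[array([-3.98456901e-02, 1.11602008e-01,\n 6.12380356e-03])]"
values = parse_array_string(s)
```

It can be applied to the column the same way as parse_mystring. If you need NumPy arrays rather than lists, wrapping the result in np.array(...) should be enough.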
Upvotes: 2
Reputation: 54635
It's not actually an array. It is a list containing an array. You will have to process this by hand. Read it a line at a time, convert it, make a list, and pass that to pandas. Your CSV file should not have been created that way.
import numpy as np
data = """\
1, 4, [array([-3.98456901e-02, 1.11602008e-01, 6.12380356e-03, -4.49424982e-02, 7.13399425e-03, -4.11176607e-02, 8.72574970e-02, 9.94107723e-02])]
0, 2, [array([-3.98456901e-02, 1.11602008e-01, 6.12380356e-03, -4.49424982e-02, 7.13399425e-03, -4.11176607e-02, 8.72574970e-02, 9.94107723e-02])]"""
rows = []
for line in data.splitlines():
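    # Split off the two integer columns; the remainder is the array text.
    a, b, rest = line.split(',', 2)
    # rest[9:-3] drops the leading " [array([" and the trailing "])]".
    rows.append([int(a), int(b)] + [float(f) for f in rest[9:-3].split(',')])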
print(np.array(rows))
Output:
[[ 1.          4.         -0.03984569  0.11160201  0.0061238  -0.0449425
   0.00713399 -0.04111766  0.0872575   0.09941077]
 [ 0.          2.         -0.03984569  0.11160201  0.0061238  -0.0449425
   0.00713399 -0.04111766  0.0872575   0.09941077]]
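On the point that the CSV should not have been created that way: if you control the code that writes the file, one option is to store the array column as JSON text, which reads back without any slicing. The following is a minimal sketch using only the standard library. The sample rows are shortened, made-up values based on the question's data, and it assumes a NumPy array has already been converted with .tolist() before writing.

```python
import csv
import io
import json

# Shortened sample rows: two ints plus a list of floats
# (for a NumPy array, call .tolist() first).
rows = [
    (1, 4, [-3.98456901e-02, 1.11602008e-01, 6.12380356e-03]),
    (0, 2, [7.13399425e-03, -4.11176607e-02, 8.72574970e-02]),
]

# Write: the third column becomes JSON text. csv.writer quotes it
# automatically because it contains commas.
buffer = io.StringIO()
writer = csv.writer(buffer)
for a, b, values in rows:
    writer.writerow([a, b, json.dumps(values)])

# Read back: json.loads turns the text into a list of floats again.
buffer.seek(0)
restored = [
    (int(a), int(b), json.loads(text))
    for a, b, text in csv.reader(buffer)
]
```

An in-memory buffer stands in for the file here; with a real file you would open it with newline="" as the csv module documentation recommends.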
Upvotes: -1