Frank B.

Reputation: 41

Reading and writing dictionary of NumPy arrays with CSV files

I am using Python 2.7 and networkx to make a spring layout plot of my network. To compare different setups, I want to store the positions calculated and used by networkx in a file (my choice at the moment: CSV) and read them back every time I make a new plot.

My code looks like this:

pos_spring = nx.spring_layout(H, pos=fixed_positions, fixed = fixed_nodes, k = 4, weight='passengers')

This line calculates the positions used later for plotting, which are what I want to store. The dictionary (pos_spring) looks like this:

{1536: array([ 0.53892015,  0.984306  ]), 
1025: array([ 0.12096853,  0.82587976]), 
1030: array([ 0.20388712,  0.7046137 ]),

Writing file:

w = csv.writer(open("Mexico_spring_layout_positions_2.csv", "w"))
for key, val in pos_spring.items():
    w.writerow([key, val])

The content of the file looks like this:

1536,[ 0.51060853  0.80129841]
1025,[ 0.47442269  0.99838177]
1030,[ 0.02952256  0.45073233]

Reading file:

with open('Mexico_spring_layout_positions_2.csv', mode='r') as infile:
    reader = csv.reader(infile)
    pos_spring = dict((rows[0],rows[1]) for rows in reader)

The content of pos_spring now looks like this:

{'2652': '[ 0.78480322  0.103894  ]', 
'1260': '[ 0.8834103   0.82542163]', 
'2969': '[ 0.33044548  0.31282113]',

The data read back differs from the original dictionary that was written to the CSV file: the keys and values are now strings instead of integers and NumPy arrays. What needs to change when writing and/or reading the data to fix this?

Upvotes: 4

Views: 1183

Answers (1)

jpp

Reputation: 164643

You can't store NumPy arrays in a CSV file and keep their data types. CSV files hold only text, so csv.writer converts each value with str() before writing it. What you are seeing is the text representation of your NumPy array.
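To make this concrete, here is a minimal sketch (Python 3, using an in-memory buffer instead of a real file) of what happens to an array on its way through csv. It also shows one possible way to recover an array from the bracketed text, assuming short one-dimensional arrays like the ones in the question; the `parsed` name and the strip/split approach are my own illustration, not part of the csv module.

```python
import csv
import io

import numpy as np

original = np.array([0.53892015, 0.984306])

# Write one row to an in-memory buffer; csv.writer stringifies each value.
buf = io.StringIO()
csv.writer(buf).writerow([1536, original])

# Read the row back: both fields are now plain strings.
buf.seek(0)
key, val = next(csv.reader(buf))
print(type(key), type(val))
print(repr(val))

# Assumption: the text is a short 1-D array such as "[ 0.5  0.9 ]".
# Strip the brackets, split on whitespace and convert to floats.
parsed = np.array(val.strip("[]").split(), dtype=float)
print(parsed)
```

This kind of parsing is fragile (long arrays are abbreviated and wrapped when converted to text), so it is probably only worth using to rescue files that have already been written.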

Instead, you can unpack each NumPy array into separate columns as you write to your CSV file (note that the code below uses Python 3 syntax):

import csv
import numpy as np

d = {1536: np.array([ 0.53892015,  0.984306  ]), 
     1025: np.array([ 0.12096853,  0.82587976]), 
     1030: np.array([ 0.20388712,  0.7046137 ])}

fp = r'C:\temp\out.csv'

with open(fp, 'w', newline='') as fout:
    w = csv.writer(fout)
    for key, val in d.items():
        w.writerow([key, *val])

Then convert back to NumPy when you read back. For this step you can use a dictionary comprehension:

with open(fp, 'r', newline='') as fin:
    r = csv.reader(fin)
    res = {int(k): np.array(list(map(float, v))) for k, *v in r}

print(res)

{1536: array([ 0.53892015,  0.984306  ]),
 1025: array([ 0.12096853,  0.82587976]),
 1030: array([ 0.20388712,  0.7046137 ])}
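As an alternative to the csv module, the same file layout (key, x, y per row) can be produced with NumPy's own text I/O. The following is only a sketch under a few assumptions: all arrays have the same length, the keys are integers that survive a round trip through float, and the output path is a hypothetical temporary file. It avoids the star-unpacking syntax, which may matter on Python 2.7.

```python
import os
import tempfile

import numpy as np

pos = {1536: np.array([0.53892015, 0.984306]),
       1025: np.array([0.12096853, 0.82587976]),
       1030: np.array([0.20388712, 0.7046137])}

# Hypothetical output path inside a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "positions.csv")

# Build one table with a row per node: key, x, y.
# The integer keys are stored as floats alongside the coordinates.
keys = np.array(list(pos))
coords = np.array([pos[k] for k in keys])
table = np.column_stack([keys, coords])
np.savetxt(path, table, delimiter=",")

# ndmin=2 keeps the result two-dimensional even for a single-row file.
loaded = np.loadtxt(path, delimiter=",", ndmin=2)

# Rebuild the dictionary: first column is the key, the rest is the position.
restored = {int(row[0]): row[1:].copy() for row in loaded}
print(restored)
```

The keys are written in floating-point notation here, so the file is a little less readable than the csv version; that is the trade-off for letting NumPy handle the conversion in both directions.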

Upvotes: 4
