I am using Python 2.7 and networkx to make a spring layout plot of my network. To compare different setups, I want to store the positions that networkx calculates and uses in a file (my choice at the moment: CSV) and read them back every time I make a new plot.
My code looks like this:
pos_spring = nx.spring_layout(H, pos=fixed_positions, fixed=fixed_nodes, k=4, weight='passengers')
This line calculates the positions that are later used for plotting and that I want to store. The dictionary (pos_spring) looks like this:
{1536: array([ 0.53892015, 0.984306 ]),
1025: array([ 0.12096853, 0.82587976]),
1030: array([ 0.20388712, 0.7046137 ]),
Writing file:
w = csv.writer(open("Mexico_spring_layout_positions_2.csv", "w"))
for key, val in pos_spring.items():
    w.writerow([key, val])
The content of the file looks like this:
1536,[ 0.51060853 0.80129841]
1025,[ 0.47442269 0.99838177]
1030,[ 0.02952256 0.45073233]
Reading file:
with open('Mexico_spring_layout_positions_2.csv', mode='r') as infile:
    reader = csv.reader(infile)
    pos_spring = dict((rows[0], rows[1]) for rows in reader)
The content of pos_spring now looks like this:
{'2652': '[ 0.78480322 0.103894 ]',
'1260': '[ 0.8834103 0.82542163]',
'2969': '[ 0.33044548 0.31282113]',
Somehow this data looks different from the original dictionary that I stored in the CSV file: both the keys and the coordinates are now strings. What needs to be changed when writing and/or reading the data to fix this?
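A minimal sketch that reproduces the problem without touching the disk. It assumes Python 3 (io.StringIO stands in for the real file), and the two sample entries are copied from the dictionary above; otherwise the write and read steps are the same as in the question.

```python
import csv
import io

import numpy as np

# Sample data in the same shape as pos_spring: node id -> 2-element array
pos_spring = {1536: np.array([0.53892015, 0.984306]),
              1025: np.array([0.12096853, 0.82587976])}

# Write step from the question, but into an in-memory buffer
buf = io.StringIO()
w = csv.writer(buf)
for key, val in pos_spring.items():
    w.writerow([key, val])

# Read step from the question
buf.seek(0)
reader = csv.reader(buf)
loaded = dict((rows[0], rows[1]) for rows in reader)

# Every key and every value comes back as a str
for k, v in loaded.items():
    print(repr(k), type(k).__name__, repr(v), type(v).__name__)
```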
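You can't store NumPy arrays in CSV files and keep their data types. Remember, CSV files can only store text: csv.writer converts every non-string field with str() before writing it. What you are seeing is the text representation of each NumPy array, and for the same reason the integer keys come back as strings.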
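A small sketch (Python 3, writing to an io.StringIO buffer instead of a file) showing that the cell written for the array is exactly str(arr). It also shows why that text is meant for display rather than storage: NumPy summarizes arrays larger than its default print threshold of 1000 elements with "...". The variable names here are illustrative only.

```python
import csv
import io

import numpy as np

arr = np.array([0.53892015, 0.984306])

# csv.writer stringifies the ndarray with str(), so the whole array ends up in
# a single text cell; exact spacing inside the brackets depends on NumPy's version
buf = io.StringIO()
csv.writer(buf).writerow([1536, arr])
line = buf.getvalue()
print(repr(line))

# str() of a large array is abbreviated, so it cannot be relied on for storage
big = np.arange(2000.0)
print("..." in str(big))
```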
Instead, you can unpack each NumPy array into separate columns as you write the CSV file, so every coordinate gets its own cell:
import csv
import numpy as np

d = {1536: np.array([ 0.53892015, 0.984306 ]),
     1025: np.array([ 0.12096853, 0.82587976]),
     1030: np.array([ 0.20388712, 0.7046137 ])}

fp = r'C:\temp\out.csv'

# Python 3 only: open(..., newline='') and [key, *val] do not work in Python 2.7
with open(fp, 'w', newline='') as fout:
    w = csv.writer(fout)
    for key, val in d.items():
        # one row per node: key, x, y
        w.writerow([key, *val])
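To see what the unpacked rows look like, here is the same write step sketched against an in-memory buffer (Python 3). Unlike the snippet above, it calls val.tolist() so that plain Python floats, rather than NumPy scalars, are written; that is a choice made for this sketch, not part of the original answer.

```python
import csv
import io

import numpy as np

d = {1536: np.array([0.53892015, 0.984306]),
     1025: np.array([0.12096853, 0.82587976])}

buf = io.StringIO()
w = csv.writer(buf)
for key, val in d.items():
    # tolist() converts the float64 elements to plain Python floats
    w.writerow([key, *val.tolist()])

# One node per row: id, x, y
rows = buf.getvalue().splitlines()
print(rows)
# ['1536,0.53892015,0.984306', '1025,0.12096853,0.82587976']
```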
Then convert the values back to integers and NumPy arrays when you read the file. A dictionary comprehension does this in one step:
with open(fp, 'r', newline='') as fin:
    r = csv.reader(fin)
    # first cell is the node key, the remaining cells are the coordinates
    res = {int(k): np.array(list(map(float, v))) for k, *v in r}

print(res)
{1536: array([ 0.53892015, 0.984306 ]),
1025: array([ 0.12096853, 0.82587976]),
1030: array([ 0.20388712, 0.7046137 ])}
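Both steps can be wrapped into a pair of helpers so the stored layout can be reused between runs. This is a sketch under stated assumptions: save_positions and load_positions are hypothetical names, node ids are assumed to be integers (as in the question), and the code is Python 3. In `for node, *coords in ...` the extended unpacking splits each row into its first cell and a list of the remaining cells. The loaded dictionary has the same node-to-array shape as the output of nx.spring_layout, so it can be passed wherever networkx expects a positions dict, for example the pos argument.

```python
import csv
import os
import tempfile

import numpy as np


def save_positions(pos, path):
    """Write a {node: array([x, y, ...])} dict as one CSV row per node (hypothetical helper)."""
    with open(path, "w", newline="") as fout:
        w = csv.writer(fout)
        for node, xy in pos.items():
            w.writerow([node, *np.asarray(xy, dtype=float).tolist()])


def load_positions(path):
    """Read the CSV back into {int node: float64 array}; assumes integer node ids."""
    with open(path, newline="") as fin:
        return {int(node): np.array(coords, dtype=float)
                for node, *coords in csv.reader(fin)}


pos = {1536: np.array([0.53892015, 0.984306]),
       1025: np.array([0.12096853, 0.82587976]),
       1030: np.array([0.20388712, 0.7046137])}

# Round trip through a temporary file
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "positions.csv")
    save_positions(pos, path)
    loaded = load_positions(path)

print(all(np.array_equal(pos[n], loaded[n]) for n in pos))
```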