I want to create a function that takes a dictionary of lists as a parameter, plus two other arguments that are keys of that dictionary:
>>> where_clause_case_1({'a': [1, 2], 'b': [1, 2], 'c':[3, 2]}, 'b', 'c')
{'a': [2], 'b': [2], 'c': [2]}
The function is supposed to compare the lists at keys b1 and b2 position by position. If they are equal at an index, every list keeps its element at that index, so it is unchanged. If they differ at an index, the function removes that index from all the lists. For this specific example I get the error "builtins.KeyError: 'c'" and I don't understand why.
def dictionary_change(my_dict, b1, b2):
    i = 0
    for key in my_dict:
        if(my_dict[b1][i] == my_dict[b2][i]):
            my_dict[key][i] = my_dict[key][i]
        else:
            my_dict[key][i] = []
        i += 1
    return my_dict
You loop over the dictionary keys and increment i at the same time. This means that for 3 keys, i runs from 0 to 2, but there are only ever 2 items in each value (indices 0 and 1). On the third key, my_dict[b1][2] is out of range, which leads to your index error.
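To see the mismatch concretely, here is the question's loop, re-indented but otherwise unchanged, run on the example input. It fails on the third key because i has reached 2. Note that by then 'a' has already been modified: `my_dict[key][i] = []` replaces an element with an empty list rather than removing it.

```python
def dictionary_change(my_dict, b1, b2):
    # The question's logic, unchanged: i advances once per *key*,
    # but it is used as an index into the *lists*.
    i = 0
    for key in my_dict:
        if my_dict[b1][i] == my_dict[b2][i]:
            my_dict[key][i] = my_dict[key][i]
        else:
            my_dict[key][i] = []  # replaces the element, does not delete it
        i += 1
    return my_dict


data = {'a': [1, 2], 'b': [1, 2], 'c': [3, 2]}
try:
    dictionary_change(data, 'b', 'c')
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range
print(data)  # {'a': [[], 2], 'b': [1, 2], 'c': [3, 2]} (on Python 3.7+)
```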
Instead, use zip() to loop over the two lists, record the indices to keep, then apply that to all values in the whole dictionary:
def where_clause_case_1(my_dict, b1, b2):
    # build a set of indices to keep
    keep = {i for i, (x, y) in enumerate(zip(my_dict[b1], my_dict[b2])) if x == y}
    # build a new dictionary with kept indices
    return {key: [v for i, v in enumerate(value) if i in keep] for key, value in my_dict.items()}
The last line builds a new dictionary using a dict comprehension ({key_expression: value_expression for variables in sequence}); basically a loop that builds keys and values for a dictionary. It keeps the keys from my_dict but builds new values. Each value is built using a list comprehension, which is another loop: it keeps an original element only when its index is in the set keep.
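To make the two steps concrete, here are the intermediate values for the example input, computed at the top level instead of inside the function. The variable names pairs, keep and result are just for illustration.

```python
data = {'a': [1, 2], 'b': [1, 2], 'c': [3, 2]}

# enumerate(zip(...)) walks the two chosen lists side by side, with indices.
pairs = list(enumerate(zip(data['b'], data['c'])))
print(pairs)  # [(0, (1, 3)), (1, (2, 2))]

# Only index 1 holds equal elements, so it is the only index kept.
keep = {i for i, (x, y) in pairs if x == y}
print(keep)  # {1}

# The dict comprehension keeps every key; the list comprehension filters each value.
result = {key: [v for i, v in enumerate(value) if i in keep] for key, value in data.items()}
print(result)  # {'a': [2], 'b': [2], 'c': [2]}
```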
Without comprehensions, it'd look like this:
def where_clause_case_1(my_dict, b1, b2):
    # build a set of indices to keep
    keep = set()
    for i, (x, y) in enumerate(zip(my_dict[b1], my_dict[b2])):
        if x == y:
            keep.add(i)
    # build a new dictionary with kept indices
    retval = {}
    for key, oldvalue in my_dict.items():
        retval[key] = newvalue = []
        for i, v in enumerate(oldvalue):
            if i in keep:
                newvalue.append(v)
    return retval
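A variation (not part of the approach above): instead of a set of indices, you could build a list of booleans and let itertools.compress() do the filtering. The name where_clause_compress is just an illustrative choice.

```python
from itertools import compress


def where_clause_compress(my_dict, b1, b2):
    # One boolean per position: True where the two chosen lists agree.
    mask = [x == y for x, y in zip(my_dict[b1], my_dict[b2])]
    # compress() yields only the elements whose matching mask entry is True.
    # Like the set-based version, positions past the shorter of the two
    # compared lists are dropped, because compress() stops when mask runs out.
    return {key: list(compress(value, mask)) for key, value in my_dict.items()}


print(where_clause_compress({'a': [1, 2], 'b': [1, 2], 'c': [3, 2]}, 'b', 'c'))
# {'a': [2], 'b': [2], 'c': [2]}
```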
Demo:
>>> def where_clause_case_1(my_dict, b1, b2):
...     # build a set of indices to keep
...     keep = {i for i, (x, y) in enumerate(zip(my_dict[b1], my_dict[b2])) if x == y}
...     # build a new dictionary with kept indices
...     return {key: [v for i, v in enumerate(value) if i in keep] for key, value in my_dict.items()}
...
>>> where_clause_case_1({'a': [1, 2], 'b': [1, 2], 'c':[3, 2]}, 'b', 'c')
{'a': [2], 'c': [2], 'b': [2]}
>>> where_clause_case_1({'a': [1, 2], 'b': [1, 2], 'c':[3, 2]}, 'a', 'b')
{'a': [1, 2], 'c': [3, 2], 'b': [1, 2]}
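The functions above return a new dictionary and leave the input untouched. If you specifically want to modify the dictionary in place, as the original attempt did, one possible sketch computes the indices first and then overwrites each list with slice assignment. The name where_clause_in_place is hypothetical.

```python
def where_clause_in_place(my_dict, b1, b2):
    # Compute the indices to keep *before* touching any list, so that
    # filtering my_dict[b1] or my_dict[b2] cannot affect the comparison.
    keep = {i for i, (x, y) in enumerate(zip(my_dict[b1], my_dict[b2])) if x == y}
    for value in my_dict.values():
        # Slice assignment replaces the contents of the existing list object,
        # so any other reference to the same list sees the change too.
        value[:] = [v for i, v in enumerate(value) if i in keep]


data = {'a': [1, 2], 'b': [1, 2], 'c': [3, 2]}
where_clause_in_place(data, 'b', 'c')
print(data)  # {'a': [2], 'b': [2], 'c': [2]}
```

Because the function mutates its argument, it returns None; call it for its side effect rather than assigning the result.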