I am working on a simple simulation of a network of nonlinear systems. In particular, I have N nodes, each composed of m units. The output function of each unit depends both on its own activity and on the activities of the other units in the same node.
I implemented the simulation with scipy + jitcode.
In the first version the outputs followed a softmax distribution, so I wrote this simple function to compute the output of each unit:
import symengine

def soft_max(nodes_activities):
    """
    This function computes the output of all the mini-columns
    :param nodes_activities: Activities of the minicolumns grouped in nested lists
    :return: One unique list with all the outputs
    """
    G = 10  # gain of the softmax
    act = []
    for node in nodes_activities:
        # normalisation term: sum of exp(G * activity) over the units of this node
        sum_hc = 0
        for unit in node:
            sum_hc += symengine.exp(unit * G)
        for unit in node:
            act.append(symengine.exp(unit * G) / sum_hc)
    return act
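For reference, here is a minimal numerical sketch of the same per-node softmax in NumPy. `softmax_per_node` is a hypothetical helper, not part of symengine or jitcode, and it works on plain numbers rather than symbols. It shows that raising the gain `G` pushes each node's output toward a one-hot vector, so a large `G` acts as a smooth approximation of a winner-take-all rule.

```python
import numpy as np

def softmax_per_node(node_activities, G=10.0):
    """Numerical counterpart of soft_max: one softmax per node, outputs flattened."""
    out = []
    for node in node_activities:
        a = np.asarray(node, dtype=float) * G
        # subtracting the maximum before exponentiating avoids overflow
        # and leaves the softmax result unchanged
        e = np.exp(a - a.max())
        out.extend(e / e.sum())
    return np.array(out)

nodes = [[0.2, 0.5], [0.9, 0.1, 0.3]]
for G in (1, 10, 100):
    print(G, np.round(softmax_per_node(nodes, G), 3))
```

Because the softmax stays smooth for any finite `G`, adaptive-step integrators usually cope with it better than with a discontinuous step. Whether a very large `G` is acceptable for this particular model is an assumption to check.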
Now I would like to replace the function above with a simpler one that, for each node, outputs 1 for the unit with the highest activity and 0 for all the other units. Long story short, in each node exactly one unit outputs 1.
The main issue I am facing right now is how to do this with symengine so that jitcode can use it. The function below doesn't work, for an obvious reason: I guess the if condition is not symbolic.
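To pin down the target behaviour, here is a purely numerical reference in NumPy. It cannot be used directly in jitcode's symbolic right-hand side, because `np.argmax` needs concrete numbers, but it is handy for checking a symbolic version against. `hard_max_per_node` is a hypothetical name used only for illustration.

```python
import numpy as np

def hard_max_per_node(node_activities):
    """One-hot output per node: 1.0 for the most active unit, 0.0 elsewhere."""
    out = []
    for node in node_activities:
        a = np.asarray(node, dtype=float)
        one_hot = np.zeros_like(a)
        # np.argmax returns the index of the first maximum, so exactly one
        # unit per node is set to 1 even when several units are tied
        one_hot[np.argmax(a)] = 1.0
        out.extend(one_hot)
    return np.array(out)

print(hard_max_per_node([[0.2, 0.5], [0.9, 0.1, 0.3], [0.4, 0.4]]))
```

Note the tie-breaking choice: with `[0.4, 0.4]` only the first unit wins. A test of the form "activity >= node maximum" would instead give 1 to every tied unit.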
import symengine

def soft_max(nodes_activities):
    """
    This function computes the output of all the mini-columns
    :param nodes_activities: Activities of the minicolumns grouped in nested lists
    :return: One unique list with all the outputs
    """
    act = []
    for node in nodes_activities:
        max_act = symengine.Max(*node)
        for unit in node:
            # unit >= max_act is a symbolic relation, not a plain True/False,
            # so this if cannot branch on it while the activities are symbols
            if unit >= max_act:
                act.append(1)
            else:
                act.append(0)
    return act
I didn't find any symengine.argmax() function or any clever alternative. Do you have any suggestions?
UPDATES
import symengine

def max_activation(activities):
    act = []
    for hc in activities:
        max_act = symengine.Max(*hc)
        for mc in hc:
            # GreaterThan(mc, max_act) represents the relation mc >= max(hc)
            act.append(symengine.GreaterThan(mc, max_act))
        print(act)
    return act
Testing this function:
from jitcode import y
max_activation([[y(1), y(2)], [y(3), y(4)]])
I get the following output, which looks promising. I will update as soon as I have run some tests.
[max(y(2), y(1)) <= y(1), max(y(2), y(1)) <= y(2)]
[max(y(4), y(3)) <= y(3), max(y(4), y(3)) <= y(4)]
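Each printed relation is true only for a unit whose activity equals its node's maximum. The sketch below evaluates the same test numerically and casts it to 0/1, to show what these relations mean once concrete activities are plugged in. `max_activation_numeric` is a hypothetical helper. How the relations are turned into numeric 0/1 inside jitcode (for example with `symengine.Piecewise`) is not shown here, and whether jitcode's code generation accepts that construct is an assumption worth testing.

```python
import numpy as np

def max_activation_numeric(node_activities):
    """Numerical counterpart of max_activation: each unit's output is
    1.0 if max(node) <= unit, else 0.0."""
    out = []
    for node in node_activities:
        a = np.asarray(node, dtype=float)
        # same test as GreaterThan(mc, Max(*hc)), cast from bool to float
        out.extend((a >= a.max()).astype(float))
    return np.array(out)

print(max_activation_numeric([[0.2, 0.5], [0.4, 0.4]]))
```

With distinct activities this gives the intended one-hot output per node. With tied activities, as in the second node, every tied unit outputs 1, so "exactly one unit outputs 1" only holds when ties do not occur.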