giangian

Argmax in symengine or equivalent alternative

I am working on a simple simulation of a network of nonlinear systems. In particular, I have N nodes, each composed of m units. The output function of each unit depends both on its own activity and on the activities of the other units in the same node.

I implemented the simulation with SciPy and jitcode.

My first version used a softmax output, so I wrote this simple function to compute the output of each unit:

import symengine

def soft_max(nodes_activities):
    """
    This function computes the output of all the mini-columns
    :param nodes_activities: Activities of the minicolumns grouped in nested lists
    :return: One unique list with all the outputs
    """
    G = 10  # gain of the softmax
    act = []
    for node in nodes_activities:
        # Normalization term of the softmax for this node
        sum_hc = 0
        for unit in node:
            sum_hc += symengine.exp(unit * G)
        for unit in node:
            act.append(symengine.exp(unit * G) / sum_hc)
    return act
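For checking the symbolic expressions against numbers, a plain-Python counterpart of the same grouped softmax can help. This is only a sketch: grouped_softmax and its gain parameter are hypothetical names, and subtracting the node maximum before exponentiating is a standard overflow guard that leaves the result unchanged. It also shows that a larger gain pushes each node's output toward the one-hot pattern described next.

```python
import math


def grouped_softmax(nodes_activities, gain=10.0):
    """Numeric analogue of soft_max: one softmax per node, results flattened."""
    out = []
    for node in nodes_activities:
        # Shifting by the node maximum does not change the softmax,
        # but keeps exp() from overflowing for large gain * activity.
        m = max(node)
        exps = [math.exp(gain * (a - m)) for a in node]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out


print(grouped_softmax([[0.1, 0.3], [0.5, 0.2]]))  # ≈ [0.119, 0.881, 0.953, 0.047]
```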

Now I would like to replace the function above with a simpler one that, for each node, outputs 1 for the unit with the highest activity and 0 for all the other units. In short, in each node exactly one unit should output 1.
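To pin down the target behavior, here is a numeric (non-symbolic) sketch. hard_max is a hypothetical name, and giving ties to the first of the tied units is an assumption, since the rule above does not say how ties should be resolved.

```python
def hard_max(nodes_activities):
    """Numeric one-hot output: 1 for the most active unit of each node, 0 elsewhere."""
    out = []
    for node in nodes_activities:
        # max() returns the first index among tied maxima,
        # so every node gets exactly one winner.
        winner = max(range(len(node)), key=lambda i: node[i])
        out.extend(1 if i == winner else 0 for i in range(len(node)))
    return out


print(hard_max([[0.1, 0.3], [0.5, 0.5]]))  # [0, 1, 1, 0]
```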

The main issue I am facing is how to express this with symengine so that jitcode can use it. The function below doesn't work, for an obvious reason: the if condition is not symbolic. The comparison unit >= max_act is a relation between symbols, which cannot be decided while the function is building the expressions.

def soft_max(nodes_activities):
    """
    This function computes the output of all the mini-columns
    :param nodes_activities: Activities of the minicolumns grouped in nested lists
    :return: One unique list with all the outputs
    """
    act = []
    for node in nodes_activities:
        max_act = symengine.Max(*node)
        for unit in node:
            # Does not work: unit >= max_act is a symbolic relation,
            # not a Python bool, so the if cannot be decided here
            if unit >= max_act:
                act.append(1)
            else:
                act.append(0)
    return act
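One branch-free way to express the same rule (a suggestion, not an existing symengine function) is to write each output as a product of step functions of activity differences: unit i wins if it is strictly more active than every earlier unit and at least as active as every later one, which yields exactly one winner per node even with ties. The sketch below uses plain-Python stand-ins step_gt and step_ge (hypothetical names); in a symbolic version each would have to be replaced by a conditional expression that jitcode can compile, for example something built with symengine.Piecewise, which is an assumption I have not checked against jitcode. The loop conditions compare indices, which are known when the expressions are built, so they are not a problem for a symbolic version.

```python
def step_gt(z):
    # Stand-in for a symbolic step: 1 if z > 0, else 0
    return 1 if z > 0 else 0


def step_ge(z):
    # Stand-in for a symbolic step: 1 if z >= 0, else 0
    return 1 if z >= 0 else 0


def winner_take_all(nodes_activities):
    out = []
    for node in nodes_activities:
        for i, x_i in enumerate(node):
            o = 1
            for j, x_j in enumerate(node):
                # Branch on indices only; activities enter the
                # result exclusively through the step functions.
                if j < i:
                    o *= step_gt(x_i - x_j)  # strictly beat every earlier unit
                elif j > i:
                    o *= step_ge(x_i - x_j)  # tie or beat every later unit
            out.append(o)
    return out


print(winner_take_all([[0.1, 0.3], [0.5, 0.5]]))  # [0, 1, 1, 0]
```

Ties go to the first of the tied units; swapping the strict and non-strict steps would hand them to the last one instead.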

I didn't find any symengine.argmax() function or any smart alternative solution. Do you have any suggestions?

UPDATES

def max_activation(activities):
    act = []

    for hc in activities:
        max_act = symengine.Max(*hc)
        for mc in hc:
            # GreaterThan(mc, max_act) is the relation mc >= max_act
            act.append(symengine.GreaterThan(mc, max_act))
        print(act)
    return act

Testing this function:

    max_activation([[y(1), y(2)], [y(3), y(4)]])

I get the following output, which looks somewhat promising. I will update the question as soon as I have run some tests.

[max(y(2), y(1)) <= y(1), max(y(2), y(1)) <= y(2)]

[max(y(4), y(3)) <= y(3), max(y(4), y(3)) <= y(4)]
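For comparison, here is a numeric analogue of what these relations would give once mapped to 1 or 0 (ge_max is a hypothetical name). It highlights two points: GreaterThan returns a relation rather than a number, so it still has to be turned into 1 or 0 before it can serve as an output, and when two units of a node tie for the maximum, both relations hold, so that node ends up with two units outputting 1.

```python
def ge_max(nodes_activities):
    # Numeric analogue of GreaterThan(mc, Max(*hc)), with True -> 1 and False -> 0
    out = []
    for node in nodes_activities:
        m = max(node)
        out.extend(1 if x >= m else 0 for x in node)
    return out


print(ge_max([[0.2, 0.7], [0.4, 0.4]]))  # [0, 1, 1, 1]: the tied node has two winners
```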
