symengine或同等替代品中的Argmax

时间:2019-04-11 13:51:20

标签: max simulation jit argmax symengine

我正在研究非线性系统网络的简单仿真。特别是我有N个节点,每个节点由m个单元组成。每个单元的输出功能取决于它的活动以及同一节点中其他单元的活动。

我实现的模拟是scipy + jitcode。

我实现的第一个版本是根据softmax发行版本, 因此,我实现了这个简单的函数来计算每个单元的输出。

def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
    sum_hc = 0
    for unit in node:
        sum_hc += symengine.exp(unit * G)
    for unit in node:
        act.append(symengine.exp(unit * G)/sum_hc)
return act

现在,我想用一个简单的函数替换上面的函数,对于每个节点,该函数为活动性最高的单元输出1,在其他单元中输出0。长话短说,对于每个节点,只有一个单位输出1。

我现在面临的主要问题是如何使用symengine做到这一点,以便jitcode可以使用它。我在下面实现的功能由于明显的原因无法正常工作。我猜if条件不是很象征。

def soft_max(node_activities):
"""
This function computes the output of all the mini-columns
:param nodes_activities: Activities of the minicolumns grouped in nested lists
:return: One unique list with all the outputs
"""
G = 10
act = []
for node in nodes_activities:
    max_act = symengine.Max(*node)
    for unit in node:
        if unit >= max_act:
            act.append(1)
        else:
            act.append(0)           
return act

我没有找到任何symengine.argmax()函数或任何智能的替代解决方案。你有什么建议吗?

更新

def max_activation(activities):
    act = []

for hc in activities:
    sum_hc = 0
    max_act = symengine.Max(*hc)
    for mc in hc:
        act.append(symengine.GreaterThan(mc, max_act))
    print(act)
return act

测试此功能:

    max_activation([[y(1), y(2)], [y(3), y(4)]])

我得到以下输出,该输出在某种程度上很有希望。一旦进行一些测试,我将立即更新。

  

[max(y(2),y(1))<= y(1),max(y(2),y(1))<= y(2)]

     

[max(y(4),y(3))<= y(3),max(y(4),y(3))<= y(4)]

0 个答案:

没有答案