求和连续的非零张量值

时间:2018-12-20 19:43:37

标签: python deep-learning pytorch

我正在尝试找到连续的非零张量值的总和,如下所示

比方说,我有一个张量A = [1.3, 0.0, 0.6, 0.7, 0.8]。我想

1)对张量的连续个非零值求和以输出[1.3, 0.0, 2.1],然后选择最大值2.1

2)还要找到用于对这些值求和的索引。在这种情况下,它将是2, 3, 4

2 个答案:

答案 0 :(得分:0)

我们可以通过两个简单的步骤解决此问题:

首先,基于零分割主张量。因此,给定的张量A看起来像[ [1.3], [0], [0.6, 0.7, 0.8] ]。可以使用以下功能完成此操作:

def split_list(lst, value=0):
    """
    Splits a given list based on a given value
    default is zero
    """
    groups = []
    sub_group = []
    for i in lst:
        if i == 0:
            groups.append(sub_group)
            sub_group = []
            groups.append([0])
        else:
            sub_group.append(i)
    if sub_group:
        groups.append(sub_group)
    return groups

第二,对每个子组求和。返回的索引会有些棘手。因此,让我们在代码中看到它:

def get_max_indices(groups):
    """
    This function takes a list of lists and 
    returns the indices of the maximum elements
    """
    maximum = 0
    max_length = 0
    total_elements = 0
    length_before = 0
    for idx, sub_group in enumerate(groups):
        summation = sum(sub_group)
        if summation > maximum:
            maximum = summation
            max_length = len(sub_group)
            length_before = total_elements
        total_elements += len(sub_group)
    return [_ for _ in range(length_before, length_before+max_length)]

现在,让我们同时尝试它们:

>>> lst = [1.3, 0, 0.6, 0.7, 0.8]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[2, 3, 4]

让我们尝试另一个示例:

>>> lst = [1, 2, 3, 0, 6, 9, 0, 10]
>>> groups = split_list(lst, value=0)
>>> print(get_max_indices(groups))
[4, 5]

我希望这可以解决您的问题。我知道这比您想象的要复杂一些,但这将使您入门。我认为可以对此进行优化和清理,但我将留给您。

答案 1 :(得分:0)

我的方法与@Anwarvic有所不同。我试图一口气做到这一点。请参见下面的功能。它遍历数组,并保持到目前为止已看到的最大值和当前和的对数。如果我们为零,则当前总和更新为0;如果非零,则将当前总和更新为当前值。

def find_continguous_max_sum(t): 
    max_,cur=0,0 
    max_indices, cur_indices = [], []

    for idx, i in enumerate(t): 
        if (i==0.0): 
            cur = 0 
            cur_indices = []
        else: 
            cur = cur+i
            cur_indices.append(idx)
        if (max_ < cur): 
            max_ = cur 
            max_indices = cur_indices 

    return max_, max_indices

#usage
find_contiguous_max_sum(A)
相关问题