Python中的二进制树计算范围内的节点

时间:2017-10-22 18:38:44

标签: python python-3.x tree binary-tree

我有一个简单的二叉树实现:

class Node:
    def __init__(self, item, left = None, right = None):
        self.item = item
        self.left = left
        self.right = right

class BST:
    def __init__(self):
        self.root = None

    def add(self, item):
        if self.root == None:
            self.root = Node(item, None, None)
        else:
            child_tree = self.root
            while child_tree != None:
                parent = child_tree
                if item < child_tree.item:
                    child_tree = child_tree.left
                else:
                    child_tree = child_tree.right
            if item < parent.item:
                parent.left = Node(item, None, None)
            elif item > parent.item:
                parent.right = Node(item, None, None)

我想添加count(lo,hi)方法,它计算范围内的所有节点(lo,hi)(包括hi)这是我到目前为止所拥有的:

def count(self, lo, hi, ptr='lol', count=0):
    if ptr == 'lol':
        ptr = self.root
    if ptr.left != None:
        if ptr.item >= lo and ptr.item <= hi:
            count += 1
        ptr.left = self.count(lo, hi, ptr.left, count)
    if ptr.right != None:
        if ptr.item >= lo and ptr.item <= hi:
            count += 1
        ptr.right = self.count(lo, hi, ptr.right, count)
    return count

当二叉树右倾或左倾时,它似乎才起作用。它不适用于平衡树,我不知道为什么。我的意见是:

bst = BST()
for ele in [10, 150, 80, 40, 20, 10, 30, 60, 50, 70, 120, 100, 90, 110, 140, 130, 150]:
    bst.add(ele)
print(bst.count(30, 100))

我的代码为我提供了output: 0,但应该说output: 8。你能告诉我哪里出错了吗?

1 个答案:

答案 0 :(得分:2)

错误的部分:

   while child_tree != None:
        if child_tree.item >= lo and child_tree.item <= hi:
            count += 1
        if hi > child_tree.item:  # from here
            child_tree = child_tree.right
        else:
            child_tree = child_tree.left . # to here

如果child_tree介于低和高之间,你应该递归地迭代左右两个孩子 - 你只迭代正确的孩子。

提示:既然你需要检查右边和左边的孩子,应该有一个递归的呼叫......

<强>更新

def count(self, lo, hi, ptr, count=0):
    if not ptr:
        return 0
    elif lo <= ptr.item <= hi:
        return 1 + self.count(lo, hi, ptr.left, count) + \
               self.count(lo, hi, ptr.right, count)
    elif ptr.item < lo:
        return self.count(lo, hi, ptr.right, count)
    elif ptr.item > hi:
        return self.count(lo, hi, ptr.left, count)