漂亮打印二进制搜索树

时间:2018-12-11 14:40:06

标签: python binary-search-tree pretty-print

我正在尝试以“图形”方式打印二进制搜索树(prettyPrint函数会这样做)。 我的node类看起来像这样:

class Node:
    def __init__(self, dataValue):
        self.dataValue = dataValue
        self.leftChild = None
        self.rightChild = None

我的tree类看起来像这样:

class binary_search_tree(Node):
    def __init__(self):
        self.root = None

    def insert(self, value):
        if self.root is None:
            self.root = Node(value)
        else:
            self._insert(value, self.root)

    def _insert(self, value, cur_node):
        if value < cur_node.dataValue:
            if cur_node.leftChild is None:
                cur_node.leftChild = Node(value)
            else:
                self._insert(value, cur_node.leftChild)
        if value > cur_node.dataValue:
            if cur_node.rightChild is None:
                cur_node.rightChild = Node(value)
            else:
                self._insert(value, cur_node.rightChild)

    def print_tree(self, transversal_type):
        if transversal_type == "preorder":
            return self._print_preorder(self.root, "")
        elif transversal_type == "inorder":
            return self._print_inorder(self.root, "")
        elif transversal_type == "postorder":
            return self._print_postorder(self.root, "")
        else:
            print("{} does not exist".format(transversal_type))
            return False

    def _print_tree(self, root, indent, transversal = ""):

        if root is not None:
            self._print_tree(root.rightChild, indent + "   ")
            transversal += indent + str(root.dataValue)
            self._print_tree(root.leftChild, indent + "   ")
        return transversal


    def _print_preorder(self, start, transversal):
        # Root -> Left -> Right
        if start:
            transversal += (str(start.dataValue) + " - ")
            transversal = self._print_preorder(start.leftChild, transversal)
            transversal = self._print_preorder(start.rightChild, transversal)
        return transversal

    def _print_inorder(self, start, transversal):
        #Left -> Root -> Right
        if start:
            transversal = self._print_inorder(start.leftChild, transversal)
            transversal += (str(start.dataValue) + " - ")
            transversal = self._print_inorder(start.rightChild, transversal)
        return transversal

    def _print_postorder(self, start, transversal):
        #Left -> Right -> Root
        if start:
            transversal = self._print_postorder(start.leftChild, transversal)
            transversal = self._print_postorder(start.rightChild, transversal)
            transversal += (str(start.dataValue) + " - ")
        return transversal

    def search(self, value):
        if self.root!=None:
            return self._search(value,self.root)
        else:
            return False

    def _search(self, value, cur_node):
        if value==cur_node.dataValue:
            return True
        elif value < cur_node.dataValue and cur_node.leftChild!=None:
            return self._search(value, cur_node.leftChild)
        elif value > cur_node.dataValue and cur_node.rightChild!=None:
            return self._search(value, cur_node.rightChild)
        return False

    def min_value(self, node):
        current = node
        while(current.leftChild is not None):
            current = current.leftChild
        return current

    def delete_node(self, node, value):
        if node is None:
            return node
        if value < node.dataValue:
            node.leftChild = self.delete_node(node.leftChild, value)
        elif value > node.dataValue:
            node.rightChild = self.delete_node(node.rightChild, value)
        else:
            if node.leftChild is None:
                temp = node.rightChild
                node = None
                return temp
            elif node.rightChild is None:
                temp = node.leftChild
                node = None
                return temp
            temp = self.min_value(node)
            node.dataValue = temp.dataValue
            node.rightChild = self

    def getNumNodes(self):
        if self.root:
            return self._getNumNodes(self.root)
        else:
            return 0

    def _getNumNodes(self, node):
        total = 1
        if node.leftChild:
            total += self._getNumNodes(node.leftChild)
        if node.rightChild:
            total += self._getNumNodes(node.rightChild)
        return total

    def getHeight(self):
        return self._getHeight(self.root)

    def _getHeight(self, node):
        if not node:
            return 0
        else:
            return max(self._getHeight(node.leftChild), self._getHeight(node.rightChild)) + 1

    def fillTree(self, height):
        self._fillTree(self.root, height)

    def _fillTree(self, node, height):
        if height <= 1:
            return
        if node:
            if not node.leftChild: node.leftChild = Node(' ')
            if not node.rightChild: node.rightChild = Node(' ')
            self._fillTree(node.leftChild, height - 1)
            self._fillTree(node.rightChild, height - 1)

    def prettyPrint(self):
        """
        """
        # get height of tree
        total_layers = self.getHeight()

        tree = deepcopy(self)

        tree.fillTree(total_layers)
        # start a queue for BFS
        queue = Queue()
        # add root to queue
        queue.enqueue(tree)  # self = root
        # index for 'generation' or 'layer' of tree
        gen = 1
        # BFS main
        while not queue.isEmpty():
            # copy queue
            #
            copy = Queue()
            while not queue.isEmpty():
                copy.enqueue(queue.dequeue())
            #
            # end copy queue

            first_item_in_layer = True
            edges_string = ""
            extra_spaces_next_node = False

            # modified BFS, layer by layer (gen by gen)
            while not copy.isEmpty():

                root = copy.dequeue()

                # -----------------------------
                # init spacing
                spaces_front = pow(2, total_layers - gen + 1) - 2
                spaces_mid = pow(2, total_layers - gen + 2) - 2
                dash_count = pow(2, total_layers - gen) - 2
                if dash_count < 0:
                    dash_count = 0
                spaces_mid = spaces_mid - (dash_count * 2)
                spaces_front = spaces_front - dash_count
                init_padding = 2
                spaces_front += init_padding
                if first_item_in_layer:
                    edges_string += " " * init_padding
                # ----------------------------->

                # -----------------------------
                # construct edges layer
                edge_sym = "/" if root.leftChild and root.leftChild.data is not " " else " "
                if first_item_in_layer:
                    edges_string += " " * (pow(2, total_layers - gen) - 1) + edge_sym
                else:
                    edges_string += " " * (pow(2, total_layers - gen + 1) + 1) + edge_sym
                edge_sym = "\\" if self.root.rightChild and self.root.rightChild.data is not " " else " "
                edges_string += " " * (pow(2, total_layers - gen + 1) - 3) + edge_sym
                # ----------------------------->

                # -----------------------------
                # conditions for dashes
                if self.root.leftChild and self.root.leftChild.data == " ":
                    dash_left = " "
                else:
                    dash_left = "_"

                if self.root.rightChild and self.root.rightChild.data == " ":
                    dash_right = " "
                else:
                    dash_right = "_"
                # ----------------------------->

                # -----------------------------
                # handle condition for extra spaces when node lengths don't match or are even:
                if extra_spaces_next_node:
                    extra_spaces = 1
                    extra_spaces_next_node = False
                else:
                    extra_spaces = 0
                # ----------------------------->

                # -----------------------------
                # account for longer data
                data_length = len(str(self.root.data))
                if data_length > 1:
                    if data_length % 2 == 1:  # odd
                        if dash_count > 0:
                            dash_count -= ((data_length - 1) / 2)
                        else:
                            spaces_mid -= (data_length - 1) / 2
                            spaces_front -= (data_length - 1) / 2
                            if data_length is not 1:
                                extra_spaces_next_node = True
                    else:  # even
                        if dash_count > 0:
                            dash_count -= ((data_length) / 2) - 1
                            extra_spaces_next_node = True
                            # dash_count += 1
                        else:
                            spaces_mid -= (data_length - 1)
                            spaces_front -= (data_length - 1)
                # ----------------------------->

                # -----------------------------
                # print node with/without dashes
                if first_item_in_layer:
                    print(str(" " * spaces_front)   + str(dash_left * dash_count) + str(node.data) + str(dash_right * dash_count), end="")
                    first_item_in_layer = False
                else:
                    print((" " * (spaces_mid - extra_spaces)) + (dash_left * dash_count) + (self.root.data) + (dash_right * dash_count), end=""),
                # ----------------------------->

                if self.root.leftChild: queue.enqueue(self.root.leftChild)
                if self.root.rightChild: queue.enqueue(self.root.rightChild)

            # print the fun squiggly lines
            if not queue.isEmpty():
                print("\n" + edges_string)
            gen += 1

我得到的错误是'binary_search_tree' object has no attribute 'leftChild',在代码的这一部分:

edge_sym = "/" if root.leftChild and root.leftChild.data is not " " else " "

我猜我已经使类“冲突”了,但是我不知道如何解决。

1 个答案:

答案 0 :(得分:0)

您的代码是不完整的,实际上并不是最小的可复制示例。但看起来像

在这里复制binary_search_tree对象

tree = deepcopy(self)

然后假设树是您的根对象,又名Node

queue.enqueue(tree)  # self = root

这是不正确的,因为您的根节点将是tree.root

这至少解决了您目前的问题,但是还有更多问题,例如尝试从名为.data的节点打印.dataValue