Scala中二叉树上的尾递归折叠

时间:2017-01-03 09:32:47

标签: scala tree binary-tree tail-recursion fold

我正在尝试为二叉树找到尾递归折叠函数。鉴于以下定义:

// From the book "Functional Programming in Scala", page 45
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

实现非尾递归函数非常简单:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B =
  t match {
    case Leaf(v)      => map(v)
    case Branch(l, r) => 
      red(fold(l)(map)(red), fold(r)(map)(red))
  }

但现在我正在努力寻找尾递归折叠函数,以便可以使用注释@annotation.tailrec

在我的研究过程中,我发现了几个例子,其中树上的尾递归函数可以例如使用自己的堆栈计算所有叶子的总和,然后基本上是List[Tree[Int]]。但据我所知,在这种情况下它只适用于添加,因为无论您是首先评估运算符的左侧还是右侧都不重要。但对于广义折叠来说,它是非常相关的。在这里展示我的意图是一些示例树:

val leafs = Branch(Leaf(1), Leaf(2))
val left = Branch(Branch(Leaf(1), Leaf(2)), Leaf(3))
val right = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)))
val bal = Branch(Branch(Leaf(1), Leaf(2)), Branch(Leaf(3), Leaf(4)))
val cmb = Branch(right, Branch(bal, Branch(leafs, left)))
val trees = List(leafs, left, right, bal, cmb)

基于这些树,我想用给定的折叠方法创建一个深层副本,如:

val oldNewPairs = 
  trees.map(t => (t, fold(t)(Leaf(_): Tree[Int])(Branch(_, _))))

然后证明所有创建的副本的平等条件都适用:

val conditionHolds = oldNewPairs.forall(p => {
  if (p._1 == p._2) true
  else {
    println(s"Original:\n${p._1}\nNew:\n${p._2}")
    false
  }
})
println("Condition holds: " + conditionHolds)

有人可以给我一些指示吗?

您可以在ScalaFiddle中找到此问题中使用的代码:https://scalafiddle.io/sf/eSKJyp2/15

1 个答案:

答案 0 :(得分:6)

如果停止使用函数调用堆栈并开始使用由代码和累加器管理的堆栈,则可以达到尾递归解决方案:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: Vector[B]): Vector[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          val leafRes = map(v)
          foldImp(
            toVisit.tail,
            acc :+ leafRes
          )
        case Branch(l, r) =>
          foldImp(l :: r :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.dropRight(2) ++   Vector(acc.takeRight(2).reduce(red)))
      }
    }

  foldImp(t::Nil, Vector.empty).head

}

这个想法是从左到右累积值,通过引入存根节点跟踪父母关系,并使用你的red函数使用累加器的最后两个元素减少结果节点在探索中找到。

此解决方案可以进行优化,但它已经是尾递归函数实现。

修改

通过将累加器数据结构更改为看作堆栈的列表,可以略微简化:

def fold[A, B](t: Tree[A])(map: A => B)(red: (B, B) => B): B = {

  case object BranchStub extends Tree[Nothing]

  @tailrec
  def foldImp(toVisit: List[Tree[A]], acc: List[B]): List[B] =
    if(toVisit.isEmpty) acc
    else {
      toVisit.head match {
        case Leaf(v) =>
          foldImp(
            toVisit.tail,
            map(v)::acc 
          )
        case Branch(l, r) =>
          foldImp(r :: l :: BranchStub :: toVisit.tail, acc)
        case BranchStub =>
          foldImp(toVisit.tail, acc.take(2).reduce(red) :: acc.drop(2))
      }
    }

  foldImp(t::Nil, Nil).head

}