/** Recursively trains a decision tree.
  *
  * If `shouldStop` accepts the samples, a leaf built from their proportions is returned.
  * Otherwise the samples are partitioned on the feature chosen by `getSplittingFeature`,
  * and both halves are trained recursively with that feature added to `usedFeatures`.
  *
  * Note: recursion depth grows with tree depth (this is the version the stack-based
  * rewrite replaces).
  *
  * @param samples      samples reaching this node
  * @param usedFeatures feature indices already split on along the path to this node
  * @return a `DTLeaf` or a `DTBranch(withFeature, withoutFeature, featureIdx)`
  */
private def trainTree(samples: Seq[Sample], usedFeatures: Set[Int]): DTree = {
  if (shouldStop(samples)) {
    DTLeaf(makeProportions(samples))
  } else {
    val featureIdx = getSplittingFeature(samples, usedFeatures)
    val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
    val nextUsedFeatures = usedFeatures + featureIdx
    DTBranch(
      trainTree(statsWithFeature, nextUsedFeatures),
      trainTree(statsWithoutFeature, nextUsedFeatures),
      featureIdx)
  }
}

所以基本上，我根据数据的某个特征递归地把列表一分为二，并传递一个"已使用特征"的集合，以免重复使用同一个特征——这些都在 `getSplittingFeature` 函数中处理，所以可以忽略它。代码非常简单！尽管如此，我仍然想不出一个基于显式栈的解决方案，使它不只是用闭包来实现（那样实际上就变成了蹦床/trampoline）。我知道至少需要在栈中保留少量"参数帧"，但我想避免闭包调用。

我明白，我应该把递归解法中由调用栈和程序计数器隐式替我处理的东西显式地写出来，但如果不使用 continuation，我很难做到这一点。到这一步，这已经几乎与效率无关了，我只是好奇。所以，请不必提醒我"过早优化是万恶之源"，也不必提醒我基于蹦床的解决方案多半能正常工作——我知道它多半可以，这基本上只是一个谜题。


更新：基于 Thipor Kong 出色的解决方案，我用 while 循环、栈和查找表（实际上是把 ArrayBuffer 当作稠密的整数索引表来用）写了该算法的一个实现，它应该是递归版本的直接翻译。这正是我想要的：


/** Stack-based (non-recursive) version of `trainTree` with a depth budget.
  *
  * Runs in two passes over explicit stacks instead of the JVM call stack:
  *  - "go down": pop pending (samples, usedFeatures, maxDepth, nodeId) work items; either store
  *    a leaf for the node or push both children as new work items and remember the branch;
  *  - "go up": pop the remembered branches in reverse order. A child's branch is always recorded
  *    after its parent's, so both children are already built when the parent is assembled.
  *
  * @param startingSamples  samples for the root node
  * @param startingMaxDepth depth budget; a node whose budget is 0 becomes a leaf. The budget only
  *                         decreases, so a negative value never reaches 0 (presumably meaning
  *                         "no limit" — confirm with callers)
  * @return the root node, which always has id 0
  */
private def trainTreeNoMaxDepth(startingSamples: Seq[Sample], startingMaxDepth: Int): DTree = {
  // ArrayBuffer used as a dense, mutable Int-indexed map: writing past the end grows it
  // (padding with a default) instead of throwing IndexOutOfBoundsException.
  type DenseIntMap[T] = ArrayBuffer[T]

  // Sets ab(idx) = item, first padding ab with `dfault` so that idx is in range.
  def updateIntMap[@specialized T](ab: DenseIntMap[T], idx: Int, item: T, dfault: T = null.asInstanceOf[T]): Unit = {
    if (ab.length <= idx) { ab.insertAll(ab.length, Iterable.fill(idx - ab.length + 1)(dfault)) }
    ab.update(idx, item)
  }

  // Node ids are handed out sequentially (1, 2, 3, ...); despite the "heapIdx" name they are
  // NOT binary-heap positions. -1 marks "no child allocated yet".
  var currentChildId = 0
  // Returns the child id recorded for heapIdx in childMap, allocating a fresh id if absent.
  def child(childMap: DenseIntMap[Int], heapIdx: Int): Int =
    if (childMap.length > heapIdx && childMap(heapIdx) != -1) childMap(heapIdx)
    else { currentChildId += 1; updateIntMap(childMap, heapIdx, currentChildId, -1); currentChildId }

  // go down
  val leftChildren, rightChildren = new DenseIntMap[Int]() // heapIdx -> childHeapIdx (separate buffers)
  val todo = Stack((startingSamples, Set.empty[Int], startingMaxDepth, 0)) // samples, usedFeatures, maxDepth, heapIdx
  val branches = new Stack[(Int, Int)]() // heapIdx, featureIdx
  val nodes = new DenseIntMap[DTree]() // heapIdx -> node
  while (todo.nonEmpty) {
    val (samples, usedFeatures, maxDepth, heapIdx) = todo.pop()
    if (shouldStop(samples) || maxDepth == 0) {
      updateIntMap(nodes, heapIdx, DTLeaf(makeProportions(samples)))
    } else {
      val featureIdx = getSplittingFeature(samples, usedFeatures)
      val (statsWithFeature, statsWithoutFeature) = samples.partition(hasFeature(featureIdx, _))
      val nextUsedFeatures = usedFeatures + featureIdx
      todo.push((statsWithFeature, nextUsedFeatures, maxDepth - 1, child(leftChildren, heapIdx)))
      todo.push((statsWithoutFeature, nextUsedFeatures, maxDepth - 1, child(rightChildren, heapIdx)))
      branches.push((heapIdx, featureIdx))
    }
  }

  // go up: children were recorded after their parents, so LIFO order builds them first
  while (branches.nonEmpty) {
    val (heapIdx, featureIdx) = branches.pop()
    updateIntMap(nodes, heapIdx, DTBranch(nodes(child(leftChildren, heapIdx)), nodes(child(rightChildren, heapIdx)), featureIdx))
  }

  nodes(0) // the root was seeded with heapIdx 0
}

只需像维基百科所描述的那样把二叉树存储在数组中：对于节点 i，左孩子放在 2*i+1，右孩子放在 2*i+2。在"向下"阶段，维护一个待办列表，其中是仍需继续拆分才能到达叶子的节点。一旦只剩下叶子，就"向上"（在数组中从右到左）构建决策节点。下面的代码沿用这一思路，但没有直接使用 2*i+1 / 2*i+2 下标：它从 `ids` 流中为每个新节点取一个唯一 id，并用 `Map` 按 id 记录节点：



/** Binary tree in which every node stores a value of type `B`; leaves additionally store an `A`.
  * (In `mktree`, `a` is the data that could not be split further and `b` is the value
  * produced by the split that created the node.)
  */
sealed trait DTree[A, B]

/** Leaf node holding `a` and `b`. */
final case class DTLeaf[A, B](a: A, b: B) extends DTree[A, B]

/** Inner node with two subtrees and the value `b`. */
final case class DTBranch[A, B](left: DTree[A, B], right: DTree[A, B], b: B) extends DTree[A, B]

/** Builds a binary tree without growing the JVM call stack (both phases are `@tailrec` loops).
  *
  * Phase 1 (`goDown`) repeatedly applies `split` to pending nodes and records every leaf
  * (in a map keyed by id) and every branch (in a list, newest first).
  * Phase 2 (`goUp`) walks the branch list newest-first. A branch's children are always created
  * after the branch itself, so both children already exist when the branch is assembled.
  *
  * @param a     data for the root node
  * @param b     value stored in the root node
  * @param split returns `None` to turn `(a, b)` into a leaf, or `Some((left, right, b2))` to
  *              create a branch storing `b2` whose children start from `(left, b2)` and `(right, b2)`
  * @param ids   source of node ids; elements must be distinct and there must be one per node
  *              (an infinite stream such as `Stream.from(0)` always suffices)
  * @return the root of the assembled tree
  * @throws IllegalArgumentException if `ids` runs out before the tree is complete
  */
def mktree[A, B, Id](a: A, b: B, split: (A, B) => Option[(A, A, B)], ids: Stream[Id]): DTree[A, B] = {
  // Pending work items are (data, value, id); recorded branches are (id, value, leftId, rightId).
  @scala.annotation.tailrec
  def goDown(
      todo: List[(A, B, Id)],
      branches: List[(Id, B, Id, Id)],
      leafs: Map[Id, DTree[A, B]],
      freshIds: Stream[Id]
  ): (List[(Id, B, Id, Id)], Map[Id, DTree[A, B]]) =
    todo match {
      case Nil => (branches, leafs)
      case (nodeData, nodeValue, id) :: rest =>
        split(nodeData, nodeValue) match {
          case None =>
            goDown(rest, branches, leafs + (id -> DTLeaf(nodeData, nodeValue)), freshIds)
          case Some((left, right, b2)) =>
            freshIds match {
              case leftId #:: rightId #:: idRest =>
                goDown(
                  (right, b2, rightId) :: (left, b2, leftId) :: rest,
                  (id, b2, leftId, rightId) :: branches,
                  leafs,
                  idRest)
              case _ =>
                throw new IllegalArgumentException("mktree: ran out of ids")
            }
        }
    }

  // Newest branch first: its children (leaves or already-built branches) are in `nodes`.
  @scala.annotation.tailrec
  def goUp(branches: List[(Id, B, Id, Id)], nodes: Map[Id, DTree[A, B]]): Map[Id, DTree[A, B]] =
    branches match {
      case Nil => nodes
      case (id, value, leftId, rightId) :: rest =>
        goUp(rest, nodes + (id -> DTBranch(nodes(leftId), nodes(rightId), value)))
    }

  ids match {
    case rootId #:: restIds =>
      val (branches, leafs) = goDown(List((a, b, rootId)), Nil, Map.empty, restIds)
      goUp(branches, leafs)(rootId)
    case _ =>
      throw new IllegalArgumentException("mktree: ids is empty")
  }
}

// try it out

/** Example splitter for `mktree`: halves `xs` while it has more than one element.
  *
  * @param xs elements of the current node
  * @param b  value of the current node (used here as a depth counter)
  * @return `Some((firstHalf, secondHalf, b + 1))` when `xs.size > 1`; otherwise `None`, so the
  *         node becomes a leaf
  */
def split(xs: Seq[Int], b: Int): Option[(Seq[Int], Seq[Int], Int)] =
  if (xs.size > 1) {
    val (left, right) = xs.splitAt(xs.size / 2)
    Some((left, right, b + 1))
  } else {
    None
  }
// Build a tree over 0..1000 by halving the range until single elements remain;
// node ids are drawn from 0, 1, 2, ...
val tree = {
  val input = 0 to 1000
  val nodeIds = Stream.from(0)
  mktree(input, 0, split _, nodeIds)
}