How to implement DFS using immutable data types

Date: 2011-03-29 10:41:39

Tags: scala graph-traversal

I'm trying to figure out a concise, Scala-style way to traverse a graph, preferably using vals and immutable data types.

Given the following graph,

val graph = Map(0 -> Set(1),
                1 -> Set(2),
                2 -> Set(0, 3, 4),
                3 -> Set[Int](),
                4 -> Set(3))

I'd like the output to be a depth-first traversal starting at a given node. Starting at 1, for example, should yield something like 1 2 3 0 4.

I can't seem to find a nice way to do this without mutable collections or vars. Any help would be much appreciated.

8 answers:

Answer 0: (score: 8)

Tail-recursive solution:

  def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
    def childrenNotVisited(parent: Int, visited: List[Int]) =
      graph(parent) filter (x => !visited.contains(x))

    @annotation.tailrec
    def loop(stack: Set[Int], visited: List[Int]): List[Int] = {
      if (stack.isEmpty) visited
      else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
        stack.head :: visited)
    }
    loop(Set(start), Nil).reverse
  }

Answer 1: (score: 4)

Here is one variation, I guess:

graph.foldLeft((List[Int](), 1)){
  (s, e) => if (e._2.size == 0) (0 :: s._1, s._2) else (s._2 :: s._1, (s._2 + 1))
}._1.reverse

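Updated: here is an expanded version. I fold left over the entries of the map, starting from a tuple of an empty list and the number 1. For each entry, I check the size of its neighbour set and build a new tuple accordingly. The list is built in reverse order, hence the final reverse.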

val init = (List[Int](), 1)
val (result, _) = graph.foldLeft(init) {
  (s, elem) => 
    val (stack, count) = s
    if (elem._2.size == 0) 
      (0 :: stack, count) 
    else 
      (count :: stack, count + 1)
}
result.reverse

Answer 2: (score: 2)

Here is a recursive solution (I hope I understood your requirements correctly):

def traverse(graph: Map[Int, Set[Int]], node: Int, visited: Set[Int] = Set()): List[Int] = 
    List(node) ++ (graph(node) -- visited flatMap(traverse(graph, _, visited + node)))

traverse(graph, 1)

Also note that this function is NOT tail-recursive.
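In the function above, `visited` is only passed down to children, so siblings do not see each other's visits, and duplicates are removed only because the `flatMap` runs over a `Set`. A minimal sketch of one alternative, assuming every reachable node is a key of the map: thread the visited set through a `foldLeft`. The names `FoldDfs` and `visit` are made up for illustration, and this version is still not tail-recursive.

```scala
object FoldDfs {
  // Returns the nodes reachable from `start` in depth-first preorder.
  // Assumes every reachable node appears as a key in `graph`.
  def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
    // acc = (nodes visited so far, nodes in reverse visit order)
    def visit(acc: (Set[Int], List[Int]), node: Int): (Set[Int], List[Int]) = {
      val (visited, order) = acc
      if (visited(node)) acc
      else graph(node).foldLeft((visited + node, node :: order))(visit)
    }
    visit((Set.empty[Int], Nil), start)._2.reverse
  }
}
```

Because the accumulator carries the visited set from one sibling to the next, each node is expanded at most once. The order among siblings still depends on the iteration order of the `Set`.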

Answer 3: (score: 1)

I don't know if you are still looking for an answer 6 years later, but here it is :)

It also returns a topological ordering of the graph and whether the graph is cyclic:

case class Node(label: Int)

case class Graph(adj: Map[Node, Set[Node]]) {
  case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(),
                      tsOrder: List[Node] = List(), isCyclic: Boolean = false)

  // Returns (topological order, whether a cycle was found).
  def dfs: (List[Node], Boolean) = {
    def dfsVisit(currState: DfsState, src: Node): DfsState = {
      // A neighbour that is still active (on the current DFS path) closes a cycle.
      val newState = currState.copy(discovered = currState.discovered + src,
        activeNodes = currState.activeNodes + src,
        isCyclic = currState.isCyclic || adj(src).exists(currState.activeNodes))

      // Re-check `discovered` inside the fold: an earlier sibling may already have reached n.
      val finalState = adj(src).foldLeft(newState) { (state, n) =>
        if (state.discovered(n)) state else dfsVisit(state, n)
      }
      finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
    }

    val stateAfterSearch = adj.keys.foldLeft(DfsState()) { (state, n) =>
      if (state.discovered(n)) state else dfsVisit(state, n)
    }
    (stateAfterSearch.tsOrder, stateAfterSearch.isCyclic)
  }
}

Answer 4: (score: 0)

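It seems this problem is more complicated than I originally thought. I wrote another recursive solution. It is still not tail-recursive. I also tried hard to make it a one-liner, but readability suffers a lot in that case, so this time I decided to declare a few vals: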

def traverse(graph: Map[Int, Set[Int]], node: Int, result: List[Int] = Nil): List[Int] = {
  val newResult = result :+ node
  val currentEdges = graph(node) -- newResult
  val realEdges = if (currentEdges.isEmpty) graph.keySet -- newResult else currentEdges

  realEdges.foldLeft(newResult)((r, n) => if (r contains n) r else traverse(graph, n, r))
}

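In my previous answer I tried to find all paths from the given node in a directed graph, but that was wrong according to the requirements. This answer tries to follow the directed edges; if it can't, it simply picks some unvisited node and continues from there.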

Answer 5: (score: 0)

Angel,

I haven't fully understood your solution, but if I'm not mistaken its time complexity is at least O(|V|^2), because the following line has complexity O(|V|):

val newResult = result :+ node

That is, it appends an element to the right end of a list.

In addition, the code is not tail-recursive, which may be a problem if the recursion depth is limited by the environment you are using.
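To illustrate the point about appending: on an immutable `List`, `:+` copies the whole list, while `::` prepends in constant time. A small sketch of the usual workaround, building the list by prepending and reversing once at the end. `ListBuilding` and its two methods are hypothetical names used only for this comparison.

```scala
object ListBuilding {
  // Each :+ copies the list built so far, so n appends cost O(n^2) overall.
  def byAppending(n: Int): List[Int] =
    (1 to n).foldLeft(List.empty[Int])((acc, x) => acc :+ x)

  // Each :: is O(1); the single reverse at the end is O(n).
  def byPrepending(n: Int): List[Int] =
    (1 to n).foldLeft(List.empty[Int])((acc, x) => x :: acc).reverse
}
```

Both methods return the same list; only the cost of building it differs.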

The following code solves several DFS-related graph problems on directed graphs. It is not the most elegant code, but if I'm not mistaken it is:

  1. Tail-recursive.
  2. Uses only immutable collections (and iterators over them).
  3. Has optimal time complexity, O(|V| + |E|), and space complexity, O(|V|).

Code:

    import scala.annotation.tailrec
    import scala.util.Try
    
    /**
     * Created with IntelliJ IDEA.
     * User: mishaelr
     * Date: 5/14/14
     * Time: 5:18 PM
     */
    object DirectedGraphTraversals {
    
      type Graph[Vertex] = Map[Vertex, Set[Vertex]]
    
      def dfs[Vertex](graph: Graph[Vertex], initialVertex: Vertex) =
        dfsRec(DfsNeighbours)(graph, List(DfsNeighbours(graph, initialVertex, Set(), Set())), Set(), Set(), List())
    
      def topologicalSort[Vertex](graph: Graph[Vertex]) =
        graphDfsRec(TopologicalSortNeighbours)(graph, graph.keySet, Set(), Set(), List())
    
      def stronglyConnectedComponents[Vertex](graph: Graph[Vertex]) = {
        val exitOrder = graphDfsRec(DfsNeighbours)(graph, graph.keySet, Set(), Set(), List())
        val reversedGraph = reverse(graph)
    
        exitOrder.foldLeft((Set[Vertex](), List(Set[Vertex]()))){
          case (acc @(visitedAcc, connectedComponentsAcc), vertex) =>
            if(visitedAcc(vertex))
              acc
            else {
              val connectedComponent = dfsRec(DfsNeighbours)(reversedGraph, List(DfsNeighbours(reversedGraph, vertex, visitedAcc, visitedAcc)),
                visitedAcc, visitedAcc,List()).toSet
              (visitedAcc ++ connectedComponent, connectedComponent :: connectedComponentsAcc)
            }
        }._2
      }
    
      def reverse[Vertex](graph: Graph[Vertex]) = {
        val reverseList = for {
          (vertex, neighbours) <- graph.toList
          neighbour <- neighbours
        } yield (neighbour, vertex)
    
        reverseList.groupBy(_._1).mapValues(_.map(_._2).toSet)
      }
    
      private sealed trait NeighboursFunc {
        def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex])
      }
    
      private object DfsNeighbours extends NeighboursFunc {
        def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) =
          (vertex, graph.getOrElse(vertex, Set()).iterator)
      }
    
      private object TopologicalSortNeighbours extends NeighboursFunc {
        def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) = {
          val neighbours = graph.getOrElse(vertex, Set())
          if(neighbours.exists(neighbour => entered(neighbour) && !exited(neighbour)))
            throw new IllegalArgumentException("The graph is not a DAG, it contains cycles: " + graph)
          else
            (vertex, neighbours.iterator)
        }
      }
    
      @tailrec
      private def dfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], toVisit: List[(Vertex, Iterator[Vertex])],
                                                                 entered: Set[Vertex], exited: Set[Vertex],
                                                                 exitStack: List[Vertex]): List[Vertex] = {
        toVisit match {
          case List() => exitStack
          case (currentVertex, neighbours) :: tl =>
            val filtered = neighbours.filterNot(entered)
            if(filtered.hasNext) {
              val nextNeighbour = filtered.next()
              dfsRec(neighboursFunc)(graph, neighboursFunc(graph, nextNeighbour, entered, exited) :: toVisit,
                entered + nextNeighbour, exited, exitStack)
            } else
              dfsRec(neighboursFunc)(graph, tl, entered, exited + currentVertex, currentVertex :: exitStack)
        }
      }
    
      @tailrec
      private def graphDfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], notVisited: Set[Vertex],
                                                                      entered: Set[Vertex], exited: Set[Vertex], order: List[Vertex]): List[Vertex] = {
        if(notVisited.isEmpty)
          order
        else {
          val orderSuffix = dfsRec(neighboursFunc)(graph, List(neighboursFunc(graph, notVisited.head, entered, exited)), entered, exited, List())
          graphDfsRec(neighboursFunc)(graph, notVisited -- orderSuffix, entered ++ orderSuffix, exited ++ orderSuffix, orderSuffix ::: order)
        }
      }
    }
    
    object DirectedGraphTraversalsExamples extends App {
      import DirectedGraphTraversals._
    
      val graph = Map(
        "B" -> Set("D", "C"),
        "A" -> Set("B", "D"),
        "D" -> Set("E"),
        "E" -> Set("C"))
    
      println("dfs A " +  dfs(graph, "A"))
      println("dfs B " +  dfs(graph, "B"))
    
      println("topologicalSort " +  topologicalSort(graph))
    
      println("reverse " + reverse(graph))
      println("stronglyConnectedComponents graph " + stronglyConnectedComponents(graph))
    
      val graph2 = graph + ("C" -> Set("D"))
      println("stronglyConnectedComponents graph2 " + stronglyConnectedComponents(graph2))
      println("topologicalSort graph2 " + Try(topologicalSort(graph2)))
    }
    

Answer 6: (score: 0)

Marimuthu Madasamy's answer does indeed work.

Here is a generic version of it:

val graph = Map(0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
  def childrenNotVisited(parent: T, visited: List[T]) =
    graph(parent) filter (x => !visited.contains(x))

  @annotation.tailrec
  def loop(stack: Set[T], visited: List[T]): List[T] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
      stack.head :: visited)
  }
  loop(Set(start), Nil).reverse
}

traverse(graph,0)

Note: you must make sure that instances of T correctly implement equals and hashCode. Using a case class with primitive values is an easy way to do that.
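A short illustration of the note about equals and hashCode, using two made-up vertex types (`City` and `PlainCity` are hypothetical and do not come from the answer): a case class gets structural equality generated by the compiler, while a plain class falls back to reference equality.

```scala
object VertexEquality {
  // Hypothetical vertex types, for illustration only.
  final case class City(name: String)      // structural equals/hashCode are generated
  final class PlainCity(val name: String)  // inherits reference equality from AnyRef

  val graph: Map[City, Set[City]] = Map(
    City("A") -> Set(City("B")),
    City("B") -> Set(City("A"))
  )

  // A freshly built City("A") is equal to the key stored in the map.
  val neighboursOfA: Set[City] = graph(City("A"))

  // Two PlainCity instances with the same name are still different objects.
  val plainEqual: Boolean = new PlainCity("A") == new PlainCity("A")
}
```

With `PlainCity` as the vertex type, a lookup such as `graph(new PlainCity("A"))` would miss the stored key, and `visited.contains` would never recognise a node built a second time.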

Answer 7: (score: 0)

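I'd like to amend Marimuthu Madasamy's answer. That code uses a Set, an unordered data structure, for the stack, and a List for visited, whose contains method takes linear time, so the overall time complexity is O(E * V), which is inefficient (E is the number of edges, V the number of vertices). I would rather use a List for the stack, a Set for visited (named discovered here), and a second List for the result, so that the visited nodes stay in order.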

def dfs(stack: List[Int], discovered: Set[Int], orderedVisited: List[Int]): List[Int] = {
  def childrenNotVisited(start: Int) =
    graph(start).filter(!discovered.contains(_)).toList

  if (stack.isEmpty)
    orderedVisited
  else {
    val nextNodes = childrenNotVisited(stack.head)
    dfs(nextNodes ::: stack.tail, discovered ++ nextNodes, stack.head :: orderedVisited)
  }
}

val start = 0
// dfs accumulates nodes by prepending, so reverse once to get the visit order
val visitOrder = dfs(List(start), Set(start), Nil).reverse
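As a variation on the same idea, here is a hedged sketch that takes the graph as a parameter, is checked with `@tailrec`, and marks a node as visited when it is popped rather than when it is pushed. Marking on pop keeps the result a strict depth-first preorder, at the cost of possible duplicate entries on the stack. `StackDfs` and `loop` are illustrative names, and treating a missing key as a node without outgoing edges is an assumption of this sketch.

```scala
import scala.annotation.tailrec

object StackDfs {
  def dfs[T](graph: Map[T, Set[T]], start: T): List[T] = {
    @tailrec
    def loop(stack: List[T], visited: Set[T], order: List[T]): List[T] =
      stack match {
        case Nil => order.reverse
        // The node was already reached through another path: skip it.
        case node :: rest if visited(node) => loop(rest, visited, order)
        case node :: rest =>
          // Assumption: a missing key means the node has no outgoing edges.
          val next = graph.getOrElse(node, Set.empty[T]).filterNot(visited).toList
          loop(next ::: rest, visited + node, node :: order)
      }
    loop(List(start), Set.empty[T], Nil)
  }
}
```

Calling `StackDfs.dfs(graph, 0)` on the graph from the question returns all five nodes starting with 0, 1, 2. The order of the remaining nodes depends on how the `Set` iterates.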