Spark accumulator thread safety (LongAccumulator seems to be broken)

Asked: 2017-09-25 10:15:16

Tags: scala apache-spark concurrency

I am writing a custom AccumulatorV2 and want to get its synchronization right.

I have read several posts about AccumulatorV2 concurrency issues in Spark, and as far as I understand:

1) a write to an accumulator happens-before a subsequent read, and vice versa

2) but writes to the accumulator from different partitions are not synchronized with each other

So the accumulator's add method must support concurrent access.

Here is the evidence:

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.AccumulatorV2

object MyAccumulator {
  val uniqueId: AtomicInteger = new AtomicInteger()
}

// Just a regular Int accumulator, but with a delay inside the add method
class MyAccumulator(var w: Int = 0) extends AccumulatorV2[Int, Int] {

  // log this id to prove that all updates happen on the same object
  val thisAccId = MyAccumulator.uniqueId.incrementAndGet()

  override def isZero: Boolean = w == 0

  override def copy(): AccumulatorV2[Int, Int] = new MyAccumulator(w)

  override def reset(): Unit = w = 0

  override def add(v: Int): Unit = {
    println(s"Start adding $thisAccId " + Thread.currentThread().getId)
    Thread.sleep(500)
    w += v
    println(s"End adding $thisAccId " + Thread.currentThread().getId)
  }

  override def merge(other: AccumulatorV2[Int, Int]): Unit = w += other.value

  override def value: Int = w
}

object Test extends App {
  val conf = new SparkConf()
  conf.setMaster("local[5]")
  conf.setAppName("test")
  val sc = new SparkContext(conf)
  val rdd = sc.parallelize(1 to 50, 10)
  val acc = new MyAccumulator
  sc.register(acc)
  rdd.foreach(x => acc.add(x))
  println(acc.value)

  sc.stop()
}

Sample output:

Start adding 1 73
Start adding 1 77
Start adding 1 76
Start adding 1 74
Start adding 1 75
End adding 1 74
End adding 1 75
Start adding 1 74
End adding 1 76
End adding 1 77
End adding 1 73
Start adding 1 77
Start adding 1 76
....
....
....
1212

We can see that multiple threads are inside the add method at the same time, and the result is wrong (it should be 1275 = 1 + 2 + ... + 50 = 50 * 51 / 2).
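For comparison, here is a minimal thread-safe sketch of the same accumulator, assuming add really can be called concurrently. The SyncAccumulator name and the synchronized blocks are mine, not anything from Spark:

class SyncAccumulator(private var w: Int = 0) extends AccumulatorV2[Int, Int] {

  override def isZero: Boolean = synchronized { w == 0 }

  override def copy(): AccumulatorV2[Int, Int] = synchronized { new SyncAccumulator(w) }

  override def reset(): Unit = synchronized { w = 0 }

  // serialize concurrent add calls so no increment is lost
  override def add(v: Int): Unit = synchronized { w += v }

  override def merge(other: AccumulatorV2[Int, Int]): Unit = synchronized { w += other.value }

  override def value: Int = synchronized { w }
}

With this variant the test above should consistently print 1275.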

Then I looked at the source code of the built-in accumulators (e.g. org.apache.spark.util.LongAccumulator) and found no trace of synchronization there. It just uses plain mutable vars.
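For reference, this is roughly what the relevant part looks like (paraphrased and abbreviated from the Spark 2.x source, so treat it as a sketch rather than an exact copy):

class LongAccumulator extends AccumulatorV2[java.lang.Long, java.lang.Long] {

  // plain unsynchronized vars, no locks or atomics anywhere
  private var _sum = 0L
  private var _count = 0L

  override def add(v: java.lang.Long): Unit = {
    _sum += v
    _count += 1
  }

  // isZero, copy, reset, merge and value all operate on the same two vars
}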

How does it work?

UPDATE: LongAccumulator is indeed broken (the following code fails):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.LongAccumulator

object LongAccumulatorIsBroken extends App {
  val conf = new SparkConf()
  conf.setMaster("local[5]")
  conf.setAppName("test")
  val sc = new SparkContext(conf)
  val rdd = sc.parallelize((1 to 50000).map(x => 1), 1000)
  val acc = new LongAccumulator
  sc.register(acc)
  rdd.foreach(x => acc.add(x))
  val value = acc.value
  println(value)
  sc.stop()
  assert(value == 50000, message = s"$value is not 50000")
}
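Until this is answered, a possible workaround is to guard the state with an atomic instead of plain vars. A minimal sketch, assuming nothing beyond the standard AccumulatorV2 contract (the AtomicLongAccumulator name and shape are my own, not a Spark API):

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.util.AccumulatorV2

class AtomicLongAccumulator extends AccumulatorV2[Long, Long] {

  // all mutation goes through the AtomicLong, so concurrent adds cannot lose updates
  private val sum = new AtomicLong(0L)

  override def isZero: Boolean = sum.get() == 0L

  override def copy(): AtomicLongAccumulator = {
    val acc = new AtomicLongAccumulator
    acc.sum.set(sum.get())
    acc
  }

  override def reset(): Unit = sum.set(0L)

  override def add(v: Long): Unit = sum.addAndGet(v)

  override def merge(other: AccumulatorV2[Long, Long]): Unit = sum.addAndGet(other.value)

  override def value: Long = sum.get()
}

Substituting this for LongAccumulator in the failing test above should make the assertion pass.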

0 Answers