Spark RDD or SQL operations to compute conditional counts

Asked: 2018-03-09 00:36:45

Tags: scala apache-spark spark-dataframe rdd

As some background, I'm trying to implement Kaplan-Meier in Spark. In particular, I assume I have a dataframe/Dataset with a Double column denoted data and an Int column named censorFlag (value 0 if censored, 1 if not; I prefer this over a Boolean type).

Example:

val df = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag").as[(Double, Int)] 

Now I need to compute a column wins that counts, for each value of data, the number of instances with censorFlag == 1. I achieved this with the following code:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val distDF = df.withColumn("wins", sum(col("censorFlag")).over(Window.partitionBy("data").orderBy("data")))
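
For the sample data this gives every row the total censorFlag count of its data value, so distDF.orderBy("data").show() should print something like:

+----+----------+----+
|data|censorFlag|wins|
+----+----------+----+
| 0.7|         0|   0|
| 0.8|         1|   2|
| 0.8|         1|   2|
| 1.0|         1|   1|
| 2.3|         0|   0|
| 4.0|         1|   1|
| 4.5|         1|   1|
+----+----------+----+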

The problem arises when I need to compute a quantity called atRisk, which, for each value of data, counts the number of data points greater than or equal to it (a cumulative filtered count, if you will).

The following code works:

// We perform the counts per value of "bins". This is an array of doubles
val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins").as[Double].collect 
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// this works:
atRiskCounts.show
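
For the seven sample rows above, counting rows of data greater than or equal to each bin, this should show something like:

+----+------+
|data|atRisk|
+----+------+
| 0.7|     7|
| 0.8|     6|
| 1.0|     4|
| 2.3|     3|
| 4.0|     2|
| 4.5|     1|
+----+------+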

However, the real use case involves deriving the bins from the column data itself, which I would like to keep as a single-column Dataset (or an RDD at worst), but certainly not a local array. Yet this doesn't work:

// Here, 'bins' rightfully come from the data itself.
val bins = df.select(col("data").as("dataBins")).distinct().as[Double]
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toSeq.toDF("data", "atRisk")
// This doesn't work -- NullPointerException
atRiskCounts.show

Neither does this:

// Manually creating the bins and then parallelizing them.
val bins = Seq(0.7, 0.8, 1.0, 3.0).toDS
val atRiskCounts = bins.map(x => (x, df.filter(col("data").geq(x)).count)).toDF("data", "atRisk")
// Also fails with a NullPointerException
atRiskCounts.show

Another approach that works, but is also unsatisfactory from a parallelization standpoint, is to use a Window:

// Do the counts in one fell swoop using a giant window per value.
val atRiskCounts = df.withColumn("atRisk", count("censorFlag").over(Window.orderBy("data").rowsBetween(0, Window.unboundedFollowing))).groupBy("data").agg(first("atRisk").as("atRisk"))
// Works, BUT, we get a "WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation." 
atRiskCounts.show

This last solution isn't useful either, since it ends up shuffling my data into a single partition (and in that case I might as well go with option 1).

The approaches that succeed are fine, except that the bins are not parallelized, and that's something I would really like to keep. I've looked at groupBy aggregations and pivot-style aggregations, but nothing seems to fit.

My question is: is there a way to compute the atRisk column in a distributed fashion? Also, why do I get a NullPointerException in the failing solutions?

EDIT (per comments):

I initially did not post the NullPointerException because it did not seem to contain anything useful. I'm running Spark installed via Homebrew on a MacBook Pro (Spark version 2.2.1, standalone localhost mode).

                18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.package on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/package.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.scala on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/scala.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.org on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/org.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR ExecutorClassLoader: Failed to check existence of class <root>.java on REPL class server at spark://10.37.109.111:53360/classes
            java.net.URISyntaxException: Illegal character in path at index 36: spark://10.37.109.111:53360/classes/<root>/java.class
                at java.net.URI$Parser.fail(URI.java:2848)
                at java.net.URI$Parser.checkChars(URI.java:3021)
                at java.net.URI$Parser.parseHierarchical(URI.java:3105)
                at java.net.URI$Parser.parse(URI.java:3053)
                at java.net.URI.<init>(URI.java:588)
                at org.apache.spark.rpc.netty.NettyRpcEnv.openChannel(NettyRpcEnv.scala:327)
                at org.apache.spark.repl.ExecutorClassLoader.org$apache$spark$repl$ExecutorClassLoader$$getClassFileInputStreamFromSparkRPC(ExecutorClassLoader.scala:90)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader$$anonfun$1.apply(ExecutorClassLoader.scala:57)
                at org.apache.spark.repl.ExecutorClassLoader.findClassLocally(ExecutorClassLoader.scala:162)
                at org.apache.spark.repl.ExecutorClassLoader.findClass(ExecutorClassLoader.scala:80)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
                at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
                . . .
            18/03/12 11:41:00 ERROR Executor: Exception in task 0.0 in stage 55.0 (TID 432)
            java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
                at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
                at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
                at java.lang.Thread.run(Thread.java:748)
            18/03/12 11:41:00 WARN TaskSetManager: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at $line124.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            18/03/12 11:41:00 ERROR TaskSetManager: Task 0 in stage 55.0 failed 1 times; aborting job
            org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 55.0 failed 1 times, most recent failure: Lost task 0.0 in stage 55.0 (TID 432, localhost, executor driver): java.lang.NullPointerException
                at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
                at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
                at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
                at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
                at $anonfun$1.apply(<console>:33)
                at $anonfun$1.apply(<console>:33)
                at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
                at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
                at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
                at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
                at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
                at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
                at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
                at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
                at org.apache.spark.scheduler.Task.run(Task.scala:108)
                at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)

            Driver stacktrace:
              at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1517)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1505)
              at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1504)
              at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
              ... 50 elided
            Caused by: java.lang.NullPointerException
              at org.apache.spark.sql.Dataset.<init>(Dataset.scala:171)
              at org.apache.spark.sql.Dataset$.apply(Dataset.scala:62)
              at org.apache.spark.sql.Dataset.withTypedPlan(Dataset.scala:2889)
              at org.apache.spark.sql.Dataset.filter(Dataset.scala:1301)
              at $anonfun$1.apply(<console>:33)
              at $anonfun$1.apply(<console>:33)
              at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
              at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
              at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:395)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:234)
              at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:228)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:827)
              at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
              at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
              at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
              at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
              at org.apache.spark.scheduler.Task.run(Task.scala:108)
              at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:338)
              at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
              at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
              at java.lang.Thread.run(Thread.java:748)

My best guess is that the df.filter(col("data").geq(x)).count part is what barfs, since not every node has df, and hence the null pointer?

3 Answers:

Answer 0 (score: 1):

I haven't tested this, so the syntax may be off, but I would do a series of joins:

I believe your first statement is equivalent to the following, which counts the wins for each value of data:

val distDF = df.groupBy($"data").agg(sum($"censorFlag").as("wins"))

Then, as you noted, we can build a dataframe of the bins:

val distinctData = df.select($"data".as("dataBins")).distinct()

Then join on the >= condition:

val atRiskCounts = distDF.join(distinctData, distDF("data") >= distinctData("dataBins"))
  .groupBy($"data", $"wins")
  .count()
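
If the duplicated values in data also need to be counted (as in the row-level atRisk definition in the question), a hedged, untested variant of the same idea is to join the bins against the full df and group by the bin itself (the name atRiskPerBin is hypothetical); this mirrors the join used in Answer 2 below:

// Untested sketch: count rows of df at or above each bin
val atRiskPerBin = df.join(distinctData, df("data") >= distinctData("dataBins"))
  .groupBy($"dataBins")
  .count()
  .withColumnRenamed("count", "atRisk")

The result can then be joined back to distDF on data === dataBins to attach the wins column.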

Answer 1 (score: 1):

Collecting is important when your requirement is to check a value in a column against all the remaining values of that column. When all values need to be checked, it is certain that all the data of that column has to be accumulated in one executor or in the driver. You cannot avoid this step when your requirement is like that.

Now the main part is how to define the remaining steps so that they benefit from Spark's parallelization. I would suggest you broadcast the collected set (it is only the distinct data of a single column, so it cannot be huge) and use a udf function to check the gte condition, as below.

First of all, you can optimize your collecting step:

import org.apache.spark.sql.functions._
val collectedData = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]

Then you broadcast the collected set:

val broadcastedArray = sc.broadcast(collectedData)

The next step is to define a udf function that checks the gte condition and returns the counts:

def checkingUdf = udf((data: Double)=> broadcastedArray.value.count(x => x >= data))

and use it as:

distDF.withColumn("atRisk", checkingUdf(col("data"))).show(false)

So in the end you should have

+----+----------+----+------+
|data|censorFlag|wins|atRisk|
+----+----------+----+------+
|4.5 |1         |1   |1     |
|0.7 |0         |0   |6     |
|2.3 |0         |0   |3     |
|1.0 |1         |1   |4     |
|0.8 |1         |2   |5     |
|0.8 |1         |2   |5     |
|4.0 |1         |1   |2     |
+----+----------+----+------+

I hope this is what is required.

Answer 2 (score: 1):

I tried out the examples above (admittedly not the most rigorous testing!), and the join approach seems to work best.

Data:

import org.apache.spark.mllib.random.RandomRDDs._
val df = logNormalRDD(sc, 1, 3.0, 10000, 100).zip(uniformRDD(sc, 10000, 100).map(x => if(x <= 0.4) 1 else 0)).toDF("data", "censorFlag").withColumn("data", round(col("data"), 2))

Join version:

import org.apache.spark.SparkContext
import org.apache.spark.sql.DataFrame

def runJoin(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(col("data").as("dataBins")).distinct().sort("dataBins")
  val wins = df.groupBy(col("data")).agg(sum("censorFlag").as("wins"))
  val atRiskCounts = bins.join(df, bins("dataBins") <= df("data")).groupBy("dataBins").count().withColumnRenamed("count", "atRisk")
  val finalDF = wins.join(atRiskCounts, wins("data") === atRiskCounts("dataBins")).select("data", "wins", "atRisk").sort("data")
  finalDF.show
}
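
As a rough sanity check (a sketch only, reusing the seven-row sample from the question under the hypothetical name smallDF), the join version should produce something like:

val smallDF = Seq((1.0, 1), (2.3, 0), (4.5, 1), (0.8, 1), (0.7, 0), (4.0, 1), (0.8, 1)).toDF("data", "censorFlag")
runJoin(sc, smallDF)
// expected, roughly:
// +----+----+------+
// |data|wins|atRisk|
// +----+----+------+
// | 0.7|   0|     7|
// | 0.8|   2|     6|
// | 1.0|   1|     4|
// | 2.3|   0|     3|
// | 4.0|   1|     2|
// | 4.5|   1|     1|
// +----+----+------+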

Broadcast version:

def runBroadcast(sc: SparkContext, df: DataFrame): Unit = {
  val bins = df.select(sort_array(collect_set("data"))).collect()(0)(0).asInstanceOf[collection.mutable.WrappedArray[Double]]
  val binsBroadcast = sc.broadcast(bins)
  val df2 = binsBroadcast.value.map(x => (x, df.filter(col("data").geq(x)).select(count(col("data"))).as[Long].first)).toDF("data", "atRisk")
  val finalDF = df.groupBy(col("data")).agg(sum("censorFlag").as("wins")).join(df2, "data")
  finalDF.show
  binsBroadcast.destroy
}

Test code:

import java.util.concurrent.TimeUnit

// sampleDF is whichever DataFrame is being timed (e.g. the randomly generated df above)
var start = System.nanoTime()
runJoin(sc, sampleDF)
val joinTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

start = System.nanoTime()
runBroadcast(sc, sampleDF)
val broadTime = TimeUnit.SECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)

I ran this code against random data of various sizes, supplying manual bins arrays (some very fine-grained, 50% of the original distinct data; some very coarse, 10% of the original distinct data), and the join method consistently appears to be the fastest (both arrive at the same solution, though, which is a plus!).

On average, I find that the smaller the bins array, the better the broadcast approach performs, while the join doesn't seem to be affected much. If I had more time/resources to test this, I would run extensive simulations to see what the average run times look like, but for now I'm accepting @hoyland's solution.

I'm still not sure why the original approach doesn't work, though, so comments on that are welcome.

Please point out any problems or improvements in the code! Thanks, all :)
