Filter count on a Spark DataFrame

Date: 2020-04-26 12:49:29

Tags: scala apache-spark apache-spark-sql user-defined-functions

I have two DataFrames, shown below. The logic DF is read from a MySQL table.

Logic DF:

slNo | filterCondtion |
-----------------------
1    | age > 100      |
2    | age > 50       |
3    | age > 10       |
4    | age > 20       |

InputDF - read from a file:

age   | name           |
------------------------
11    | suraj          |
22    | surjeth        |
33    | sam            |
43    | ram            |

I want to apply each filter expression from the logic DataFrame to InputDF and add a column with the count of rows matching each filter.

Expected output:

slNo | filterCondtion | count |
------------------------------
1    | age > 100      |   10  |
2    | age > 50       |   2   |
3    | age > 10       |   5   |
4    | age > 20       |   6   |
-------------------------------

Code I have tried:

val LogicDF = spark.read.format("jdbc")
  .option("url", "jdbc:mysql://localhost:3306/testDB")
  .option("driver", "com.mysql.jdbc.Driver")
  .option("dbtable", "logic_table")
  .option("user", "root")
  .option("password", "password")
  .load()

def filterCount(str: String): Long = {
  val counte = inputDF.where(str).count()
  counte
}

val filterCountUDF = udf[Long, String](filterCount)

LogicDF.withColumn("count",filterCountUDF(col("filterCondtion")))

Error trace:

Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$1: (string) => bigint)
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
  at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
  at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
  at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
  at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
  at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
  at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
  at org.apache.spark.scheduler.Task.run(Task.scala:121)
  at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.NullPointerException
  at org.apache.spark.sql.Dataset.where(Dataset.scala:1525)
  at filterCount(<console>:28)
  at $anonfun$1.apply(<console>:25)
  at $anonfun$1.apply(<console>:25)
  ... 21 more

Any other approach would also be fine. Thanks in advance.

Answers:

Answer 0 (score: 0)

A solution without a UDF

Your UDF fails because a UDF cannot reference another DataFrame: Dataset operations such as where and count only exist on the driver, so inputDF is null inside the closure when it runs on the executors, hence the NullPointerException. The approach below works as long as your logicDF is small enough to be collected to the driver.

Step 1

Collect your logic into an Array[(Int, String)], like this:

val rules = logicDF.collect().map{ r: Row =>
  val slNo = r.getAs[Int](0)
  val condition = r.getAs[String](1)
  (slNo, condition)
}

Step 2

Build a single Column that chains these rules into a sequence of when clauses. For that, use a Scala loop over the rules, for example:

val unused = when(lit(false), lit(false))
val filters: Column = rules.foldLeft(unused){
  case (acc: Column, (slNo: Int, cond: String)) =>
    acc.when(col("slNo") === slNo, expr(cond))
}

//You will get something like:
//when(col("slNo") === 1, expr("age > 100"))
//.when(col("slNo") === 2, expr("age > 50"))
//...
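The foldLeft chaining above can be illustrated without a Spark session. The sketch below is a plain-Scala analogue (not Spark API): it folds the same hypothetical rules from the question into one function keyed by slNo, just as the answer folds when clauses into one Column.

```scala
object FoldRules {
  // Hypothetical rules mirroring the logicDF contents from the question.
  val rules: Seq[(Int, Int => Boolean)] = Seq(
    (1, (age: Int) => age > 100),
    (2, (age: Int) => age > 50),
    (3, (age: Int) => age > 10),
    (4, (age: Int) => age > 20)
  )

  // Start from a function that matches nothing,
  // analogous to the `unused` when(lit(false), lit(false)) column.
  val unmatched: (Int, Int) => Option[Boolean] = (_, _) => None

  // Each fold step wraps the accumulator in one more "when slNo matches" check.
  val chained: (Int, Int) => Option[Boolean] =
    rules.foldLeft(unmatched) { (acc, rule) =>
      val (slNo, cond) = rule
      (s: Int, age: Int) => if (s == slNo) Some(cond(age)) else acc(s, age)
    }
}
```

Calling chained(slNo, age) evaluates only the condition registered for that slNo, which is exactly what the chained when(...) Column does per row of the joined DataFrame.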

Step 3

Take the cartesian product of the two DataFrames via a join, so each rule can be applied to every row of the data:

val joinDF = logicDF.join(inputDF, lit(true), "inner") //inner or whatever

Step 4

Filter using the Column of chained conditions built above:

val withRulesDF = joinDF.filter(filters)

Step 5

Group and count:

val resultDF = withRulesDF
  .groupBy("slNo", "filterCondtion")
  .agg(count("*") as "count")
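The five steps together amount to: cross-join rules with rows, keep the pairs where the rule's condition holds, then count per rule. This plain-Scala sketch (no Spark, sample data taken from the question's InputDF, predicates hand-translated from the filterCondtion strings) shows that shape end to end:

```scala
object RuleCounts {
  case class Person(age: Int, name: String)

  // Sample data from the question's InputDF.
  val people: Seq[Person] = Seq(
    Person(11, "suraj"), Person(22, "surjeth"),
    Person(33, "sam"), Person(43, "ram")
  )

  // (slNo, predicate) pairs standing in for the filterCondtion strings.
  val rules: Seq[(Int, Int => Boolean)] = Seq(
    (1, (age: Int) => age > 100),
    (2, (age: Int) => age > 50),
    (3, (age: Int) => age > 10),
    (4, (age: Int) => age > 20)
  )

  // Cartesian product of rules x people, filter, then group by slNo and count.
  val counts: Map[Int, Int] =
    (for {
      (slNo, cond) <- rules
      p <- people
      if cond(p.age)
    } yield slNo)
      .groupBy(identity)
      .map { case (slNo, hits) => (slNo, hits.size) }
      .withDefaultValue(0) // rules matching no rows count as 0
}
```

On this sample data, rule 3 (age > 10) matches all four rows and rule 4 (age > 20) matches three; rules 1 and 2 match none.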
