pyspark中同一列上的多个AND条件,无需连接操作

时间:2019-02-28 19:15:30

标签: sql apache-spark pyspark

我有一个三列[s,p,o]的表。我想删除行,对于s中的每个条目,p列均不包含[P625, P36]值。例如

+----+----+------
|   s|   p|  o  |
+----+----+-----|
| Q31| P36| Q239|
| Q31|P625|   51|
| Q45| P36| Q597|
| Q45|P625|  123|
| Q51|P625|   22|
| Q24|P625|   56|

最终结果应该是

+----+----+------
|   s|   p|  o  |
+----+----+-----|
| Q31| P36| Q239|
| Q31|P625|   51|
| Q45| P36| Q597|
| Q45|P625|  123|

使用加入操作,上述任务很容易。

df.filter(df.p=='P625').join(df.filter(df.p=='P36'),'s')

但是还有更优雅的方法吗?

2 个答案:

答案 0 :(得分:1)

原谅我,因为我对Scala API更加熟悉,但是也许您可以轻松地将其转换:

scala> val df = spark.createDataset(Seq(
     |      ("Q31", "P36", "Q239"),
     |      ("Q31", "P625", "51"),
     |      ("Q45", "P36", "Q597"),
     |      ("Q45", "P625", "123"),
     |      ("Q51", "P625", "22"),
     |      ("Q24", "P625", "56")
     | )).toDF("s", "p", "o")
df: org.apache.spark.sql.DataFrame = [s: string, p: string ... 1 more field]

scala> (df.select($"s", struct($"p", $"o").as("po"))
     |   .groupBy("s")
     |   .agg(collect_list($"po").as("polist"))
     |   .as[(String, Array[(String, String)])]
     |   .flatMap(r => {
     |     val ps = r._2.map(_._1).toSet
     |           if(ps("P625") && ps("P36")) {
     |             r._2.flatMap(po => Some(r._1, po._1, po._2))
     |           } else {
     |             None
     |           }
     |   }).toDF("s", "p", "o")
     |   .show())
+---+----+----+                                                                 
|  s|   p|   o|
+---+----+----+
|Q31| P36|Q239|
|Q31|P625|  51|
|Q45| P36|Q597|
|Q45|P625| 123|
+---+----+----+

作为参考,您上面的join()命令将返回:

scala> df.filter($"p" === "P625").join(df.filter($"p" === "P36"), "s").show
+---+----+---+---+----+
|  s|   p|  o|  p|   o|
+---+----+---+---+----+
|Q31|P625| 51|P36|Q239|
|Q45|P625|123|P36|Q597|
+---+----+---+---+----+

也许可以用更少的代码将其应用于最终解决方案中,但是我不确定哪种方法会更有效,因为这在很大程度上取决于数据。

答案 1 :(得分:1)

您需要一个窗口

from pyspark.sql import Window
from pyspark.sql.functions import *

winSpec = Window.partitionBy('s')
df.withColumn("s_list", collect_list("s").over(winSpec)).
filter(array_contains(col("s_list"), "P625") & array_contains(col("s_list"), "P36") & size(col("s_list")) = 2)