添加聚合的列以无需连接即可进行数据透视

时间:2019-02-28 09:14:52

标签: dataframe group-by pyspark pivot-table

考虑表格:

df=sc.parallelize([(1,1,1),(5,0,2),(27,1,1),(1,0,3),(5,1,1),(1,0,2)]).toDF(['id', 'error', 'timestamp'])
df.show()

+---+-----+---------+
| id|error|timestamp|
+---+-----+---------+
|  1|    1|        1|
|  5|    0|        2|
| 27|    1|        1|
|  1|    0|        3|
|  5|    1|        1|
|  1|    0|        2|
+---+-----+---------+

我想重点介绍timestamp列,保留原始表中的其他一些汇总信息。我感兴趣的结果可以通过

实现
df1=df.groupBy('id').agg(sf.sum('error').alias('Ne'),sf.count('*').alias('cnt'))
df2=df.groupBy('id').pivot('timestamp').agg(sf.count('*')).fillna(0)
df1.join(df2, on='id').filter(sf.col('cnt')>1).show()

及其结果表:

+---+---+---+---+---+---+
| id| Ne|cnt|  1|  2|  3|
+---+---+---+---+---+---+
|  5|  1|  2|  1|  1|  0|
|  1|  1|  3|  1|  1|  1|
+---+---+---+---+---+---+

但是,上述解决方案至少存在两个问题:

  1. 我正在脚本末尾按cnt进行过滤。如果我一开始就能做到这一点,那么我可以避免几乎所有处理,因为使用此过滤可以删除大部分数据。除了collectisin方法以外,还有什么方法可以做到这一点?
  2. 我两次进行groupBy的{​​{1}}。首先,汇总我需要的结果中的某些列,第二次获取枢轴列。最后,我需要id来合并这些列。我觉得我肯定会错过一些解决方案,因为仅使用一个join而不使用groubBy就能做到这一点,但是我不知道该怎么做。

2 个答案:

答案 0 :(得分:1)

我认为您无法绕开联接,因为枢轴将需要时间戳值,并且第一分组不应考虑它们。因此,如果必须创建NEcnt值,则只能按id对数据框进行分组,如果要保留列中的值,则会导致时间戳丢失像您单独做枢轴一样,然后将其重新加入。

唯一可以做的改进是将过滤器移至df1创建。因此,正如您所说,由于df1在过滤实际数据后应该小得多,因此已经可以提高性能。

from pyspark.sql.functions import *

df=sc.parallelize([(1,1,1),(5,0,2),(27,1,1),(1,0,3),(5,1,1),(1,0,2)]).toDF(['id', 'error', 'timestamp'])
df1=df.groupBy('id').agg(sum('error').alias('Ne'),count('*').alias('cnt')).filter(col('cnt')>1)
df2=df.groupBy('id').pivot('timestamp').agg(count('*')).fillna(0)
df1.join(df2, on='id').show()

输出:

+---+---+---+---+---+---+
| id| Ne|cnt|  1|  2|  3|
+---+---+---+---+---+---+
|  5|  1|  2|  1|  1|  0|
|  1|  1|  3|  1|  1|  1|
+---+---+---+---+---+---+

答案 1 :(得分:0)

实际上确实可以避免使用join作为Window

w1 = Window.partitionBy('id')
w2 = Window.partitionBy('id', 'timestamp')
df.select('id', 'timestamp', 
          sf.sum('error').over(w1).alias('Ne'), 
          sf.count('*').over(w1).alias('cnt'),
          sf.count('*').over(w2).alias('cnt_2')
         ).filter(sf.col('cnt')>1) \
  .groupBy('id', 'Ne', 'cnt').pivot('timestamp').agg(sf.first('cnt_2')).fillna(0).show()