通过列上的爆炸跨列进行pyspark聚合?

时间:2018-02-09 00:42:10

标签: pyspark-sql

我有一个pyspark数据框,看起来像这样:

data = [{'event_id':0, 'bid_0_a':1, 'bid_0_b':2, 'bid_1_a':1, 'bid_1_b':2},
        {'event_id':1, 'bid_0_a':1, 'bid_0_b':2, 'bid_1_a':1, 'bid_1_b':2}]
schema = T.StructType([T.StructField(nm, T.IntegerType(), True) 
                                     for nm in ['event_id', 'bid_0_a', 'bid_0_b', 'bid_1_a', 'bid_1_b']])
df3 = spark.createDataFrame(data, 
                            schema=schema)

+--------+-------+-------+-------+-------+
|event_id|bid_0_a|bid_0_b|bid_1_a|bid_1_b|
+--------+-------+-------+-------+-------+
|       0|      1|      2|      1|      2|
|       1|      1|      2|      1|      2|
+--------+-------+-------+-------+-------+

真实桌面每个活动有很多出价,还有更多的出价字段(超过a和b)。最后,我想对出价进行分类,并在每个类别中进行一些汇总。我想我首先需要爆炸这些列,这就产生了一个像:

这样的表
+--------+---+--+--+
|event_id|bid| a| b|
+--------+---+--+--+
|  0     | 0 | 1| 2|
|  0     | 1 | 1| 2|
|  1     | 0 | 1| 2|
|  1     | 1 | 1| 2|
+--------+---+--+--+

我可以想象这样做,但我想知道是否有更快的pyspark SQL方法来做到这一点?也许如果我将_ * _列的出价收集到地图或数组中,我可以使用explode

1 个答案:

答案 0 :(得分:0)

这是使用create_map()explode()pivot()的有点h​​acky解决方案。我说有点hacky,因为我依靠max()函数来聚合,当你的列名称是唯一的时,它应该工作,但我觉得应该有更好的方法。

首先从列中创建一个地图字段。您必须有一个bid_ids列表。

import pyspark.sql.functions as f
from operator import add
bid_ids = ['0', '1']

df4 = df3.withColumn(
    'map',
    f.create_map(
        *(reduce(
            add, 
            [[f.lit(b),
              f.create_map(f.lit('a'),
                           f.col('bid_%s_a'%b), 
                           f.lit('b'), 
                           f.col('bid_%s_b'%b))
             ] for b in bid_ids
            ]
        ))
    )
)
df4.select('event_id', 'map').show(truncate=False)
#+--------+-------------------------------------------------------+
#|event_id|map                                                    |
#+--------+-------------------------------------------------------+
#|0       |Map(0 -> Map(a -> 1, b -> 2), 1 -> Map(a -> 1, b -> 2))|
#|1       |Map(0 -> Map(a -> 1, b -> 2), 1 -> Map(a -> 1, b -> 2))|
#+--------+-------------------------------------------------------+

现在调用explode()两次(因为地图是嵌套的)。如果不清楚,可以打印出中间步骤(这里省略)。

df4 = df4.select('event_id', f.explode('map'))\
    .select('event_id', f.col('key').alias('bid'), f.explode('value'))
df4.show(truncate=False)
#+--------+---+---+-----+
#|event_id|bid|key|value|
#+--------+---+---+-----+
#|0       |0  |a  |1    |
#|0       |0  |b  |2    |
#|0       |1  |a  |1    |
#|0       |1  |b  |2    |
#|1       |0  |a  |1    |
#|1       |0  |b  |2    |
#|1       |1  |a  |1    |
#|1       |1  |b  |2    |
#+--------+---+---+-----+

数据采用此格式后,您可以拨打groupBy()pivot()。您需要汇总作为分组,因此我选择max() - 它不应该重要,因为每个(event_id, bid, key)组应该只有一个值。

df4.groupBy('event_id', 'bid').pivot('key').max('value').show()
#+--------+---+---+---+
#|event_id|bid|  a|  b|
#+--------+---+---+---+
#|       1|  0|  1|  2|
#|       0|  1|  1|  2|
#|       0|  0|  1|  2|
#|       1|  1|  1|  2|
#+--------+---+---+---+
相关问题