PySpark中可变列数的总和

时间:2018-08-07 18:00:07

标签: python apache-spark pyspark apache-spark-sql

我有一个像这样的Spark DataFrame:

+-----+--------+-------+-------+-------+-------+-------+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|
+-----+--------+-------+-------+-------+-------+-------+
|  Cat|       1|      1|      2|      3|      4|      5|
|  Dog|       2|      1|      2|      3|      4|      5|
|Mouse|       4|      1|      2|      3|      4|      5|
|  Fox|       5|      1|      2|      3|      4|      5|
+-----+--------+-------+-------+-------+-------+-------+

您可以使用下面的代码重现它:

data = [('Cat', 1, 1, 2, 3, 4, 5),
        ('Dog', 2, 1, 2, 3, 4, 5),
        ('Mouse', 4, 1, 2, 3, 4, 5),
        ('Fox', 5, 1, 2, 3, 4, 5)]
columns = ['Type', 'Criteria', 'Value#1', 'Value#2', 'Value#3', 'Value#4', 'Value#5']
df = spark.createDataFrame(data, schema=columns)
df.show()

我的任务是添加总计列,该列是所有“值”列的总和,其中#号不超过此行的条件。

在此示例中:

  • 对于第'Cat'行:条件是1,所以Total仅是Value#1
  • 对于第'Dog'行:条件是2,因此TotalValue#1Value#2的总和。
  • 对于第'Fox'行:条件是5,因此Total是所有列的总和(Value#1Value#5)。

结果应如下所示:

+-----+--------+-------+-------+-------+-------+-------+-----+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|Total|
+-----+--------+-------+-------+-------+-------+-------+-----+
|  Cat|       1|      1|      2|      3|      4|      5|    1|
|  Dog|       2|      1|      2|      3|      4|      5|    3|
|Mouse|       4|      1|      2|      3|      4|      5|   10|
|  Fox|       5|      1|      2|      3|      4|      5|   15|
+-----+--------+-------+-------+-------+-------+-------+-----+

我可以使用Python UDF做到这一点,但是我的数据集很大,并且由于序列化,Python UDF速度很慢。我正在寻找纯Spark解决方案。

我正在使用PySpark和Spark 2.1

1 个答案:

答案 0 :(得分:5)

您可以通过PySpark: compute row maximum of the subset of columns and add to an exisiting dataframe轻松地将解决方案调整为user6910411

from pyspark.sql.functions import col, when

total = sum([
    when(col("Criteria") >= i, col("Value#{}".format(i))).otherwise(0)
    for i in range(1, 6)
])

df.withColumn("total", total).show()

# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# |  Cat|       1|      1|      2|      3|      4|      5|    1|
# |  Dog|       2|      1|      2|      3|      4|      5|    3|
# |Mouse|       4|      1|      2|      3|      4|      5|   10|
# |  Fox|       5|      1|      2|      3|      4|      5|   15|
# +-----+--------+-------+-------+-------+-------+-------+-----+

对于任意一组订单列,请定义一个list

cols = df.columns[2:]

并将总计重新定义为:

total_ = sum([
    when(col("Criteria") > i, col(cols[i])).otherwise(0)
    for i in range(len(cols))
])

df.withColumn("total", total_).show()
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# |  Cat|       1|      1|      2|      3|      4|      5|    1|
# |  Dog|       2|      1|      2|      3|      4|      5|    3|
# |Mouse|       4|      1|      2|      3|      4|      5|   10|
# |  Fox|       5|      1|      2|      3|      4|      5|   15|
# +-----+--------+-------+-------+-------+-------+-------+-----+

重要

这里sum__builtin__.sum而不是pyspark.sql.functions.sum