我有一个像这样的Spark DataFrame:
+-----+--------+-------+-------+-------+-------+-------+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|
+-----+--------+-------+-------+-------+-------+-------+
| Cat| 1| 1| 2| 3| 4| 5|
| Dog| 2| 1| 2| 3| 4| 5|
|Mouse| 4| 1| 2| 3| 4| 5|
| Fox| 5| 1| 2| 3| 4| 5|
+-----+--------+-------+-------+-------+-------+-------+
您可以使用下面的代码重现它:
data = [('Cat', 1, 1, 2, 3, 4, 5),
('Dog', 2, 1, 2, 3, 4, 5),
('Mouse', 4, 1, 2, 3, 4, 5),
('Fox', 5, 1, 2, 3, 4, 5)]
columns = ['Type', 'Criteria', 'Value#1', 'Value#2', 'Value#3', 'Value#4', 'Value#5']
df = spark.createDataFrame(data, schema=columns)
df.show()
我的任务是添加总计列,该列是所有“值”列的总和,其中#号不超过此行的条件。
在此示例中:
'Cat'
行:条件是1
,所以Total
仅是Value#1
。 'Dog'
行:条件是2
,因此Total
是Value#1
和Value#2
的总和。'Fox'
行:条件是5
,因此Total
是所有列的总和(Value#1
至Value#5
)。结果应如下所示:
+-----+--------+-------+-------+-------+-------+-------+-----+
| Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|Total|
+-----+--------+-------+-------+-------+-------+-------+-----+
| Cat| 1| 1| 2| 3| 4| 5| 1|
| Dog| 2| 1| 2| 3| 4| 5| 3|
|Mouse| 4| 1| 2| 3| 4| 5| 10|
| Fox| 5| 1| 2| 3| 4| 5| 15|
+-----+--------+-------+-------+-------+-------+-------+-----+
我可以使用Python UDF做到这一点,但是我的数据集很大,并且由于序列化,Python UDF速度很慢。我正在寻找纯Spark解决方案。
我正在使用PySpark和Spark 2.1
答案 0 :(得分:5)
您可以通过PySpark: compute row maximum of the subset of columns and add to an exisiting dataframe轻松地将解决方案调整为user6910411
from pyspark.sql.functions import col, when
total = sum([
when(col("Criteria") >= i, col("Value#{}".format(i))).otherwise(0)
for i in range(1, 6)
])
df.withColumn("total", total).show()
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Cat| 1| 1| 2| 3| 4| 5| 1|
# | Dog| 2| 1| 2| 3| 4| 5| 3|
# |Mouse| 4| 1| 2| 3| 4| 5| 10|
# | Fox| 5| 1| 2| 3| 4| 5| 15|
# +-----+--------+-------+-------+-------+-------+-------+-----+
对于任意一组订单列,请定义一个list
:
cols = df.columns[2:]
并将总计重新定义为:
total_ = sum([
when(col("Criteria") > i, col(cols[i])).otherwise(0)
for i in range(len(cols))
])
df.withColumn("total", total_).show()
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Type|Criteria|Value#1|Value#2|Value#3|Value#4|Value#5|total|
# +-----+--------+-------+-------+-------+-------+-------+-----+
# | Cat| 1| 1| 2| 3| 4| 5| 1|
# | Dog| 2| 1| 2| 3| 4| 5| 3|
# |Mouse| 4| 1| 2| 3| 4| 5| 10|
# | Fox| 5| 1| 2| 3| 4| 5| 15|
# +-----+--------+-------+-------+-------+-------+-------+-----+
重要:
这里sum
是__builtin__.sum
而不是pyspark.sql.functions.sum
。