将减法应用于rdd - PySpark中的每一行

时间:2017-06-09 18:25:04

标签: python apache-spark pyspark spark-dataframe

此代码创建一个整数的rdd并打印它们:

schema = StructType([StructField('value', IntegerType(), False)])
rdd = sc.parallelize([[100],[50],[25]])
myrdd = sqlContext.createDataFrame(rdd, schema).rdd
for x in myrdd.collect():
    print(x)

打印:

Row(value=100)
Row(value=50)
Row(value=25)

我正在尝试从此rdd中减去一个值,以便每次减法时如果有一个余数则从下一行中减去。

作为减去125的例子,从第一行取100,从第二行取25,这将留下一个新的rdd值:

Row(value=0)
Row(value=25)
Row(value=25)

作为减去160的另一个例子,从第一行获取100,从第二行获取50,从第三行获取10,这将留下新的rdd值:

Row(value=0)
Row(value=0)
Row(value=15)

我的尝试:

valueToRemove = 125
def myFun(s):
    valueToRemove = valueToRemove - s['value']
    return Row(value = valueToRemove)

myrdd1 = myrdd.map(myFun)

for x in myrdd1.collect():
    print(x)

导致错误:

UnboundLocalError: local variable 'valueToRemove' referenced before assignment

我认为一个自然的解决方案是foldLeft,但Apache spark不支持foldLeft。此外,我无法使用fold,因为要按确定的顺序处理行。

如何从每一行中减去一个值并存储要在下一行中使用的减法结果?

更新:

添加全局:

schema = StructType([StructField('value', IntegerType(), False)])
rdd = sc.parallelize([[100],[50],[25]])
myrdd = sqlContext.createDataFrame(rdd, schema).rdd
for x in myrdd.collect():
    print(x)

global valueToRemove
valueToRemove = 125

def myFun(s):
    valueToRemove = valueToRemove - s['value']
    return Row(value = valueToRemove)

myrdd1 = myrdd.map(myFun)

for x in myrdd1.collect():
    print(x)

导致同样的错误。

1 个答案:

答案 0 :(得分:1)

假设

我解决了它假设:

  1. 数据可以保留为DataFrame
  2. 有一列表示值
  3. 的行号

    根据上述假设,这是我输入的版本

    schema = StructType([StructField('row', IntegerType(), 
    False),StructField('value', IntegerType(), False)])
    rdd = sc.parallelize([[1, 100],[2, 50],[3, 25],[4,225]])
    myrdd = sqlContext.createDataFrame(rdd, schema)
    for x in myrdd.collect():
        print(x)
    

    打印:

    Row(row=1, value=100)
    Row(row=2, value=50)
    Row(row=3, value=25)
    Row(row=4, value=225)
    

    解决方案

    首先添加累积总和列:

    from pyspark.sql.window import Window
    import pyspark.sql.functions as F
    
    w = Window.orderBy("row")
    tempDF = myrdd.select("value","row",F.sum("value").over(w).alias("cumsum"))
    
    tempDF.show()
    

    打印:

    +-----+---+------+
    |value|row|cumsum|
    +-----+---+------+
    |  100|  1|   100|
    |   50|  2|   150|
    |   25|  3|   175|
    |  225|  4|   400|
    +-----+---+------+
    

    最后我定义了一个UDF来计算新值:

    def new_val(cumsum_val, row_val, target_val):
        if cumsum_val - row_val >= target_val:
            #rows that are after the "affected area"
            return row_val
        if cumsum_val - target_val < 0:
            # rows that use all their values
            return 0
        # rows with reminders
        return cumsum_val - target_val
    new_val_udf = F.udf(new_val)
    value = 160
    tempDF.withColumn("new_val",new_val_udf(F.col("cumsum"), F.col("value"), F.lit(value))).show()
    

    输出结果为:

    +-----+---+------+-------+
    |value|row|cumsum|new_val|
    +-----+---+------+-------+
    |  100|  1|   100|      0|
    |   50|  2|   150|      0|
    |   25|  3|   175|     15|
    |  225|  4|   400|    225|
    +-----+---+------+-------+
    
相关问题