如何计算列的乘积,然后计算所有列的总和?

时间:2017-06-22 13:29:32

标签: scala apache-spark apache-spark-sql

表1 --Spark DataFrame表

enter image description here

表1中有一个名为“productMe”的栏目;还有其他列,如a,b,c等,其架构名称包含在架构数组T中。

我想要的是具有列productMe(表2)的模式数组T中的列(产品中每列两列)的内积。并对表2的每一列求和,得到表3。

如果您有充分的理由在一步中获得表3,则表2不是必需的。

表2 - 内部产品表

enter image description here

例如,列“a·productMe”为(3 * 0.2,6 * 0.6,5 * 0.4)得到(0.6,3.6,2)

表3 - 总和表

enter image description here

例如,“sum(a·productMe)”列为0.6 + 3.6 + 2 = 6.2。

表1是Spark的DataFrame,我该如何获得表3?

4 个答案:

答案 0 :(得分:6)

您可以尝试以下内容:

val df = Seq(
  (3,0.2,0.5,0.4),
  (6,0.6,0.3,0.1),
  (5,0.4,0.6,0.5)).toDF("productMe", "a", "b", "c")
import org.apache.spark.sql.functions.col
val columnsToSum = df.
  columns.  // <-- grab all the columns by their name
  tail.     // <-- skip productMe
  map(col). // <-- create Column objects
  map(c => round(sum(c * col("productMe")), 3).as(s"sum_${c}_productMe"))
val df2 = df.select(columnsToSum: _*)
df2.show()
# +---------------+---------------+---------------+
# |sum_a_productMe|sum_b_productMe|sum_c_productMe|
# +---------------+---------------+---------------+
# |            6.2|            6.3|            4.3|
# +---------------+---------------+---------------+

诀窍是使用df.select(columnsToSum: _*),这意味着您要选择我们将列总和乘以productMe列的所有列。 :_*是一种特定于Scala的语法,用于指定我们传递重复的参数,因为我们没有固定数量的参数。

答案 1 :(得分:2)

我们可以使用简单的SparkSql

来完成
   val table1 = Seq(
   (3,0.2,0.5,0.4),
   (6,0.6,0.3,0.1),
   (5,0.4,0.6,0.5)
 ).toDF("productMe", "a", "b", "c")

table1.show
table1.createOrReplaceTempView("table1") 

val table2 = spark.sql("select a*productMe, b*productMe, c*productMe  from table1")   //spark is sparkSession here
table2.show

val table3 = spark.sql("select sum(a*productMe), sum(b*productMe), sum(c*productMe) from table1")
table3.show

答案 2 :(得分:2)

所有其他答案都使用sum下的groupBy聚合。

groupBy总是引入一个shuffle阶段,通常(总是?)比相应的窗口聚合慢。

在这种特殊情况下,我也相信窗口聚合可以提供更好的性能,因为您可以在他们的物理计划和他们唯一的工作细节中看到。

注意

这两种解决方案都使用一个单独的分区来进行计算,这反过来使得它们对于大型数据集不适合,因为它们的大小可能很容易超过单个JVM的内存大小。

窗口聚合

以下是基于窗口聚合的计算,在这种特殊情况下,我们将数据集中的所有行分组,但遗憾的是,它提供了相同的物理计划。这使我的答案只是一个(希望)很好的学习经历。

val df = Seq(
  (3,0.2,0.5,0.4),
  (6,0.6,0.3,0.1),
  (5,0.4,0.6,0.5)).toDF("productMe", "a", "b", "c")

// yes, I did borrow this trick with columns from @eliasah's answer
import org.apache.spark.sql.functions.col
val columns = df.columns.tail.map(col).map(c => c * col("productMe") as s"${c}_productMe")
val multiplies = df.select(columns: _*)
scala> multiplies.show
+------------------+------------------+------------------+
|       a_productMe|       b_productMe|       c_productMe|
+------------------+------------------+------------------+
|0.6000000000000001|               1.5|1.2000000000000002|
|3.5999999999999996|1.7999999999999998|0.6000000000000001|
|               2.0|               3.0|               2.5|
+------------------+------------------+------------------+

def sumOverRows(name: String) = sum(name) over ()
val multipliesCols = multiplies.
  columns.
  map(c => sumOverRows(c) as s"sum_${c}")
val answer = multiplies.
  select(multipliesCols: _*).
  limit(1)  // <-- don't use distinct or dropDuplicates here
scala> answer.show
+-----------------+---------------+-----------------+
|  sum_a_productMe|sum_b_productMe|  sum_c_productMe|
+-----------------+---------------+-----------------+
|6.199999999999999|            6.3|4.300000000000001|
+-----------------+---------------+-----------------+

物理计划

让我们看看物理计划(因为这是我们想要看看如何使用窗口聚合进行查询的唯一原因,不是吗?)

enter image description here

以下是唯一的工作0的详细信息。

enter image description here

答案 3 :(得分:1)

如果我正确理解您的问题,那么以下可以成为您的解决方案

   val df = Seq(
       (3,0.2,0.5,0.4),
       (6,0.6,0.3,0.1),
       (5,0.4,0.6,0.5)
     ).toDF("productMe", "a", "b", "c")

这样可以提供输入数据帧(可以添加更多)

+---------+---+---+---+
|productMe|a  |b  |c  |
+---------+---+---+---+
|3        |0.2|0.5|0.4|
|6        |0.6|0.3|0.1|
|5        |0.4|0.6|0.5|
+---------+---+---+---+

val productMe = df.columns.head
val colNames = df.columns.tail
var tempdf = df
for(column <- colNames){
  tempdf = tempdf.withColumn(column, col(column)*col(productMe))
}

上面的步骤应该给你Table2

+---------+------------------+------------------+------------------+
|productMe|a                 |b                 |c                 |
+---------+------------------+------------------+------------------+
|3        |0.6000000000000001|1.5               |1.2000000000000002|
|6        |3.5999999999999996|1.7999999999999998|0.6000000000000001|
|5        |2.0               |3.0               |2.5               |
+---------+------------------+------------------+------------------+

表3可以实现如下

tempdf.select(sum("a").as("sum(a.productMe)"), sum("b").as("sum(b.productMe)"), sum("c").as("sum(c.productMe)")).show(false)

表3是

+-----------------+----------------+-----------------+
|sum(a.productMe) |sum(b.productMe)|sum(c.productMe) |
+-----------------+----------------+-----------------+
|6.199999999999999|6.3             |4.300000000000001|
+-----------------+----------------+-----------------+

对于您拥有的任意数量的列,可以实现表2,但Table3将要求您明确定义列