Exploding matching columns

Date: 2018-03-26 15:45:44

Tags: scala apache-spark

I am using Spark to read a CSV containing two array columns, where each element of the first array matches the element at the same index in the second array, i.e.:

arr1[i] <-> arr2[i]

Reading it:

val table = spark.read.csv("my_data.csv")

Sample data:

+---------+---------------------+
| ids     |  avg                |
+---------+---------------------+
|  [11,23]|  [0.368633,0.750615]|
+---------+---------------------+

I want to explode the columns so that in each resulting row the "id" and "avg" values come from the same array position (rather than getting every combination of the two).
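One caveat: spark.read.csv loads every column as a plain string, so the bracketed lists above would first need to be turned into real arrays. A minimal parsing sketch, assuming a headerless file whose raw cells literally contain text like "[11,23]":

import org.apache.spark.sql.functions.{regexp_replace, split}

// hypothetical parsing step: strip the brackets, split on commas,
// then cast the resulting array<string> to the desired element types
val table = spark.read.csv("my_data.csv")
  .select(
    split(regexp_replace($"_c0", "[\\[\\]]", ""), ",").cast("array<int>").as("ids"),
    split(regexp_replace($"_c1", "[\\[\\]]", ""), ",").cast("array<double>").as("avg")
  )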

2 Answers:

Answer 0 (score: 2)

Try posexplode: for each array it produces two columns, one with the element value and one with its position in the array.

import org.apache.spark.sql.functions.posexplode

table
  // first explode each array together with its position, renaming the position columns
  .select(posexplode($"ids").as(Seq("id_pos", "id")), $"avg")
  .select($"id", $"id_pos", posexplode($"avg").as(Seq("avg_pos", "avg")))
  // then keep only rows where the two array positions match
  .filter($"id_pos" === $"avg_pos")
  // finally drop the position columns
  .select($"id", $"avg")

Answer 1 (score: 0)

You can use a small UDF:

import org.apache.spark.sql.functions.{explode, udf}

val df = spark.createDataFrame(
    Seq((
      Array(11, 23), 
      Array(0.368633, 0.750615)
    ))).toDF("ids", "avg")

// zip the two arrays element-wise into an array of (id, avg) tuples
val udfZip = udf((ids: Seq[Int], avg: Seq[Double]) => ids.zip(avg))

val res = df.select(explode(udfZip($"ids", $"avg")).as("pair"))
res.show
// +-------------+
// |         pair|
// +-------------+
// |[11,0.368633]|
// |[23,0.750615]|
// +-------------+

res.printSchema
// root
//  |-- pair: struct (nullable = true)
//  |    |-- _1: integer (nullable = false)
//  |    |-- _2: double (nullable = false)
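If you want plain top-level columns instead of the struct, you can select the tuple fields afterwards, for example:

val flat = res.select($"pair._1".as("id"), $"pair._2".as("avg"))
flat.show
// +--+--------+
// |id|     avg|
// +--+--------+
// |11|0.368633|
// |23|0.750615|
// +--+--------+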

Or, better yet, a UDF that returns a case class, so the struct fields get meaningful names:

case class Pair(id: Int, avg: Double)
def udfBetterZip = udf[Seq[Pair], Seq[Int], Seq[Double]](
    (ids: Seq[Int], avg: Seq[Double]) => 
        ids.zip(avg).map{
            case (id, avg) => Pair(id, avg)
        }
)

val res2 = df.select(explode(udfBetterZip($"ids", $"avg")).as("pair"))
res2.printSchema
// root
//  |-- pair: struct (nullable = true)
//  |    |-- id: integer (nullable = false)
//  |    |-- avg: double (nullable = false)
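With named fields, flattening the pair into top-level columns reads more naturally:

res2.select($"pair.id", $"pair.avg").show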