Efficiently indexing an array column in a DataFrame

Time: 2017-09-19 12:56:07

Tags: scala, apache-spark

I can get what I need with the code below, but it is very slow because it calls collect and creates a new RDD/DataFrame at every step of the loop, which I know is terrible practice...

I need to apply a StringIndexer to each element of the DataFrame column event_name, shown below:

+--------------------+-------+-------+---------+----------------------+
|               email|country|manager|       bu|            event_name|
+--------------------+-------+-------+---------+----------------------+
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|       [event1,event2]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|              [event3]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|       [event1,event2]|
|xxxxxxxx@xxxxxxx....|     CA|      0|Core - CS|[event3,event4,event3]|
|xxxxxxxx@xxxxxxx....|     US|      0|Core - CS|              [event1]|
+--------------------+-------+-------+---------+----------------------+

I need to transform this event_name column and either replace it or append the result to this DF, for example:

+--------------------+-------+-------+---------+----------------+
|               email|country|manager|       bu|      event_name|
+--------------------+-------+-------+---------+----------------+
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|           [1,2]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|             [3]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|           [1,2]|
|xxxxxxxx@xxxxxxx....|     CA|      0|Core - CS|         [3,4,3]|
|xxxxxxxx@xxxxxxx....|     US|      0|Core - CS|             [1]|
+--------------------+-------+-------+---------+----------------+

How can I do this without the huge overhead of the code below?

Thanks

  // parse each event_name string ("[event1,event2]") into a List of event names
  val rddX = dfWithSchema.select("event_name").rdd.map(_.getString(0).split(",").map(_.trim replaceAll ("[\\[\\]\"]", "")).toList)
  //val oneRow = Converted(eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq("CCB")).toDF("event_name")).select("eventIndex").first().getDouble(0))
  rddX.take(5).foreach(println)
  // collect to the driver, then run eventIndexer.transform on a one-element DataFrame per event (this is the slow part)
  val severalRows = rddX.collect().map(row =>
    if (row.length == 1) {
      (eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq(row(0).toString)).toDF("event_name")).select("eventIndex").first().getDouble(0))
    } else {
      row.map(tool => {
        (eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq(tool.toString)).toDF("event_name")).select("eventIndex").first().getDouble(0))
      })
    })

1 Answer:

Answer 0 (score: 1)

I believe there is a simpler solution: explode the events, apply the indexer, and then aggregate them back:

// the column transformations below all come from org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
// (the $"..." syntax additionally needs the implicits of your SparkSession / SQLContext)

// string2list: a Column => Column function that turns a string such as
// "[event1,event2]" into an array("event1", "event2") column
def string2list = (regexp_replace(_: Column, "[\\[\\]\"]", "")) andThen
  (split(_: Column, ","))
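
As a quick sanity check, string2list can be used directly in a select. This snippet is only illustrative and assumes event_name is stored as a string like "[event1,event2]" (which is what the regexp_replace/split above expect); the alias "event_list" is just a name chosen here:

// parse the raw event_name string into an array<string> column
dfWithSchema
  .select($"event_name", string2list($"event_name").as("event_list"))
  .show(5, truncate = false)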

// First explode your event lists (note that this returns a DataFrame, not an RDD like your rddX)

val dfX = dfWithSchema
  .withColumn("rowID", monotonically_increasing_id())
  .withColumn("exploded_events", explode(string2list($"event_name")))

Apply your StringIndexer as described in the docs (https://spark.apache.org/docs/2.1.0/ml-features.html#stringindexer):

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("exploded_events")
  .setOutputCol("categoryEventName")

val indexedEvents = indexer.fit(dfX).transform(dfX)
indexedEvents.show()

// Then, if you need to get the data back as a list per original row:
val aggregatedEvents = indexedEvents
  .groupBy("rowID").agg(collect_list("categoryEventName"))
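
If you want the indexed lists next to the original columns (as in the desired output above) rather than keyed by rowID alone, one option is to group on rowID together with the original columns. This is only a sketch assuming the column names from the question; "indexed_events" is an illustrative alias:

// rowID keeps duplicate original rows distinct while the other columns are carried along
val result = indexedEvents
  .groupBy("rowID", "email", "country", "manager", "bu")
  .agg(collect_list("categoryEventName").as("indexed_events"))
  .drop("rowID")

result.show(5, truncate = false)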