I can get what I need with the code below, but it is very slow because it calls collect and defines a new RDD at each step of the loop, which I know is terrible practice...
I need to apply a StringIndexer to every element of the DataFrame column event_name, shown here:
+--------------------+-------+-------+---------+-----------------------+
|               email|country|manager|       bu|             event_name|
+--------------------+-------+-------+---------+-----------------------+
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|        [event1,event2]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|               [event3]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|        [event1,event2]|
|xxxxxxxx@xxxxxxx....|     CA|      0|Core - CS| [event3,event4,event3]|
|xxxxxxxx@xxxxxxx....|     US|      0|Core - CS|               [event1]|
+--------------------+-------+-------+---------+-----------------------+
I need to transform this event_name column, replacing it in (or appending the result to) this DF, for example:
+--------------------+-------+-------+---------+----------------+
|               email|country|manager|       bu|      event_name|
+--------------------+-------+-------+---------+----------------+
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|           [1,2]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|             [3]|
|xxxxxxxx@xxxxxxx....|     GB|      0|Core - CS|           [1,2]|
|xxxxxxxx@xxxxxxx....|     CA|      0|Core - CS|         [3,4,3]|
|xxxxxxxx@xxxxxxx....|     US|      0|Core - CS|             [1]|
+--------------------+-------+-------+---------+----------------+
How can I do this without the huge overhead of the code below?
Thanks
// Turn each event_name string like "[event1,event2]" into a trimmed List[String]
val rddX = dfWithSchema.select("event_name").rdd
  .map(_.getString(0).split(",").map(_.trim.replaceAll("[\\[\\]\"]", "")).toList)

//val oneRow = Converted(eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq("CCB")).toDF("event_name")).select("eventIndex").first().getDouble(0))

rddX.take(5).foreach(println)

// Collect everything to the driver, then launch one Spark job per event:
// this is where the overhead comes from
val severalRows = rddX.collect().map(row =>
  if (row.length == 1) {
    eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq(row(0).toString)).toDF("event_name")).select("eventIndex").first().getDouble(0)
  } else {
    row.map(tool =>
      eventIndexer.transform(sqlContext.sparkContext.parallelize(Seq(tool.toString)).toDF("event_name")).select("eventIndex").first().getDouble(0)
    )
  }
)
Answer 0 (score: 1)
I believe there is a simpler solution: explode the events, apply the indexer, and then aggregate them back:
// All the required transformations below use org.apache.spark.sql.functions._
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._

// Compose the cleanup: strip brackets and quotes, then split on commas
def string2list = (regexp_replace(_: Column, "[\\[\\]\"]", "")) andThen
  (split(_: Column, ","))
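As a quick sanity check (a minimal sketch, assuming the dfWithSchema from your question and import spark.implicits._ for the $ syntax), the helper turns the bracketed string into an array<string> column:

dfWithSchema.select(string2list($"event_name").as("events")).show(2, false)
// e.g. "[event1,event2]" becomes the array [event1, event2]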
// First explode your list of events (this returns a DataFrame, not an RDD like rddX)
val dfX = dfWithSchema
  .withColumn("rowID", monotonically_increasing_id())
  .withColumn("exploded_events", explode(string2list($"event_name")))
Apply your StringIndexer as in https://spark.apache.org/docs/2.1.0/ml-features.html#stringindexer:
import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("exploded_events")
  .setOutputCol("categoryEventName")

// Fit on the exploded column and append the numeric index
val indexedEvents = indexer.fit(dfX).transform(dfX)
indexedEvents.show()
// Then, if you need the data back as a list per original row:
val aggregatedEvents = indexedEvents
  .groupBy("rowID").agg(collect_list("categoryEventName"))