在Pyspark上高效实施SOM(自组织地图)

时间:2019-02-10 14:26:30

标签: apache-spark parallel-processing pyspark som

我正在努力在Spark / Pyspark上实现SOM Batch算法的高性能版本,以获取具有100多个功能的庞大数据集。 我觉得我可以在自己可以/必须指定并行化的地方使用RDD,或者我可以使用Dataframe来提高性能,但是我看不到如何在使用时为每个工人使用诸如本地累积变量之类的东西数据框。

想法:

  • 使用蓄电池。通过创建UDF并行计算,该UDF将观察值作为输入,计算对网络的影响,并将影响发送到驱动程序中的累加器。 (已经实施了该版本,但似乎很慢(我认为累加器更新需要很长时间))
  • 将结果存储在Dataframe的新列中,然后将其加在一起。 (是否必须在每行(例如20 * 20 * 130)中存储整个神经网络?)火花优化算法是否实现了,它不需要保存每个网络,而仅将它们加在一起即可?
  • 使用类似于https://machinelearningnepal.com/2018/01/22/apache-spark-implementation-of-som-batch-algorithm/的RDD创建自定义的并行算法(但具有更高性能的计算算法)。但是我将不得不使用某种循环来遍历每一行并更新网络->这样听起来听起来效果不佳。)

对其他选项有何想法?还有更好的选择吗?

还是所有想法都不是那么好,我应该只选择我的数据集的最大子集,然后在本地训练一个SOM。 谢谢!

1 个答案:

答案 0 :(得分:2)

这正是我去年所做的,所以我可能会很高兴为您提供答案。

首先,here is my Spark implementation of the batch SOM algorithm(它是用Scala编写的,但是大多数情况在Pyspark中都是相似的。)

我在项目中需要这种算法,我发现的每个实现都至少存在以下两个问题或局限性之一:

  • 他们并没有真正实现批处理SOM算法,但是使用了映射平均方法,该方法给了我奇怪的结果(输出映射中的异常对称性)
  • 他们没有使用DataFrame API(纯RDD API),也不符合Spark ML / MLlib的精神,即使用在DataFrames上运行的简单fit() / transform() API。

因此,我继续自己编写代码:Spark ML风格的批处理SOM算法。我所做的第一件事是查看如何在Spark ML中实现k-means,因为您知道,批处理SOM与k-means算法非常相似。实际上,我可以重用Spark ML k-means代码的很大一部分,但是我不得不修改核心算法和超参数。

我可以快速总结一下模型的构建方式:

  1. 一个SOMParams类,包含SOM超参数(大小,训练参数等)
  2. 一个SOM类,该类继承自spark的Estimator,并包含训练算法。特别是,它包含一个在输入fit()上操作的DataFrame方法,其中要素以spark.ml.linalg.Vector的形式存储在单个列中。 fit()随后将选择此列并解压缩DataFrame以获得所需的RDD[Vector]功能,并在其上调用run()方法。这是所有计算发生的地方,并且您猜到了,它使用RDD,累加器和广播变量。最后,fit()方法返回一个SOMModel对象。
  3. SOMModel是经过训练的SOM模型,它继承自spark的Transformer / Model。它包含地图原型(中心向量),还包含一个transform()方法,该方法可以通过输入要素列并添加带有预测值的新列(在地图上的投影)对DataFrames进行操作。这是通过预测UDF完成的。
  4. 还有SOMTrainingSummary收集诸如目标函数之类的东西。

这是要点:

  • RDDDataFrame(或Dataset)之间并没有真正的对立,但两者之间的区别在这里并不重要。它们只是用于不同的上下文。实际上,DataFrame可以看作是一种RDD,专门用于处理按列(例如关系表)组织的结构化数据,从而允许类似SQL的操作和执行计划的优化(Catalyst优化器)。
  • 对于结构化数据,请始终选择/过滤/聚合操作,务必使用Dataframe
  • ...但是对于诸如机器学习算法之类的更复杂的任务,您需要返回RDD API并使用map / mapPartitions / foreach / reduce / reduceByKey /等自己分发计算儿子。看看MLlib中的工作方式:这只是RDD操作的一个很好的包装!

希望它将解决您的问题。关于性能,正如您所要求的 efficiency 实现一样,我尚未制定任何基准测试,但我在工作中使用了它,并且在生产集群上几分钟内处理了500k / 1M行数据集。