从Spark中的相关矩阵中提取成对相关

时间:2017-03-14 15:37:15

标签: apache-spark apache-spark-sql spark-dataframe apache-spark-mllib

我正在尝试将成对相关(例如,皮尔森)提取到火花数据帧中。我希望在进一步的查询中使用表格格式的成对协同作为机器学习输入。

所以这是一个运行的例子:

数据:

import org.apache.spark.sql.{SQLContext, Row, DataFrame}
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType, DoubleType}
import org.apache.spark.sql.functions._

// rdd
    val rowsRdd: RDD[Row] = sc.parallelize(
      Seq(
        Row(2.0, 7.0, 1.0),
        Row(3.5, 2.5, 0.0),
        Row(7.0, 5.9, 0.0)
      )
    )

// Schema  
    val schema = new StructType()
      .add(StructField("item_1", DoubleType, true))
      .add(StructField("item_2", DoubleType, true))
      .add(StructField("item_3", DoubleType, true))

// Data frame  
    val df = spark.createDataFrame(rowsRdd, schema)

相关矩阵

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD

   val rows = new VectorAssembler().setInputCols(df.columns).setOutputCol("corr_features")
      .transform(df)
      .select("corr_features")
      .rdd
   val items_mllib_vector = rows.map(_.getAs[org.apache.spark.ml.linalg.Vector](0))
                             .map(org.apache.spark.mllib.linalg.Vectors.fromML)

   val correlMatrix: Matrix = Statistics.corr(items_mllib_vector, "pearson")

输出是所有元素的相关矩阵。我想成对地将每个元素(i:j)与相关系数和每个元素的名称一起提取到数据帧中。

需要输出:

item_from | item_to | Correlation
item_1    | item_2  | -0.0096912
item_1    | item_3  | -0.7313071
item_2    | item_3  | 0.68910356

1 个答案:

答案 0 :(得分:0)

在一些帮助下,我找到了解决方案:

将结果导入本地数组:

import scala.collection.mutable.ListBuffer

val pairwiseArr = new ListBuffer[Array[Double]]()

for( i <- 0 to correlMatrix.numRows-1){
  for(j <- 0 to correlMatrix.numCols-1){
    pairwiseArr += Array(i, j, correlMatrix.apply(i,j))
  }
}

将Array转换为spark Dataframe:

case class pairRow(i: Double, j: Double, corr: Double)

val pairwiseDF = pairwiseArr.map(x => pairRow(x(0), x(1), x(2))).toDF()
display(pairwiseDF

由于Array是本地数组,因此最好使用ColumnSimilarities

相关问题