How to compute a distance matrix in Spark?

Asked: 2016-06-14 10:47:22

Tags: apache-spark distance-matrix bigdata

I have tried pairing the samples, but it costs a huge amount of memory, since 100 samples already lead to 9,900 ordered pairs. What would be a more effective way to compute a distance matrix in a distributed environment in Spark?

Here is a snippet of the pseudocode I am trying:

import org.apache.spark.mllib.linalg.Vectors

val input = sc.textFile("AirPassengers.csv", numPartitions / 2)
val i = input.map(s => Vectors.dense(s.split(',').map(_.toDouble)))
val indexed = i.zipWithIndex()                          // attach the index of each sample
val indexedData = indexed.map { case (k, v) => (v, k) } // key each sample by its index

val pairedSamples = indexedData.cartesian(indexedData)

val filteredSamples = pairedSamples.filter { case (x, y) =>
  x._1 > y._1   // keep only one triangle of the symmetric matrix
}
filteredSamples.cache
filteredSamples.count

The code above creates the pairs, but even though my dataset contains only 100 samples, the pairing still yields 4,950 entries in filteredSamples, which could be very expensive for big data.
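The remaining step would then be a plain map that applies some distance function to each pair, for example (a sketch only, using MLlib's Vectors.sqdist as a placeholder metric):

// sketch: turn each surviving pair into ((i, j), distance)
val distances = filteredSamples.map { case ((i, v1), (j, v2)) =>
  ((i, j), Vectors.sqdist(v1, v2))
}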

3 Answers:

Answer 0 (score: 4)

I recently answered a similar question.

Basically, it comes down to computing the n(n-1)/2 pairs, which in your example is 4,950 computations. What makes this approach different, however, is that I use a join instead of cartesian. With your code, the solution would look like this:

import org.apache.spark.mllib.linalg.Vectors

val input = sc.textFile("AirPassengers.csv", numPartitions / 2)
val i = input.map(s => Vectors.dense(s.split(',').map(_.toDouble)))
val indexed = i.zipWithIndex()

// including the index of each sample
val indexedData = indexed.map { case (k, v) => (v, k) }

// prepare the n(n-1)/2 index pairs; note they are generated on the driver
val count = i.count
val indices = sc.parallelize(for (i <- 0L until count; j <- 0L until count; if i > j) yield (i, j))

// two joins attach the vectors for index i and index j onto each pair
val joined1 = indices.join(indexedData).map { case (i, (j, v)) => (j, (i, v)) }
val joined2 = joined1.join(indexedData).map { case (j, ((i, v1), v2)) => ((i, j), (v1, v2)) }

// after that, you can compute the distance using your distFunc
val distRDD = joined2.mapValues { case (v1, v2) => distFunc(v1, v2) }

Try this approach and compare it with the one you already posted. Hopefully it speeds up your code a bit.
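Here, distFunc is your own distance function; a minimal sketch, assuming plain squared Euclidean distance via MLlib's built-in Vectors.sqdist:

import org.apache.spark.mllib.linalg.{Vector, Vectors}

// stand-in distance function; swap in whatever metric you actually need
def distFunc(v1: Vector, v2: Vector): Double = Vectors.sqdist(v1, v2)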

Answer 1 (score: 0)

As far as I can tell from checking various sources and the Spark mllib clustering site, Spark does not currently support distance or pdist matrices.

In my opinion, 100 samples will always produce at least 4,950 values; so manually building a distributed matrix solver out of transformations (such as .map) would be the best solution.
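As a sketch of that idea (the wiring below is my own assumption, reusing distRDD and count from answer 0), the pairwise values can be poured into MLlib's CoordinateMatrix to get a distributed, sparse lower-triangular matrix:

import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

// distRDD: RDD[((Long, Long), Double)] as produced in answer 0
val entries = distRDD.map { case ((i, j), d) => MatrixEntry(i, j, d) }
val distMatrix = new CoordinateMatrix(entries, count, count)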

Answer 2 (score: 0)

A Java version of jtitusj's answer.

// imports assumed for this snippet (Spark 2.x Java API; the vector column is
// taken to hold org.apache.spark.ml.linalg.Vector values)
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

public JavaPairRDD<Tuple2<Long, Long>, Double> getDistanceMatrix(Dataset<Row> ds, String vectorCol) {

    SparkContext sc = ds.sparkSession().sparkContext();
    JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);

    JavaRDD<Vector> rdd = ds.toJavaRDD().map(new Function<Row, Vector>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Vector call(Row row) throws Exception {
            return row.getAs(vectorCol);
        }

    });

    JavaPairRDD<Vector, Long> indexed_ = rdd.zipWithIndex();

    JavaRDD<Tuple2<Long, Vector>> tmp1 = indexed_.map(new Function<Tuple2<Vector,Long>, Tuple2<Long, Vector>>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Long, Vector> call(Tuple2<Vector, Long> t) throws Exception {
            return new Tuple2<Long, Vector>(t._2, t._1);
        }
    });

    JavaPairRDD<Long, Vector> _indexed = JavaPairRDD.fromJavaRDD(tmp1);

    long count = ds.count();

    List<Tuple2<Long, Long>> indexPairs = new ArrayList<Tuple2<Long, Long>>();

    // build only the lower-triangle index pairs (i > j); valid indices run 0 .. count-1
    for(long i = 0; i < count; i++) {
        for(long j = 0; j < count; j++) {
            if(i > j) {
                indexPairs.add(new Tuple2<Long, Long>(i, j));
            }
        }
    }

    JavaRDD<Tuple2<Long, Long>> tmp2 = jsc.parallelize(indexPairs);
    JavaPairRDD<Long, Long> indices = JavaPairRDD.fromJavaRDD(tmp2);

    JavaRDD<Tuple2<Long, Tuple2<Long, Vector>>> tmp3 = indices.join(_indexed).map(new Function<Tuple2<Long,Tuple2<Long,Vector>>, Tuple2<Long,Tuple2<Long,Vector>>>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Long, Tuple2<Long, Vector>> call(Tuple2<Long, Tuple2<Long, Vector>> t) throws Exception {
            long i = t._1;
            long j = t._2._1;
            Vector v = t._2._2;

            return new Tuple2<Long, Tuple2<Long,Vector>>(j, new Tuple2<Long,Vector>(i, v));
        }
    });

    JavaPairRDD<Long, Tuple2<Long, Vector>> joined1 = JavaPairRDD.fromJavaRDD(tmp3);

    JavaRDD<Tuple2<Tuple2<Long, Long>, Tuple2<Vector, Vector>>> tmp4 = joined1.join(_indexed).map(new Function<Tuple2<Long,Tuple2<Tuple2<Long,Vector>,Vector>>, Tuple2<Tuple2<Long,Long>,Tuple2<Vector,Vector>>>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Tuple2<Long, Long>, Tuple2<Vector, Vector>> call(Tuple2<Long, Tuple2<Tuple2<Long, Vector>, Vector>> t) throws Exception {
            long j = t._1;
            long i = t._2._1._1;
            Vector v1 = t._2._1._2;
            Vector v2 = t._2._2;

            Tuple2<Long, Long> t1 = new Tuple2<Long, Long>(i, j);
            Tuple2<Vector, Vector> t2 = new Tuple2<Vector, Vector>(v1, v2);

            return new Tuple2<Tuple2<Long, Long>, Tuple2<Vector, Vector>>(t1, t2);
        }

    });

    JavaPairRDD<Tuple2<Long, Long>, Tuple2<Vector, Vector>> joined2 = JavaPairRDD.fromJavaRDD(tmp4);

    JavaPairRDD<Tuple2<Long, Long>, Double> distance = joined2.mapValues(new Function<Tuple2<Vector,Vector>, Double>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Double call(Tuple2<Vector, Vector> t) throws Exception {
            // DistanceMeasure.getDistance is a distance helper not shown in this
            // answer; substitute whatever metric implementation you use here
            return DistanceMeasure.getDistance(t._1, t._2);
        }
    });

    return distance;
}
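A hypothetical call site; the column name "features" is my assumption, not part of the answer:

// assumes ds contains a vector column named "features"
JavaPairRDD<Tuple2<Long, Long>, Double> dm = getDistanceMatrix(ds, "features");
for (Tuple2<Tuple2<Long, Long>, Double> e : dm.take(5)) {
    System.out.println(e);
}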