Convert JavaRDD<Row> to JavaRDD<Vector>

Time: 2016-04-04 19:03:17

Tags: java apache-spark spark-dataframe apache-spark-mllib

I'm trying to run LDA on a Wikipedia XML dump. After getting an RDD of raw text, I create a DataFrame and transform it through a pipeline of RegexTokenizer, StopWordsRemover and CountVectorizer stages. I intend to pass the RDD of Vectors output by CountVectorizer to the online LDA optimizer in MLlib. Here's my code:

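 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.ml.Pipeline;
 import org.apache.spark.ml.PipelineModel;
 import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.feature.CountVectorizer;
 import org.apache.spark.ml.feature.RegexTokenizer;
 import org.apache.spark.ml.feature.StopWordsRemover;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Row;

 // Configure an ML pipeline (input/output column names are illustrative)
 RegexTokenizer tokenizer = new RegexTokenizer()
          .setInputCol("text")
          .setOutputCol("words");

 StopWordsRemover remover = new StopWordsRemover()
          .setInputCol("words")
          .setOutputCol("filtered");

 CountVectorizer cv = new CountVectorizer()
          .setInputCol("filtered")
          .setOutputCol("features");

 Pipeline pipeline = new Pipeline()
          .setStages(new PipelineStage[] {tokenizer, remover, cv});

 // Fit the pipeline to the training documents.
 PipelineModel model = pipeline.fit(fileDF);

 JavaRDD<Vector> countVectors = model.transform(fileDF)
          .select("features").javaRDD()
          .map(new Function<Row, Vector>() {
            public Vector call(Row row) throws Exception {
                Object[] arr = row.getList(0).toArray();

                double[] features = new double[arr.length];
                int i = 0;
                for (Object obj : arr) {
                    features[i++] = (double) obj;
                }
                return Vectors.dense(features);
            }
          });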

I'm getting a ClassCastException from this line:

Object[] arr = row.getList(0).toArray();

Caused by: java.lang.ClassCastException: org.apache.spark.mllib.linalg.SparseVector cannot be cast to scala.collection.Seq
at org.apache.spark.sql.Row$class.getSeq(Row.scala:278)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getSeq(rows.scala:192)
at org.apache.spark.sql.Row$class.getList(Row.scala:286)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getList(rows.scala:192)
at xmlProcess.ParseXML$
at xmlProcess.ParseXML$

I found the Scala syntax for this here, but couldn't find an example in Java. I tried row.getAs[Vector](0), but that's Scala syntax. Is there a way to do this in Java?
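CountVectorizer does not output a list of numbers: each row's output column holds a term-count vector, which Spark stores sparsely (vector size, indices of the terms present, their counts). That is why row.getList(0) fails with the ClassCastException above. The plain-Java sketch below mimics the three stages for a single document. It is an illustration, not Spark's implementation: the class name, vocabulary, stop-word list and sample sentence are made up, lower-casing is a simplifying choice, and Spark's CountVectorizer builds its vocabulary from the whole corpus, ordered by term frequency.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

/** Plain-Java illustration (not Spark code) of what the three pipeline stages produce for one document. */
public class ToyCountVectorizer {

    // Stage 1: split on runs of whitespace (RegexTokenizer's default pattern) and lower-case the tokens.
    static List<String> tokenize(String text) {
        List<String> tokens = new ArrayList<>();
        for (String t : text.toLowerCase().split("\\s+")) {
            if (!t.isEmpty()) {
                tokens.add(t); // a leading space would otherwise produce an empty first token
            }
        }
        return tokens;
    }

    // Stage 2: drop stop words (the list passed in main is tiny and made up).
    static List<String> removeStopWords(List<String> tokens, Set<String> stopWords) {
        List<String> kept = new ArrayList<>();
        for (String t : tokens) {
            if (!stopWords.contains(t)) {
                kept.add(t);
            }
        }
        return kept;
    }

    // Stage 3: count in-vocabulary terms and keep only non-zero entries, keyed by vocabulary index.
    // This (index -> count) map is the information a sparse vector stores as indices/values arrays.
    static SortedMap<Integer, Double> countVector(List<String> tokens, List<String> vocabulary) {
        SortedMap<Integer, Double> counts = new TreeMap<>();
        for (String t : tokens) {
            int idx = vocabulary.indexOf(t);
            if (idx >= 0) {
                counts.merge(idx, 1.0, Double::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        List<String> vocab = Arrays.asList("spark", "lda", "topic", "wikipedia", "vector", "row");
        Set<String> stopWords = new HashSet<>(Arrays.asList("the", "on", "a", "of"));
        List<String> tokens = removeStopWords(tokenize("LDA on the Wikipedia dump of Spark topic topic"), stopWords);
        SortedMap<Integer, Double> sparse = countVector(tokens, vocab);
        // prints: size=6 entries={0=1.0, 1=1.0, 2=2.0, 3=1.0}
        System.out.println("size=" + vocab.size() + " entries=" + sparse);
    }
}
```

Indices 4 and 5 never appear because those terms are absent from the document. With a Wikipedia-sized vocabulary almost every index is absent from any one article, so expanding each row into a dense double[] (as the loop in the question tries to do) would allocate one slot per vocabulary term.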

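In Java, take the value with row.get(0) and cast it to Vector (this assumes the CountVectorizer output column is named features):

         JavaRDD<Vector> countVectors = model.transform(fileDF)
              .select("features").javaRDD()
              .map(new Function<Row, Vector>() {
                public Vector call(Row row) throws Exception {
                    return (Vector) row.get(0);
                }
              });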
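Why the cast works: a Row is a positional container of objects, and typed accessors such as getList only cast whatever value is stored at that position, so asking for a list when the stored value is a sparse vector fails. The plain-Java model below uses hypothetical MiniRow and SparseCounts classes (they are not Spark's classes) to show the same behaviour. Spark's own Row also has a generic getAs(int) method, which Java code can call as row.<Vector>getAs(0) as an alternative to the explicit cast.

```java
import java.util.Arrays;
import java.util.List;

/** Plain-Java model (hypothetical classes, not Spark's) of why getList(0) fails while a cast works. */
public class RowCastDemo {

    /** Stand-in for a sparse term-count vector such as the one CountVectorizer emits. */
    static final class SparseCounts {
        final int size;
        final int[] indices;
        final double[] values;

        SparseCounts(int size, int[] indices, double[] values) {
            this.size = size;
            this.indices = indices;
            this.values = values;
        }

        /** Expand to a dense array; positions not listed in indices stay 0.0. */
        double[] toArray() {
            double[] dense = new double[size];
            for (int k = 0; k < indices.length; k++) {
                dense[indices[k]] = values[k];
            }
            return dense;
        }
    }

    /** Stand-in for a Row: get(i) returns the raw Object, typed accessors just cast it. */
    static final class MiniRow {
        private final Object[] fields;

        MiniRow(Object... fields) {
            this.fields = fields;
        }

        Object get(int i) {
            return fields[i];
        }

        @SuppressWarnings("unchecked")
        <T> T getAs(int i) {
            return (T) fields[i]; // the cast is checked at the caller's assignment
        }

        List<?> getList(int i) {
            return (List<?>) fields[i]; // throws ClassCastException unless a List was stored
        }
    }

    public static void main(String[] args) {
        MiniRow row = new MiniRow(new SparseCounts(5, new int[] {1, 3}, new double[] {2.0, 1.0}));
        try {
            row.getList(0);
        } catch (ClassCastException e) {
            // prints: getList(0) failed: ClassCastException
            System.out.println("getList(0) failed: " + e.getClass().getSimpleName());
        }
        SparseCounts v = (SparseCounts) row.get(0); // same pattern as (Vector) row.get(0)
        SparseCounts w = row.getAs(0);              // same pattern as row.<Vector>getAs(0)
        // prints: [0.0, 2.0, 0.0, 1.0, 0.0] true
        System.out.println(Arrays.toString(v.toArray()) + " " + (v == w));
    }
}
```

One caveat on versions: the stack trace shows org.apache.spark.mllib.linalg.SparseVector, which is what spark.ml stages emitted in Spark 1.x. From Spark 2.0 they emit org.apache.spark.ml.linalg vectors instead, so the cast target becomes org.apache.spark.ml.linalg.Vector, and the value has to be converted with org.apache.spark.mllib.linalg.Vectors.fromML before it is handed to MLlib's LDA.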

import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

The rest of the snippet follows, in line with this example:

val cvModel = new CountVectorizer()
  .setInputCol("filtered")   // assumed column names
  .setOutputCol("features")
  .fit(filteredTokens)       // hypothetical: the DataFrame output by StopWordsRemover

// LDA.run expects (Long id, mllib Vector) pairs, so convert the ml vectors with Vectors.fromML
val countVectors = cvModel.transform(filteredTokens)
  .select("docId", "features")
  .rdd
  .map { case Row(docId: String, features: MLVector) =>
    (docId.toLong, Vectors.fromML(features))
  }

val mbf = {
  // add (1.0 / actualCorpusSize) to MiniBatchFraction to be more robust on tiny datasets.
  val corpusSize = countVectors.count()
  2.0 / maxIterations + 1.0 / corpusSize
}
val lda = new LDA()
  .setMaxIterations(maxIterations) // assumed: maxIterations is defined earlier, as mbf implies
  .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(math.min(1.0, mbf)))
  .setDocConcentration(-1) // use default symmetric document-topic prior
  .setTopicConcentration(-1) // use default symmetric topic-word prior

val startTime = System.nanoTime()
val ldaModel = lda.run(countVectors)
val elapsed = (System.nanoTime() - startTime) / 1e9

// Print training time
println(s"Finished training LDA model.  Summary:")
println(s"Training time (sec)\t$elapsed")
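The mini-batch fraction and the -1 priors in the Scala snippet are plain arithmetic, restated here in Java so they can be checked without Spark. With a fraction of 2/maxIterations, each iteration samples that share of the corpus, so over maxIterations iterations the optimizer sees about two passes' worth of documents in expectation; the extra 1/corpusSize guarantees an expected mini-batch of at least one document on tiny corpora, and the result is capped at 1.0. The values that -1 resolves to are taken from the MLlib LDA documentation (online optimizer: 1/k for both priors; EM optimizer: 50/k + 1 for the document prior and 1.1 for the topic prior). The class name and the sample values in main are hypothetical.

```java
/** Plain-Java restatement of the LDA settings used in the Scala snippet (arithmetic only, no Spark). */
public class LdaSettings {

    /** 2/maxIterations + 1/corpusSize, capped at 1.0 as in math.min(1.0, mbf). */
    static double miniBatchFraction(int maxIterations, long corpusSize) {
        double mbf = 2.0 / maxIterations + 1.0 / corpusSize;
        return Math.min(1.0, mbf);
    }

    /** Value the document-topic prior takes when set to -1, per optimizer. */
    static double defaultDocConcentration(int k, boolean online) {
        return online ? 1.0 / k : 50.0 / k + 1.0;
    }

    /** Value the topic-word prior takes when set to -1, per optimizer. */
    static double defaultTopicConcentration(int k, boolean online) {
        return online ? 1.0 / k : 1.1;
    }

    public static void main(String[] args) {
        int maxIterations = 100; // hypothetical values for illustration
        long corpusSize = 1000;
        int k = 10;
        System.out.println("miniBatchFraction = " + miniBatchFraction(maxIterations, corpusSize));
        System.out.println("online doc prior = " + defaultDocConcentration(k, true)
                + ", online topic prior = " + defaultTopicConcentration(k, true));
    }
}
```

For maxIterations = 100 and 1000 documents the fraction is 0.02 + 0.001 = 0.021, i.e. roughly 21 documents per iteration.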

