Converting a DataFrame to JavaPairRDD&lt;Long, Vector&gt;

Asked: 2015-10-23 17:53:52

Tags: java apache-spark apache-spark-mllib

I am trying to implement the LDA algorithm with Apache Spark using the Java API. The method LDA().run() takes a JavaPairRDD of documents as its parameter. In Scala, I created the RDD[(Long, Vector)] as follows:

val countVectors = cvModel.transform(filteredTokens)
    .select("docId", "features")
    .map { case Row(docId: Long, countVector: Vector) => (docId, countVector) }
    .cache()

and then fed it into LDA:

lda.run(countVectors)

In the Java API, however, I build the CountVectorizerModel with the following code:

CountVectorizerModel cvModel = new CountVectorizer()
        .setInputCol("filtered").setOutputCol("features")
        .setVocabSize(vocabSize).fit(filteredTokens);

The transformed output looks like this:

(0,(22,[0,8,9,10,14,16,18],
[1.0,1.0,1.0,1.0,1.0,1.0,1.0]))
(1,(22,[0,1,2,3,4,5,6,7,11,12,13,15,17,19,20,21],
[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]))

How can I convert the output of cvModel into a JavaPairRDD&lt;Long, Vector&gt; countVectors? I tried this:

JavaPairRDD<Long, Vector> countVectors = cvModel.transform(filteredTokens)
          .select("docId", "features").toJavaRDD()
          .mapToPair(new PairFunction<Row, Long, Vector>() {
            public Tuple2<Long, Vector> call(Row row) throws Exception {
                return new Tuple2<Long, Vector>(Long.parseLong(row.getString(0)), Vectors.dense(row.getDouble(1)));
            }
          }).cache();

But it doesn't work. I get an exception at:

Vectors.dense(row.getDouble(1))

So if you have any idea how to convert from the cvModel DataFrame to a JavaPairRDD&lt;Long, Vector&gt;, please let me know.

I am using Spark and MLlib 1.5.1 with Java 8.

Any help is much appreciated. Thanks.

EDIT: Here is the exception log from when I try to convert from the DataFrame to a JavaPairRDD:

15/10/25 10:03:07 ERROR Executor: Exception in task 0.0 in stage 7.0 (TID 6)
java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.String
at org.apache.spark.sql.Row$class.getString(Row.scala:249)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getString(rows.scala:191)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:88)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:1)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.storage.MemoryStore.unrollSafely(MemoryStore.scala:278)
at org.apache.spark.CacheManager.putInBlockManager(CacheManager.scala:171)
at org.apache.spark.CacheManager.getOrCompute(CacheManager.scala:78)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:262)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
15/10/25 10:03:07 WARN TaskSetManager: Lost task 0.0 in stage 7.0 (TID 6, localhost): java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.String
at org.apache.spark.sql.Row$class.getString(Row.scala:249)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getString(rows.scala:191)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:88)
at UIT_LDA_ONLINE.LDAOnline$2.call(LDAOnline.java:1)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at org.apache.spark.api.java.JavaPairRDD$$anonfun$pairFunToScalaFun$1.apply(JavaPairRDD.scala:1030)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.storage.MemoryStore.unrollSafely(MemoryStore.scala:278)
at org.apache.spark.CacheManager.putInBlockManager(CacheManager.scala:171)
at org.apache.spark.CacheManager.getOrCompute(CacheManager.scala:78)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:262)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:297)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:264)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:66)
at org.apache.spark.scheduler.Task.run(Task.scala:88)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:214)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)

1 Answer:

Answer 0 (score: 2):

Now that we have the error stack, here is the problem:

You are trying to get a String out of the row while your field is a Long, so as a start you need to replace row.getString(0) with row.getLong(0).

Once you correct that one, you will run into other errors of the same kind on different fields, which I can't point out with the information given. But you will be able to solve them if you apply the following:

Row getters are specific to each field type, so you need to use the correct get method for each field.

If you are not sure which method you need to use, you can check the type of each field with the printSchema() method on the DataFrame, and you will find all the type mappings described in the official documentation here.
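
Applying this advice to the code from the question, a corrected mapping might look like the sketch below. It is untested and assumes the schema from the question: "docId" is a LongType column, and "features" (the CountVectorizer output) holds an org.apache.spark.mllib.linalg.Vector, for which Row has no typed getter, so a plain get plus a cast is used instead of Vectors.dense:

```java
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Row;
import scala.Tuple2;

// cvModel.transform(filteredTokens).printSchema(); // check each column's actual type first

JavaPairRDD<Long, Vector> countVectors = cvModel.transform(filteredTokens)
        .select("docId", "features").toJavaRDD()
        .mapToPair(new PairFunction<Row, Long, Vector>() {
            public Tuple2<Long, Vector> call(Row row) throws Exception {
                // "docId" is a Long, so use the typed getter, not getString
                Long docId = row.getLong(0);
                // "features" is a vector column with no typed getter: get and cast
                Vector features = (Vector) row.get(1);
                return new Tuple2<Long, Vector>(docId, features);
            }
        }).cache();
```

If this conversion works, the result should be usable the same way as in the Scala version, i.e. passed straight to lda.run(countVectors).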