Mlib RandomForest(Spark 2.0)预测单个向量

时间:2016-08-02 12:48:54

标签: random apache-spark machine-learning random-forest

使用mlib和DataFrame(Spark 2.0)在PipelineModel中训练RandomForestRegressor之后 我将保存的模型加载到我的RT环境中,以便使用模型预测每个请求 通过加载的PipelineModel处理和转换,但在这个过程中我必须转换 使用spark.createdataframe的单行请求向量到一行DataFrame所有这需要大约700ms!

如果我使用mllib RDD RandomForestRegressor.predict(VECTOR),则比较2.5ms。 有没有办法使用新的mlib来预测单个矢量而不转换为DataFrame或做其他事情来加快速度?

1 个答案:

答案 0 :(得分:0)

基于数据帧的org.apache.spark.ml.regression.RandomForestRegressionModel也将Vector作为输入。我认为您不需要为每个调用将向量转换为数据帧。

这是我认为您的代码应该工作的方式。

    //load the trained RF model
    val rfModel = RandomForestRegressionModel.load("path")  
    val predictionData = //a dataframe containing a column 'feature' of type Vector
    predictionData.map { row => 
        Vector feature = row.getAs[Vector]("feature")
        Double result = rfModel.predict(feature)
        (feature, result)
    }