如何加载训练有素的RandomForestClassificationModel模型?

时间:2017-09-05 03:35:17

标签: java apache-spark apache-spark-mllib

我已经训练过&测试了ML模型(GBTClassificationModel或RandomForestClassificationModel)。然后我想保存训练好的模型以备将来使用。所以我做了以下事情:

 model.save("...");

保存后,以GBTClassificationModel为例。保存的文件是包含"数据,元数据和树的元数据"的目录。我的问题是如何使用这个保存的模型以备将来使用?例如,我想做类似以下的事情:

 model = spark.load("...");
 Dataset<Row> predict_data= model_model.transform(dataset_test1)

有什么建议吗?谢谢。

更新:

结果非常简单:

 GBTClassificationModel model1 = GBTClassificationModel.load("...");
 Dataset<Row> predict_data= model1.transform(dataset_test)

1 个答案:

答案 0 :(得分:2)

您应该使用RandomForestClassificationModel.load方法。

  

load(path:String):RandomForestClassificationModel 从输入路径读取ML实例,快捷方式为read.load(path)

在Scala中,在您的情况下,它如下:

import org.apache.spark.ml.classification.RandomForestClassificationModel
val model = RandomForestClassificationModel.load("/analytics_shared/qoe/km_model")

我强烈建议使用Spark MLlib的ML Pipeline功能:

  

ML Pipelines提供了一套基于DataFrame构建的统一的高级API,可帮助用户创建和调整实用的机器学习流程。

使用ML Pipeline,只需将RandomForestClassificationModel替换为PipelineModel即可轻松实现。

import org.apache.spark.ml.PipelineModel
val model = PipelineModel.load("...")