Reducing two Scala methods that differ only in one object type

Time: 2015-03-24 23:34:22

Tags: scala apache-spark

I have the following two methods, each of which uses an object from Apache Spark.

  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.classification.SVMModel
  import org.apache.spark.mllib.tree.model.DecisionTreeModel
  import org.apache.spark.mllib.util.MLUtils
  import org.apache.spark.rdd.RDD

  def SVMModelScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
    // Load the persisted SVM model and score each point as (prediction, label).
    val model = SVMModel.load(sc, modelFileName)

    val scoreAndLabels =
      MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }
    scoreAndLabels
  }

  def DecisionTreeScoring(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] = {
    val model = DecisionTreeModel.load(sc, modelFileName)

    val scoreAndLabels = 
      MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
        val score = model.predict(point.features)
        (score, point.label)
      }
    scoreAndLabels
  }

My previous attempts to merge these functions resulted in errors around model.predict.

Is there a way I can pass the model as a weakly typed parameter in Scala?
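
For reference, a naive merge along the following lines (a hypothetical reconstruction; this is not my exact attempt) fails to compile, because SVMModel and DecisionTreeModel have no common supertype that declares predict, so the model parameter can only be typed as something like Any:

  // Hypothetical reconstruction of a failing merge attempt: there is no
  // common supertype of SVMModel and DecisionTreeModel that declares predict.
  def ModelScoring(sc: SparkContext, scoringDataset: String, model: Any): RDD[(Double, Double)] = {
    MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
      val score = model.predict(point.features) // error: value predict is not a member of Any
      (score, point.label)
    }
  }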

1 Answer:

Answer 0 (score: 2):

Disclaimer: I have never used Apache Spark.

As far as I can tell, the only difference between the two methods is how model is instantiated. Unfortunately, the two model instances don't actually share a common trait that provides predict(...), but we can still make this work by pulling out the part that varies: the scorer.

import org.apache.spark.mllib.linalg.Vector

// The shared scoring pipeline, parameterized by the prediction function.
def scoreWith(sc: SparkContext, scoringDataset: String)(scorer: Vector => Double): RDD[(Double, Double)] = {
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    val score = scorer(point.features)
    (score, point.label)
  }
}

Now we can recover the original functionality like this:

def svmScorer(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] =
  scoreWith(sc, scoringDataset)(SVMModel.load(sc, modelFileName).predict)

def dtScorer(sc: SparkContext, scoringDataset: String, modelFileName: String): RDD[(Double, Double)] =
  scoreWith(sc, scoringDataset)(DecisionTreeModel.load(sc, modelFileName).predict)
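
As a footnote on the asker's "weakly typed parameter" idea: a structural type would also compile here, since both models happen to expose a predict(Vector): Double method even though they share no common trait. This is only a sketch (the Predictor alias and scoreWithModel helper are illustrative names, not part of Spark); structural calls are dispatched via runtime reflection, so the function-parameter version above is usually the better choice.

import scala.language.reflectiveCalls
import org.apache.spark.mllib.linalg.Vector

// Illustrative structural type: any object with predict(Vector): Double
// qualifies, regardless of its class hierarchy (calls use reflection).
type Predictor = { def predict(features: Vector): Double }

def scoreWithModel(sc: SparkContext, scoringDataset: String)(model: Predictor): RDD[(Double, Double)] =
  MLUtils.loadLibSVMFile(sc, scoringDataset).randomSplit(Array(0.1), seed = 11L)(0).map { point =>
    (model.predict(point.features), point.label)
  }

// Usage, e.g.: scoreWithModel(sc, scoringDataset)(SVMModel.load(sc, modelFileName))
//          or: scoreWithModel(sc, scoringDataset)(DecisionTreeModel.load(sc, modelFileName))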