在Scala中将方法参数限制为特定类型

时间:2019-02-21 17:31:22

标签: scala generics

我有一个函数,希望它是通用的,但将其限制为采用某些子类型。为了使事情简单,我希望我的函数只对Long,Int,Float和Double起作用。所以这是我想出的:

def covariance[A](xElems: Seq[A], yElems: Seq[A]): A = {
  val (meanX, meanY) = (mean(xElems), mean(yElems))
  val (meanDiffX, meanDiffY) = (meanDiff(meanX, xElems), meanDiff(meanY, yElems))
  ((meanDiffX zip meanDiffY).map { case (x, y) => x * y }.sum) / xElems.size - 1
}

def mean[A](elems: Seq[A]): A = {
  (elems.fold(_ + _)) / elems.length
}

def meanDiff[A](mean: A, elems: Seq[A]) = {
  elems.map(elem => elem - mean)
}

这是我用来检查上述类型的方法:

import scala.reflect.{ClassTag, classTag}
def matchList2[A : ClassTag](list: List[A]) = list match {
  case intlist: List[Int @unchecked] if classTag[A] == classTag[Int] => println("A List of ints!")
  case longlist: List[Long @unchecked] if classTag[A] == classTag[Long] => println("A list of longs!")
}

请注意,我正在使用ClassTag。我还可以使用TypeTag,甚至可以使用Shapeless库。

我想知道这是否是一个好方法?还是应该使用有界类型来解决我想要的问题?

编辑:基于使用分数类型类的评论和建议,我认为这是可行的!

def covariance[A: Fractional](xElems: Seq[A], yElems: Seq[A]): A = {
  val (meanX, meanY) = (mean(xElems), mean(yElems))
  val (meanDiffX, meanDiffY) = (meanDiff(meanX, xElems), meanDiff(meanY, yElems))
  ((meanDiffX zip meanDiffY).map { case (x, y) => x * y }.sum) / xElems.size - 1
}

def mean[A](elems: Seq[A]): A = {
  (elems.fold(_ + _)) / elems.length
}

def meanDiff[A](mean: A, elems: Seq[A]) = {
  elems.map(elem => elem - mean)
}

1 个答案:

答案 0 :(得分:0)

根据评论和输入,这是我的想法!

def mean[A](xs: Seq[A])(implicit A: Fractional[A]): A =
  A.div(xs.sum, A.fromInt(xs.size))

def covariance[A](xs: Seq[A], ys: Seq[A])(implicit A: Fractional[A]): A = {
  val (meanX, meanY) = (mean(xs), mean(ys))
  val (meanDiffX, meanDiffY) = (meanDiff(meanX, xs), meanDiff(meanY, ys))
  (meanDiffX zip meanDiffY).map { case (x, y) => A.div(A.times(x, y), A.fromInt(xs.size - 1)) }.sum
}

def meanDiff[A](mean: A, elems: Seq[A])(implicit A: Fractional[A]): Seq[A] = {
  elems.map(elem => A.minus(elem, mean))
}