确定在编译时表达式的值是否已知

时间:2018-10-17 14:34:00

标签: scala macros constant-expression

假设我想创建一个NonZero类型,以便我的整数除法函数是total:

def div(numerator: Int, denominator: NonZero): Int =
  numerator / denominator.value

我可以通过使用私有构造函数创建一个NonZero类来实现此目的:

class NonZero private[NonZero] (val value : Int) { /*...*/ }

还有一个用于保存Int => Option[NonZero]构造函数的辅助对象,以及一个unapply以便可以在match表达式中使用:

object NonZero {
  def build(n:Int): Option[NonZero] = n match {
    case 0 => None
    case n => Some(new NonZero(n))
  }
  def unapply(nz: NonZero): Option[Int] = Some(nz.value)
  // ...
}

build适用于运行时值,但必须对文字执行NonZero.build(3).get很难看。

使用宏,我们可以定义apply only for literals,因此NonZero(3)可以工作,但是NonZero(0)是编译时错误:

object NonZero {
  // ...
  def apply(n: Int): NonZero = macro apply_impl
  def apply_impl(c: Context)(n: c.Expr[Int]): c.Expr[NonZero] = {
    import c.universe._
    n match {
      case Expr(Literal(Constant(nValue: Int))) if nValue != 0 =>
        c.Expr(q"NonZero.build(n).get")
      case _ => throw new IllegalArgumentException("Expected non-zero integer literal")
    }
  }
}

但是,此宏的作用远不如它有用,因为它只允许使用文字,而不能使用编译时常量表达式:

final val X: Int = 3
NonZero(X) // compile-time error

我在宏中could pattern match on Expr(Constant(_)),但是NonZero(X + 1)又如何呢?我宁愿不必实现自己的Scala表达式评估器。

是否有帮助程序或一些简单的方法来确定在编译时是否知道给宏的表达式的值(C ++称之为constexpr)?

2 个答案:

答案 0 :(得分:0)

如果忽略宏,则在Scala中,仅类型在编译时存在,而值仅在运行时存在。您可以执行类型级别的技巧,以在编译时将数字编码为类型,例如Type Level Programming in Scala

这是上面Peano算术示例的简化版本。首先,我们定义一个类型类,以显示某种类型如何将其转换为整数。

@annotation.implicitNotFound("Create an implicit of type TValue[${T}] to convert ${T} values to integers.")
final class TValue[T](val get: Int) extends AnyVal

然后,我们定义Peano的“零”类型,并说明如何将其转换为运行时整数0:

case object TZero {
  implicit val tValue: TValue[TZero.type] = new TValue(0)
}

然后是Peano的“继任者”类型,以及如何将其转换为运行时整数1 +先前值:

case class TSucc[T: TValue]()
object TSucc {
  implicit def tValue[TPrev](implicit prevTValue: TValue[TPrev]): TValue[TSucc[TPrev]] =
    new TValue(1 + prevTValue.get)
}

然后测试安全划分:

object Test {
  def safeDiv[T](numerator: Int, denominator: TSucc[T])(implicit tValue: TValue[TSucc[T]]): Int =
    numerator / tValue.get
}

尝试一下:

scala> Test.safeDiv(10, TZero)
<console>:14: error: type mismatch;
 found   : TZero.type
 required: TSucc[?]
       Test.safeDiv(10, TZero)
                        ^

scala> Test.safeDiv(10, TSucc[String]())
<console>:14: error: Create an implicit of type TValue[String] to convert String values to integers.
       Test.safeDiv(10, TSucc[String]())
                                     ^

scala> Test.safeDiv(10, TSucc[TZero.type]) // 10/1
res2: Int = 10

scala> Test.safeDiv(10, TSucc[TSucc[TZero.type]]) // 10/2
res3: Int = 5

不过,您可以想象,它很快就会变得冗长。

答案 1 :(得分:0)

som-snytt's advice to check out ToolBox.eval带我去了Context.eval,这是我一直想要的助手:

object NonZero {
  // ...
  def apply(n: Int): NonZero = macro apply_impl
  def apply_impl(c: Context)(n: c.Expr[Int]): c.Expr[NonZero] = try {
    if (c.eval(n) != 0) {
      import c.universe._
      c.Expr(q"NonZero.build(n).get")
    } else {
      throw new IllegalArgumentException("Non-zero value required")
    }
  } catch {
    case _: scala.tools.reflect.ToolBoxError =>
      throw new IllegalArgumentException("Unable to evaluate " + n.tree + " at compile time")
  }
}

所以现在我可以传递NonZero.apply常量和用常量制成的表达式:

scala> final val N = 3
scala> NonZero(N)
res0: NonZero = NonZero(3)
scala> NonZero(2*N + 1)
res1: NonZero = NonZero(7)
scala> NonZero(N - 3)
IllegalArgumentException: ...
scala> NonZero((n:Int) => 2*n + 1)(3))
IllegalArgumentException: ...

虽然eval可以像上面的最后一个示例那样处理纯函数会很好,但这已经足够了。

令人尴尬的是,从问题中回顾并重新测试了我以前的代码,证明我的原始宏也同样处理了相同的表达式!

我断言final val X = 3; NonZero(X) // compile-time error是错误的,因为所有评估都是通过内联处理的(如som-snytt的评论所暗示)。