Creating serializable objects from Scala source code at runtime

Asked: 2014-02-23 18:01:16

Tags: scala compilation interpreter

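To embed Scala as a "scripting language", I need to be able to compile text fragments into simple objects, such as a Function0[Unit], that can be serialized to and deserialized from disk, and that can be loaded into the current runtime and executed.

How would I go about this?

Say, for example, my text fragment is (purely hypothetical):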

Document.current.elements.headOption.foreach(_.open())

This might be wrapped in the following complete text:

package myapp.userscripts
import myapp.DSL._

object UserFunction1234 extends Function0[Unit] {
  def apply(): Unit = {
    Document.current.elements.headOption.foreach(_.open())
  }
}

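What comes next? Should I use IMain to compile this code? I don't want to use the normal interpreter mode, because the compilation should be "context-free" and not accumulate requests.

What I need to get hold of from the compilation is, I guess, the binary class file? In that case, serialization is straightforward (a byte array). How would I then load that class into the runtime and invoke the apply method?

What happens if the code compiles to multiple auxiliary classes? The example above contains a closure, _.open(). How do I make sure that all those auxiliary things are "packaged" into one object for serialization and class loading?


Note: Given that Scala 2.11 is imminent and the compiler API has probably changed, I am happy to receive hints on how to approach this problem on Scala 2.11.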

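1 answer:

Answer 0 (score: 4)

Here is one idea: use a regular Scala compiler instance. Unfortunately, it seems to require files on disk for both input and output, so we use temporary files for that. The compiler output is zipped into a JAR, and the JAR is stored as a byte array (which would go into the hypothetical serialization process). We then need a special class loader to retrieve the classes again from the unpacked JAR.

The following assumes Scala 2.10.3 with the scala-compiler library on the class path: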

import scala.tools.nsc
import java.io._
import scala.annotation.tailrec

Wrap the user-provided code in a function class whose synthetic name is incremented for each new fragment:

val packageName = "myapp"

var userCount = 0

def mkFunName(): String = {
  val c = userCount
  userCount += 1
  s"Fun$c"
}

def wrapSource(source: String): (String, String) = {
  val fun = mkFunName()
  val code = s"""package $packageName
                |
                |class $fun extends Function0[Unit] {
                |  def apply(): Unit = {
                |    $source
                |  }
                |}
                |""".stripMargin
  (fun, code)
}
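
One edge case in wrapSource: stripMargin runs after the user source has been interpolated, so any user line that begins with whitespace followed by `|` (for example a continuation line starting with `||`) loses that prefix. The mutable counter is also not thread-safe. A minimal sketch of a variant that avoids both, assuming the names `SnippetWrapper` and `wrap` (hypothetical, not part of the answer):

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical variant of wrapSource; names are illustrative only.
object SnippetWrapper {
  private val counter = new AtomicInteger(0)

  /** Returns the synthetic class name and the complete source text. */
  def wrap(pkg: String, source: String): (String, String) = {
    val fun = s"Fun${counter.getAndIncrement()}"
    // stripMargin is applied to the template only, before the user
    // source is appended, so the user text is left untouched.
    val header = s"""package $pkg
                    |
                    |class $fun extends Function0[Unit] {
                    |  def apply(): Unit = {
                    |""".stripMargin
    val footer = "\n  }\n}\n"
    (fun, header + source + footer)
  }
}
```

The generated class is the same as the one produced by wrapSource for ordinary snippets; only the treatment of `|` at the start of user lines differs.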

A function that compiles a source fragment and returns the resulting jar as a byte array:

/** Compiles a source code consisting of a body which is wrapped in a `Function0`
  * apply method, and returns the function's class name (without package) and the
  * raw jar file produced in the compilation.
  */
def compile(source: String): (String, Array[Byte]) = {
  val set             = new nsc.Settings
  val d               = File.createTempFile("temp", ".out")
  d.delete(); d.mkdir()
  set.d.value         = d.getPath
  set.usejavacp.value = true
  val compiler        = new nsc.Global(set)
  val f               = File.createTempFile("temp", ".scala")
  val out             = new BufferedOutputStream(new FileOutputStream(f))
  val (fun, code)     = wrapSource(source)
  out.write(code.getBytes("UTF-8"))
  out.flush(); out.close()
  val run             = new compiler.Run()
  run.compile(List(f.getPath))
  f.delete()

  val bytes = packJar(d)
  deleteDir(d)

  (fun, bytes)
}

def deleteDir(base: File): Unit = {
  base.listFiles().foreach { f =>
    if (f.isFile) f.delete()
    else deleteDir(f)
  }
  base.delete()
}
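
The compile function obtains its output directory by creating a temporary file, deleting it, and creating a directory of the same name, which leaves a small window in which another process could claim the name. On Java 7 or later (an assumption; the answer only states Scala 2.10.3), `java.nio.file.Files.createTempDirectory` does this in one step. The sketch below also guards against `listFiles()` returning null, which it does when the path is not a directory or an I/O error occurs. The names `TempDirs`, `create` and `deleteRecursively` are hypothetical:

```scala
import java.io.File
import java.nio.file.Files

// Hypothetical helpers; a sketch, not the answer's code.
object TempDirs {
  /** Creates a fresh, uniquely named temporary directory. */
  def create(prefix: String): File =
    Files.createTempDirectory(prefix).toFile

  /** Deletes a file or a directory tree, children first. */
  def deleteRecursively(f: File): Unit = {
    // listFiles() returns null for plain files, hence the Option wrapper.
    val children = Option(f.listFiles()).getOrElse(Array.empty[File])
    children.foreach(deleteRecursively)
    f.delete()
  }
}
```

With these helpers, the output directory in compile could be obtained with `TempDirs.create("temp")` and removed with `TempDirs.deleteRecursively(d)`.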

Note: compiler errors are not handled yet!

The packJar method takes the compiler output directory and produces an in-memory jar file from it:

// cf. http://stackoverflow.com/questions/1281229
def packJar(base: File): Array[Byte] = {
  import java.util.jar._

  val mf = new Manifest
  mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
  val bs    = new java.io.ByteArrayOutputStream
  val out   = new JarOutputStream(bs, mf)

  def add(prefix: String, f: File): Unit = {
    val name0 = prefix + f.getName
    val name  = if (f.isDirectory) name0 + "/" else name0
    val entry = new JarEntry(name)
    entry.setTime(f.lastModified())
    out.putNextEntry(entry)
    if (f.isFile) {
      val in = new BufferedInputStream(new FileInputStream(f))
      try {
        val buf = new Array[Byte](1024)
        @tailrec def loop(): Unit = {
          val count = in.read(buf)
          if (count >= 0) {
            out.write(buf, 0, count)
            loop()
          }
        }
        loop()
      } finally {
        in.close()
      }
    }
    out.closeEntry()
    if (f.isDirectory) f.listFiles.foreach(add(name, _))
  }

  base.listFiles().foreach(add("", _))
  out.close()
  bs.toByteArray
}

A utility function that takes the byte array recovered during deserialization and creates a map from class names to class bytecode:

def unpackJar(bytes: Array[Byte]): Map[String, Array[Byte]] = {
  import java.util.jar._
  import scala.annotation.tailrec

  val in = new JarInputStream(new ByteArrayInputStream(bytes))
  val b  = Map.newBuilder[String, Array[Byte]]

  @tailrec def loop(): Unit = {
    val entry = in.getNextJarEntry
    if (entry != null) {
      if (!entry.isDirectory) {
        val name  = entry.getName  
        // cf. http://stackoverflow.com/questions/8909743
        val bs  = new ByteArrayOutputStream
        var i   = 0
        while (i >= 0) {
          i = in.read()
          if (i >= 0) bs.write(i)
        }
        val bytes = bs.toByteArray
        b += mkClassName(name) -> bytes
      }
      loop()
    }
  }
  loop()
  in.close()
  b.result()
}

def mkClassName(path: String): String = {
  require(path.endsWith(".class"))
  path.substring(0, path.length - 6).replace("/", ".")
}
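
mkClassName requires every entry to end in ".class", yet the jar written by packJar also contains a manifest. This works because `JarInputStream` consumes a leading `META-INF/MANIFEST.MF` entry in its constructor and exposes it through `getManifest` rather than `getNextJarEntry`. A small self-contained round trip that illustrates this, using only `java.util.jar` (the object and method names are hypothetical):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.jar.{Attributes, JarEntry, JarInputStream, JarOutputStream, Manifest}

// Hypothetical demo object; not part of the answer's code.
object JarRoundTrip {
  /** Packs the given (entry name, content) pairs into an in-memory jar. */
  def pack(entries: Seq[(String, Array[Byte])]): Array[Byte] = {
    val mf = new Manifest
    mf.getMainAttributes.put(Attributes.Name.MANIFEST_VERSION, "1.0")
    val bs  = new ByteArrayOutputStream
    val out = new JarOutputStream(bs, mf) // writes the manifest as first entry
    entries.foreach { case (name, data) =>
      out.putNextEntry(new JarEntry(name))
      out.write(data)
      out.closeEntry()
    }
    out.close()
    bs.toByteArray
  }

  /** Lists the entry names visible through getNextJarEntry. */
  def entryNames(jar: Array[Byte]): List[String] = {
    val in = new JarInputStream(new ByteArrayInputStream(jar))
    try Iterator.continually(in.getNextJarEntry).takeWhile(_ != null).map(_.getName).toList
    finally in.close()
  }
}
```

Calling `JarRoundTrip.entryNames` on a jar packed with a single class entry yields only that entry, so the `require` in mkClassName is not tripped by the manifest.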

A suitable class loader:

class MemoryClassLoader(map: Map[String, Array[Byte]]) extends ClassLoader {
  override protected def findClass(name: String): Class[_] =
    map.get(name).map { bytes =>
      println(s"defineClass($name, ...)")
      defineClass(name, bytes, 0, bytes.length)
    }.getOrElse(super.findClass(name)) // throws exception
}
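
Because MemoryClassLoader extends ClassLoader with the no-argument constructor, its parent is the system class loader. `loadClass` first checks for an already loaded class, then asks the parent, and only then calls `findClass`, which is why library classes and repeated lookups never reach the map. If the host application is itself loaded by another class loader (a build tool or container, for instance), the snippet may not see the application's classes through the system loader. A sketch of a variant that takes the parent explicitly, under the hypothetical name `ParentAwareClassLoader`:

```scala
// Hypothetical variant of MemoryClassLoader with an explicit parent.
class ParentAwareClassLoader(classes: Map[String, Array[Byte]], parent: ClassLoader)
  extends ClassLoader(parent) {

  // Called by loadClass only after the cache and the parent have failed.
  override protected def findClass(name: String): Class[_] =
    classes.get(name) match {
      case Some(bytes) => defineClass(name, bytes, 0, bytes.length)
      case None        => throw new ClassNotFoundException(name)
    }
}
```

A typical choice for the parent would be the loader of one of the application's own classes, so that imports such as the DSL mentioned in the question can be resolved.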

A test case that involves additional classes (closures):

val exampleSource =
  """val xs = List("hello", "world")
    |println(xs.map(_.capitalize).mkString(" "))
    |""".stripMargin

def test(fun: String, cl: ClassLoader): Unit = {
  val clName  = s"$packageName.$fun"
  println(s"Resolving class '$clName'...")
  val clazz = Class.forName(clName, true, cl)
  println("Instantiating...")
  val x     = clazz.newInstance().asInstanceOf[() => Unit]
  println("Invoking 'apply':")
  x()
}

locally {
  println("Compiling...")
  val (fun, bytes) = compile(exampleSource)

  val map = unpackJar(bytes)
  println("Classes found:")
  map.keys.foreach(k => println(s"  '$k'"))

  val cl = new MemoryClassLoader(map)
  test(fun, cl)   // should call `defineClass`
  test(fun, cl)   // should find cached class
}
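
The test above keeps the class name and jar bytes in memory. The serialization step itself is left hypothetical in this answer; one possible sketch, assuming a simple length-prefixed layout written with `DataOutputStream` (the `FunCodec` name and the layout are assumptions, not part of the original), is:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical codec for the (class name, jar bytes) pair returned by compile.
object FunCodec {
  /** Encodes the class name and jar as: UTF name, int length, raw bytes. */
  def write(fun: String, jar: Array[Byte]): Array[Byte] = {
    val bs  = new ByteArrayOutputStream
    val out = new DataOutputStream(bs)
    out.writeUTF(fun)
    out.writeInt(jar.length)
    out.write(jar)
    out.flush()
    bs.toByteArray
  }

  /** Decodes a blob produced by write. */
  def read(blob: Array[Byte]): (String, Array[Byte]) = {
    val in  = new DataInputStream(new ByteArrayInputStream(blob))
    val fun = in.readUTF()
    val jar = new Array[Byte](in.readInt())
    in.readFully(jar)
    (fun, jar)
  }
}
```

The resulting blob can be written to a file as is. After reading it back, the jar bytes go through unpackJar and MemoryClassLoader exactly as in the test case.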