使用Word2VecModel.transform()在map函数中不起作用

时间:2015-12-24 06:41:18

标签: python apache-spark pyspark apache-spark-mllib word2vec

我使用Spark构建了一个Word2Vec模型并将其另存为模型。现在,我想在另一个代码中使用它作为离线模型。我已经加载了模型并用它来呈现一个单词的向量(例如Hello),它运行良好。但是,我需要使用map在RDD中为多个单词调用它。

当我在map函数中调用model.transform()时,会抛出此错误:

  

"您似乎正在尝试从广播引用SparkContext"   例外:您似乎尝试从广播变量,操作或转换引用SparkContext。 SparkContext只能在驱动程序上使用,而不能在工作程序上运行的代码中使用。有关更多信息,请参阅SPARK-5063。

代码:

from pyspark import SparkContext
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import Word2VecModel

sc = SparkContext('local[4]',appName='Word2Vec')

model=Word2VecModel.load(sc, "word2vecModel")

x= model.transform("Hello")
print(x[0]) # it works fine and returns [0.234, 0.800,....]

y=sc.parallelize([['Hello'],['test']])
y.map(lambda w: model.transform(w[0])).collect() #it throws the error

我将非常感谢你的帮助。

1 个答案:

答案 0 :(得分:8)

这是一种预期的行为。与其他MLlib模型一样,Python对象只是Scala模型的包装器,实际处理委托给它的JVM对应物。由于工作人员无法访问Py4J网关(请参阅How to use Java/Scala function from an action or a transformation?),因此无法通过操作或转换调用Java / Scala方法。

通常,MLlib模型提供了一种可以直接在RDD上工作的辅助方法,但这不是这种情况。 Word2VecModel提供了getVectors方法,该方法将字词从单词返回到向量,但遗憾的是它是JavaMap,因此在转换中不起作用。你可以尝试这样的事情:

from pyspark.mllib.linalg import DenseVector

vectors_ = model.getVectors() # py4j.java_collections.JavaMap
vectors = {k: DenseVector([x for x in vectors_.get(k)])
    for k in vectors_.keys()}

获取Python字典,但速度非常慢。另一个选择是以Python可以使用的形式将此对象转储到磁盘,但它需要对Py4J进行一些修改,最好避免这种情况。而是让我们将模型读作DataFrame:

lookup = sqlContext.read.parquet("path_to_word2vec_model/data").alias("lookup")

我们将得到以下结构:

lookup.printSchema()
## root
## |-- word: string (nullable = true)
## |-- vector: array (nullable = true)
## |    |-- element: float (containsNull = true)

可用于将单词映射到向量,例如通过join

from pyspark.sql.functions import col

words = sc.parallelize([('hello', ), ('test', )]).toDF(["word"]).alias("words")

words.join(lookup, col("words.word") == col("lookup.word"))

## +-----+-----+--------------------+
## | word| word|              vector|
## +-----+-----+--------------------+
## |hello|hello|[-0.030862354, -0...|
## | test| test|[-0.13154022, 0.2...|
## +-----+-----+--------------------+

如果数据适合驱动程序/工作程序内存,您可以尝试使用广播进行收集和映射:

lookup_bd = sc.broadcast(lookup.rdd.collectAsMap())
rdd = sc.parallelize([['Hello'],['test']])
rdd.map(lambda ws: [lookup_bd.value.get(w) for w in ws])
相关问题