在Colab中加载已保存的Doc2Vec模型

时间:2018-03-25 16:30:27

标签: python gensim google-colaboratory doc2vec

我已经在colab中训练并使用doc2vec保存了一个模型

model = gensim.models.Doc2Vec(vector_size=size_of_vector, window=10, min_count=5, workers=16,alpha=0.025, min_alpha=0.025, epochs=40)
model.build_vocab(allXs)
model.train(allXs, epochs=model.epochs, total_examples=model.corpus_count)

模型保存在无法从我的驱动器访问的文件夹中,但我可以看到:

from os import listdir
from os.path import isfile, getsize
from operator import itemgetter

files = [(f, getsize(f)) for f in listdir('.') if isfile(f)]
files.sort(key=itemgetter(1), reverse=True)

for f, size in files:
    print ('{} {}'.format(size, f))
print ('({} files {} total size)'.format(len(files), sum(f[1] for f in files)))

输出结果为:

79434928 Model_after_train.docvecs.vectors_docs.npy
9155086 Model_after_train
1024 .rnd
(3 files 88591038 total size)

将两个文件移动到与笔记本相同的共享目录中

folder_id = FolderID

for f, size in files:
  if 'our_first_lda' in f:  
    file = drive.CreateFile({'parents':[{u'id': folder_id}]})
    file.SetContentFile(f)
    file.Upload()

我现在面临的问题是两个: 1)gensim在保存模型时创建两个文件。我应该加载哪一个?

2)当我尝试加载文件或其他文件时:

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

from googleapiclient.discovery import build
drive_service = build('drive', 'v3')

file_id = FileID


import io
from googleapiclient.http import MediaIoBaseDownload

request = drive_service.files().get_media(fileId=file_id)
downloaded = io.BytesIO()
downloader = MediaIoBaseDownload(downloaded, request)
done = False
while done is False:
  _, done = downloader.next_chunk()
model = doc2vec.Doc2Vec.load(downloaded.read())

我无法加载获取错误的模型:

TypeError: file() argument 1 must be encoded string without null bytes, not str

有什么建议吗?

1 个答案:

答案 0 :(得分:0)

我从未使用过gensim,但是从文档的角度来看,这就是我的想法:

  1. 您正在获取两个文件,因为您传递了separately=True to save,这会将输出中的大型numpy数组保存为单独的文件。您将要复制两个文件。

  2. 根据load docs,您要传递文件名,文件的内容。因此,从云端硬盘获取文件时,保存到文件,然后将mmap='r'传递给load

  3. 如果这不能让您启动并运行,那么查看完整的示例(例如使用虚假数据)会很有帮助。

相关问题