将Class实例作为参数传递

时间:2017-11-20 15:44:02

标签: python tensorflow

我有一个python脚本,在其中一行中调用一个存在于另一个类中的函数,存在于另一个脚本中,我已经导入了脚本并访问了该函数,但该函数接受了一个参数,这是一个实例班级本身。 该函数位于脚本model.py中,其中Model内的Model类是函数, model.py

内的函数
  class Model(object):    
    def create_base(self,
      images,
      labels_one_hot,
      scope='AttentionOcr_v1',
      reuse=None):

所以当我运行以下代码时

import skimage.io as io
import numpy as np
import collections
import os
import tensorflow as tf
import common_flags
import model

from tensorflow.python.platform import flags

FLAGS = flags.FLAGS
common_flags.define()

images_placeholder = tf.placeholder(tf.float32, shape=[635, 1219])
fn ='/home/ubuntu/tensorflow/6m.jpg'
images = [io.imread(fn, dtype='float')]
print(images)
images_actual_data = np.stack(images)
images_actual_data = 2.5*(images_actual_data - 0.5)  # normalize values



dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
modelout = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
endpoints = model.Model.create_base(images_placeholder, labels_one_hot=None)

with tf.Session() as sess:
    init_fn = model.create_init_fn_to_restore('/nsfs/tensor_models/models/research/attention_ocr/python/inception_v3.ckpt', '')
    init_fn(sess)
    sess.run(tf.global_variables_initializer()) 
    predictions = sess.run(endpoints.predicted_chars, feed_dict={images_placeholder:images_actual_data.reshape(1,imHeight,imWidth,imChannel)})
    print predictions

我收到以下错误:

INFO 2017-11-20 15:08:46.000545: fsns.py: 130 Using FSNS dataset split_name=train dataset_dir=/nsfs/tensor_models/models/research/attention_ocr/python/datasets/data/fsns
Traceback (most recent call last):
  File "run_tf.py", line 31, in <module>
    endpoints = model.Model.create_base(images_placeholder, labels_one_hot=None)
TypeError: unbound method create_base() must be called with Model instance as first argument (got Tensor instance instead)

我是python的新手,所以任何想法都有帮助吗? 感谢

1 个答案:

答案 0 :(得分:0)

正如评论中所建议的那样,我使用了

endpoints = modelout.create_base(images_placeholder, labels_one_hot=None)

并且有效