如何在Session中运行多个图形 - Tensorflow API

时间:2017-10-07 07:09:48

标签: python session tensorflow models object-detection

Tensorflow API提供了一些预先训练过的模型,并允许我们使用任何数据集对其进行训练。

我想知道如何在一个张量流会话中初始化和使用多个图。我想在两个图中导入两个经过训练的模型并将它们用于对象检测,但我在一次会话中尝试运行多个图时迷失了。

在一个会话中使用多个图表是否有任何特定方法?

另一个问题是,即使我为2个不同的图创建两个不同的会话并尝试使用它们,我最终会在第一个实例化会话中获得类似的结果。

3 个答案:

答案 0 :(得分:11)

每个Session只能有一个Graph。话虽这么说,根据你特别想做的事情,你有几个选择。

第一个选项是创建两个单独的会话,并将一个图表加载到每个会话中,如the documentation here中所述。您提到您使用该方法从每个会话中得到意外类似的结果,但没有更多详细信息,很难确定具体问题在您的案例中。我怀疑在每个会话中加载了相同的图表,或者当您尝试单独运行每个会话时,同一个会话正在运行两次,但没有更多细节,很难说清楚。

第二个选项是将两个图形作为主会话图的子图加载。您可以在图形中创建两个范围,并为要在该范围内加载的每个图形构建图形。然后你可以将它们视为独立的图形,因为它们之间没有连接。在正常运行图形全局函数时,您需要指定这些函数应用于哪个范围。例如,当使用其优化器对其中一个子图执行更新时,您需要使用this answer中显示的内容来获取该子图范围的可训练变量。

除非您明确要求两个图形能够在TensorFlow图形中以某种方式进行交互,否则我建议使用第一种方法,这样您就不需要跳过子图将需要的额外环形(例如需要)过滤你在任何特定时刻使用的范围,以及两者之间共享图表全局事物的可能性。)

答案 1 :(得分:3)

我面临着同样的挑战,经过几个月的研究,我终于能够解决问题。我做了tf.graph_util.import_graph_def。根据{{​​3}}:

name :(可选。)前缀,该前缀将位于 graph_def。请注意,这不适用于导入的函数名称。 默认为“导入”。

因此,通过添加此前缀,可以区分不同的会话。

例如:

first_graph_def = tf.compat.v1.GraphDef()
second_graph_def = tf.compat.v1.GraphDef()

# Import the TF graph : first
first_file = tf.io.gfile.GFile(first_MODEL_FILENAME, 'rb')
first_graph_def.ParseFromString(first_file.read())
first_graph = tf.import_graph_def(first_graph_def, name='first')

# Import the TF graph : second
second_file = tf.io.gfile.GFile(second_MODEL_FILENAME, 'rb')
second_graph_def.ParseFromString(second_file.read())
second_graph = tf.import_graph_def(second_graph_def, name='second')

# These names are part of the model and cannot be changed.
first_output_layer = 'first/loss:0'
first_input_node = 'first/Placeholder:0'

second_output_layer = 'second/loss:0'
second_input_node = 'second/Placeholder:0'

# initialize probability tensor
first_sess = tf.compat.v1.Session(graph=first_graph)
first_prob_tensor = first_sess.graph.get_tensor_by_name(first_output_layer)

second_sess = tf.compat.v1.Session(graph=second_graph)
second_prob_tensor = second_sess.graph.get_tensor_by_name(second_output_layer)

first_predictions, = first_sess.run(
        first_prob_tensor, {first_input_node: [adapted_image]})
    first_highest_probability_index = np.argmax(first_predictions)

second_predictions, = second_sess.run(
        second_prob_tensor, {second_input_node: [adapted_image]})
    second_highest_probability_index = np.argmax(second_predictions)

如您所见,您现在可以在一个tensorflow会话中初始化和使用多个图。

希望这会有所帮助

答案 2 :(得分:0)

在一个会话中的图arg应该为None(默认图或)

这是source code

class BaseSession(SessionInterface):
  """A class for interacting with a TensorFlow computation.
  The BaseSession enables incremental graph building with inline
  execution of Operations and evaluation of Tensors.
  """

  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.
    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.
    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))

从波纹管片段中您可以看到它不能是列表。

if graph is None:
  self._graph = ops.get_default_graph()
else:
  if not isinstance(graph, ops.Graph):
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))

ops.Graph(由help(ops.Graph)查找)对象开始,它不能是多个图形。

对于more,有关该点和图表的信息:

If no `graph` argument is specified when constructing the session,
the default graph will be launched in the session. If you are
using more than one graph (created with `tf.Graph()` in the same
process, you will have to use different sessions for each graph,
but each graph can be used in multiple sessions. In this case, it
is often clearer to pass the graph to be launched explicitly to
the session constructor.