计算tf.estimator.DNNLinearCombinedClassifier的精度/召回率

时间:2018-02-23 16:47:46

标签: tensorflow tensorflow-estimator

我刚刚更新了DNNLinearCombinedClassifier以使用tf.estimator,只需要SessionRunHook而不是ValidationMonitor,我一直在使用here描述的代码计算精度/召回,因为评估者不会打印这两个指标。但是,我无法使用SessionRunHook所需的tf.estimator.Estimator找到实现相同功能的方法。

我找到了一个相关的帖子here,但似乎在代码中,如果指定了指标(来自tensorflow / contrib / learn / python / learn / monitors),则特别禁用ValidationMonitor。 PY):

if isinstance(self._estimator, core_estimator.Estimator):
  if any((x is not None for x in
          [self.x, self.y, self.batch_size, self.metrics])):
    raise ValueError(
        "tf.estimator.Estimator does not support following "
        "arguments: x, y, batch_size, metrics. Should set as `None` "
        "in ValidationMonitor")

我正在使用tensorflow 1.5.0。

有关如何实施此建议的任何建议吗?

0 个答案:

没有答案