如何创建输入函数,input_fn()

时间:2017-01-09 15:59:41

标签: python tensorflow

我正在关注_tensorflow.org上的this教程。 我正在尝试正确处理 input_fn _ ,以在 .fit()中用作参数。 我创建了分类器:

classifier = tf.contrib.learn.SKCompat(tf.contrib.learn.DNNClassifier(
feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir=("C:\\........\tmp"),
n_classes=2,
activation_fn=tf.sigmoid,
optimizer=tf.train.ProximalAdagradOptimizer(
    learning_rate=0.1,
    l1_regularization_strength=0.001
    )))

然后输入功能:

def input_fn(data_set):
  feature_cols = {k: tf.constant(data_set[k].values)
                  for k in FEATURES}
  labels = tf.constant(data_set[LABEL].values)
  return feature_cols, labels

最后我将 input_fn()放在 fit()中:

classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)

当我运行代码时,我收到此错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-6-938bcd2f929f> in <module>()
----> 1 classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)

TypeError: fit() got an unexpected keyword argument 'input_fn'

我不知道它是关于 input_fn 定义还是 fit 参数

1 个答案:

答案 0 :(得分:0)

如果您想使用input_fn,请不要使用SKCompat,将第一行替换为:

classifier = tf.contrib.learn.DNNClassifier(

根据需要调整括号。