在DNNClassifier TensorFlow中创建适当的input_fn

时间:2018-03-23 15:14:27

标签: python tensorflow tensorflow-datasets tensorflow-estimator

我正在使用DNNClassifier构建一个神经网络,我已经阅读了网站上的示例,并由其他人完成了关于此估算工具的示例,但我仍然对input_fn的构造感到困惑。我在下面发布我的代码

import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

df = pd.read_csv('chunk.csv')
y = df.MoreClass
x_train, x_test, y_train, y_test = train_test_split(df, y, test_size=0.2)

CATEGORICAL_COLUMNS = [#list of categorical columns]
CONTINUOUS_COLUMNS = [#list of continuous columns]
COLUMNS = [#all the columns]
FEATURES = [#columns containing features]
LABEL = "label" #categorical column with 10 classes (D, BBB, BB and so on)

#embedding of categorical columns
col1 = tf.feature_column.categorical_column_with_hash_bucket(
  "df.col1", hash_bucket_size=1000)
col1_emb = tf.feature_column.embedding_column(col1, 30)
#list of others embedded columns

#transformation of numerical columns in indicator ones
col2 = tf.feature_column.numeric_column("df.col2")
col2_ind = tf.feature_column.indicator_column(col2)

#all the transformed columns
dense_cols = [col1, col2 #etcetc]

#DNNClassifier
classifier = tf.estimator.DNNClassifier(
 feature_columns=dense_cols,
 hidden_units=[10, 10],
 n_classes=10,
 dropout=0.1
)

def create_train_input_fn(): 
    return tf.estimator.inputs.pandas_input_fn(
        x=x_train,
        y=y_train, 
        batch_size=32,
        num_epochs=None, 
        shuffle=True)

def create_test_input_fn():
    return tf.estimator.inputs.pandas_input_fn(
        x=x_test,
        y=y_test, 
        num_epochs=1, 
        shuffle=False) 


train_input_fn = create_train_input_fn()
classifier.train(train_input_fn, steps=1000)

我省略了部分代码,因为在定义嵌入列和指示符列时,以及定义COLUMNS,FEATURES和LABELS时,大致相同。

运行脚本后我遇到错误:'_NumericColumn'对象没有属性'_get_sparse_tensors',我不知道如何克服它或我做错了。

创建input_fn时出现问题吗?还是在那之前呢?如果它在input_fn之前,我该如何编写正确的input_fn?

非常感谢任何帮助,提前谢谢。

1 个答案:

答案 0 :(得分:0)

indicator_column函数需要_CategoricalColumn,但您使用_NumericColumn返回的numeric_column来调用它。我想您可以通过致电_CategoricalColumn获取bucketized_column

相关问题