保存并运行wide_deep.py模型

时间:2017-11-26 09:20:48

标签: tensorflow

我使用人口普查数据集一直在玩Tensorflow Wide and Deep tutorial

线性/广泛教程陈述:

We will train a logistic regression model, and given an individual's information our model will output a number between 0 and 1

目前,我无法弄清楚如何预测单个输入的输出(从单元测试中复制):

TEST_INPUT_VALUES = {
    'age': 18,
    'education_num': 12,
    'capital_gain': 34,
    'capital_loss': 56,
    'hours_per_week': 78,
    'education': 'Bachelors',
    'marital_status': 'Married-civ-spouse',
    'relationship': 'Husband',
    'workclass': 'Self-emp-not-inc',
    'occupation': 'abc',
}

我们如何预测并输出此人是否可能获得<50k(0)或> = 50k(1)?

2 个答案:

答案 0 :(得分:2)

函数是predict,但我没有弄清楚如何直接输入一个示例数据(我试过numpy_input_fn和张量的dict)。

相反,使用wide_deep.py中的输入函数将数据写入临时csv文件然后读取它,可以使用predict函数:

TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
              'Husband,zyx,wvu,34,56,78,tsr,<=50K')
# Create temporary CSV file
input_csv = '/tmp/census_model/test.csv'
with tf.gfile.Open(input_csv, 'w') as temp_csv:
    temp_csv.write(TEST_INPUT)

# restore model trained by wide_deep.py with same model_dir and model_type 
model = wide_deep.build_estimator(FLAGS.model_dir, FLAGS.model_type)
pred_iter = model.predict(input_fn=lambda: wide_deep.input_fn(input_csv, 1, False, 1))
for pred in pred_iter:
    # print(pred)
    print(pred['classes'])

probability中还有logitspred等其他属性。

答案 1 :(得分:1)

Hookay,我现在可以回答这个问题。所以如果你想评估测试集的准确性,你可以按照接受的答案,但如果你想做出自己的预测,这里是步骤。

首先,构建一个新的input_fn,注意您需要更改列和默认列值,因为标签列不在那里。

def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_PREDICT_COLUMNS_DEFAULTS)
    features = dict(zip(_PREDICT_COLUMNS, columns))

    return features


def predict_input_fn(data_file):
    assert tf.gfile.Exists(data_file), ('%s not found. Please make sure the path is correct.' % data_file)

    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv, num_parallel_calls=5)
    dataset = dataset.batch(1) # => This is very important to get the rank correct
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    return features

然后你可以通过

简单地调用它
results = model.predict(
        input_fn=lambda: predict_input_fn(data_file='test.csv')
    )
相关问题