"Features are incompatible with given information" in evaluate() when using tf.contrib.learn.SVM

Time: 2017-07-05 22:30:46

Tags: python tensorflow svm

My code is as follows:

import tensorflow as tf
from tensorflow.contrib import learn

def svm_tf(file):
    X,Y,training_size, index = process_data(file)
    def input_fn():
        return {
              'example_id': tf.constant(index[:training_size]),
              'multi_dim_feature': tf.constant(X[:training_size].values.tolist()),
        }, tf.constant(Y[:training_size])

    feature_columns = [tf.contrib.layers.real_valued_column("multi_dim_feature",dimension=4)]
    svm = learn.SVM(feature_columns=feature_columns,
                    l1_regularization=0.0,
                    l2_regularization=1.0,
                    example_id_column='example_id') 
    svm.fit(input_fn=input_fn,steps=50)

    def test_input():
        return {
            'example_id': tf.constant(index[training_size:]),
            'features': tf.constant(X[training_size:].values.tolist())
        }, tf.constant(Y[training_size:])


    accuracy = svm.evaluate(input_fn=test_input,steps=1)['accuracy']
    print('\nAccuracy :{0:f}\n'.format(accuracy))

However, when I run the program, it fails with the following error:

Traceback (most recent call last):
  File "subscriber.py", line 84, in <module>
    svm_tf(file)
  File "subscriber.py", line 75, in svm_tf
    accuracy = svm.evaluate(input_fn=test_input,steps=1)['accuracy']
  File "/home/annie/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 289, in new_func
    return func(*args, **kwargs)
  File "/home/annie/.local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 543, in evaluate
    log_progress=log_progress)
  File "/home/annie/.local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 827, in _evaluate_model
    self._check_inputs(features, labels)
  File "/home/annie/.local/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 757, in _check_inputs
    (str(features), str(self._features_info)))
ValueError: Features are incompatible with given information. Given features: {'example_id': <tf.Tensor 'Const:0' shape=(1000,) dtype=string>, 'features': <tf.Tensor 'Const_1:0' shape=(1000, 4) dtype=float32>}, required signatures: {'example_id': TensorSignature(dtype=tf.string, shape=TensorShape([Dimension(4000)]), is_sparse=False), 'multi_dim_feature': TensorSignature(dtype=tf.float32, shape=TensorShape([Dimension(4000), Dimension(4)]), is_sparse=False)}.

I could not find any related questions online, so I am quite lost. Please help! Thanks in advance.

1 answer:

Answer 0: (score: 0)

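In the test_input function, replace the key features with multi_dim_feature, so that it matches the name passed to real_valued_column and the key returned by input_fn. The error message shows the mismatch directly: the given features contain 'features', while the required signatures contain 'multi_dim_feature'.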
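
One way to avoid this kind of mismatch is to build the feature dict in a single place and reuse it for both training and evaluation. The following is only a minimal sketch: plain Python lists stand in for the tensors, and the names FEATURE_KEY, ID_KEY, make_features and check_keys are hypothetical helpers, not part of TensorFlow. In the code from the question, each value would still be wrapped in tf.constant(...).

```python
# Sketch only: plain lists stand in for the tf.constant(...) tensors in the question.

FEATURE_KEY = 'multi_dim_feature'  # must match the real_valued_column name
ID_KEY = 'example_id'              # must match example_id_column


def make_features(ids, rows):
    """Build the feature dict once so train and test inputs share the same keys."""
    return {ID_KEY: list(ids), FEATURE_KEY: [list(r) for r in rows]}


def check_keys(train_features, test_features):
    """Return the sorted keys that appear in only one of the two dicts."""
    return sorted(set(train_features) ^ set(test_features))


train = make_features(['0', '1'], [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
test = make_features(['2'], [[9.0, 10.0, 11.0, 12.0]])

# Same shape of mistake as in the question: the test dict uses a different key.
broken_test = {'example_id': ['2'], 'features': [[9.0, 10.0, 11.0, 12.0]]}

print(check_keys(train, test))         # empty list: keys agree
print(check_keys(train, broken_test))  # the two keys that differ
```

Running a check like this on the dicts returned by input_fn and test_input before calling evaluate() points at the offending key without going through the estimator.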