将.csv加载到tensorflow时出错

时间:2016-12-16 03:48:14

标签: python numpy tensorflow

我已经采用了在Iris csv上训练的预制代码并试图使用我自己的csv。

此处发生错误

train_data = "train_data.csv"
test_data = "test_data.csv"

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=train_data,
target_dtype=np.int,
features_dtype=np.float32)

错误

ValueError: invalid literal for int() with base 10: 'feature1'

csv看起来像这样

feature1,feature2,feature3,label
1028.0,1012.0,1014.0,1
1029.0,1011.0,1017.0,-1
1027.0,1013.0,1015.0,1
...(and so on)

我得知错误是试图说feature1不是整数。但是,当我对Iris数据集使用相同的代码时,有些字符串标头不用作张量。 Iris数据csv看起来像这样。

30,4,setosa,versicolor,virginica
5.9,3.0,4.2,1.5,1
6.9,3.1,5.4,2.1,2
5.1,3.3,1.7,0.5,0

另外,不确定我是否应该将此问题作为一个不同的问题,但我将功能标题更改为

1,2,3,4
1028.0,1012.0,1014.0,1
1029.0,1011.0,1017.0,-1
1027.0,1013.0,1015.0,1
...(and so on)

我现在收到此错误

ValueError: could not broadcast input array from shape (3) into shape (2)

非常感谢任何想法或帮助!感谢!!!

1 个答案:

答案 0 :(得分:2)

如果要使用此功能,则必须以预期格式编写数据集。第一行应该是:

n_samples, n_features, [feature names]

例如,您显示的虹膜数据集的格式正确:

30,4,setosa,versicolor,virginica

即。 30个样本4个特征

如果您创建的数据集中有50个样本,则应该是:

50,4,labelname
1028.0,1012.0,1014.0,1
1029.0,1011.0,1017.0,-1
1027.0,1013.0,1015.0,1
...(and so on)
相关问题