TFlearn:ValueError:无法为张量u'TargetsData / Y:0'输入形状为((,,)'

时间:2018-09-25 14:50:40

标签: python tensorflow machine-learning tflearn

我已经检查了所有其他答案,以查找相同的错误,但是我不知道这些数据形状来自何处。我有一个包含600个样本的训练集,每个样本具有6个功能,试图将它们映射到600x2的标签矩阵,标签为0到18之间的整数。

import numpy as np
import tflearn
from sklearn.model_selection import train_test_split
import tensorflow as tf

typnp = np.array(types) #800x2 mtx
stanp = np.array(stats) #800x6 mtx
stat_tr, stat_tst, typ_tr, typ_tst, nam_tr, nam_tst = train_test_split(stanp,typnp,names) 
#training sets are cut to 600 samples

learning_rate = 0.05
epochs = 10

net = tflearn.input_data(shape=[None,6])
net = tflearn.fully_connected(net, 15, activation='sigmoid')
net = tflearn.dropout(net, keep_prob=0.5)
net = tflearn.fully_connected(net,19,activation='sigmoid')
net = tflearn.regression(net, optimizer='SGD',learning_rate=learning_rate, to_one_hot=True,
    n_classes=19, loss='categorical_crossentropy')

model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(stat_tr,typ_tr,n_epoch=epochs,validation_set=(stat_tst,typ_tst),show_metric=True)

typ_pred = model.predict(stanp)
print(typ_pred)

我不确定64x2形状的东西在哪里或如何提出,或者为什么要尝试将它与缺少尺寸的东西匹配? (?,)我尝试对标签进行热编码,因为标签没有排序,但是我什至不知道这是否是引发标签错误的因素。我确定我需要发挥学习率和激活功能,但是现在,我只需要了解我对该模型做错的事情以及如何将整数标签作为预测。任何帮助表示赞赏。

编辑:根据要求,这也是完整的错误跟踪。

Traceback (most recent call last):
  File "type3.py", line 54, in <module>
model.fit(stat_tr,typ_tr,n_epoch=epochs,validation_set=(stat_tst,typ_tst),show_metric=True)
  File "/home/peetzaman521069/.local/lib/python2.7/site-packages/tflearn/models/dnn.py", line 216, in fit
callbacks=callbacks)
  File "/home/peetzaman521069/.local/lib/python2.7/site-packages/tflearn/helpers/trainer.py", line 339, in fit
show_metric)
  File "/home/peetzaman521069/.local/lib/python2.7/site-packages/tflearn/helpers/trainer.py", line 818, in _train
feed_batch)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 877, in run
run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1076, in _run
str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (64, 2) for Tensor u'TargetsData/Y:0', which has shape '(?,)'

0 个答案:

没有答案
相关问题