Tensorflow数据集用法

时间:2019-04-14 03:14:01

标签: python tensorflow tensorflow-datasets

我正在尝试创建一个简单的字符串标签对数据集,并且无法获得tensorflow来正确连接这些对

我正在尝试使用Dataset.from_tensor_slices初始化程序和数据集。make_one_shot_iterator迭代器:

import tensorflow as tf

strings = [
    'aaaa',
    'asdf'
]
labels = [1,0]

sess = tf.Session()
tf.global_variables_initializer()

dataset = tf.data.Dataset.from_tensor_slices((strings, labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(512)
iterator = dataset.make_one_shot_iterator()


x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)

最后,我希望输出为'aaaa'为'1',为'asdf'为'0',但是反复得到一些随机值:

aaaa 0
asdf 0
aaaa 1
asdf 1
aaaa 1
aaaa 0
asdf 1
aaaa 1

请提出我的代码中可能有什么问题

顺便说一句,如果删除混洗,将无法获取另一个字符串,迭代器将仅输出:

aaaa 0
aaaa 0 
aaaa 0
...

标签错误...有人知道开始的原因吗?

1 个答案:

答案 0 :(得分:0)

这就是我的用法。

next = iterator.get_next()
# print(next)

with  tf.Session() as sess:
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))
    print(sess.run(next))


(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'aaaa', 1)