我在张量流中处理一些CNN模型,但是我遇到了与读取数据有关的问题。
我有一个文件,一个TFRecord文件,使用GZIP压缩,我想批量读取该文件中的数据,然后设置下一个代码:
def _input_fn( files ):
print( files )
thread_count = multiprocessing.cpu_count()
batch_size = 2 # for debug
num_epochs = 2
min_after_dequeue = 1000
queue_size_multiplier = thread_count + 3
filename_queue = tf.train.string_input_producer( files , num_epochs = num_epochs )
example_id , encoded_examples = tf.TFRecordReader(
options = tf.python_io.TFRecordOptions (
compression_type= TFRecordCompressionType.GZIP
)
).read_up_to( filename_queue , batch_size)
features, targets = example_parser(encoded_examples )
capacity = min_after_dequeue + queue_size_multiplier
images , labels = tf.train.shuffle_batch(
[features , targets ] , batch_size , num_threads = thread_count ,
capacity = capacity , min_after_dequeue = min_after_dequeue ,
enqueue_many = True
)
return images , labels
读取TFRecords的常用代码。然后我开始创建一个测试图,而不是实际的NN,用于测试
inputs, labels = _inpunt_fn(files )
# the shape of the tensors returned by _input_fn is correct. [batch_size ,150*150*150 ]
ss = inputs + 1 # some computation
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess = tf.Session()
sess.run( init_op )
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# i cut some things to make it shorter ...
while not coord.should_stop():
sess.run( [ ss ] )
print(" test ")
我运行此代码,我永远不会看到"测试"打印。该程序崩溃了我的电脑。我使用" top"命令观察内存使用情况,它会快速增长,直到程序和我电脑中的所有内容都崩溃。
数据集很大。每个样本是一个3d矩阵(150 x 150 x 150),有一千个样本。但我(理论上)没有把它装到记忆中,我小批量地读它,对吧? ,那么为什么会发生这种情况......我阅读文件的方式有什么问题,我该如何修复它。
事先谢谢。欢迎任何见解。