将预训练的已保存模型从NCHW转换为NHWC,使其与Tensorflow Lite兼容

时间:2018-05-01 15:59:28

标签: tensorflow keras

我已将模型从PyTorch转换为Keras,并使用后端提取张量流图。由于PyTorch的数据格式是NCHW,因此提取和保存的模型也是如此。在将模型转换为TFLite时,由于格式为NCHW,因此无法转换。有没有办法将整个图表转换为NHCW?

2 个答案:

答案 0 :(得分:1)

最好让图表的数据格式与TFLite匹配,以加快推理速度。一种方法是手动将转置操作插入图形,如以下示例所示: How to convert the CIFAR10 tutorial to NCHW

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])
    images = tf.ones(shape=[64,24,24,3])

    imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

    print("conv=",conv.eval())

答案 1 :(得分:0)

不幸的是,当前无法将NCHW图转换为NHWC。您必须从NHWC图开始进行训练,如果以后要使用TF lite进行训练。