无法将保存的模型转换为tflite

时间:2020-07-05 17:22:59

标签: tensorflow

我正在使用 TensorFlow 2.2 重新训练 MobileNet 架构。总体来说训练效果不错,但在尝试将模型转换为 TensorFlow Lite 时抛出了错误。代码如下:

import tensorflow as tf
import PIL
from tensorflow.lite.python import lite
from tensorflow.python.keras.layers.core import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.preprocessing import image
from tensorflow.keras.models import Model
from tensorflow.python.keras.applications import imagenet_utils
from sklearn.metrics import confusion_matrix
import pathlib
from mlxtend.plotting import plot_confusion_matrix
import seaborn as sns
import numpy as np

# Dataset root. flow_from_directory expects one sub-directory per class
# inside each split directory and derives the labels from those names.
data_root = "C:/Users/rosha/PycharmProjects/FYP/Cat-Man-Car"
train_path = f"{data_root}/train"
validation_path = f"{data_root}/validation"
test_path = f"{data_root}/test"

image_size = (224, 224)  # matches the default input size of MobileNet() below
batch_size = 64

# All three splits use the same MobileNet preprocessing, so one generator
# configuration is shared; each flow_from_directory call returns its own
# independent iterator.
datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet.preprocess_input)

train_batch = datagen.flow_from_directory(
    train_path, target_size=image_size, batch_size=batch_size)
valid_batch = datagen.flow_from_directory(
    validation_path, target_size=image_size, batch_size=batch_size)
# shuffle=False keeps test_batch.classes in the same order as the predictions
# made on test_batch, which the confusion matrix computation relies on.
test_batch = datagen.flow_from_directory(
    test_path, target_size=image_size, batch_size=batch_size, shuffle=False)


# Pretrained MobileNet including its original ImageNet classification head.
base_model = tf.keras.applications.mobilenet.MobileNet()
base_model.summary()

# Attach a new 3-class softmax head and wrap everything in a new Model.
# NOTE(review): layers[-1] is MobileNet's top classification output (ImageNet
# class scores), so the new Dense layer is stacked on those scores rather than
# on pooled image features. Confirm this is intended; a common alternative is
# MobileNet(include_top=False, pooling='avg') as the feature extractor.
x = base_model.layers[-1].output
predictions = Dense(3, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)


def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    Plot a confusion matrix (e.g. one produced by sklearn) and show it.

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix
                  (any 2-D array-like of counts is accepted)

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low'];
                  None suppresses the tick labels

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues; defaults to 'Blues'

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions of each true-label row
                  (rows with no samples are shown as 0 instead of NaN)

    The x-axis label reports overall accuracy and misclassification rate,
    always computed from the raw counts.

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import itertools

    import matplotlib.pyplot as plt
    import numpy as np

    cm = np.asarray(cm)

    # Accuracy/misclassification come from the raw counts, before normalisation.
    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype('float')
        # A class with no true samples has a zero row sum; leave that row at 0
        # instead of producing NaN from a division by zero.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)

    # Cell text is white on dark cells and black on light cells.
    cell_fmt = "{:0.4f}" if normalize else "{:,}"
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cell_fmt.format(cm[i, j]),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # tight_layout must run after the labels exist, otherwise it cannot
    # reserve room for them and they may be clipped.
    plt.tight_layout()
    plt.show()


# 'learning_rate' is the documented Adam argument; 'lr' is a deprecated alias.
model.compile(Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

# model.fit accepts the directory iterators directly (fit_generator is
# deprecated). It returns a History object holding per-epoch metrics,
# NOT the trained model. Validation uses the dedicated validation split so
# the test split stays untouched for the evaluation below.
trained_model = model.fit(train_batch, validation_data=valid_batch, epochs=1)

# Export the Keras model itself. tf.saved_model.save needs a trackable object
# such as the model; passing the History returned by fit() cannot be exported,
# which is what broke the TFLite conversion.
export_dir = 'saved/model/1'
tf.saved_model.save(model, export_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()

tflite_model_file = pathlib.Path('/content/model.tflite')
# '/content' exists on Colab but usually not elsewhere; create it so
# write_bytes does not fail with FileNotFoundError.
tflite_model_file.parent.mkdir(parents=True, exist_ok=True)
tflite_model_file.write_bytes(tflite_model)

# Ground-truth class index of every test image, in iterator order
# (test_batch was created with shuffle=False, so this matches predict order).
test_labels = test_batch.classes
print(test_labels)
# Class-name -> index mapping, e.g. 0 for car, 1 for cat, 2 for man.
print(test_batch.class_indices)

# model.predict accepts the iterator directly (predict_generator is deprecated).
predictions = model.predict(test_batch)
# Rows of the confusion matrix are true classes, columns are predicted classes.
cm = confusion_matrix(test_labels, predictions.argmax(axis=1))

print(f'model input : {model.input}, model input_names : {model.input_names} ')
print(f'model output : {model.output}, model output_names : {model.output_names} ')


# Display names for the confusion matrix; their order must match the indices
# printed from test_batch.class_indices.
cm_plot_labels = ['car', 'cat', 'man']
plot_confusion_matrix(cm, cm_plot_labels, normalize=False)

import matplotlib.pyplot as plt

# Plot the per-epoch curves stored in the History object returned by fit().
# The val_* series are the metrics computed on whatever was passed to fit()
# as validation_data (labelled 'Test' in the legend).
for metric, title, ylabel in (('accuracy', 'Model accuracy', 'Accuracy'),
                              ('loss', 'Model loss', 'Loss')):
    # marker='o' keeps single-epoch runs visible: a one-point series drawn as
    # a plain line has zero length and would not appear at all.
    plt.plot(trained_model.history[metric], marker='o')
    plt.plot(trained_model.history['val_' + metric], marker='o')
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train','Test'], loc='upper left')
    plt.show()

（原帖此处附有一张图片,内容未能保留）

0 个答案:

没有答案
相关问题