在TF估算器中使用Keras模型

时间:2018-01-17 07:32:12

标签: python tensorflow keras

我想使用tf.keras.application中包含的一个预先构建的keras模型(vgg,inception,resnet等)进行特征提取,以节省一些时间训练。

在估算器模型函数中执行此操作的正确方法是什么?

这就是我现在所拥有的。

import tensorflow as tf

def model_fn(features, labels, mode):

    # Import the pretrained model
    base_model = tf.keras.applications.InceptionV3(
            weights='imagenet', 
            include_top=False,
            input_shape=(200,200,3)
    )

    # get the output features from InceptionV3
    resnet_features = base_model.predict(features['x'])

    # flatten and feed into dense layers
    pool2_flat = tf.layers.flatten(resnet_features)

    dense1 = tf.layers.dense(inputs=pool2_flat, units=5120, activation=tf.nn.relu)

    # ... Add in N number of dense layers depending on my application

    logits = tf.layers.dense(inputs=denseN, units=5)

    # Calculate Loss
    onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=5)

    loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits)

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    train_op = optimizer.minimize(
        loss=loss,
        global_step=tf.train.get_global_step()
    )

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

if __name__ == "__main__":

    # import Xtrain and Ytrain

    classifier = tf.estimator.Estimator(
        model_fn=model_fn, model_dir="/tmp/conv_model")

    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': Xtrain},
        y=Ytrain,
        batch_size=100,
        num_epochs=None,
        shuffle=True)

    classifier.train(
        input_fn=train_input_fn,
        steps=100)

但是,此代码会抛出错误:

TypeError: unsupported operand type(s) for /: 'Dimension' and 'float'

在第resnet_features = base_model.predict(features['x'])

我认为这是因为keras模型期待一个numpy数组,但估算器传入tf.Tensor。

那么,在估算器中使用keras模型的正确方法是什么。并且,如果您不想这样做,那么在TF中利用预训练模型进行转移学习的最简单方法是什么?

2 个答案:

答案 0 :(得分:5)

我不知道有任何可用的方法允许您从预训练的keras模型创建自定义 model_fn。更简单的方法是使用tf.keras.estimator.model_to_estimator()

model = tf.keras.applications.ResNet50(
    input_shape=(224, 224, 3),
    include_top=False,
    pooling='avg',
    weights='imagenet')
logits =  tf.keras.layers.Dense(10, 'softmax')(model.layers[-1].output)
model = tf.keras.models.Model(model.inputs, logits)
model.compile('adam', 'categorical_crossentropy', ['accuracy'])

# Convert Keras Model to tf.Estimator
estimator = tf.keras.estimator.model_to_estimator(keras_model=model)
estimator.train(input_fn=....)

但是,如果您想创建自定义model_fn以添加更多操作(例如摘要操作),您可以编写如下:

import tensorflow as tf

_INIT_WEIGHT = True

def model_fn(features, labels, mode, params):
  global _INIT_WEIGHT

  # This is important, it allows keras model to update weights
  tf.keras.backend.set_learning_phase(mode == tf.estimator.ModeKeys.TRAIN)

  model = tf.keras.applications.MobileNet(
      input_tensor=features,
      include_top=False,
      pooling='avg',
      weights='imagenet' if _INIT_WEIGHT else None)

  # Only init weights on first run
  if _INIT_WEIGHT:
    _INIT_WEIGHT = False

  feature_map = model(features)
  logits = tf.keras.layers.Dense(units=params['num_classes'])(feature_map)

  # loss
  loss = tf.losses.softmax_cross_entropy(labels=labels, logits=logits)
  ...

答案 1 :(得分:0)

model_fn只能拥有张量。也许你可以尝试这样的事情。这可以被视为黑客。更好的部分是这个代码除了提供model_fn之外,它还将加载模型的权重存储为检查点。这有助于您在检查点呼叫estimator.train(...)estimator.evaluate(...)时获得权重。

def model_fn(features, labels, mode):  

    # Import the pretrained model
    base_model = tf.keras.applications.InceptionV3(
        weights='imagenet', 
        include_top=False,
        input_shape=(200,200,3)
    )    

    # some check
    if not hasattr(m, 'optimizer'):
        raise ValueError(
            'Given keras model has not been compiled yet. '
            'Please compile first '
            'before creating the estimator.')

    # get estimator object from model
    keras_estimator_obj = tf.keras.estimator.model_to_estimator(
        keras_model=base_model,
        model_dir=<model_dir>,
        config=<run_config>,
    ) 

    # pull model_fn that we need (hack)
    return keras_estimator_obj._model_fn