具有自定义model_fn的TensorFlow估算器在训练阶段永远运行

时间:2017-10-03 05:50:48

标签: python tensorflow

以下代码的相关部分。对Scattering对象的调用基于固定的滤波器映射返回系数的3D张量。该程序仅从Scattering调用进入和返回一次,表示代码在第一个训练步骤中的某个地方永久挂起,但不在Scattering调用中。这可能发生在哪里?

def my_model_fn(features, labels, mode, params):
    M, N = features.get_shape().as_list()[-2:]
    scattering_coefficients = Scattering(M=M, N=N, J=1, L=2)(features)
    batch_size = scattering_coefficients.get_shape().as_list()[0]
    # throw all coefficients into single vector for each image
    scattering_coefficients = tf.reshape(scattering_coefficients, [batch_size, -1])
    # returns tensor of correct shape
    print(scattering_coefficients)
    n_classes = 10
    n_coefficients = scattering_coefficients.get_shape().as_list()[1]

    # use linear classifier
    W = tf.Variable(tf.zeros([n_coefficients, n_classes]))
    b = tf.Variable(tf.zeros([n_classes]))
    y_predict = tf.nn.softmax(tf.matmul(scattering_coefficients, W) + b)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions={"predictions": y_predict})

    # loss function and training step
    cross_entropy = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_predict) )
    train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(cross_entropy)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=cross_entropy,
        train_op=train_op)


def sample_batch(X, y, batch_size):
    idx = np.random.choice(X.shape[0], batch_size, replace=False)
    return tf.convert_to_tensor(X[idx]), tf.convert_to_tensor(y[idx])

LEARNING_RATE = 0.01
BATCH_SIZE = 2
n_training_steps = 2
image_dimension = 28
model_params = {"learning_rate": LEARNING_RATE}

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

X_train = mnist.train.images.astype(np.float32)
X_train = normalize(X_train)
# number of channels is 1, -1 infers number of samples
X_train = X_train.reshape(-1, 1, image_dimension, image_dimension)
y_train = mnist.train.labels.astype(np.int64)

X_validation = mnist.validation.images.astype(np.float32)
X_validation = normalize(X_validation)
X_validation = X_validation.reshape(-1, 1, image_dimension, image_dimension)
y_validation = mnist.validation.labels.astype(np.int64)

train_input_fn = lambda: sample_batch(X_train, y_train, BATCH_SIZE)
validation_input_fn = lambda: sample_batch(X_validation, y_validation, BATCH_SIZE)

# Train
scattering_classifier = tf.estimator.Estimator(model_fn=my_model_fn, params=model_params)
# Hangs forever...
scattering_classifier.train(input_fn=train_input_fn, max_steps=n_training_steps)
# If I comment out training step, this finishes immediately.
print("start scoring accuracy")
predictions = scattering_classifier.predict(input_fn=validation_input_fn)

1 个答案:

答案 0 :(得分:0)

更改

train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(cross_entropy)

train_op = tf.train.GradientDescentOptimizer(params["learning_rate"]).minimize(
    cross_entropy, global_step=tf.train.get_global_step())

解决了这个问题。非常欢迎解释。

相关问题