model.compile中的内存不足错误

时间:2017-05-16 13:15:18

标签: tensorflow keras

我有一个相对较大的多层回归模型,我想要端到端训练。我的训练是一个两步的过程,我首先将欧几里德的损失降到最低,然后我将损失降到最低。实际上,这意味着以下伪代码:

model.compile(optimizer='Adam', loss='mse')
model.fit()
model.compile(optimizer='Adam', loss=my_metric)
model.fit()

我可以毫无问题地运行前两个语句。但是当我的代码到达第二个model.compile语句时,我得到一个内存不足错误。我应该采取哪些不同的措施来避免这个问题?

编辑包含my_metric。将y_true和y_pred视为3-dim向量。首先,我将它们之间的欧氏距离最小化以初始化权重,然后我最小化它们之间的测地线损失。

# compute geodesic viewpoint loss
def my_metric(y_true, y_pred):
    # compute angles
    angle_true = K.sqrt(K.sum(K.square(y_true), axis=1))
    angle_pred = K.sqrt(K.sum(K.square(y_pred), axis=1))
    # compute axes
    axis_true = K.l2_normalize(y_true, axis=1)
    axis_pred = K.l2_normalize(y_pred, axis=1)
    # convert axes to corresponding skew-symmetric matrices
    proj = tf.constant(np.asarray([[0, -1, 0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, -1, 0, 0], [0, 0, 0, 0, 0, -1, 0, 1, 0]]), dtype=tf.float32)
    skew_true = K.dot(axis_true, proj)
    skew_pred = K.dot(axis_pred, proj)
    skew_true = K.map_fn(lambda x: K.reshape(x, [3, 3]), skew_true)
    skew_pred = K.map_fn(lambda x: K.reshape(x, [3, 3]), skew_pred)
    # compute rotation matrices and do a dot product
    R = tf.map_fn(my_R, (skew_true, skew_pred, angle_true, angle_pred), dtype=tf.float32)
    # compute the angle error
    theta = K.map_fn(get_theta, R)
    return K.mean(theta)

# function to compute R1^T R2 given the axis angle representations (\theta_1, v_1) and (\theta_2, v_2)
# x is a list that contains x[0] = v_1, x[1] = v_2, x[2] = \theta_1, x[3] = \theta_2
# note that the v_1 and v_2 are skew-symmetric matrices corresponding to the 3-dim vectors in this function
def my_R(x):
    R1 = K.eye(3) + K.sin(x[2]) * x[0] + (1.0 - K.cos(x[2])) * K.dot(x[0], x[0])
    R2 = K.eye(3) + K.sin(x[3]) * x[1] + (1.0 - K.cos(x[3])) * K.dot(x[1], x[1])
    return K.dot(K.transpose(R1), R2)


# Rodrigues' formula
def get_theta(x):
    return K.abs(tf.acos(K.clip(0.5*(tf.reduce_sum(tf.diag_part(x))-1.0), -1.0+1e-7, 1.0-1e-7)))

0 个答案:

没有答案