在自定义损失函数中重塑张量

时间:2017-07-18 15:31:58

标签: python tensorflow keras

我遇到与this question类似的问题。我试图在keras中设计一个损失函数:

    allfiles = glob.glob("*.csv", )
    dataframes = []
    for file in allfiles :
        dataframes.append(pd.read_csv(file, sep=";", decimal=","))
    df = pd.concat(dataframes)

基于this question中给出的答案。但是,我收到了一个错误:

def depth_loss_func(lr):
    def loss(actual_depth,pred_depth):
        actual_shape = actual_depth.get_shape().as_list()
        dim = np.prod(actual_shape[1:])
        actual_vec = K.reshape(actual_depth,[-1,dim])
        pred_vec = K.reshape(pred_depth,[-1,dim])
        di = K.log(pred_vec)-K.log(actual_vec)
        di_mean = K.mean(di)
        sq_mean = K.mean(K.square(di))

        return (sq_mean - (lr*di_mean*di_mean))
    return loss

具体而言,此语句提供以下输出

 TypeError: unsupported operand type(s) for *: 'NoneType' and 'NoneType'

后端是TensorFlow。谢谢你的帮助。

1 个答案:

答案 0 :(得分:1)

我设法使用形状张量<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script> <form action="/" method="post"> <select id="SelectedMonth" name="SelectedMonth"> <option>7/1/2017</option> <option>6/1/2017</option> </select> <button type="submit">Submit</button> </form>重现您的异常,当这样调用(None, None, None, 9)时:

np.prod()

这是因为您尝试将from keras import backend as K #create tensor placeholder z = K.placeholder(shape=(None, None, None, 9)) #obtain its static shape with int_shape from Keras actual_shape = K.int_shape(z) #obtain product, error fires here... TypeError between None and None dim = np.prod(actual_shape[1:]) 类型的两个元素相乘,即使您对None进行了切片(因为actual_shape中的元素超过1个)。在某些情况下,如果在切片后只剩下一个非类型元素,您甚至可以在NoneTypeError之间获得None

看一下你提到的answer,他们会指出在这些情况下该怎么做,引用它:

  

对于未定义多个维度的情况,我们可以将 tf.shape() tf.reduce_prod()一起使用。

基于此,我们可以分别使用intdocs)和K.shape()docs)将这些操作转换为Keras API:

K.prod()

此外,对于只有一个维度未定义的情况,请记住使用z = K.placeholder(shape=(None, None, None, 9)) #obtain Real shape and calculate dim with prod, no TypeError this time dim = K.prod(K.shape(z)[1:]) #reshape z2 = K.reshape(z, [-1,dim]) 或其包装K.int_shape(z)而不只是K.get_variable_shape(z),同样在后端定义({{3} })。希望这能解决你的问题。