Keras:批次中有正负样本的三联体丢失

时间:2017-11-29 11:12:35

标签: keras triplet

我尝试重构我的Keras代码,对{3}}中提出的三元组使用'Batch Hard'采样。

  

“核心思想是通过随机抽样P类来形成批次   (人物身份),然后随机抽样每个班级的K图像   (人),从而产生一批PK图像。现在,每个人   在批次中抽样,我们可以选择最难的正面和   形成三胞胎时,批次中最难的阴性样本   用于计算损失,我们称之为批量硬“

所以目前我有一个Python生成器(用于Keras中的model.fit_generator),它在CPU上生成批处理。然后,可以在GPU上完成模型的实际前向和后向传递。

但是,如何使用“批量硬”方法?发生器采样64个图像,应该形成64个三元组。首先需要正向传递以获得当前模型的64个嵌入。

    embedding_model = Model(inputs = input_image, outputs = embedding)

但是,必须从64个嵌入中选择最难的正面和最难的负面形成三元组。然后可以计算损失

    anchor = Input(input_shape, name='anchor')
    positive = Input(input_shape, name='positive')
    negative = Input(input_shape, name='negative')

    f_anchor = embedding_model(anchor)
    f_pos = embedding_model(pos)
    f_neg = embedding_model(neg)

    triplet_model = Model(inputs = [anchor, positive, negative], outputs=[f_anchor, f_pos, f_neg])

这个triplet_model可以通过定义三元组丢失函数来训练。但是,Keras是否可以使用fit_generator和'Batch Hard'方法?或者如何从批处理中的其他样本访问嵌入?

编辑:使用keras.layers.Lambda我可以定义一个自己的图层来创建带有输入(batch_size,height,width,3)和输出(batch_size,3,height,width,3)的三元组,但我还需要访问id在某个地方。这可能在图层内吗?

0 个答案:

没有答案