如何在损失函数中访问张量(y_true)?

时间:2020-08-03 07:48:37

标签: tensorflow machine-learning keras deep-learning nlp

我有一个像这样的字典

dict = {"class_1" : array(12, 13, 14, 1686, 124 ,123,....), "class_2" : array(12312,312,3,34,3...), ...}

在这里, 它包含使用bert的类(字符串类型)的嵌入。 因此,在损失函数中,我想最小化实际类和预测类的嵌入之间的差异。

def loss_function(y_true, y_pred):
    # what can I do for finding the class here 
    # I need to find the classes y_pred is pointing to 
    # then need to find mse between embedding vector of y_pred class and y_true classes.

所以主要问题是如何在每次迭代中找到y_true指向的值。 我不能对此执行任何数组函数,因为它是张量。 我需要在损失函数中执行几个任务:

  • 找到y_true为1的类的名称
  • 找到y_pred为1的类的名称
  • 从dict中获取它们的嵌入并对其求平均值,因为任务是多标签分类
  • 计算它们之间的mse

步骤1、2有问题。请帮助我,谢谢。

0 个答案:

没有答案
相关问题