使用tensorflowjs进行联合学习

时间:2019-04-22 03:07:59

标签: tensorflow tensorflow.js

我正在使用tensorflowjs实现联合学习。但是我有点陷入联合平均过程中。这个想法很简单:从多个客户端获取更新的权重并将其平均在服务器中。

我已经在浏览器上训练了一个模型,通过model.getWeights()方法获取了更新的权重,并将权重发送到服务器进行平均。


//get weights from multiple clients(happens i client-side)
w1 = model.getWeights(); //weights from client 1
w2 = model.getWeights(); //weights from client 2


//calculate average of the weights(server-side)
var mean_weights= [];
let length = w1.length; // length of all weights_array is same
for(var i=0; i<length; i++){
    let sum = w1[i].add(w2[i]);
    let mean = sum.divide(2); //got confused here, how to calculate mean of tensors ??
    mean_weights.push(mean);
}

//apply updates to the model(both client-side and server-side)
model.setWeights(mean_weights);

所以我的问题是: 如何计算张量数组的均值? 而且,这是通过tensorflowjs执行联邦平均的正确方法吗?

2 个答案:

答案 0 :(得分:1)

是的,但要小心。你可以用 tf.mean 平均两个张量,就像 https://stackoverflow.com/users/5069957/edkeveked 说的那样。但是,请记住 axis=0 在 JavaScript 中应缩写为 0

只是用第二种方式重写他的代码:

const x = tf.tensor([1, 2, 3, 2, 3, 4], [2, 3]);
x.mean(0).print()

但是,你问你做得是否正确,这取决于你是否在进行平均。滚动平均值存在问题。

示例:

如果你平均 (10, 20) 然后 30,你得到的数字是 (22.5) 不同于平均 (20, 30) 然后 10 (17.5),这当然不同于同时平均所有三个,这会给你 20。

平均值一经计算就不再遵循与订单无关的原则。它是删除关联属性的除法部分。所以你需要:

A:存储所有模型权重并根据所有以前的模型每次计算新的平均值

B:为联合平均值添加一个加权系统,这样更新的模型就不会对系统产生重大影响。

哪个有意义?

我推荐 B 在您的情况下:

  1. 不想或无法存储提交的每个模型和重量。
  2. 您知道某些模型已经看到了更有效的数据,与盲模型相比,应该适当加权。

您可以计算加权平均值,调整现有模型与传入模型的分母。

在 JavaScript 中,你可以做一些简单的事情来计算两个值之间的加权平均值:

const modelVal1 = 0
const modelVal2 = 1

const weight1 = 0.5
const weight2 = 1 - weight1


const average = (modelVal1 * weight1) + (modelVal2 * weight2)

上面的代码是您常用的均匀加权平均值,但是当您调整权重 1 时,您正在重新平衡量表以显着调整结果以支持 modelVal1modelVal2

显然,您需要将我展示的 JavaScript 转换为张量数学函数,但这很简单。

在联邦学习中经常使用权重衰减的迭代平均(或加权平均)。请参阅 Iterate averaging as regularization for stochastic gradient descentServer Averaging for Federated Learning

答案 1 :(得分:0)

要计算2个张量的平均值,可以使用tf.mean

const x = tf.tensor1d([1, 2, 3]);
const y = tf.tensor1d([2, 3, 4]);
tf.stack([x, y]).print()
const mean = tf.stack([x, y]).mean(axis=0)

mean.print();