fed_avg

2025/04/03

Fed_Avg

算法原理

设一共有$K$个客户机。中心服务器初始化模型参数,执行若干轮,每轮选取至少1个至多$K$个客户机参与训练,每个被选中的客户机同根据服务器下发的本轮($t$轮)模型参数$w_t$,用自己的数据训练得到模型参数$w_{t + 1}^k$,上传回服务器。服务器将收集来的各客户机的模型根据各样本数量用加权平均的方式进行聚合,得到下一轮的模型参数$w_{t + 1}$:

$$ w_{t+1} \leftarrow \sum_{k = 1}^{K} \frac{n_k}{n} w_{t + 1}^k $$

// $n_k$为客户机$k$上的样本数量,$n$为所有被选中客户机的总样本数量

FedAvg步骤:

$K$个客户端编号为$1...k$;$B$,$E$,$\eta$ 分别代表本地的minibatch size,epochs,学习率learning rate

服务器执行:

初始化$w_0$

for 每轮$t = 1,2,\dots$, do

 $m \leftarrow \max(C \cdot K, 1)$ // $C$为比例系数

 $S_t \leftarrow$ (随机选取$m$个客户端)

 for 每个客户端$k \in S_t$同时 do

 ​ $w_{t + 1}^k \leftarrow$ 客户端更新$(k, w_t)$

 ​ $w_{t+1}\leftarrow\sum_{k = 1}^{K}\frac{n_k}{n}w_{t + 1}^k$ // $n_k$为客户机$k$上的样本数量,$n$为所有被选中客户机的总样本数量

客户端更新$(k, w_t)$:   // 在客户端 $k$ 上运行

$\beta \leftarrow$ (将$P_k$分成若干大小为$B$的batch) // $P_k$为客户机$k$上数据点的索引集,$P_k$大小为$n_k$

for 每个本地的epoch $i (1 \sim E)$ do

 for batch $b \in \beta$ do

 ​ $w \leftarrow w - \eta \nabla l(w; b)$ // $\nabla$为计算梯度,$l(w; b)$为损失函数

返回$w$给服务器

为了增加客户机计算量,可以在中心服务器做聚合(加权平均)操作前在每个客户机上多迭代更新几次。计算量由三个参数决定:

当$E = 1\ B = \infty$时,对应的就是FedSGD,即每一轮客户机一次性将所有本地数据投入训练,更新模型参数。

对于一个有着$n_k$个本地样本的客户机$k$来说,每轮的本地更新次数为$u_k = E \cdot \frac{n_k}{B}$。

Fed_Avg

算法原理

设一共有$K$个客户机。中心服务器初始化模型参数,执行若干轮,每轮选取至少1个至多$K$个客户机参与训练,每个被选中的客户机同根据服务器下发的本轮($t$轮)模型参数$w_t$,用自己的数据训练得到模型参数$w_{t + 1}^k$,上传回服务器。服务器将收集来的各客户机的模型根据各样本数量用加权平均的方式进行聚合,得到下一轮的模型参数$w_{t + 1}$:

$$ w_{t+1} \leftarrow \sum_{k = 1}^{K} \frac{n_k}{n} w_{t + 1}^k $$

// $n_k$为客户机$k$上的样本数量,$n$为所有被选中客户机的总样本数量

FedAvg步骤:

$K$个客户端编号为$1...k$;$B$,$E$,$\eta$ 分别代表本地的minibatch size,epochs,学习率learning rate

服务器执行:

初始化$w_0$

for 每轮$t = 1,2,\dots$, do

 $m \leftarrow \max(C \cdot K, 1)$ // $C$为比例系数

 $S_t \leftarrow$ (随机选取$m$个客户端)

 for 每个客户端$k \in S_t$同时 do

 ​ $w_{t + 1}^k \leftarrow$ 客户端更新$(k, w_t)$

 ​ $w_{t+1}\leftarrow\sum_{k = 1}^{K}\frac{n_k}{n}w_{t + 1}^k$ // $n_k$为客户机$k$上的样本数量,$n$为所有被选中客户机的总样本数量

客户端更新$(k, w_t)$:   // 在客户端 $k$ 上运行

$\beta \leftarrow$ (将$P_k$分成若干大小为$B$的batch) // $P_k$为客户机$k$上数据点的索引集,$P_k$大小为$n_k$

for 每个本地的epoch $i (1 \sim E)$ do

 for batch $b \in \beta$ do

 ​ $w \leftarrow w - \eta \nabla l(w; b)$ // $\nabla$为计算梯度,$l(w; b)$为损失函数

返回$w$给服务器

为了增加客户机计算量,可以在中心服务器做聚合(加权平均)操作前在每个客户机上多迭代更新几次。计算量由三个参数决定:

当$E = 1\ B = \infty$时,对应的就是FedSGD,即每一轮客户机一次性将所有本地数据投入训练,更新模型参数。

对于一个有着$n_k$个本地样本的客户机$k$来说,每轮的本地更新次数为$u_k = E \cdot \frac{n_k}{B}$。