联邦学习(一)
今天我们来看一下联邦学习的开山论文Communication-Efficient Learning of Deep Networks from Decentralized Data
,简单介绍一下这个东西出现的背景。起因是谷歌的输入法想要做一个训练模型,基于谷歌超大体量的全球用户,他们的工程师提出了以下的解决问题的流程:
- Problems: Google wants to train a model using users' data
- Solutions:
- Centralized learning
- Collect users' data
- Train a model on the cluster
- challenge: Users may refuse to upload their data, especially sensitive data to Google's Server
问题显而易见,那么大量的隐私数据,用户是不可能允许你全部上传到服务器端进行训练的, 即使你谷歌无心,也难以保证其他的恶意用户不会窃取你的隐私。所以我们才会发现问题,也就有了这篇论文的诞生,下面我们开始详细地跟着论文过一遍。
Introduction
Increasingly, phones and tablets are the primary computing devices for many people. The powerful sensors on these devices (including cameras, microphones, and GPS), combined with the fact they are frequently carried, means they have access to an unprecedented amount of data, much of it private in nature.
首先这一段指出了联邦学习的主体就将集中在手机、平板等智能移动终端上,并且指出了数据隐私问题,我们不能直接将用户的数据赤裸裸地传到Server上进行训练。这样会导致隐私泄露。
- Instead, each client computes an update to the current global model maintained by the server, and only this update is communicated.......
- Since these updates are specific to improving the current model, there is no reason to store them once they have been applied.
然后文章指出了一种方法就是:每个客户端自身利用自己的数据更新当前服务端的全局模型,然后只将这一小部分上传 ,这样就使得至少从表面上看起来,客户的隐私始终没有离开client端。
the identification of the problem of training on decentralized data from mobile devices as an important research direction;
the selection of a straightforward and practical algorithm that can be applied to this setting;
an extensive empirical evaluation of the proposed approach. More concretely, we introduce the FederatedAveraging algorithm, which combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging.
在这里作者渐渐给我们展开一个模型了,名字叫FederatedAveraging
的算法,大致是想要让devices
在各自利用随机梯度下降(SGD)训练数据,然后传给Server
。然后作者号称这个模型是可以做到在非独立同分布(NON-IID)数据下的鲁棒性,并且可以减少分布式学习中的通信轮数。
Federated Learning
- Training on real-world data from mobile devices provides a distinct advantage over training on proxy data that is generally available in the data center.
- This data is privacy sensitive or large in size (compared to the size of the model), so it is preferable not to log it to the data center purely for the purpose of model training (in service of the focused collection principle).
- For supervised tasks, labels on the data can be inferred naturally from user interaction.
然后作者给出了适用于联邦学习的应用问题,总结来说就是有三点特性:一是适用于需要更加真实的来自于现实生活的数据,这一点是数据中心的样板数据无法比拟的优势;二是这些数据是相当隐私或者相当庞大的,不适用于完全传输到服务中心进行训练;三是对于监督任务,数据标签很容易就能从用户交互中获得。 紧接着,作者举出了很明显的两个例子:image classification和language models,仔细分析会发现这些都满足上面的三点特性,并且分别可以将卷积神经网络和递归神经网络应用模型。
Privacy
这里作者简单地argue了一下联邦学习对于隐私泄露的风险仅仅在于minimal update necessary to improve a particular model
,并且提到会在文章最后简单介绍一下和差分隐私算法的结合。
Federated Optimization
- Non-IID: The training data on a given client is typically based on the usage of the mobile device by a particular user, and hence any particular user’s local dataset will not be representative of the population distribution.
- Unbalanced: The training data on a given client is typically based on the usage of the mobile device by a particular user, and hence any particular user’s local dataset will not be representative of the population distribution.
- Massively distributed: We expect the number of clients participating in an optimization to be much larger than the average number of examples per client.
- Limited communication: Mobile devices are frequently offline or on slow or expensive connections.
总的来说,联邦学习比一般的分布式机器学习要考虑的优化问题更多一点。作者给出了其中的四点(当然其实还有更多的问题等待着暴露),第一是各个节点没办法实现数据的独立同分布 ,每个手机的数据特点显然也无法概括整体的数据集;第二是不平衡性 ,显然有些用户的使用频率、数据质量可能高于另一些用户,那么如何分配权重也是一个很大的问题;第三就是超大体量的分布式计算 ,显然用户的数量已经远远超过了在实验室中的样例;第四是交流的限制性 ,我们都知道手机等移动端设备肯定会经常出现不稳定的现状,掉线、低带宽的通信都是问题。作者重点argue的是 Non-IID和不平衡性。
There is a fixed set of K clients, each with a fixed local dataset. At the beginning of each round, a random fraction C of clients is selected, and the server sends the current global algorithm state to each of these clients (e.g., the current model parameters).
作者首先做了一个实验,固定K个Client,每个Client都持有一定的本地数据。从中随机挑选一部分客户端
,服务器将model parameters
发送给这些客户端,客户端利用本地的数据集进行运算,并且将更新后的参数发送给服务端,服务端整合数据以后重复上述过程,继续发送给客户端。(至于为什么只选择一部分Clients,作者给出的解释是We only select a fraction of clients for effificiency, as our experiments show diminishing returns for adding more clients beyond a certain point.
)换言之,是为了实验的效率考虑。
对于大部分非凸神经网络,我们考虑以下的公式是普适的: \[ \min _{ w \in \mathbb{ R }^{ d } } f(w) \quad \text { where } \quad f(w) \stackrel{ \operatorname{ def } } { = } \frac { 1 } { n } \sum_{ i=1 }^{ n } f_{ i }(w) \] 对于分布式机器学习的问题,我们定义如下的函数\(f_{i}(w)=l(x_{i},y_{i};w)\),来表示由\(model \ parameter \ w\)造成的\(loss\)损失\((x_{i},y_{i})\)。我们假设有K个Clients,并且client k上分配到的数据集索引为\(P_{k}\),并且\(n_{k}=|P_{k}|\),我们可以修改上述公式为 \[ f(w)=\sum_{k=1}^{K} \frac{n_{k}}{n}F_{k}(w) \quad where \quad F_{k}(w)=\frac{1}{n_{k}}\sum_{i\in{P_{k}}}f_{i}(w). \] 从这个式子我们可以看出,如果每个节点Client的数据是完全随机分布的话,也就是会做到$Independent & Identically Distributed (IID) \(,那么我们就可以得到\)F_{k}(w)\(的期望事实上是和\)f(w)$大抵相近的,也就是每个独立节点得到的模型已经几乎可以代替全局模型。
但是联邦学习中明显数据的分布是不满足\(IID\)的,每个用户节点都带有很强的个性,所以无法使用这种假设。作者为了破坏这种随机性,在实验中使用的是(that is, F could be an arbitrarily bad approximationto f)
,也就是将\(F_{k}\)拟合为一个很糟糕的函数以至于无法很好地描述\(f\)。
上面是针对模型非独立同分布作出的假设。下面作者考虑如何对联邦学习的通信进行优化。显然的是,在计算中心,用有足够的带宽和稳定的通信,那么针对计算量的优化是很有必要的;但是联邦学习的环境下,用户节点的带宽甚至无法超过1MB/s,而且像手机等移动设备,通常用户只会在充电、使用未计量的WIFI时才会主动参与计算,这限制了通信;另一方面,考虑到现代设备比如手机都有着相对较快的处理器以及GPU,每台设备上自己的数据集也是非常小的尺度,在多数模型下,计算所带来的消耗和通信相比,几乎可以忽略不计。所以,作者给出的优化方法就是,在客户端本地使用额外的计算,来减少需要通信的轮数,具体来说是以下两个方面:
- increased parallelism, where we use more clients working independently between each communication round.
- increased computation on each client, where rather than performing a simple computation like a gradient calculation, each client performs a more complex calculation between each communication round.
相较于增加数据的并发性,作者在实验中的方法是增加在每台客户端上的计算量,不仅仅再是简单的梯度计算。
Related work
总结前人在分布式计算中作出的贡献,\(McDonald\)等人研究了通过逐轮迭代平均本地训练模型的方法进行分布式训练;\(Zhang\)等人也研究了通过一种"软"平均的方式来实现异步分布训练。这些研究成果最大的问题就是要么只考虑到了数据中心少量的\(worker \ node\)s集合,要么就是忽略了数据的\(Unbalance\)和\(Non-IID\)特性,而这些恰恰都是联邦学习的特征。
同时作者也提到了诸如\(Neverova\)和\(Shori \ and \ Shmatikov\)等研究团队都在考虑用户数据的隐私敏感问题,也采取了一定的措施来防止,但是依然没有考虑到上述的两点问题。等等多种算法都不能完美地解决联邦学习中存在的问题。
作者考虑的(参数化)算法是简单的做一次平均,也就是说每一个客户端利用本地的数据训练一个最小化\(Loss\)函数的模型,然后这些模型将被平均成为全局模型。作者也指出,在\(IID\)情况下,这种方法训练出的本地模型甚至比全局模型更优,当然,是在某些特殊的状况。
The FederatedAveraging Algorithm
在这一部分作者开始介绍著名的FedAvg
算法。首先,作者指出最近的深度学习的研究几乎都是依赖于随机梯度下降的方法\((SGD)\);而事实上很多的研究进展也都是可以理解为在调整模型的结构(或者是损失函数)来使得模型可以实现通过简单的基于梯度的优化。
根据已有的前人的实验结果可以发现,\(SGD\)是可以应用在联邦优化上的。只不过之前的在数据中心做的实验,为了提高效率,都是采取小的batch,多的rounds。但是考虑到在联邦学习中,参与计算的客户端可以很多,而且不必考虑wall-clock time
的损失问题,所以作者的实验都是基于较大的批尺寸的同步随机梯度下降法。每次从所有客户端中随机选择一部分,比例大小为C,这里当\(C=1\)时,也就是全批次的非随机梯度下降算法。我们称这种算法为\(FederatedSGD\)或\(FedSGD\)。
一种典型的算法是,当\(C=1\)时,在一个固定的学习率\(\eta\)和客户端数量\(k\)中,每一个节点应用以下的梯度计算: \[ g_{k}=\nabla{F_{k}(w_t)} \] 然后中央服务器汇合这些梯度并且做一次全局的更新: \[ w_{t+1} \gets w_{t}-\eta\sum_{k=1}^{K} \frac {n_{k}}{n} g_{k}=w_{t}-\eta\nabla f(w_{t}) \] 一种等价的变换就是对于每个节点来说都利用本地的数据集做一次参数的更新,然后将这些更新后的结果交给中央服务器,中央服务器对这些本地已经局部梯度更新后的参数再做加权平均。 这也就是整个联邦学习\(FederatedAveraging\)的核心: \[ \forall{k}, \quad w_{t+1}^{k} \gets w_{t}- \eta g_{k} \\ then, \quad w_{t+1} \gets \sum_{k=1}^K \frac {n_{k}}{n} w_{t+1}^{k} \] 因此,到了这一步我们就已经搞清楚了整个\(FedAvg\)算法的要点,那就是把梯度计算和模型参数更新全部下放到客户端节点中去,然后各个client相当于直接将模型参数传回云端,云端的服务器最终只需要简单地对这些模型作加权平均处理,然后再下放到各个客户端作为新一轮训练的模型参数。这样的话也带来了很大的一个便捷那就是,本地客户端可以多次重复进行梯度下降的过程,以期在更少的全局轮数内获得更好的拟合效果。针对每个客户端设备的计算量,有如下的衡量参数:
C:the fraction of clients that perform computation on each round;
E:the number of training passes each client makes over its local dataset on each round;
B:the local minibatch size used for client updates.
简单解释一下,就是:C代表的是每轮计算中参与的设备比例;E代表的是每个节点每轮中本地的训练次数;B代表的是本地训练中每次拿出的batch size。那么这样一来一个拥有\(n_{k}\)个样例,