A Federated Deep Learning Method for Adaptive System Scale
In the context of complex task-oriented collaboration, the limitations of communication bandwidth and computing resources, as well as the need for privacy protection, together constitute major challenges in this research area. To solve these problems, researchers propose a framework of Federated Learning (FL) as a solution. FL allows multiple devices to perform collaborative model training without directly exchanging raw data, thereby reducing communication requirements and protecting data privacy. However, some FL methods adopt a full-client-participation strategy, where all clients update their local model in each round. This method not only increases the number of communication, but also leads to performance degradation and response delay with the increase of client size. Therefore, this paper introduces a new FL protocol (Federated Deep Learning Alternating Direction Method of Multipliers, FDLADMM) based on primitive-dual optimization. FDLADMM algorithm uses bivariate to guide the client to carry out local training, reduces the communication times between devices, optimizes the training speed of the model, and with the increase of the system size, it can adapt effectively without super parameter adjustment. Through experiments, the advantages of the proposed method in communication efficiency and training speed are demonstrated, and it can be adapted effectively without hyperparameter adjustment when the system size is adjusted continuously. This innovative approach provides a viable and efficient solution to the challenges of collaboration in complex missions and is expected to be widely applied in future research and practice.
Federated Learning
随着数据在生活各个领域的大量积累,机器学习(Machine Learning, ML)技术得到了迅速发展和广泛应用。在众多应用场景中,数据以分布式的方式被大量生成,例如通过众包、移动电话、自动驾驶汽车、医疗中心和智能电网中的分布式传感器等手段。大量的数据,加之对个人隐私保护的迫切需求,使得传统的分布式机器学习实践难以直接应用。数据中心分布式学习框架
在这样的背景下,联邦学习作为一个新范式被提出,用以解决与数据中心环境相比具有显著不同特点的大规模分布式机器学习问题。FL是一种机器学习的分布式学习方法,旨在保护数据隐私的同时进行模型训练。在传统的集中式机器学习中,所有数据被收集到一个中心服务器进行训练,这可能会涉及大量用户数据的集中存储,存在隐私泄露的风险。而在联邦学习中,模型的训练过程分布在多个设备或智能体中进行,每个设备或智能体只需在本地存储和处理自己的数据,不需要将数据传输到中心服务器。联邦学习框架通过服务器与客户端协作计算全局ML模型,以此来克服传统方法中存在的问题,展示在
尽管FL为数据隐私提供了更好的保护,但也面临一些挑战,如通信效率、训练速度等问题。因此,本文提出一种新FL协议,FDLADMM算法是一种基于原始–对偶优化的算法,能够自动适应系统规模,并相比现有技术显著提高收敛速度。FDLADMM每轮涉及少数客户端,并为每位用户提供丰富的计算选择,减少了设备之间的通信次数,从而优化了模型训练速度。同时使用双变量引导本地训练过程,避免了客户端漂移问题,即防止模型过度拟合于特定客户端的数据。
在联邦学习的研究领域,尽管FedAvg因其简洁性而广泛应用于联邦学习,但它未能解决数据偏移问题,即数据量较小的设备对全局模型的贡献较小,从而影响模型的泛化性能,此外,FedAvg可能会使跨客户端的数据分布发生分歧。作者在
与之相比,本文提出的FDLADMM算法基于原始–对偶优化算法,能够有效适应系统规模,同时优化通信效率和模型训练速度。在FDLADMM中,服务器每轮仅与部分客户端通信,客户端只在被选中时执行本地训练,此外,FDLADMM使用双变量来指导模型进行本地训练,这有效地解释了在没有调优的情况下自动适应数据分布,这解决了在FL设置中遇到的一个关键挑战,即避免客户端漂移,以防止模型过度拟合特定选定客户的数据。本文的方法与随机ADMM
(1)
FL的目标是将其构建为一个损失最小化问题:最小化 ,其中m是客户端的数量, 表示加权的本地训练损失。常见的权重选择包括 ,其中 (即根据它们数据量比例给客户端赋予权重),或者 (即所有客户端平等权重,这有助于避免对拥有更多数据的客户端过度拟合,同时这也是本文实验中采用的选择)。FL的目标是针对分布在不同客户端的数据,对模型参数θ解决公式(1)所示的问题。与数据中心分布式机器学习相比,通信是FL应用的瓶颈。联邦系统中大量的客户端数量使得大规模消息传输的同步通信变得不可行,这是由于带宽限制和延迟节点的存在。此外,大数据量的参与以及原始数据中个人信息的可能泄漏,阻止了服务器收集本地数据进行集中处理。因此,本文提出的FDLADMM算法,在优化通信效率和训练速度的前提下,有效的保护了数据隐私。
如公式(2)所示,其所描述的最优化问题在数学上与公式(1)是等效的,因为它们的最优解是相同的。这种等效性使得公式(2)更适合在FL环境中进行解释和应用。
(2)
在公式(2)中, 可以被解释为客户端i持有的本地模型,这代表了每个参与方(如移动设备、智能体、个人用户等)在本地训练出的模型参数。而θ则可以被解释为服务器持有的全局模型,这是各个参与方本地模型的整合和更新结果。在t轮的初期,系统将选取一组客户端(标记为 ),并且这些客户端将会加载来自服务器的模型 。需要指出的是,FDLADMM的一个积极因素在于其客户端的选取机制:该机制可以是基于设备运行状况(如能量和带宽)的动态规则,其必须确保所有客户端都有一定的概率参与进来。然后进行本地训练,这一过程涉及优化函数 ,以求达到最小化,其中 定义为每个客户端i持有的本地对偶变量,且 属于 , 定义为最小化 的解, 是二次项的系数。
(3)
对偶变量和二次项的结合有助于在更新模型参数(使用本地数据)和与服务器模型保持一致之间取得平衡(服务器模型用于整合所有参与者的信息)。本文注意到,FedProx以
的方式类似地解决了(3)式,虽然这在一定程度上为客户端偏移提供了保障。然而,FedProx的竞争性能依赖于对ρ的精确调整
(4)
服务器在从选定的客户端收集更新消息后更新全局模型:
(5)
其中 是服务器聚合步长。不同的 值适用于系统规模和统计变化各异的场景。本文通过经验观察到,将 设置为1可以实现快速训练速度,而设置 为 有助于在检测到显著的数据差异性时消除振荡行为。与FedAvg/Prox不同,后者仅使用客户端 的当前模型信息来更新全局模型,FDLADMM有效地融入了历史信息,这是通过公式(5)中的更新规则实现的。
本文阐述了FDLADMM如何对FedAvg和FedProx进行了扩展。如公式(3)所定义,当设置对偶变量 时,本文就回到了FedProx的局部训练问题。如果进一步将二次项系数ρ设置为0,那么就转变为了FedAvg的局部训练问题。这两个参数的引入主要是为了应对联邦学习中的一个基本挑战,即如何平衡广泛的本地训练(以加快整个FL过程的收敛速度,从而减少通信轮次总数),以及防止客户端模型过度拟合其本地数据。因此本文认为FDLADMM是一种有效的策略,能够利用本地设备的能力来减少通信成本。
本文分别在两种图像分类的数据集上进行实验,数据集的具体介绍如下。
Fashion-MNIST是一个图像分类数据集,用于机器学习算法的基准测试。Fashion-MNIST (FMNIST)由Zalando Research提供,包含70,000个服饰图像,分为10个类别,其中训练集包含60,000张图像,测试集包含10,000张图像,数据集的图像大小为28 × 28像素
CIFAR-10是一个包含日常物品的彩色图像数据集,由Hinton的学生Alex Krizhevsky和Ilya Sutskever整理,旨在识别广泛的物体类别
本文实验具体的硬件配置如
参数 |
参数取值 |
操作系统 |
Ubuntu 20.04 |
开发语言 |
Python 3.8.0 |
CUDA |
11.3 |
CPU |
Xeon(R) Platinum 8358P |
GPU |
RTX 3090(24G) |
本文使用两种流行的具有两个卷积层的卷积神经网络(Convolutional Neural Network, CNN)模型
模型 |
参数数量 |
数据集 |
目标准确率 |
CNN 1 |
1,663,370 |
FMNIST |
80% |
CNN 2 |
1,105,098 |
CIFAR-10 |
45% |
首先,本文使用了FMNIST和CIFAR-10数据集,通过设置相同的客户端数量,比较了相同的通信次数时模型的测试精度,测试精度的具体计算如式(6)所示。
(6)
其中,TP表示将正样本预测为正样本总数,FN表示将正样本预测为负样本总数。实验结果如
国家自然科学基金资助项目(61602305, 61802257);上海市自然科学基金资助项目(18ZR1426000, 19ZR1477600。
*通讯作者。