首页 / 专利库 / 人工智能 / 人工智能 / 机器学习 / 半监督学习 / 一种自适应通信系统神经网络均衡方法

一种自适应通信系统神经网络均衡方法

阅读:497发布:2020-05-15

专利汇可以提供一种自适应通信系统神经网络均衡方法专利检索,专利查询,专利分析的服务。并且本 发明 提出一种自适应通信系统神经网络均衡方法,考虑到在实际通信场景下器件和信道特性随时间变化这一自适应调整的需求,借助 机器学习 当中的半 监督学习 算法 ,通过设计损失函数来加快模型微调的收敛速度,从而使得基于神经网络的均衡算法灵活性得到极大提升,在信道条件发生变化的情况下通过调整模型来维持较低的误码率。本发明涉及的算法完全不需要发端提供训练序列,而是类似于判决反馈模式,将现有模型的判决作为对应符号的标签,在此 基础 上设计合适的半监督损失函数并进行学习。,下面是一种自适应通信系统神经网络均衡方法专利的具体信息内容。

1.一种自适应通信系统神经网络均衡方法,包括以下步骤:
利用初始的训练集对一个基于神经网络的均衡器进行线下训练,训练完成后得到神经网络模型;
通信系统在线接收经过信道传输的物理信号,并转化为电信号,对电信号进行重新采样和零均值标准化,得到接收信号序列;
针对接收信号序列,将当中的任一符号及其前后L个符号所对应的接收信号相拼接,作为该符号的输入特征向量
每接收Nb的输入特征向量,利用滑动窗口组合为一个batch的数据;
遍历每个batch的数据中的所有输入特征向量,计算损失函数对神经网络模型的模型参数的梯度,基于梯度下降调整模型参数;
利用调整好模型参数的神经网络模型继续处理后续的batch的数据,直到将接收到的所有数据处理完毕。
2.如权利要求1所述的方法,其特征在于,初始的训练集的构建方法为:对于一个长度为Γ·Ntr的序列的原始信号,其中Γ为上采样倍数,Ntr为训练集所包含的符号总数目;其中第i个符号的输入特征向量包括以第i个符号为中心的前后共2L+1个符号,长度为Γ·(2L+1);用该Ntr个输入特征向量作为网络的输入,对应的Ntr个符号所属类别作为网络的输出,构建初始的数据集。
3.如权利要求1所述的方法,其特征在于,对基于神经网络的均衡器进行训练的步骤包括:
神经网络的输入向量为v,输出向量为o=NNθ(v),其中o是一个M维的向量,给出了v对应符号属于M个类别的概率;
训练时采用损失函数作为交叉熵,公式为
对模型参数进行调整,减小神经网络在训练集上的损失函数,完成训练,得到神经网络模型。
4.如权利要求1所述的方法,其特征在于,对电信号进行重新采样的方法为:利用数字示波器对电信号进行采用,并在数字域进行重采样。
5.如权利要求1所述的方法,其特征在于,零均值标准化的方法为:将重采样后的信号序列记为 该训练的均值记为μs,均方差记为σs,则通过公式
得到标准化的接收信号序列。
6.如权利要求1所述的方法,其特征在于,每接收Nb的输入特征向量,利用滑动窗口组合为一个batch的数据,步骤如下:
对于接收信号序列,每个符合对应Γ个采样点,建立一个长度为Γ(2L+1)的滑动窗口,开始时以序列的第一个符号为中心,按照时间顺序滑动;每滑动一次,窗口内的向量即为相应符号的输入特征向量,表达式为v(i)=[s′i-L,…,s′i,…,s′i+L];
窗口滑动Nb次之后共采集了相邻Nb个符号的输入特征向量,将这些输入特征向量共同组成一个batch的数据。
7.如权利要求1所述的方法,其特征在于,计算损失函数对神经网络模型参数的梯度的方法为:
开始处理单个batch的数据时,初始化损失函数L=0,初始化全零向量gt用于记录损失函数对模型参数的梯度;
遍历batch中的所有输入特征向量v(i),为样本v(i)赋予一个伪标签l(i),表明当前神经网络模型将v(i)判决为属于第l(i)类,其中l(i)∈{1,…,M};
针对样本v(i),更新损失函数:
其中l(i)代表算法赋予样本v(i)的伪标签;
借助反向传播算法求解损失函数变化量ΔL对模型参数的梯度,并累加到gt上。
8.如权利要求7所述的方法,其特征在于,赋予伪标签的方法包括如下两种:
一种是仅做数据增强:在训练阶段,给输入特征向量加入一给定的噪声,得到增强的神经网络输出的概率向量,进而得到伪标签;
另一种是将数据增强与虚拟对抗训练相结合:
首先对于给定的单个样本v,生成一个向量d,d的维度与输入特征向量维度相同,并且每一维相互独立,服从均方差为σ的零均值高斯分布;
将d叠加在输入特征向量上,用交叉熵来衡量NN(v+d)与NN(v)这两个概率向量的差异,并求这一交叉熵对向量d的梯度,记为g;
将向量g长度归一化,再乘以一个给定值,即可最终得到虚拟对抗扰动向量radv;
将对抗扰动向量加入到特征输入向量中,得到增强的神经网络输出的概率向量,进而得到伪标签。
9.如权利要求1所述的方法,其特征在于,基于梯度下降调整模型参数的方法为:
利用优化算法对模型参数进行调整,参数调整公式如下:
其中,θt表示模型参数,φ(g1,…,gt)表示基于当前梯度和历史梯度信息对实际梯度的估计, 为对应于优化算法的自适应学习率。
10.如权利要求9所述的方法,其特征在于,优化算法至少包括如下三种:
1)如果取αt=α,ψ(g1,…,gt)=1以及φ(g1,…,gt)=gt,所述参数调整公式对应朴素的随机梯度下降:θt+1=θt-αgt;
2)如果取αt=α,ψ(g1,…,gt)=1以及 所述参数调整公式对
应带动量的随机梯度下降;
3)如果取αt=α, 以及φ(g1,…,gt)=gt,所述参数调整
公式对应AdaGrad优化算法,其中,ε0为一个给定小量。

说明书全文

一种自适应通信系统神经网络均衡方法

技术领域

[0001] 本发明属于通信技术领域,涉及新型通信系统中针对神经网络均衡算法进行自适应参数调整的关键技术,具体涉及在借助判决反馈模式进行参数调整的过程中引入数据增强和虚拟对抗训练的技术,以提高神经网络均衡算法的泛化性能,降低通信系统误码率。

背景技术

[0002] 近年来,随着新兴科技的不断发展,互联网已与人类生活紧密结合在一起。人们对互联网的诉求也远远超越了原先简单的邮件收发、文本图片传送等,取而代之的是随时随地高质量的视频通信、视频下载等高端需求。作为人类通向信息化时代的主要手段,通信技术在人类文明进一步发展的过程中发挥着至关重要的作用。举例来说,在多种主流的通信方式当中,光纤通信承担了主干道与高速公路的色。经过数十年的发展,光纤通信系统的通信容量已经达到每秒上百Tbit,引领当今社会进入信息时代。作为互联网、计算和人工智能技术基础设施的数据中心网络,则更离不开光纤通信技术。
[0003] 为了不断提高通信系统的容量,研究者们做了大量的工作,其中新型的数字信号处理(DSP)技术(主要是新型均衡算法)是改善误码率(BER)性能、提高通信系统传输速率的关键。目前普遍使用的传统均衡算法包括前馈均衡(FFE)、判决反馈均衡(DFE)、最大似然序列估计(MLSE),这些方法借助接收信号和给定的模型,判断发端发送的符号。上述DSP算法都是基于丰富的专家知识来设计的,在某些特定的信道模型下可以证明是最优的算法。然而,实际系统中存在多种非线性效应(如调制非线性和平方律检测),很难借助传统的DSP技术进行均衡。针对这一问题,许多研究人员提出了基于神经网络的均衡算法。由于采用了更加复杂、表达能更为强大的神经网络模型,新型均衡算法相比于传统方法,达到更好的误码率性能。
[0004] 然而,当前已经提出的基于神经网络的均衡算法,依赖于充足的线下训练,通信系统上线后,信道发生变化时如何自适应地对模型进行调整仍然是一个难题。

发明内容

[0005] 本发明针对上述问题,借助机器学习当中的半监督学习算法,针对神经网络均衡器提出了一种普适的自适应参数调整方法,考虑到在实际通信场景下器件和信道特性随时间变化这一自适应调整的需求,本方法针对神经网络均衡算法提出了一种基于半监督学习的在线更新神经网络均衡器参数方案,通过设计损失函数来加快模型微调的收敛速度,从而使得基于神经网络的均衡算法灵活性得到极大提升,在信道条件发生变化的情况下通过调整模型来维持较低的误码率。本发明涉及的算法完全不需要发端提供训练序列,而是类似于判决反馈模式,将现有模型的判决作为对应符号的标签,在此基础上设计合适的半监督损失函数并进行学习。
[0006] 本发明提出的一种自适应通信系统神经网络均衡方法,包括以下步骤:
[0007] 1)对于任何一个基于神经网络的均衡器(不论其具体结构如何,是否涉及卷积层或循环层),首先需要在初始的训练集上进行线下训练。线下训练完成后,得到神经网络模型。通信系统上线,此后的均衡和参数调整均在线进行,初始的训练集被抛弃。
[0008] 2)通信系统上线后,发送端对所要发送的数据段进行编码,同时持续地发送信号;信号经过信道传输,波形发生畸变;接收端用相应器件(如光电二极管)接收,将物理信号转化为电信号,再经过模数转换,变为数字信号,从而可以方便地在计算机内部进行数字信号处理的相关操作。
[0009] 3)在接收端对电信号进行重采样、零均值标准化,得到接收信号序列;对于某一个符号,将这一符号及其前后L个符号(共2L+1个符号)所对应的接收信号相拼接,作为这一个符号的输入特征向量
[0010] 4)每接收一定数量Nb的输入特征向量,就组合成一个batch(批次)的数据,就根据这一个batch的数据对神经网络模型的参数进行在线更新。
[0011] 5)若发送端完成所有信息的发送,接收端数据处理完毕,则系统可停止运行。
[0012] 进一步地,步骤3)得到输入特征向量之后,便可以借助神经网络模型对接收信号进行判决,并输出序列中符号属于各个类别的概率,并根据这一概率向量进行分类,从而在接收端得到每个符号分别属于哪一类的判决结果。
[0013] 整个通信系统的示意图如图1所示;完整的工作流程如图2所示。
[0014] 进一步地,接收Nb个输入特征向量作为一个batch可以以下面的方式实现:
[0015] 1)接收到的原始数据经过重采样、零均值标准化之后,可以排成一个序列,每个符号对应Γ个采样点。
[0016] 2)构建一个长度为Γ(2L+1)的滑动窗口,每次采集2L+1个符号所对应的数据,作为一个输入特征向量。
[0017] 3)对连续的Nb个特征向量完成采集之后,这Nb个特征向量被组织为一个单独的batch。
[0018] 借助滑动窗口采集输入特征向量的过程如图3所示。
[0019] 进一步地,根据第t个batch中的数据对模型参数进行在线更新可以以下面的方式实现:
[0020] 1)对于第t个batch中的第i个样本v(i),可以计算出损失函数(预先定义好)对模型参数的梯度。
[0021] 2)遍历该batch中的所有样本,保存所有Nb个样本对应的梯度。
[0022] 3)将第t个batch中,根据每个样本计算出的梯度(共Nb个)进行加权平均,得到第t个batch对应的梯度gt。
[0023] 4)根据梯度gt,结合历史梯度信息gt-1、…、g1,计算出实际梯度的估计值,根据梯度下降算法对参数进行微调。梯度下降的步长(也称为学习率)可以是定值(对应朴素的随机梯度下降算法),也可以是自适应的(对应各种自适应学习率的优化算法)。
[0024] 一个普适的优化算法表达式为:
[0025]
[0026] 其中θt代表模型参数,φ(g1,…,gt)代表对实际梯度的估计,前面的因子对应于优化算法的自适应学习率。举例来说,考虑如下三种深度学习领域常见的优化算法:
[0027] (1)如果取αt=α,ψ(g1,…,gt)=1以及φ(g1,…,gt)=gt,上述公式对应朴素的随机梯度下降:θt+1=θt-αgt。
[0028] (2)如果取αt=α,ψ(g1,…,gt)=1以及 上述公式对应带动量的随机梯度下降(Momentum-SGD)。
[0029] (3)如果取αt=α, 以及φ(g1,…,gt)=gt,上述公式对应AdaGrad优化算法。其中,ε0为一个小量。
[0030] 对单个batch数据的处理过程如图4所示。
[0031] 进一步地,针对第i个样本v(i)计算损失函数对模型参数的梯度可以以下面的方式实现:
[0032] 1)首先对v(i)作数据增强,一般是在v(i)上加入一个噪声向量η。举例来说,可以使η中每个元素满足一个高斯分布: 其中σ为高斯分布的均方差。
[0033] 2)由于这一自适应调整算法不需要训练序列,因此发送端并没有对v(i)提供对应(i) (i) (i)的标签l 。为此,必须根据神经网络的判决结果,手动为v 赋予一个伪标签l 。参考半监督学习算法,这一伪标签可以通过不同方式得到。两个例子:一,直接用神经网络对数据增强后的v(i)进行判决得到l(i);二,将虚拟对抗训练(virtual adversarial training,简记为VAT)与数据增强结合,先计算出对抗扰动radv,将其与v(i)加和作为新的输入特征向量,再(i)
用神经网络对这一输入向量进行判决得到l 。
[0034] 3)由第i个样本计算出损失函数的变化。具体来说,损失函数可以采用分类问题中常用的交叉熵(cross-entropy)形式,只不过标签l(i)并非由发送端提供,而是由神经网络内部的判决得到的伪标签,在计算损失函数时直接当作真实标签来使用。损失函数的变化(即为第i个样本对应的损失函数值)具体表示为:
[0035]
[0036] 4)在计算出损失函数的基础上,更新第t个batch对应的总梯度:
[0037]
[0038] 对单个输入样本的处理流程如图5所示。
[0039] 值得注意的是,本专利中涉及的损失函数计算方法,实质上是把两种半监督学习领域常见的损失函数计算方法——Π-模型和虚拟对抗训练——相结合而得到的。具体来说,Π-模型的主要思想是借助随机扰动进行数据增强,在训练时优化损失函数的过程,即是希望同一个输入特征向量在不同随机扰动之下得到相同的分类结果;相比之下,虚拟对抗训练不使用随机扰动,而是使用人为计算出的虚拟对抗扰动,一般而言比Π-模型更为有效。在本专利中,既计算了虚拟对抗扰动,又在输入特征向量上加了随机扰动。当然,原则上讲,Π-模型或虚拟对抗训练可以单独使用,但专利中涉及的损失函数构建方式融合了两种方式的优点,能够稳定地降低误码率。另一方面,如果完全不对输入特征向量进行任何扰动,直接对交叉熵形式的损失函数进行优化,这种方式在实现上比较简单,原理与通信领域常用的基于判决反馈的自适应均衡十分类似,但会存在收敛速度极其缓慢的问题,因此并不实用。
[0040] 进一步地,对于v(i),借助虚拟对抗训练来计算对抗扰动radv可以以下面的方式实现,相比于只使用高斯噪声进行数据增强的方式,引入对抗扰动一般可以加快在线调整参数的收敛速度,性能上更为优越:
[0041] 1)首先生成向量d,d的每一维度相互独立,服从相同的高斯分布 其中ε为一较小的数值。
[0042] 2)将向量d与原始输入向量相加,采用这一新的输入特征向量计算损失函数(损失函数的形式仍为交叉熵),并使用反向传播算法计算出损失函数对向量d的梯度。这一梯度记为向量g。
[0043] 3)虚拟对抗扰动radv的方向与g相同,长度则一般为一较小的值。通过对向量g的长度进行归一化,再乘以给定的值,即可最终得到虚拟对抗扰动radv。
[0044] 计算虚拟对抗扰动radv的流程如图6所示。
[0045] 与现有技术相比,本发明的积极效果为:
[0046] 当前基于神经网络的均衡算法虽然能够在给定数据集上体现出明显优于传统均衡算法的性能,但许多文献表明,基于神经网络的均衡器在泛化性能上比较弱——当信道条件发生变化,误码率会急剧上升。为了解决这一问题,本发明针对基于神经网络的均衡器,实现了一种不需要训练序列、适用性较强的模型参数自适应调整方法;同时,本发明涉及的方法不同于单纯的判决反馈模式——由于引入了数据增强和虚拟对抗训练的技术,在线训练过程的收敛速度得到了明显提升,从而直接降低了系统的误码率。附图说明
[0047] 图1为完整通信系统的示意图;
[0048] 图2为系统工作流程(包含均衡和在线训练)示意图;
[0049] 图3为用滑动窗口采集单个batch数据的示意图;
[0050] 图4为单个batch数据的处理过程示意图;
[0051] 图5为单个输入样本的处理流程示意图;
[0052] 图6为虚拟对抗扰动向量的计算过程示意图;
[0053] 图7为具体实施方案中所涉及的光通信系统的示意图。

具体实施方式

[0054] 根据文献调研,当使用基于神经网络的均衡算法对接收信号进行均衡和判决时,一般需要首先在给定的信道条件下采集大量数据,用这些数据对神经网络进行线下训练。虽然在相关论文中,已经使用基于神经网络的均衡器取得了更好的误码率性能,但基于神经网络的均衡器能否在某些应用场景中取代传统的均衡器仍然是一个问题——神经网络面临的最严重问题是泛化性能弱。在实际通信系统中,系统所处的外部环境以及信道参数均可能发生缓慢变化,导致接收数据的概率分布不同于线下训练阶段所使用的训练集的概率分布。因此,虽然离线训练的神经网络均衡器能够在特定信道条件下获得的测试集上表现出更好的性能,但如果接收数据的分布偏离原始分布时,其性能就会严重下降。以数据中心中的短距离光互连系统为例,当环境温度波动,或光学器件仍处于预热阶段时,就会出现类似的问题。
[0055] 传统均衡算法通过自适应地微调模型参数来解决这一问题。具体来说,有两种不同的工作模式:训练序列模式和判决反馈模式。使用训练序列意味着发端需要额外发送一段训练序列,对于接收端而言这一段序列是已知的,因此可以借助相应算法对模型参数进行调整;如果工作在判决反馈模式,则接收端会默认现有模型对当前符号的判断是正确的,即认为现有模型的判决结果是正确的,在此基础上调整模型参数。
[0056] 原则上讲,基于神经网络的均衡器可以工作在训练序列模式。然而,使用标准中规定的较短训练序列对网络进行训练往往会导致其他问题。有文献提出神经网络对于较短序列和伪随机序列容易过拟合,模型在训练序列上准确率很高但用其他数据测试时效果明显变差。另一方面,由于神经网络参数较多,训练过程的收敛速度一般比较缓慢。还必须考虑到,在某些应用场景下,发端提供训练序列的代价较大(有时甚至是不现实的)。因此,针对神经网络均衡算法开发基于判决反馈的自适应参数调整算法是十分必要的。
[0057] 下面通过实施例,并配合附图说明,对本发明的技术方案作详细的说明。
[0058] 最近几年,产生大量互联网流量的数据中心引起了人们的极大关注。基于垂直腔面发射激光器和多模光纤(VCSEL-MMF)的短距离光互连链路由于其高容量、低功耗、低成本等优点,将继续成为数据中心网络中应用最广泛的光互连链路。为了提高基于VCSEL-MMF的光通信系统的容量,研究者们做了大量的工作。鉴于此,在本实施例中,主要针对一基于VCSEL-MMF的短距离光互连系统采用本发明提出的一种自适应通信系统神经网络均衡方法,如图7所示。所使用的数据均是在这一实验平台上采集得到的。当然,本方法在其他通信系统(长距光传输系统、无线通信系统)中的应用也是类似的。
[0059] 本方法采用基于半监督学习的神经网络均衡器参数自适应调整方案,具体包括以下步骤:
[0060] 一、首先对于给定的基于神经网络的均衡器,在初始的训练集上进行线下训练。
[0061] 1.原始信号是一个长度为Γ·Ntr的序列,其中Ntr为训练集所包含的符号总数目,Γ为上采样倍数。
[0062] 2.对于其中第i个符号,输入特征向量包括第i个符号为中心的前后共2L+1个符号,长度为Γ·(2L+1)。用这Ntr个输入特征向量作为网络的输入,对应的Ntr个符号所属类别(即真实的标签。共有M个可能的类别,单个符号属于其中的某一类)作为网络的输出,构建训练用数据集。
[0063] 3.用确定的训练集,离线训练一个神经网络,用作后续的均衡和判决功能。将输入向量记为v,那么网络输出记为:
[0064] o=NNθ(v),
[0065] 其中o是一个M维的向量,给出了v对应符号属于M个类别的概率。
[0066] 4.由于是分类问题,训练时所采用的损失函数为交叉熵,公式为:
[0067]
[0068] 5.针对神经网络中的权值、偏置(统称为模型参数),借助优化算法(如随机梯度下降算法)对参数进行调整,减小网络在训练集上的损失函数,得到神经网络模型NNθ(·)。
[0069] 二、通信系统上线后,发送端对所要发送的数据段进行编码,将编码后的信息借助激光器发送。
[0070] 1.在物理层,接收到的数据已经被转化为比特流,并根据当前系统所使用的调制格式将比特流转化为符号序列(长度记为Nte),生成对应的电信号。
[0071] 2.使用电信号驱动激光器,转换为可在光纤中传输的光信号
[0072] 3.光信号经过光纤传输,到达接收端时波形发生了畸变。因此在接收端需要借助特殊的均衡技术进行信号处理。
[0073] 三、接收端将经过传输后的光信号转化为电信号,对电信号进行重采样、零均值标准化,此时接收信号形成一个完整的序列。
[0074] 1.在接收端,一般用光电转换器件(如光电二极管)接收光信号,转化为电信号,经过数字示波器采样,并在数字域进行重采样,得到一个长度为Γ·Nte的序列。其中,Nte代表符号总数,Γ代表重采样倍数。
[0075] 2.标准化:将重采样后的信号序列记为 将这一序列的均值记为μs,均方差记为σs,那么标准化之后的信号序列表示为
[0076]
[0077] 四、在接收信号序列s’的基础上,借助滑动窗口采集输入特征向量,将Nb个输入特征向量组合为一个batch的数据。
[0078] 1.建立一个长度为Γ·(2L+1)的滑动窗口,一开始以序列的第一个符号为中心(前面用0补充即可),开始按照时间顺序滑动。
[0079] 2.每滑动一次,窗口内的向量即为相应符号的输入特征向量,表达如下:
[0080] v(i)=[s′i-L,…,s′i,…,s′i+L],
[0081] 3.窗口滑动Nb次之后,共采集了相邻Nb个符号的输入特征向量。这些特征向量共同组成一个batch的数据。
[0082] 4.每次收集到一个batch的数据,需要根据这些数据对神经网络模型的参数进行更新。实际上,Nb的数值是此更新算法的超参数,需要事先人为地根据信道变化的快慢来确定。考虑到这一自适应调整算法是应用于在线场景,Nb一般不会取得非常大,因为缓冲区中难以存储如此多的样本;另一方面,Nb也不应当取得过小,这主要有两方面原因:首先,如果Nb过小,单个batch中数据的分布与真实的数据分布就会有明显偏离,相应地计算出的梯度会与真实梯度有偏离,收敛速度会受到影响;其次,频繁地对参数进行更新会导致较高的计算负担,进而影响系统的吞吐率(由于同一个batch中不同的样本可以并行处理,采用稍大的Nb并不会导致运算时间的明显增加)。
[0083] 以上过程的示意图如图3所示。
[0084] 五、每接收到一个数据batch(记为第t个batch),遍历batch中的所有输入特征向量v(i),计算损失函数对模型参数的梯度gt,并结合历史梯度信息,进行梯度下降,微调参数。
[0085] 1.开始处理单个batch的数据时,初始化损失函数L=0,初始化全零向量gt用于记录损失函数对模型参数的梯度。
[0086] 2.遍历batch中的所有输入特征向量v(i),借助一定的方式为v(i)赋予一个伪标签l(i),表明当前神经网络将v(i)判决为属于第l(i)类(l(i)∈{1,…,M})。赋予伪标签的具体方法在下一部分会详细介绍。
[0087] 3.针对样本v(i),更新损失函数:
[0088]
[0089] 其中l(i)代表算法赋予样本v(i)的伪标签。一般而言,如果神经网络模型的准确率比较高(当前信道条件下,接收信号的概率分布与训练集比较接近),这一伪标签有很大概率是正确的;但伪标签有较小概率不是正确的,因此相比于使用真实的分类标签(即发送端提供符号所属的类别),在模型微调的收敛速度上会稍慢。
[0090] 4.借助反向传播算法,求解损失函数变化量ΔL对模型参数的梯度,并累加到gt上。对于复杂的模型,手动计算梯度是比较困难的,一般借助开源的深度学习框架(如tensorflow、pytorch等)来实现。成熟的框架中,一般已经实现了自动借助反向传播算法求梯度的功能。
[0091] 5.计算得到gt之后,便可以采用预先定义好的优化方法,基于梯度下降对网络模型参数进行微调,达到减小损失函数、提升分类准确率的目的。使用何种优化方法是自由的,参数微调方法可以普遍地表达为:
[0092]
[0093] 其中θt代表模型参数,φ(g1,…,gt)代表对实际梯度的估计,这一估计是基于当前梯度和历史梯度信息。一种最简单的估计方式是,直接对历史梯度求取加权平均值:
[0094]
[0095] 其中β是一个常数,控制求平均值时的权值大小。
[0096] 前面的因子 对应于优化算法的自适应学习率,一种最简单的方式是取ψ(g1,…,gt)=1,这样学习率就不会发生自适应变化。值得注意的是,采用不同的优化器对于模型微调过程的收敛速度会有一定影响,在实际使用中需要根据情况选择合适的优化器。
[0097] 6.根据这一batch的数据对模型进行微调之后,继续接收并处理后续的batch,直到接收的所有数据都被处理完毕。
[0098] 对于分类问题,损失函数一般是交叉熵形式:
[0099]
[0100] 然而由于发送端无法针对发送符号提供真实值(即公式中的 项),在计算梯度gt时必须首先给batch中的输入特征向量v(i)赋予一个伪标签l(i)。赋予伪标签的方式有两种,阐述如下:
[0101] 方式1:仅做数据增强。
[0102] 实践表明,在训练阶段对输入数据进行数据增强,可以提升神经网络的鲁棒性,并且一般能够加快训练过程的收敛速度。数据增强通常通过给输入特征向量加一个较小的噪声来实现。举例来说,将数据增强过程用函数g(·)表示。一种简单可靠的数据增强方式是在输入特征向量上加一高斯噪声:
[0103]
[0104] 神经网络输出的概率向量为:
[0105] o=NN(g(v)),
[0106] o的维度是M,代表输入特征向量v属于各个类别的概率(在网络内部已经进行了归一化)。根据这一输出的概率向量,可以得到网络对v的分类结果为:
[0107] l=arg maxkok,
[0108] 即向量v所对应的符号大概率属于第l类。考虑到这一判断结果有很大概率是正确的,那么类似于判决反馈模式,可以将这一伪标签l当作真实标签,代入到损失函数的表达式中进行计算。
[0109] 值得注意的是,如果不进行数据增强,该方法的收敛速度会明显变慢,并将直接导致误码率增大。
[0110] 方式2:将数据增强与虚拟对抗训练相结合。
[0111] 首先对于给定的单个样本v,需要计算对抗扰动向量radv,步骤如下:
[0112] 步骤1:生成一个向量d,d的维度与输入特征向量维度相同,并且每一维相互独立,服从均方差为σ的零均值高斯分布。
[0113] 步骤2:将d叠加在输入向量上,用交叉熵来衡量NN(v+d)与NN(v)这两个概率向量的差异,并求这一交叉熵对向量d的梯度,记为g。
[0114] 步骤3:如果在向量g对应的方向对输入进行扰动,那么与未扰动时相比,神经网络的输出将会明显改变。将向量g长度归一化,再乘以一个较小的给定值,即可最终得到虚拟对抗扰动radv。基本的训练思路是,希望在输入特征向量上加入扰动radv后,神经网络的判决不会受到明显影响。
[0115] 加入对抗扰动之后,神经网络的输出概率向量为:
[0116] oadv=NN(g(v+radv)),
[0117] 类似于方式1中的概率向量o,此处oadv的维度是M,代表输入特征向量v属于各个类别的概率。根据这一输出的概率向量,可以得到网络对v的分类结果(引入对抗扰动之后)为:
[0118] ladv=arg maxkoadv,k,
[0119] 类似于方式1,可以将这一伪标签ladv当作真实标签,代入到损失函数的表达式中进行计算。
[0120] 值得一提的是,实践表明,如果只使用虚拟对抗训练技术而不同时引入数据增强,训练过程的收敛速率以及相应的误码率与使用高斯噪声进行数据增强相比并没有明显的提升;如果在引入虚拟对抗扰动的基础上再使用噪声进行数据增强,则误码率会有稳定的提升。这里给出的方式2实际上是方式1的改进版本。一般而言采用方式2时,神经网络模型的训练收敛速度会稍快,性能稳定优于方式1。然而采用方式2的代价是,在计算对抗扰动向量时需要先额外使用一次反向传播,因此计算复杂度大约是方式1的两倍。
[0121] 最终,发送端完成所有信息的发送,接收端对所有的batch都处理完毕,可以进行后续的信息处理。
[0122] 以上实施例仅用以说明本发明的技术方案而非对其进行限制,本领域的普通技术人员可以对本发明的技术方案进行修改或者等同替换,而不脱离本发明的精神和范围,本发明的保护范围应以权利要求书所述为准。
高效检索全球专利

专利汇是专利免费检索,专利查询,专利分析-国家发明专利查询检索分析平台,是提供专利分析,专利查询,专利检索等数据服务功能的知识产权数据服务商。

我们的产品包含105个国家的1.26亿组数据,免费查、免费专利分析。

申请试用

分析报告

专利汇分析报告产品可以对行业情报数据进行梳理分析,涉及维度包括行业专利基本状况分析、地域分析、技术分析、发明人分析、申请人分析、专利权人分析、失效分析、核心专利分析、法律分析、研发重点分析、企业专利处境分析、技术处境分析、专利寿命分析、企业定位分析、引证分析等超过60个分析角度,系统通过AI智能系统对图表进行解读,只需1分钟,一键生成行业专利分析报告。

申请试用

QQ群二维码
意见反馈