首页 / 专利库 / 人工智能 / 机器学习 / 监督学习 / 一种基于注意力机制的稀疏编码方法

一种基于注意机制的稀疏编码方法

阅读:701发布:2020-05-13

专利汇可以提供一种基于注意机制的稀疏编码方法专利检索,专利查询,专利分析的服务。并且本 发明 涉及一种基于注意 力 机制的稀疏编码方法,该方法在原有模型LISTA的 基础 上加入注意力机制网络,从而充分利用 迭代 历史信息, 加速 模型收敛,并在此基础上进行 监督学习 ,提高图像分类的准确率。与 现有技术 相比,本发明具有 算法 复杂度更低,实验证明效果更佳等优点。,下面是一种基于注意机制的稀疏编码方法专利的具体信息内容。

1.一种基于注意机制的稀疏编码方法,其特征在于,该方法在原有模型LISTA的基础上加入注意力机制网络,从而充分利用迭代历史信息,加速模型收敛,并在此基础上进行监督学习,提高图像分类的准确率。
2.根据权利要求1所述的一种基于注意力机制的稀疏编码方法,其特征在于,该方法具体包括以下步骤:
将注意力机制加到原有的LISTA模型上,其中历史信息定义如下:
A[t]=WeX+WsS[t-1]
S[t]=hθ(A[t])
其中A[t]表示第t次迭代的历史信息,X表示原始数据,S[t]表示第t次迭代的稀疏编码输出,We和Ws是可学习参数;
1)基于注意力机制的稀疏编码前向传播过程;
2)在后向传播时所有可学习的参数被更新,直至模型收敛为止;
3)在步骤2)基础上加入有监督学习项,并应用于图像分类任务。
3.根据权利要求2所述的一种基于注意力机制的稀疏编码方法,其特征在于,所述的基于注意力机制的稀疏编码前向传播过程具体包括以下步骤:
输入: λ>0,n>=]
其中X表示原始数据,是m个p维度的数组;D是字典矩阵,λ表示稀疏系数,n表示迭代次数;
初始化: θ=λ/τ,S[0]=0,
其中θ是软阈值函数h的输入参数;
第一步:A[t]=WeX+WsS[t-1]
这一步求得当前迭代的历史信息A[t],X表示原始数据,S[t]表示第t次迭代的稀疏编码输出,We和Ws是可学习参数;
第二步:S[t]=hθ(A[t)
S[t]表示第t次迭代的稀疏编码输出,hθ表示软阈值函数;
第三步:计算各个历史信息的重要性权重,
其中 表示对t次迭代、第t-l+i个历史信息对当前输出的重要程度;
第四步:计算上下文向量和最终的稀疏编码。
4.根据权利要求3所述的一种基于注意力机制的稀疏编码方法,其特征在于,所述的计算各个历史信息的重要性权重具体为:
对于第t次迭代更新,通过注意力机制决定当前输出Z[t]与哪些历史信息有关,并且构建一个多层感知器,求出某一层历史信息对当前输出的重要程度,其中多层感知器建模如下:
其中 表示对t次迭代、第t-l+i
个历史信息对当前输出的重要程度,A表示历史信息,Z[t-1]表示上一层迭代输出,即该历史信息的重要程度由自身和前一层迭代输出共同决定,P表示一个多层感知器,Wa表示多层感知器的参数。
5.根据权利要求3所述的一种基于注意力机制的稀疏编码方法,其特征在于,所述的计算上下文向量和最终的稀疏编码具体为:
得到各个历史信息的权重后,将其加权求和得到上下文向量,并得到最终的稀疏编码Z,
Z[t]=hθ(C[t]).
其中C表示上下文向量,是对所有历史信息进行整合之后的结果,l表示关注历史信息的长度,hθ表示软阈值函数:
hθ(x)=sign(x)max(|x|-θ,0)
Sign(x)表示符号函数,当x大于0时值为1,小于0时值为-1。
6.根据权利要求2所述的一种基于注意力机制的稀疏编码方法,其特征在于,所述的监督学习顶采用softmax函数,其loss函数定义如下:
相比原来的ALISTA,在重构误差和稀疏误差的基础上加入了分类损失误差,其中β≥0用于权衡分类误差,Ec是用于softmax分类对交叉熵损失函数,yc是类别标签,Wc是可学习的参数,Lc表示分类loss、X表示原始数据、D表示字典矩阵、Z表示第n次迭代的稀疏编码输出、λ表示稀疏系数。

说明书全文

一种基于注意机制的稀疏编码方法

技术领域

[0001] 本发明涉及一种稀疏编码方法,尤其是涉及一种基于注意力机制的稀疏编码 方法。

背景技术

[0002] 稀疏编码SC(Sparse Coding)方法在盲源信号分离、语音信号处理、自然图 像特征提取、自然图像去噪以及模式识别等方面已经取得许多研究成果,具有重要 的实用价值,是当前学术界的一个研究热点。进一步研究稀疏编码技术,不仅会积极 地促进图像信号处理、神经网络等技术的研究,而且也将会对相关领域新技术的发 展起到一定的促进作用。
[0003] 近几年,以RNN(Recurrent Neural Network,时序循环神经网络)为基础的 SC推断方法变得流行,如LISTA(Learned iterative shrinkage-thresholding algorithm, 可学习的迭代阈值算法),LFISTA(Learned Fast iterative shrinkage-thresholding algorithm,可学习的快速迭代软阈值算法)和SLSTM(Sparse Long Short-Term Memory,基于长短期记忆单元的稀疏编码方法)。与传统的交替优化算法,如ISTA  (iterative shrinkage-thresholding algorithm,迭代软阈值算法)相比,这些基于RNN 的方法主要有两个优势:1)基于RNN的算法可以同时学习字典和进行稀疏编码 推断;2)基于RNN算法的稀疏编码推断效率更高,算法复杂度更低(传统的稀 疏编码方法需要解决凸优化问题。
[0004] 作为LISTA算法的改进,LFISTA和SLSTM并没有合理地利用迭代过程中的 历史信息,这些信息已被证明是有利于加速迭代的收敛的。
[0005] 现有技术一介绍了一种可学习的迭代软阈值算法。给定m个p维的信号X= 稀疏编码的任务是利用一组超完备的字典 k≥p找到m个k维的稀疏系数 满足:
[0006]
[0007] 其中λ>0用于平衡稀疏项和数据重构项。
[0008] 为了解决公式(1),传统方法是交替优化D和Z,固定D时,按如下公式优化Z:
[0009]
[0010] ISTA算法的解决思路如下:
[0011] Z[t]=hθ(WeX+WsZ[t-1]),      (3)
[0012]
[0013] 其中hθ(x)=sign(x)max(|x|θ,0)是软阈值函数,θ=λ/τ,τ是DTD的最大特征值, tT表示第t次迭代,D 表示D的转置矩阵,I表示单位矩阵。在LISTA[3]算法中, 参数{We,Ws}是可学习的,LISTA通过建立RNN网络学习权重参数,该网络每个 RNN层的输入是前一层的稀疏编码输出。算法框架如图2所示。
[0014] 现有技术一的LISTA算法的结构较为简单,没有充分利用历史信息加速模型 收敛。
[0015] 现有技术二为快速迭代软阈值算法FISTA,FISTA本质上是在原有的ISTA的基 础上引入动量项加速收敛:
[0016]
[0017] 其中,
[0018]
[0019] 可以发现,在LFISTA[4]中每个RNN层输入的是前两层的稀疏编码输出,如 图3所示。
[0020] 现有技术二的LFISTA算法的结构只关注了一层历史信息,同样没有充分利用 历史信息加速模型收敛。
[0021] 现有技术三提出了一种基于长短期记忆单元的稀疏编码模型SLSTM。SLSTM算 法借鉴了传统LSTM算法的思想,通过引入两个:更新门u[t]和遗忘门v[t],其 迭代公式如下:
[0022]
[0023]
[0024] 其中⊙向量之间的逐元素乘积,σ表示sigmoid激活函数,Wus,Wue,Wfs,Wfe是可学 习的LSTM参数。通过分析,注意到隐藏单元c[t]的迭代输出可推导如下:
[0025]
[0026] 其中 和 表示第i层迭代中与Z,c,x有关的函数。可以发现,SLSTM实际上 整合了所有历史信息。
[0027] 现有技术三虽然SLSTM充分利用了历史信息,但是由于把所有历史信息聚合在 一起,由于较早的迭代层数中模型尚未收敛,该层历史信息不一定有用,所以盲目 聚合会导致模型性能下降。同时网络复杂度也比较高,模型比较占内存。

发明内容

[0028] 本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种基于注意力 机制的稀疏编码方法。
[0029] 本发明的目的可以通过以下技术方案来实现:
[0030] 一种基于注意力机制的稀疏编码方法,其特征在于,该方法在原有模型LISTA 的基础上加入注意力机制网络,从而充分利用迭代历史信息,加速模型收敛,并在 此基础上进行监督学习,提高图像分类的准确率。
[0031] 优选地,该方法具体包括以下步骤:
[0032] 将注意力机制加到原有的LISTA模型上,其中历史信息定义如下:
[0033] A[t]=We X+Ws S[t-1]
[0034] S[t]=hθ(A[t])
[0035] 其中A[t]表示第t次迭代的历史信息,X表示原始数据,S[t]表示第t次迭代的 稀疏编码输出,We和Ws是可学习参数;
[0036] 1)基于注意力机制的稀疏编码前向传播过程;
[0037] 2)在后向传播时所有可学习的参数被更新,直至模型收敛为止;
[0038] 3)在步骤2)基础上加入有监督学习项,并应用于图像分类任务。
[0039] 优选地,所述的基于注意力机制的稀疏编码前向传播过程具体包括以下步骤:
[0040] 输入: λ>0,n>=1
[0041] 其中X表示原始数据,是m个p维度的数组;D是字典矩阵,λ表示稀疏系数, n表示迭代次数;
[0042] 初始化:
[0043] 其中θ是软阈值函数h的输入参数;
[0044] 第一步:A[t]=We X+Ws S[t-1]
[0045] 这一步求得当前迭代的历史信息A[t],X表示原始数据,S[t]表示第t次迭代的 稀疏编码输出,We和Ws是可学习参数;
[0046] 第二步:S[t]=hθ(A[t])
[0047] S[t]表示第t次迭代的稀疏编码输出,hθ表示软阈值函数;
[0048] 第三步:计算各个历史信息的重要性权重, α,α,…,α,其中α表示对t次迭代、第 t-l+i个历史信息对当前输出的重要程度;
[0049] 第四步:计算上下文向量和最终的稀疏编码。
[0050] 优选地,所述的计算各个历史信息的重要性权重具体为:
[0051] 对于第t次迭代更新,通过注意力机制决定当前输出Z[t]与哪些历史信息有关, 并且构建一个多层感知器,求出某一层历史信息对当前输出的重要程度,其中多层 感知器建模如下:
[0052]
[0053] 其中 α表示对t次迭代、第t-l+i个历史信息对当前输出的重要程度,A表示历史信息,Z[t-1]表示上一层迭代 输出,即该历史信息的重要程度由自身和前一层迭代输出共同决定,P表示一个多 层感知器,Wa表示多层感知器的参数。
[0054] 优选地,所述的计算上下文向量和最终的稀疏编码具体为:
[0055] 得到各个历史信息的权重后,将其加权求和得到上下文向量,并得到最终的稀 疏编码Z,
[0056]
[0057] Z[t]=hθ(C[t]).
[0058] 其中C表示上下文向量,是对所有历史信息进行整合之后的结果,l表示关注 历史信息的长度,hθ表示软阈值函数:
[0059] hθ(x)=sign(x)max(|x|-θ,0)
[0060] Sign(x)表示符号函数,当x大于0时值为1,小于0时值为-1。
[0061] 优选地,所述的监督学习项采用softmax函数,其loss函数定义如下:
[0062]
[0063] 相比原来的ALISTA,在重构误差和稀疏误差的基础上加入了分类损失误差, 其中β≥0用于权衡分类误差,Ec是用于softmax分类对交叉熵损失函数,yc是类别 标签,Wc是可学习的参数,Lc表示分类loss、X表示原始数据、D表示字典矩阵、 Z表示第n次迭代的稀疏编码输出、λ表示稀疏系数。
[0064] 与现有技术相比,本发明具有以下优点:
[0065] 1)本发明提出的ALISTA框架利用注意力机制网络可以灵活地结合迭代过程 的历史信息并为其分配合适的权重,权重参数均可学习
[0066] 2)本发明提出的ALISTA框架是简单的LISTA和attention网络的结合,没有 过多地改变LISTA的整体架构
[0067] 3)本发明提出的ALISTA几乎没有增加LISTA的参数量,大大少于LFISTA 和SLSTM的参数量,算法复杂度更低,实验证明效果更佳。附图说明
[0068] 图1为ALISTA的整体框架图;
[0069] 图2为LISTARNN单元示意图;
[0070] 图3为注意力机制网络示意图;
[0071] 图4为ISTA的结构示意图;
[0072] 图5为LISTA的结构示意图;
[0073] 图6为FISTA的结构示意图;
[0074] 图7为LFISTA的结构示意图;图8为多层感知器的结构示意图。

具体实施方式

[0075] 下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、 完整地描述,显然,所描述的实施例是本发明的一部分实施例,而不是全部实施例。 基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获 得的所有其他实施例,都应属于本发明保护的范围。
[0076] 为了更合理地发挥这些信息的作用,本文提出了一种新的基于注意力机制的可 学习迭代软阈值算法框架ALISTA(Attention-Based Learned iterative shrinkage-thresholding algorithm)。ALISTA由一个注意力机制网络和一个时序RNN 组成,如图1所示。其中时序RNN本质上是LISTA,其作用是产生一系列时序的 迭代信息;注意力机制网络的作用是决定这些迭代历史信息的重要程度并将其整合, 作为本次迭代更新的输入。前人的算法如LISTA只考虑了一层历史信息,而SLSTM 考虑了所有的历史信息,本专利提出的ALISTA可以更灵活地决定输入多少迭代历 史信息和决定各个历史信息的重要程度。ALISTA算法已在MNIST数据集和 CIFAR-10数据集作了图像重构和图像分类实验,实验结果表明,ALISTA不论从 编码质量还是编码效率上已经超过了原有的LFISTA算法和SLSTM算法。
[0077] 其中图1是ALISTA的整体框架图,图2是LISTA RNN单元,用于产生时序 的历史信息,图3是注意力机制网络,输入由RNN网络生成的各个历史信息,输 出各个历史信息的加权和,权重参数可学习。
[0078] 本发明主要解决的技术问题如下:
[0079] (1)提出一种新的基于注意力机制的稀疏编码模型,可以在不增加模型复杂 度的情况下充分考虑迭代历史信息。
[0080] (2)将提出的基于注意力机制的稀疏编码模型应用于有监督学习,从而提高 分类准确率。
[0081] 本发明提出了一种新的基于注意力机制的稀疏编码方法,在原有模型LISTA 的基础上加入注意力机制网络,从而充分利用迭代历史信息,加速模型收敛。并且, 在此基础上可以进行有监督学习,提高图像分类的准确率。
[0082] 本发明利用注意力机制进行稀疏编码,模型的理论基础、模型的建立、模型的 有监督形式及模型的应用如下:
[0083] 1、模型的理论基础---注意力机制
[0084] 从上述分析可以发现,前人的算法如LISTA、LFISTA和SLSTM都没有合理 地利用历史信息,导致模型收敛变慢或者算法复杂度过高等问题。本发明提出在原 有算法LISTA的基础上加入注意力机制。注以能更好地关注历史信息,分配合理 的参数权重,是因为它在模型训练过程中所有参数都是可学习的。具体来说,对于 第t次迭代更新,通过注意力机制决定当前输出Z[t]与哪些历史信息有关,并且构 建一个多层感知器,求出某一层历史信息的对当前输出的重要程度,多层感知器的 结构如图8所示。多层感知器建模如下:
[0085]
[0086] 其中 α表示对t次迭代、第 t-l+i个历史信息对当前输出的重要程度,A表示历史信息,Z[t-1]表示上一层迭代 输出,即该历史信息的重要程度由自身和前一层迭代输出共同决定。
[0087] 得到各个历史信息的权重后,将其加权求和得到上下文向量,并得到最终的稀 疏编码Z,
[0088]
[0089] Z[t]=hθ(C[t]).
[0090] 其中C表示上下文向量,是对所有历史信息进行整合之后的结果,l表示关注 历史信息的长度,是可人为控制的参数。
[0091] 2、模型的构建---基于注意力机制的稀疏编码模型
[0092] 将注意力机制加到原有的LISTA模型上,上节提到的历史信息定义如下:
[0093] A[t]=We X+Ws S[t-1]
[0094] S[t]=hθ(A[t])
[0095] 基于注意力机制的稀疏编码模型流程如下:
[0096] 基于注意力机制的稀疏编码前向传播
[0097] 输入: λ>0,n>=1
[0098] 初始化:
[0099] 第一步:A[t]=We X+Ws S[t-1]
[0100] 第二步:S[t]=hθ(A[t])
[0101] 第三步:计算各个历史信息的重要性权重, α,α,…,α,其中α表示对t次迭代、第 t-1+i个历史信息对当前输出的重要程度;
[0102] 第四步:计算上下文向量和最终的稀疏编码。
[0103] 重复第一步至第四步n遍。
[0104] 上述过程形成一个前向传播,在后向传播时所有可学习的参数被更新,直至模 型收敛为止。对各个模型的训练参数量进行统计,结果如表1所示。从表看出, ALISTA的模型的参数量明显少于LFISTA和SLSTM,相比LISTA只多了一点参 数。这说明本发明提出的方法仅比简单的LISTA增加了一点复杂度。
[0105] 表1
[0106]
[0107] 3、模型的有监督形式
[0108] 本发明提出的方法可以方便地推广成有监督形式,并用于图像分类。有监督形 式的ALISTA框架如图1所示。为了进行分类,本发明采用softmax函数,其loss 函数定义如下:
[0109]
[0110] 相比原来的ALISTA,在重构误差和稀疏误差的基础上加入了分类损失误差, 其中β≥0用于权衡分类误差,Ec是用于softmax分类对交叉熵损失函数,yc是类别 标签,Wc是可学习的参数。
[0111] 传统的基于RNN的稀疏编码推断算法,如LFISTA和SLSTM主要面临两个 主要问题:a)不能很好地结合迭代过程保留的历史信息来加速模型收敛,历史信 息已被证明可以加速稀疏编码推断的收敛;b)为使性能提升而改变原LISTA的简 单结构,虽然性能能够提升,但同时算法复杂度变高。本发明能很好地解决上述问 题。
[0112] 以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此, 任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到各种等效 的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明 的保护范围应以权利要求的保护范围为准。
高效检索全球专利

专利汇是专利免费检索,专利查询,专利分析-国家发明专利查询检索分析平台,是提供专利分析,专利查询,专利检索等数据服务功能的知识产权数据服务商。

我们的产品包含105个国家的1.26亿组数据,免费查、免费专利分析。

申请试用

分析报告

专利汇分析报告产品可以对行业情报数据进行梳理分析,涉及维度包括行业专利基本状况分析、地域分析、技术分析、发明人分析、申请人分析、专利权人分析、失效分析、核心专利分析、法律分析、研发重点分析、企业专利处境分析、技术处境分析、专利寿命分析、企业定位分析、引证分析等超过60个分析角度,系统通过AI智能系统对图表进行解读,只需1分钟,一键生成行业专利分析报告。

申请试用

QQ群二维码
意见反馈