首页 / 专利库 / 人工智能 / 机器学习 / 半监督学习 / 一种用于GAN模型训练的半监督学习的方法

一种用于GAN模型训练的半监督学习的方法

阅读:144发布:2020-05-17

专利汇可以提供一种用于GAN模型训练的半监督学习的方法专利检索,专利查询,专利分析的服务。并且本 发明 公开了一种用于GAN模型训练的通过自适应监督比率控制的半 监督学习 的方法。本方法有效的解决了GAN模型训练过程中对抗器学习得到的辩真规则不足以指导生成器生成高 精度 案例的问题;通过自适应监督比率,模型自动控制融入标签信息的量,有效利用了数据信息,使得生成的案例同时满足多样性与保真度。,下面是一种用于GAN模型训练的半监督学习的方法专利的具体信息内容。

1.一种用于GAN模型训练的半监督学习的方法,其特征在于,包括如下步骤:
第1步,获取图像数据,所述的图像数据中至少包括第一图像区域以及第二图像区域,第二图像区域落在第一图像区域的范围内部;并且,图像数据中还包括第一图像区域与第二图像区域的相对位置关系;
第2步,分别对每个图像数据中的第一图像区域和第二图像区域进行特征提取,分别得到经过了图形特征提取后的第一图像特征和第二图像特征;
第3步,将每个图像数据中的第二图像特征进行拼接之后,再与第一图像特征进行拼接,得到总图像特征;
第4步,随机生成种子矩阵,并与总图像特征进行拼接,作为生成器的输入值,得到生成器的输出值Goutput;生成器的监督损失函数LG_s中纳入了第一图像区域与第二图像区域的相对位置关系;
第5步,根据Goutput生成图像数据,再传入对抗器中,输出评分pG;
第6步,计算生成器的非监督损失函数值LG_u,LG_u=∑(1-pG);
第7步,生成器的损失函数LG通过如下计算: r是自适应比
率参数;
第8步,根据第一图像区域与第二图像区域的相对位置关系标签,构建真实家具布局图,并计算真实家具布局图的评分preal;
第9步,遍历数据集,重复第2到第8步,直到生成器生成较为满意的输出为止。
2.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,所述的图像数据是带有家具的户型图;所述的第一图像区域是户型结构,所述的第二图像区域是家具。
3.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,设定第二图像区域的数量为k个,k可以为大于1的整数;每个图像数据中若实际第二图像区域少于k个时,图像数据中用0补齐至k个。
4.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,第2步中,特征提取过程可以是resNet特征提取器、Inception特征特征提取器,也可以是VGG特征提取器,甚至可以是一般的卷积神经网络
5.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,生成器主体结构仍然采用resNet网络。
6.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,第一图像区域与第二图像区域的相对位置关系是指家具在户型图中的位置坐标关系。
7.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,位置坐标关系由标签l表示,标签l是一个零一矩阵,每个第二图像区域在第一图像区域中的坐标位置标记为1,矩阵大小为(i,j)。
根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,生成器的监督损失函数LG_s过程式如下:
其中, 是Goutput中的第i行第j列第k通道的元素,i和j是标签l中第k个通道中值为1的元素的坐标,函数对一批次中的所有样本以及各案例中的所有第二图像区域的元素进行累加。
8.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,所述的第8步中,对抗器的损失函数如下:LD=∑(D(Ireal)-pG);其中,D(.)为对抗器,Ireal是真实家具布局图。
9.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,利用Adam等优化器更新对抗器的参数,利用Adam等优化器更新生成器的参数;
在一个实施方式中,可以人为指定半监督学习参数r,也可以让模型自适应地学习r;在自适应地学习r的实施方式中,r的更新方式如下:
10.根据权利要求1所述的用于GAN模型训练的半监督学习的方法,其特征在于,在一个实施方式中,所述的图像数据是户型布局,其中,第一图像区域是指户型图,第二图像区域是家具。

说明书全文

一种用于GAN模型训练的半监督学习的方法

技术领域

[0001] 本发明涉及深度学习领域的GAN模型训练方法,具体涉及一种半监督GAN的通过自适应监督比率控制监督程度的学习方法。

背景技术

[0002] 生成式对抗网络(GAN,Generative Adversarial Networks)是一种深度学习模型,GAN模型在图像生成、语音生成、文字生成等领域都有广泛应用。对GAN模型的训练,一般都采用无监督式训练方法,这样可以利用大量未标记过的自然数据;当存在部分标记数据时,可以通过半监督学习框架,从而得到生成器的输入编码与标签的对应关系,完成C-GAN的训练。但这些一般方法存在以下问题:1.无法充分利用数据信息,特别是标签信息;2.对抗器学习得到的辩真规则无法有效的指导生成器更新参数。特别是当所有数据都已经是带标签的数据时,如何有效利用好GAN模型训练框架,让最后训练好的生成器输出结果既满足多样性又满足保真度,就成为了一个问题。
[0003] 一般的GAN模型训练方法,都是采用的非监督式训练,其优点是不需要标签数据,通过观察大量的真实数据,一方面让生成器生成的结果令对抗器输出较大的评分,另一方面让对抗器分辨真实数据和生成器生成结果的能不断提高。此种博弈思想的利用,使得模型可以不依赖标签进行训练。
[0004] GAN模型的另一优点是,相较于预测式模型,GAN作为生成式模型可以针对同一输入输出不同的结果,理论上可以做到不同的结果都是逼近真实的。例如,给定一个户型和多个待布局的家具,GAN模型可以输出多个满足要求的方案。
[0005] 但GAN模型的缺点也是很明显的。第一个缺点就是GAN模型难以训练,训练难以收敛,有时训练时候会发生梯度消失或梯度爆炸;第二个缺点是,如果收集的是大量带标签的数据,此时为了达到生成结果的多样性,仍然采用GAN模型的话就会丢失标签数据的重要信息。

发明内容

[0006] 基于以上两个GAN模型下的问题,本方法引入了监督学习,并通过一个监督比率将生成器的监督损失和非监督损失组合在一起,并通过了图像中的两类图像信息的位置关系作为数据标签。
[0007] 本发明的技术方案主要解决了以下问题:
[0008] 解决用全部带标签的数据训练GAN模型时如何充分利用标签信息的问题;
[0009] 解决GAN模型训练过程中对抗器学习得到的辩真规则无法正确引导生成器更新参数的问题;
[0010] 解决既保证生成器输出结果的多样性又保证其保真度的问题;
[0011] 解决两种损失函数相结合时无法及时而适当地调整不同损失函数占比的问题。
[0012] 技术方案是:
[0013] 一种用于GAN模型训练的半监督学习的方法,包括如下步骤:
[0014] 第1步,获取图像数据,所述的图像数据中至少包括第一图像区域以及第二图像区域,第二图像区域落在第一图像区域的范围内部;并且,图像数据中还包括第一图像区域与第二图像区域的相对位置关系;
[0015] 第2步,分别对每个图像数据中的第一图像区域和第二图像区域进行特征提取,分别得到经过了图形特征提取后的第一图像特征和第二图像特征;
[0016] 第3步,将每个图像数据中的第二图像特征进行拼接之后,再与第一图像特征进行拼接,得到总图像特征;
[0017] 第4步,随机生成种子矩阵,并与总图像特征进行拼接,作为生成器的输入值,得到生成器的输出值Goutput;生成器的监督损失函数LG_s中纳入了第一图像区域与第二图像区域的相对位置关系;
[0018] 第5步,根据Goutput生成图像数据,再传入对抗器中,输出评分pG;
[0019] 第6步,计算生成器的非监督损失函数值LG_u,LG_u=∑(1-pG);
[0020] 第7步,生成器的损失函数LG通过如下计算: r是自适应比率参数;
[0021] 第8步,根据第一图像区域与第二图像区域的相对位置关系标签,构建真实家具布局图,并计算真实家具布局图的评分preal,进而计算对抗器的损失函数LD,LD=∑(preal-pG);
[0022] 第9步,利用优化器,基于损失函数LG更新生成器参数,基于损失函数LD更新对抗器参数;
[0023] 第10步,遍历数据集,重复第2到第9步,直到生成器生成较为满意的输出为止。
[0024] 在一个实施方式中,所述的图像数据是带有家具的户型图;所述的第一图像区域是户型结构,所述的第二图像区域是家具。
[0025] 在一个实施方式中,设定第二图像区域的数量为k个,k可以为大于1的整数;每个图像数据中若实际第二图像区域少于k个时,图像数据中用0补齐至k个。
[0026] 在一个实施方式中,第2步中,特征提取过程可以是resNet特征提取器、Inception特征特征提取器,也可以是VGG特征提取器,甚至可以是一般的卷积神经网络
[0027] 在一个实施方式中,生成器主体结构仍然采用resNet网络。
[0028] 在一个实施方式中,第一图像区域与第二图像区域的相对位置关系是指家具在户型图中的位置坐标关系。
[0029] 在一个实施方式中,位置坐标关系由标签l表示,标签l是一个零一矩阵,每个第二图像区域在第一图像区域中的坐标位置标记为1,矩阵大小为(i,j)。
[0030] 在一个实施方式中,生成器的监督损失函数LG_s过程式如下:
[0031]
[0032] 其中, 是Goutput中的第i行第j列第k通道的元素,i和j是标签l中第k个通道中值为1的元素的坐标,函数对一批次中的所有样本以及各案例中的所有第二图像区域的元素进行累加。
[0033] 在一个实施方式中,所述的第8步中,对抗器的损失函数如下:LD=∑(D(Ireal)-pG);其中,D(.)为对抗器,Ireal是真实家具布局图。
[0034] 在一个实施方式中,利用Adam等优化器更新对抗器的参数,利用Adam等优化器更新生成器的参数。
[0035] 在一个实施方式中,可以人为指定半监督学习参数r,也可以让模型自适应地学习r。在自适应地学习r的实施方式中,r的更新方式如下:
[0036]
[0037] 在一个实施方式中,所述的图像数据是户型布局,其中,第一图像区域是指户型图,第二图像区域是家具。
[0038] 有益效果
[0039] 在GAN模型训练过程中,将监督式损失函数值通过参数r与无监督式损失函数值进行加权求和,然后作为新的生成器损失函数值,指导生成器进行参数更新。同时,通过参数r的调节作用,解决两种损失函数相结合时无法及时而适当地调整不同损失函数占比的问题。
[0040] 本发明中,通过将真实图像分成两图像区域分别进处理,并且由识别出两块图像区域的相对位置关系,将其作为数据标签引入模型之后,解决用全部带标签的数据训练GAN模型时如何充分利用标签信息的问题;
[0041] 本发明同时在生成器的非监督损失函数中引入了对抗器的计算评分,解决GAN模型训练过程中对抗器学习得到的辩真规则无法正确引导生成器更新参数的问题;
[0042] 通过以上的方法,可以有效地利用了原始图像文件中的数据标签信息,将图像数据进行了解读,将其引入模型计算过程中之后,可以提高模型的收敛速度以及提高模型的生成结果与真实样本的逼近程度,解决既保证生成器输出结果的多样性又保证其保真度的问题。附图说明
[0043] 图1是本发明的方法图。
[0044] 图2是本发明的生成的户型家具布局图。
[0045] 图3是对照GAN模型的生成的户型家具布局图。

具体实施方式

[0046] 在GAN模型(生成对抗网络模型)中,存在两个相互博弈的模型,一个是生成器(generator),一个是对抗器(discriminator)。简记生成器为函数G(.),对抗器为函数D(.)。一般的生成器损失函数定义如下:
[0047] LG=∑(1-D(G(n)))
[0048] 其中,n为生成器的输入种子,记生成器为G(.),记对抗器为D(.)。
[0049] 对抗器的损失函数定义如下:
[0050] LD=∑(D(real)-D(G(n)))
[0051] 其中,real为真实的案例。
[0052] 在GAN模型的训练过程中,生成器损失函数值指导生成器的参数更新,指导方法可以是Adam等优化器。同样,对抗器的损失函数值指导对抗器的参数更新。
[0053] 本发明提出的用于GAN模型训练的半监督学习方法,可以将监督式损失函数值通过参数r与无监督式损失函数值进行加权求和,然后作为新的生成器损失函数值,指导生成器进行参数更新,使生成器生成结果同时满足多样性和保真度。
[0054] 在以下的实施例中,采用的是针对室内家具布局业务提出的GAN模型训练方法,用于生成适合于设计师希望的户型设计结果;该方法只对数据类型有限制,对业务并无限制,只要数据类型满足本方法的要求,就可以利用本方法训练GAN模型。
[0055] 以室内户型设计为例,采用的方法流程详述如下:
[0056] 第1步,准备带标签的数据,例如,家具布局图数据,具有万级案例数,在户型图中包含有房间的平面形状数据,也包含了在这个户型中的家具的形状数据,以及家具在户型中的摆放位置信息,因此,这里的标签可以是家具的摆放位置信息,根据此信息,以及数据中的户型图和家具平面图,可以通过简单图形构造出室内布局图;
[0057] 每个案例数据内容及规格如下:
[0058] 一张未放家具的平面户型图Iroom,尺寸(128,128,1);分别是指图像的长、高、数量;
[0059] 待放的家具的平面图ifurniture,尺寸(128,128,1),分别是指图像的长、高、数量;在本步骤中,设定ifurniture数量为k个,代表k个家具,若家具不足k个,则用零值补齐,家具个数多于k个的案例作为异常数据不予考虑;k可以为大于1的整数。
[0060] 待放的家具的位置标签l,规模为(128,128,k)的零一矩阵。在图像数据中,取家具的中心点代表该家具的位置,那么该在图像数据阵列中,以1代表存在家具,以0代表不放置家具,因此对于一个家具来说,它所的对应的标签零一矩阵为含有一个1的矩阵,1在矩阵中的坐标位置对应了这个家具在房间中的相对位置。
[0061] 第2步,对第1步准备的每一个案例中的平面户型图进行特征提取。例如,采用resNet残差网络进行特征提取,得到户型图图形特征Froom,过程式如下:
[0062] Froom=resNet(Iroom)
[0063] 其中,Iroom指的是案例中的平面户型图,resNet残差网络结构属于互联网开源代码,具体实现与本发明无关。处理后,Froom的尺寸为(128,128,sroom)。
[0064] 第3步,对第1步准备的每一个案例中的每一个家具对应的平面图进行特征提取。同样,可以采用resNet残差网络,得到每一个家具的图形特征ffurniture,过程式如下:
[0065] ffurniture=resNet(ifurniture)
[0066] 其中,ifurniture指的是每一个平面图。处理后,ffurniture的尺寸为(128,128,sfurniture)
[0067] 第4步,将第3步得到的多个家具的图形特征进行拼接,形成本案例的总的家具特征,过程式如下:
[0068]
[0069] 其中,Ffurniture为拼接得到的图形特征, 为第i个家具的图形特征,k为家具最大个数,此处取16。处理后,Ffurniture的尺寸为(128,128,sfurniture*k)。
[0070] 第5步,将第2步得到的户型图特征和第4步得到的家具图形特征进行拼接,形成户型图和家具图的总特征Ftotal,过程式如下:
[0071] Ftotal=concatenate([Froom,Ffurniture])
[0072] 处理后,Ftotal的尺寸为(128,128,sfurniture*k+sroom)。
[0073] 第6步,随机生成种子矩阵,并与第5步得到的总特征Ftotal作拼接,得到生成器的输入Ginput,过程式如下:
[0074] Ginput=concatenate([Ftotal,Frandom])
[0075] 处理后,Ginput的尺寸为(128,128,sfurniture*k+sroom+srandom),其中,srandom是随机种子矩阵的深度。
[0076] 第7步,将第6步得到的Ginput传入生成器中,生成家具布局参数Goutput,生成器主体结构仍然采用resNet网络,过程式如下:
[0077] Goutput=resNet(Ginput)
[0078] 处理后,Goutput的尺寸为(128,128,k)。为了行文方便,我们记生成器为G(.),将上式改写如下:
[0079] Goutput=G(Ginput)
[0080] 第8步,本发明适用的数据集中,生成器种子n除了一部分是随机白噪声之外,另一部分是和real配对的,并且带有对应的标签label。因此,可以将前文定义的生成器损失函数重新命名为非监督式的生成器损失函数(即图中的Unsupervised Generator Loss),并定义监督式的生成器损失函数(即图中的Supervised Generator Loss)为L(G(n),label)[0081] LG_s=f(G(n),label)
[0082] 其中f的具体函数形式与本发明无关,可以是交叉熵,也可以是其他。
[0083] 在本实施例中,计算生成器的监督损失函数LG_s,过程式如下:
[0084]
[0085] 其中, 是Goutput中的第i行第j列第k通道的元素,i和j是标签l中第k个通道中值为1的元素的坐标,函数对一批次中的所有案例(图像样本)以及各案例中的所有家具进行累加;本发明中的监督损失函数中引入了数据标签中的家具坐标,使得原始数据集中的图像位置关系可以纳入至计算模型中,有效地利用了原数据集中的信息,可以使网络收敛速度更快,有效地提高了生成数据的多样性和保真度。
[0086] 第9步,根据Goutput构建生成的家具布局图Igen,传入对抗器中,输出评分pG,过程式如下:
[0087] pG=resNet(Igen)
[0088] 处理后,pG为单一标量。同样,我们记对抗器为D(.),将上式改写如下:
[0089] pG=D(Igen)
[0090] 第10步,计算生成器非监督损失函数值LG_u,过程式如下:
[0091] LG_u=∑(1-pG)
[0092] 累加指的是对深度算法训练过程中同一批次的数据的累加。
[0093] 第11步,本发明的关键点是通过自适应比率r,对两部分生成器损失函数进行加权求和,从而形成新的生成器损失函数,通过参数r对LG_s和LG_u进行加权求和,得到生成器的损失函数LG:
[0094]
[0095] 在具体实施方式中,可以人为指定半监督学习参数r,也可以让模型自适应地学习r。在本实施例中采用了自适应地学习r的实施方式,r的更新方式如下:
[0096]
[0097] 第12步,基于LG利用Adam优化器更新生成器的参数。
[0098] 第13步,根据标签,构建真实家具布局图Ireal,并计算真实家具布局图的评分preal,过程式如下:
[0099] preal=D(Ireal)
[0100] 第14步,计算对抗器损失函数LD:
[0101] LD=∑(D(Ireal)-pG)
[0102] 第15步,基于LD利用Adam优化器更新对抗器的参数。
[0103] 第16步,遍历数据集,重复第2到第15步,直到生成器生成较为满意的输出为止。
[0104] 在本实施例中,采用了3万张的户型图数据作为样本集,进行了50万次的迭代训练后,得到生成模型。通过生成模型得到100张家具布局图,由专业的设计师进行评价是否符合要求,主要是从设计布局的合理性、美观性、实用性的度评价,当符合要求时,判定为合格设计图。
[0105] 同时,还进行了对比模型的构建,在生成器的监督损失函数LG_s中,采用了常规的非监督损失函数,作为模型的性能对比。
[0106] 采用本发明的模型得到的布局图以及对比模型的布局如分别如图2和图3所示,从图中可以看出布局图的设计效果对比;图中展示的是卧室的布局结果,图2中生成的卧室布局中的各个家具(床、衣柜、床头柜)的位置符合人们的生活习惯,而在图3中生成的图形中,衣柜的位置在多张图中都出现了偏差,与常理不符。
[0107] 在生成家具布局图的实验中表明:在训练相同轮次的前提下,生成的100张家具布局图,一般的GAN模型训练方法得到的生成器平均仅有10张布局图满足设计师要求,即,百张通过率10%;而利用本发明所述的训练方法得到的生成器平均达到55张布局图满足设计师要求,即,百张通过率55%。
[0108] 基于以上的示例过程,本发明的方案可以重新概括如下:
[0109] 本方法对数据的要求如下:
[0110] 本发明的方法中,需要进行处理的图片中的信息可以分为两部分:
[0111] 一部分是背景图(在本发明中也可以称之为第一图像区域),背景图是对前文所述的户型图的概念拓广,在非室内布局业务上,需要数据中含有与户型图功能相当的背景图,用于摆放前景物体例如家具等;
[0112] 另一部分就是前景图(在本发明中也可以称之为第二图像区域),前景图是对前文所述的家具平面图的概念拓广,在非室内布局业务上,需要数据中含有与家具平面图相当的前景图,基于数据标签(例如位置信息),对前景图进行仿射变换后覆盖在背景图上;
[0113] 本发明中,利用了这两块区域的相对位置关系,也就是主要利用的数据标签,在户型图的生成问题中,标签可以是家具的摆放位置信息,根据此信息,以及数据中的户型图和家具平面图,可以通过简单图形构造出室内布局图;
[0114] 具体的对标签信息的要求仅限于如下:背景图联合前景图与标签的关系必须是一对多的映射关系,亦即,背景图联合前景图可以有多个标签,但一个标签只能对应一个背景联合前景。除此之外,对标签没有具体要求,可以是二分类的也可以是多分类的还可以是回归的;可以是标量也可以是张量。
[0115] 第一图像区域可以用于规则出处理图片的基本框架,第二图像信息代表的是图像中的一些要素,并且基于两类区域的位置关系,可以将它们的相对位置关系构建出一个数据信息,也就是类似于上文所述的家具坐标信息;因此,就可以在生成器的损失函数中引入这样的第一图像区域和第二图像区域的相对位置关系,得到了一个监督损失函数,有效地利用了原始图像数据中的数据信息;接下来,在对抗器的输入数据的获取过程中,也同样利用了这样的一个位置关系信息,得到的由两块图像区域根据位置关系而生成的图像,使得图像中的信息得到了有效的利用。
[0116] 在本发明中,单一案例中可以有多个前景图,但只能有一个背景图。
[0117] 本方法采用的生成器和对抗器的主体部分都是用的resNet残差网络,后接全连接头,但具体实施时并无此限定,任意图形特征提取器都可以,可以是Inception特征特征提取器,也可以是VGG特征提取器,甚至可以是一般的卷积神经网络。
[0118] 本方法需要计算监督损失,因此,对生成器的生成结果有一定限制,即,此结果的维度大小必须和标签相同。
[0119] 监督损失的具体函数形式亦不作限制,上文方法说明是采用的是交叉熵形式,除此之外还有均方差、Huber损失、Log-Cosh损失、Hinge损失等多种可选形式。
[0120] 监督比率r可以作为超参人为设定,也可以在训练过程中自适应调整,调整方式如下:
[0121]
[0122] 以上实施方式仅仅用于说明本发明的技术构思,不代表对本发明的方案的具体限定。
高效检索全球专利

专利汇是专利免费检索,专利查询,专利分析-国家发明专利查询检索分析平台,是提供专利分析,专利查询,专利检索等数据服务功能的知识产权数据服务商。

我们的产品包含105个国家的1.26亿组数据,免费查、免费专利分析。

申请试用

分析报告

专利汇分析报告产品可以对行业情报数据进行梳理分析,涉及维度包括行业专利基本状况分析、地域分析、技术分析、发明人分析、申请人分析、专利权人分析、失效分析、核心专利分析、法律分析、研发重点分析、企业专利处境分析、技术处境分析、专利寿命分析、企业定位分析、引证分析等超过60个分析角度,系统通过AI智能系统对图表进行解读,只需1分钟,一键生成行业专利分析报告。

申请试用

QQ群二维码
意见反馈