新闻  |   论坛  |   博客  |   在线研讨会
Python多阶段框架实现虚拟试衣间,超逼真!
AI科技大本营 | 2020-12-15 11:17:44    阅读:695   发布文章

任意姿态下的虚拟试衣因其巨大的应用潜力而引起了人们的广泛关注。然而,现有的方法在将新颖的服装和姿势贴合到一个人身上的同时,很难保留服装纹理和面部特征(面孔、毛发)中的细节。故在论文《Downto the Last Detail: Virtual Try-on with Detail Carving》中提出了一种新的多阶段合成框架,可以很好地保留图像显著区域的丰富细节。

具体地说,就是提出了一个多阶段的框架,将生成分解为空间对齐,然后由粗到细生成。为了更好地保留显著区域的细节,如服装和面部区域,我们提出了一个树块(树扩张融合块)来利用多尺度特征在发生器网络。通过多个阶段的端到端训练,可以联合优化整个框架,最终使得视觉逼真度得到了显著的提高、同时获得了细节更为丰富的结果。在标准数据集上进行的大量实验表明,他们提出的框架实现了最先进的性能,特别是在保存服装纹理和面部识别的视觉细节方面。

故今天我们将在他们代码的基础上,实现虚拟换衣系统。具体流程如下:

图片实验前的准备

首先我们使用的python版本是3.6.5所用到的模块如下:

·  opencv是将用来进行图像处理和图片保存读取等操作。

·  numpy模块用来处理矩阵数据的运算。

·  pytorch模块是常用的用来搭建模型和训练的深度学习框架,和tensorflow以及Keras等具有相当的地位。

·  json是为了读取json存储格式的数据。

·  PIL库可以完成对图像进行批处理、生成图像预览、图像格式转换和图像处理操作,包括图像基本处理、像素处理、颜色处理等。

·  argparse 是python自带的命令行参数解析包,可以用来方便地读取命令行参数。

图片网络模型的定义和训练

其中已经训练好的模型地址如下:https://drive.google.com/open?id=1vQo4xNGdYe2uAtur0mDlHY7W2ZR3shWT。其中需要将其中的模型放到"./pretrained_checkpoint"目录下。

对于数据集的存放,分为cloth_image(用来存储衣服图片),cloth_mask(用来分割衣服的mask,可以使用grabcut的方法进行分割保存),image(用来存储人物图片),parse_cihp(用来衣服语义分析的图片结果,可以使用[CIHP_PGN](https://github.com/Engineering-Course/CIHP_PGN)的方法获得)和pose_coco(用来存储提取到的人物姿态特征数据,可以使用openpose进行提取保存为josn数据即可)。

5.jpg

对于模型的训练,我们需要使用到VGG19模型,网络上可以很容易下载到,然后把它放到vgg_model文件夹下。

其中提出的一种基于目标姿态和店内服装图像由粗到细的多阶段图像生成框架,首先是设计了一个解析转换网络来预测目标语义图,该语义图在空间上对齐相应的身体部位,并提供更多关于躯干和四肢形状的结构信息。然后使用一种新的树扩张融合块(tree - block)算法,将空间对齐的布料与粗糙的渲染图像融合在一起,以获得更合理、更体面的结果。其中这个虚拟试穿网络不仅不借助3D信息,可以在任意姿态下将新衣服叠加到人的对应区域上,还保留和增强了显著区域的丰富细节,如布料纹理、面部特征等。同时还使用了空间对齐、多尺度上下文特征聚集和显著的区域增强,以由粗到细的方式各种难题。

(1)其中网络主要使用pix2pix模型,其中的部分代码如下:

class PixelDiscriminator(nn.Module):

    def __init__(self, input_nc,ndf=64, norm_layer=nn.InstanceNorm2d):

        super(PixelDiscriminator,self).__init__()

        if type(norm_layer) ==functools.partial:

            use_bias =norm_layer.func == nn.InstanceNorm2d

        else:

            use_bias = norm_layer ==nn.InstanceNorm2d

        self.net = nn.Sequential(

            nn.Conv2d(input_nc, ndf,kernel_size=1, stride=1, padding=0),

            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf, ndf * 2,kernel_size=1, stride=1, padding=0, bias=use_bias),

            norm_layer(ndf * 2),

            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf * 2, 1,kernel_size=1, stride=1, padding=0, bias=use_bias),

            nn.Sigmoid()

        )

    def forward(self, input):

        return self.net(input)

class PatchDiscriminator(nn.Module):

    def __init__(self, input_nc,ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):

        super(PatchDiscriminator,self).__init__()

        if type(norm_layer) ==functools.partial:  # no need to use biasas BatchNorm2d has affine parameters

            use_bias =norm_layer.func == nn.InstanceNorm2d

        else:

            use_bias = norm_layer ==nn.InstanceNorm2d

        kw = 4

        padw = 1

        sequence =[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),nn.LeakyReLU(0.2, True)]

        nf_mult = 1

        nf_mult_prev = 1

        # channel up

        for n in range(1,n_layers):  # gradually increase thenumber of filters

            nf_mult_prev = nf_mult #1,2,4,8

            nf_mult = min(2 ** n, 8)

            sequence += [

                nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw,bias=use_bias),

                norm_layer(ndf *nf_mult),

                nn.LeakyReLU(0.2,True)

            ]

        # channel down

        nf_mult_prev = nf_mult

        nf_mult = min(2 ** n_layers,8)

        sequence += [

            nn.Conv2d(ndf *nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw,bias=use_bias),

            norm_layer(ndf *nf_mult),

            nn.LeakyReLU(0.2, True)

        ]

        # channel = 1 (bct, 1, x, x)

        sequence += [nn.Conv2d(ndf *nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map

        sequence += [nn.Sigmoid()]

        self.model =nn.Sequential(*sequence) 

4.jpg

(2)生成器部分代码:

class GenerationModel(BaseModel):

    def name(self):

        return 'Generation model:pix2pix | pix2pixHD'

    def __init__(self, opt):

        self.t0 = time()

        BaseModel.__init__(self,opt)

        self.train_mode =opt.train_mode

        # resume of networks

        resume_gmm = opt.resume_gmm

        resume_G_parse =opt.resume_G_parse

        resume_D_parse =opt.resume_D_parse

        resume_G_appearance =opt.resume_G_app

        resume_D_appearance =opt.resume_D_app

        resume_G_face = opt.resume_G_face

        resume_D_face =opt.resume_D_face

        # define network

        self.gmm_model =torch.nn.DataParallel(GMM(opt)).cuda()

        self.generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf, opt.netG_parsing,opt.norm,

                                       not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

        self.discriminator_parsing =Define_D(opt.input_nc_D_parsing, opt.ndf, opt.netD_parsing, opt.n_layers_D,

                                       opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)

        self.generator_appearance =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,

                                       not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids,with_tanh=False)

       self.discriminator_appearance = Define_D(opt.input_nc_D_app, opt.ndf,opt.netD_app, opt.n_layers_D,

                                       opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)

        self.generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,

                                        notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

        self.discriminator_face =Define_D(opt.input_nc_D_face, opt.ndf, opt.netD_face, opt.n_layers_D,

                                       opt.norm, opt.init_type, opt.init_gain, opt.gpu_ids)

        if opt.train_mode == 'gmm':

            setattr(self,'generator', self.gmm_model)

        else:

            setattr(self,'generator', getattr(self, 'generator_' + self.train_mode))

            setattr(self, 'discriminator',getattr(self, 'discriminator_' + self.train_mode))

        # load networks

        self.networks_name = ['gmm','parsing', 'parsing', 'appearance', 'appearance', 'face', 'face']

        self.networks_model =[self.gmm_model, self.generator_parsing, self.discriminator_parsing,self.generator_appearance, self.discriminator_appearance,

                       self.generator_face, self.discriminator_face]

        self.networks =dict(zip(self.networks_name, self.networks_model))

        self.resume_path =[resume_gmm, resume_G_parse, resume_D_parse, resume_G_appearance,resume_D_appearance, resume_G_face, resume_D_face]

        for network, resume inzip(self.networks_model, self.resume_path):

            if network != [] andresume != '':

               assert(osp.exists(resume), 'the resume not exits')

                print('loading...')

               self.load_network(network, resume, ifprint=False)

        # define optimizer

        self.optimizer_gmm =torch.optim.Adam(self.gmm_model.parameters(), lr=opt.lr, betas=(0.5, 0.999))

        self.optimizer_parsing_G =torch.optim.Adam(self.generator_parsing.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])

        self.optimizer_parsing_D =torch.optim.Adam(self.discriminator_parsing.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])

        self.optimizer_appearance_G= torch.optim.Adam(self.generator_appearance.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])

        self.optimizer_appearance_D= torch.optim.Adam(self.discriminator_appearance.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])

        self.optimizer_face_G =torch.optim.Adam(self.generator_face.parameters(), lr=opt.lr, betas=[opt.beta1,0.999])

        self.optimizer_face_D =torch.optim.Adam(self.discriminator_face.parameters(), lr=opt.lr,betas=[opt.beta1, 0.999])

        if opt.train_mode == 'gmm':

            self.optimizer_G =self.optimizer_gmm

        elif opt.joint_all:

            self.optimizer_G =[self.optimizer_parsing_G, self.optimizer_appearance_G, self.optimizer_face_G]

            setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))

        else:

            setattr(self,'optimizer_G', getattr(self, 'optimizer_' + self.train_mode + '_G'))

            setattr(self,'optimizer_D', getattr(self, 'optimizer_' + self.train_mode + '_D'))

        self.t1 = time()

3.jpg

图片模型的使用

在模型训练完成之后,通过命令“python demo.py --batch_size_v 80--num_workers 4 --forward_save_path 'demo/forward'”生成图片。

2.jpg

(1)分别定义读取模型函数和模型调用处理图片函数

def load_model(model, path):

    checkpoint = torch.load(path)

    try:

       model.load_state_dict(checkpoint)

    except:

       model.load_state_dict(checkpoint.state_dict())

    model = model.cuda()

    model.eval()

    print(20*'=')

    for param in model.parameters():

        param.requires_grad = False

def forward(opt, paths, gpu_ids, refine_path):

    cudnn.enabled = True

    cudnn.benchmark = True

    opt.output_nc = 3

    gmm = GMM(opt)

    gmm =torch.nn.DataParallel(gmm).cuda()

    # 'batch'

    generator_parsing =Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf,opt.netG_parsing, opt.norm,

                            notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    generator_app_cpvton =Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app,opt.norm,

                            notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids, with_tanh=False)

    generator_face =Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face,opt.norm,

                            notopt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    models = [gmm,generator_parsing, generator_app_cpvton, generator_face]

    for model, path in zip(models,paths):

        load_model(model, path)   

    print('==>loaded model')

    augment = {}

    if '0.4' in torch.__version__:

        augment['3'] =transforms.Compose([

                                   # transforms.Resize(256),

                                   transforms.ToTensor(),

                                    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))

            ]) # change to [C, H, W]

        augment['1'] = augment['3']

    else:

        augment['3'] =transforms.Compose([

                                #transforms.Resize(256),

                               transforms.ToTensor(),

                               transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))

        ]) # change to [C, H, W]

        augment['1'] =transforms.Compose([

                                # transforms.Resize(256),

                               transforms.ToTensor(),

                               transforms.Normalize((0.5,), (0.5,))

        ]) # change to [C, H, W]

    val_dataset = DemoDataset(opt,augment=augment)

    val_dataloader = DataLoader(

                    val_dataset,

                    shuffle=False,

                    drop_last=False,

                   num_workers=opt.num_workers,

                    batch_size = opt.batch_size_v,

                    pin_memory=True)

    with torch.no_grad():

        for i, result inenumerate(val_dataloader):

            'warped cloth'

            warped_cloth =warped_image(gmm, result)

            if opt.warp_cloth:

                warped_cloth_name =result['warped_cloth_name']

                warped_cloth_path =os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])

                if notos.path.exists(os.path.split(warped_cloth_path)[0]):

                    os.makedirs(os.path.split(warped_cloth_path)[0])

               utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)

               print('processing_%d'%i)

                continue

            source_parse =result['source_parse'].float().cuda()

            target_pose_embedding =result['target_pose_embedding'].float().cuda()

            source_image =result['source_image'].float().cuda()

            cloth_parse =result['cloth_parse'].cuda()

            cloth_image =result['cloth_image'].cuda()

            target_pose_img =result['target_pose_img'].float().cuda()

            cloth_parse =result['cloth_parse'].float().cuda()

            source_parse_vis =result['source_parse_vis'].float().cuda()

            "filter add clothinfomation"

            real_s =source_parse  

            index = [x for x inlist(range(20)) if x != 5 and x != 6 and x != 7]

            real_s_ =torch.index_select(real_s, 1, torch.tensor(index).cuda())

            input_parse =torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda()

            'P'

            generate_parse =generator_parsing(input_parse) # tanh

            generate_parse =F.softmax(generate_parse, dim=1)

            generate_parse_argmax =torch.argmax(generate_parse, dim=1, keepdim=True).float()

            res = []

            for index in range(20):

                res.append(generate_parse_argmax == index)

            generate_parse_argmax =torch.cat(res, dim=1).float()

            "A"

            image_without_cloth =create_part(source_image, source_parse, 'image_without_cloth', False)

            input_app = torch.cat((image_without_cloth,warped_cloth, generate_parse), 1).cuda()

1.jpg

源码地址:

链接:https://pan.baidu.com/s/1POrmcvv_LWg-SuY3ilif5g

提取码:qcj6

*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

参与讨论
登录后参与讨论
推荐文章
最近访客