您好,欢迎来到好走旅游网。
搜索
您的当前位置:首页风格迁移实战

风格迁移实战

来源:好走旅游网

1.导包

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from PIL import Image

2.将图像转为tensor

# 将图像转为一定的大小的tensor
def img_to_tensor(img, img_shape):
    rgb_mean = torch.tensor([0.485, 0.456, 0.406])
    rgb_std = torch.tensor([0.229, 0.224, 0.225])
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(img_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)
    ])
    return transforms(img)

3.将tensor转为图片

# 将tensor还原为image
def tensor_to_img(img):
    rgb_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3)
    rgb_std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3)
    img = img.permute(1, 2, 0) * rgb_std + rgb_mean
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

4.使用预训练网络提取特征以及风格

# 使用vgg19预训练模型
pretrained_net = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)
# 在0,5,10,19,28层中提取风格特征
style_layers = [0, 5, 10, 19, 28]
# 在5, 10, 19, 25层中提取内容特征
content_layers = [5, 10, 19, 25]
max_num_layers = max(style_layers + content_layers) + 1
# 删除多余的层数
mynet = nn.Sequential(*pretrained_net.features[:max_num_layers]).cuda()

def get_features(X):
    style = []
    content = []
    for i in range(len(mynet)):
        X = mynet[i](X)
        if i in style_layers:
            style.append(X)
        if i in content_layers:
            content.append(X)
    return style, content

def get_style(img):
    style, _ = get_features(img)
    return style

def get_content(img):
    _, content = get_features(img)
    return content

5.计算内容损失

def compute_content_loss(content, new_img_content):
    content_loss = [torch.square(a.detach() - b).mean() for a, b in zip(content, new_img_content)]
    return sum(content_loss)

6.计算风格损失

# 计算通道之间的协方差矩阵
def gram(X):
    num_channels, n = X.shape[0], X.numel() // X.shape[0]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)
# 计算风格损失
def compute_style_loss(style, new_img_style):
    style_loss = [torch.square(gram(a).detach() - gram(b)).mean() for a, b in zip(style, new_img_style)]
    return sum(style_loss)

7.对生成的图像做全变分去噪

def img_loss(X):
    return 0.5 * (torch.abs(X[:, 1:, :] - X[:, :-1, :]).mean() +
                  torch.abs(X[:, :, 1:] - X[:, :, :-1]).mean())

8.计算总的损失

def get_total_loss(new_img, content, new_img_content, style, new_img_style):
    content_loss = compute_content_loss(content, new_img_content)
    style_loss = compute_style_loss(style, new_img_style)
    tv_loss = img_loss(new_img)
    return content_loss + 1e4 * style_loss + 10 * tv_loss

9.生成图像模型

class SynthesizedImage(nn.Module):
    def __init__(self, content_img_tensor_shape):
        super(SynthesizedImage, self).__init__()
        self.weight = nn.Parameter(torch.rand(*content_img_tensor_shape, requires_grad=True))

    def forward(self):
        return self.weight

10.对生成图像进行初始化

def get_init(content_img_tensor, lr):
    new_img = SynthesizedImage(content_img_tensor.shape)
    new_img.weight.data.copy_(content_img_tensor)
    optimizer = torch.optim.Adam(new_img.parameters(), lr=lr)
    return new_img(), optimizer

11. 训练函数

def train(content_img, style_img, lr=0.1, epoch_num=10, img_shape=(500, 800)):
    # 将内容图片、风格图片转为tensor,
    content_img_tensor = img_to_tensor(content_img, img_shape=img_shape).to('cuda')
    style_img_tensor = img_to_tensor(style_img, img_shape=img_shape).to('cuda')
    # 获取内容、以及风格
    content = get_content(content_img_tensor)
    style = get_style(style_img_tensor)
    # 根据内容image初始化合成图片
    new_img, optimizer = get_init(content_img_tensor, lr)
    for epoch in range(epoch_num):
        optimizer.zero_grad()
        new_img_content = get_content(new_img.to('cuda'))
        new_img_style = get_style(new_img.to('cuda'))
        total_loss = get_total_loss(new_img, content, new_img_content, style, new_img_style)
        total_loss.backward()
        optimizer.step()
        # print(total_loss)
    return tensor_to_img(new_img)

11.最后,进行训练

# 加载内容图片以及风格图片
content_img = Image.open('./img/content.jpg')
style_img = Image.open('./img/style.png')
new_img = train(content_img, style_img, lr=0.1, epoch_num=500, img_shape=(300, 400))
plt.imshow(new_img)

12.结果展示

原始内容图片为:

因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- haog.cn 版权所有

违法及侵权请联系:TEL:199 1889 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务