import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from PIL import Image
# Convert an image to a resized, ImageNet-normalized tensor.
def img_to_tensor(img, img_shape):
    """Resize ``img`` and convert it to a normalized float tensor.

    Args:
        img: image accepted by ``torchvision.transforms.Resize``/``ToTensor``
            (a PIL image in this script). Images whose ``mode`` is not
            ``"RGB"`` (e.g. RGBA/palette PNGs, grayscale) are converted to RGB
            first, because the normalization below has exactly three
            per-channel statistics.
        img_shape: target size passed to ``torchvision.transforms.Resize``.

    Returns:
        A float tensor laid out as (C, H, W) by ``ToTensor`` and normalized
        with the mean/std values below.
    """
    # Inputs without a ``mode`` attribute (non-PIL) are passed through as-is.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")
    rgb_mean = torch.tensor([0.485, 0.456, 0.406])
    rgb_std = torch.tensor([0.229, 0.224, 0.225])
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(img_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ])
    return transforms(img)
# Convert a normalized image tensor back into a PIL image.
def tensor_to_img(img):
    """Undo the normalization applied by ``img_to_tensor`` and return a PIL image.

    Args:
        img: normalized image tensor laid out as (C, H, W) with C == 3. It may
            require grad and may live on any device.

    Returns:
        A PIL image built by ``torchvision.transforms.ToPILImage``.
    """
    # Work on a detached CPU copy: optimized parameters require grad and may
    # live on the GPU, while the statistics below are CPU tensors.
    img = img.detach().cpu()
    # Shaped (3, 1, 1) so they broadcast over a (C, H, W) tensor directly.
    rgb_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    rgb_std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    # Clamp: optimization can push pixels outside [0, 1]. Without the clamp,
    # ToPILImage's scale-by-255-and-cast-to-uint8 step would overflow.
    img = (img * rgb_std + rgb_mean).clamp(0, 1)
    return torchvision.transforms.ToPILImage()(img)
# Pretrained VGG-19 (ImageNet weights) used as a fixed feature extractor.
pretrained_net = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)
# Indices into ``pretrained_net.features`` whose outputs serve as style features.
style_layers = [0, 5, 10, 19, 28]
# Indices into ``pretrained_net.features`` whose outputs serve as content features.
content_layers = [5, 10, 19, 25]
# Layers beyond the deepest requested index are never used, so they are dropped.
max_num_layers = max(style_layers + content_layers) + 1
# The network is only a fixed feature extractor. Switch it to inference mode
# and freeze its weights so that backward() computes gradients for the
# synthesized image only, instead of also filling .grad for every VGG
# parameter on each optimization step.
mynet = (
    nn.Sequential(*pretrained_net.features[:max_num_layers])
    .cuda()
    .eval()
    .requires_grad_(False)
)
def get_features(X):
    """Run ``X`` through ``mynet`` and collect selected intermediate outputs.

    Args:
        X: input tensor fed to the first layer of ``mynet``.

    Returns:
        ``(style, content)``: lists of the outputs of the layers whose indices
        appear in ``style_layers`` and ``content_layers`` respectively, in
        layer order. A layer listed in both appears in both lists.
    """
    style = []
    content = []
    for idx, layer in enumerate(mynet):
        X = layer(X)
        if idx in style_layers:
            style.append(X)
        if idx in content_layers:
            content.append(X)
    return style, content
def get_style(img):
    """Return only the style-layer activations that ``get_features`` yields for ``img``."""
    return get_features(img)[0]
def get_content(img):
    """Return only the content-layer activations that ``get_features`` yields for ``img``."""
    return get_features(img)[1]
def compute_content_loss(content, new_img_content):
    """Sum, over paired layers, of the mean squared difference of activations.

    Args:
        content: target activations. They are detached, so gradients flow only
            into ``new_img_content``.
        new_img_content: activations of the image being optimized. It is
            paired element-wise with ``content``; extra items are ignored.

    Returns:
        The summed loss, or the int ``0`` if there are no pairs.
    """
    total = 0
    for target, current in zip(content, new_img_content):
        total = total + (target.detach() - current).square().mean()
    return total
# Gram matrix of channel activations (channel-by-channel inner products).
def gram(X):
    """Return the normalized Gram matrix of ``X``.

    The first dimension of ``X`` is treated as the channel axis, and all
    remaining dimensions are flattened. The result is ``F @ F.T / (C * N)``,
    where ``F`` is the flattened (C, N) matrix.
    """
    channels = X.shape[0]
    per_channel = X.numel() // channels
    flat = X.reshape((channels, per_channel))
    products = flat @ flat.T
    return products / (channels * per_channel)
# Style loss: compare Gram matrices layer by layer.
def compute_style_loss(style, new_img_style):
    """Sum, over paired layers, of the mean squared difference of Gram matrices.

    Args:
        style: target activations. Their Gram matrices are detached, so
            gradients flow only through ``new_img_style``.
        new_img_style: activations of the image being optimized. It is paired
            element-wise with ``style``; extra items are ignored.

    Returns:
        The summed loss, or the int ``0`` if there are no pairs.
    """
    total = 0
    for target, current in zip(style, new_img_style):
        diff = gram(target).detach() - gram(current)
        total = total + diff.square().mean()
    return total
def img_loss(X):
    """Total-variation (smoothness) loss of an image tensor.

    Returns half the sum of two terms: the mean absolute difference between
    neighbours along dim 1, and the mean absolute difference between
    neighbours along dim 2. Assumes ``X`` is laid out as (C, H, W), as
    produced by ``img_to_tensor``.
    """
    vertical = (X[:, 1:, :] - X[:, :-1, :]).abs().mean()
    horizontal = (X[:, :, 1:] - X[:, :, :-1]).abs().mean()
    return 0.5 * (vertical + horizontal)
def get_total_loss(new_img, content, new_img_content, style, new_img_style,
                   content_weight=1, style_weight=1e4, tv_weight=10):
    """Weighted sum of the content, style and total-variation losses.

    Args:
        new_img: the image being optimized (input to ``img_loss``).
        content, new_img_content: target and current content activations.
        style, new_img_style: target and current style activations.
        content_weight, style_weight, tv_weight: weights of the three terms.
            The defaults reproduce the previously hard-coded values
            (1, 1e4, 10).

    Returns:
        ``content_weight * content + style_weight * style + tv_weight * tv``.
    """
    content_loss = compute_content_loss(content, new_img_content)
    style_loss = compute_style_loss(style, new_img_style)
    tv_loss = img_loss(new_img)
    return (content_weight * content_loss
            + style_weight * style_loss
            + tv_weight * tv_loss)
class SynthesizedImage(nn.Module):
    """Module whose single trainable parameter is the image being synthesized.

    ``forward`` returns the parameter itself, so optimizing this module's
    parameters updates the image directly.
    """

    def __init__(self, content_img_tensor_shape):
        super().__init__()
        # Uniform random values in [0, 1); nn.Parameter already sets requires_grad.
        initial = torch.rand(*content_img_tensor_shape)
        self.weight = nn.Parameter(initial)

    def forward(self):
        return self.weight
def get_init(content_img_tensor, lr):
    """Create the trainable synthesized image and its Adam optimizer.

    The image starts as a copy of ``content_img_tensor`` and lives on the same
    device, so the optimization loop never copies it between devices.

    Args:
        content_img_tensor: normalized content image used as the starting point.
        lr: Adam learning rate.

    Returns:
        ``(image_parameter, optimizer)``.
    """
    new_img = SynthesizedImage(content_img_tensor.shape).to(content_img_tensor.device)
    # Initialize in place without recording the copy in the autograd graph.
    with torch.no_grad():
        new_img.weight.copy_(content_img_tensor)
    optimizer = torch.optim.Adam(new_img.parameters(), lr=lr)
    return new_img(), optimizer


def train(content_img, style_img, lr=0.1, epoch_num=10, img_shape=(500, 800)):
    """Run style transfer and return the synthesized image.

    Args:
        content_img: image whose content is preserved (passed to ``img_to_tensor``).
        style_img: image whose style is transferred (passed to ``img_to_tensor``).
        lr: Adam learning rate.
        epoch_num: number of optimization steps.
        img_shape: size both images are resized to before processing.

    Returns:
        The synthesized image as produced by ``tensor_to_img``.
    """
    # Put the inputs on whatever device the feature extractor lives on.
    device = next(mynet.parameters()).device
    content_img_tensor = img_to_tensor(content_img, img_shape=img_shape).to(device)
    style_img_tensor = img_to_tensor(style_img, img_shape=img_shape).to(device)
    # The target features are constants, so no autograd graph is needed for them.
    with torch.no_grad():
        content = get_content(content_img_tensor)
        style = get_style(style_img_tensor)
    # Start the synthesized image from the content image.
    new_img, optimizer = get_init(content_img_tensor, lr)
    for _ in range(epoch_num):
        optimizer.zero_grad()
        # A single forward pass yields both the style and content activations.
        new_img_style, new_img_content = get_features(new_img)
        total_loss = get_total_loss(new_img, content, new_img_content, style, new_img_style)
        total_loss.backward()
        optimizer.step()
    return tensor_to_img(new_img.detach().cpu())
# Load the content and style images. ``convert('RGB')`` yields 3-channel
# images (the style image is a PNG, which may carry an alpha channel) to match
# the 3-channel normalization. The ``with`` blocks close the underlying files
# once the pixels have been loaded.
with Image.open('./img/content.jpg') as img_file:
    content_img = img_file.convert('RGB')
with Image.open('./img/style.png') as img_file:
    style_img = img_file.convert('RGB')
new_img = train(content_img, style_img, lr=0.1, epoch_num=500, img_shape=(300, 400))
plt.imshow(new_img)
# Required for the figure to actually appear when run as a plain script.
plt.show()
# The original content image is:
# (The source article was truncated here: "due to length, not all content can be displayed.")