My first ever blog post, https://cfml.se/blog/neural_style_transfer/, was about implementing neural style transfer, an algorithm introduced in (Gatys et al., 2015, https://arxiv.org/abs/1508.06576), in TensorFlow. In this notebook I reimplement the algorithm using PyTorch, with some changes compared to the first post.
The most notable changes are inspired by the baseline model in (Johnson, Alahi, and Fei-Fei, 2016, https://arxiv.org/abs/1603.08155), which is a reimplementation of the algorithm by Gatys et al. with some tweaks. Specifically: the CNN model (vgg16 instead of the vgg19 used in the original paper), which feature layers are extracted from the CNN model for the content and style losses, and the addition of a total variation loss.
The source code can be found at https://github.com/CarlFredriksson/neural_style_transfer_2.
import numpy as np
import PIL
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
print('np.__version__:', np.__version__)
print('PIL.__version__:', PIL.__version__)
print('matplotlib.__version__:', matplotlib.__version__)
print('torch.__version__:', torch.__version__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CONTENT_IMG_PATH = 'input/cat2.jpg'
STYLE_IMG_PATH = 'input/mosaic.jpg'
CNN_MODEL = models.vgg16(pretrained=True)
# ImageNet channel statistics, used to normalize inputs to the pretrained model
NORMALIZATION_MEAN = torch.tensor([0.485, 0.456, 0.406]).to(DEVICE).view(-1, 1, 1)
NORMALIZATION_STD = torch.tensor([0.229, 0.224, 0.225]).to(DEVICE).view(-1, 1, 1)
# Indices into CNN_MODEL.features whose outputs are used for the losses
CONTENT_LAYERS = ['15']
STYLE_LAYERS = ['3', '8', '15', '22']
# Relative weights of the three loss terms
CONTENT_WEIGHT = 1
STYLE_WEIGHT = 1e6
TV_WEIGHT = 1e-7
LEARNING_RATE = 0.1
NUM_ITERATIONS = 500
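A side note on model loading: the pretrained flag is deprecated in newer torchvision releases (0.13+). Assuming you want the same ImageNet weights, the equivalent call there is:
# Equivalent model loading for torchvision >= 0.13, where pretrained is deprecated
CNN_MODEL = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)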
c_img = Image.open(CONTENT_IMG_PATH)
plt.imshow(np.asarray(c_img))
s_img = Image.open(STYLE_IMG_PATH)
plt.imshow(np.asarray(s_img))
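The content and style images do not need to have the same size: the content loss compares the generated image against the content image (which start out identical), and the style loss compares Gram matrices, whose shape only depends on the number of channels. Still, large inputs make the optimization slow and memory hungry, so a resizing step can be useful. A minimal sketch, where the max_size cap of 512 is my assumption rather than part of the original settings:
def resize_img(img, max_size=512):
    # Downscale in place so that the longest side is at most max_size pixels
    if max(img.size) > max_size:
        img.thumbnail((max_size, max_size), Image.LANCZOS)
    return img
If used, it should be applied right after opening the images, e.g. c_img = resize_img(c_img).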
def preprocess_img(img, device):
    transform = transforms.Compose([
        transforms.ToTensor()  # PIL image -> float tensor in [0, 1] with shape (C, H, W)
    ])
    img_tensor = transform(img)
    img_tensor = img_tensor.unsqueeze(0)  # Add a batch dimension
    img_tensor = img_tensor.to(device)
    img_tensor.requires_grad = True
    return img_tensor
def postprocess_img(img_tensor):
    img = img_tensor.detach().cpu().numpy()
    img = np.squeeze(img, axis=0)  # Remove the batch dimension
    img = img.transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    img = img.clip(0, 1)
    return img
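As a quick sanity check (not part of the original post), postprocess_img should invert preprocess_img up to the clip to [0, 1]:
dummy_img = Image.fromarray(np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8))
np.allclose(postprocess_img(preprocess_img(dummy_img, DEVICE)), np.asarray(dummy_img) / 255)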
c_img = preprocess_img(c_img, DEVICE)
c_img.shape
s_img = preprocess_img(s_img, DEVICE)
s_img.shape
class LossNetwork(torch.nn.Module):
def __init__(self, cnn_layers, content_layer_keys, style_layer_keys):
super(LossNetwork, self).__init__()
self.layers = cnn_layers
self.content_layer_keys = content_layer_keys
self.style_layer_keys = style_layer_keys
def forward(self, x):
content_features = []
style_features = []
for key, layer in self.layers._modules.items():
x = layer(x)
if key in self.content_layer_keys:
content_features.append(x)
if key in self.style_layer_keys:
style_features.append(x)
return content_features, style_features
loss_network = LossNetwork(CNN_MODEL.features, CONTENT_LAYERS, STYLE_LAYERS)
loss_network.to(DEVICE).eval()
# Freeze the pretrained weights; only the generated image will be optimized
for param in loss_network.parameters():
    param.requires_grad_(False)
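The numeric layer keys can be decoded by printing the modules in CNN_MODEL.features:
for key, layer in CNN_MODEL.features._modules.items():
    print(key, layer)
Indices '3', '8', '15', and '22' are the ReLU activations usually called relu1_2, relu2_2, relu3_3, and relu4_3, with '15' (relu3_3) doubling as the content layer, matching the baseline in Johnson et al.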
def normalize(img, mean, std):
return (img - mean) / std
with torch.no_grad():
c_img_content_features, _ = loss_network(normalize(c_img, NORMALIZATION_MEAN, NORMALIZATION_STD))
_, s_img_style_features = loss_network(normalize(s_img, NORMALIZATION_MEAN, NORMALIZATION_STD))
[feature.shape for feature in c_img_content_features]
[feature.shape for feature in s_img_style_features]
def compute_gram_matrix(features):
    _, num_channels, height, width = features.shape
    # Unroll the spatial dimensions: (1, C, H, W) -> (C, H * W)
    features_unrolled = features.view(num_channels, -1)
    # Normalized Gram matrix of shape (C, C)
    return torch.matmul(features_unrolled, features_unrolled.t()) / (num_channels * height * width)
def compute_gram_matrices(features_per_layer):
return [compute_gram_matrix(features) for features in features_per_layer]
s_img_gram_matrices = compute_gram_matrices(s_img_style_features)
[mat.shape for mat in s_img_gram_matrices]
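Each Gram matrix has shape (num_channels, num_channels) regardless of image size, which is what allows content and style images of different sizes. The matrices are also symmetric by construction, which makes for a quick sanity check:
gram = s_img_gram_matrices[0]
print(torch.allclose(gram, gram.t()))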
def loss_func(g_img, g_img_content_features, c_img_content_features, g_img_gram_matrices, s_img_gram_matrices,
content_weight, style_weight, tv_weight):
    # Content loss (accumulated over layers; the original assignment would
    # overwrite the loss if more than one content layer were used)
    content_loss = 0
    for i in range(len(g_img_content_features)):
        content_loss += F.mse_loss(g_img_content_features[i], c_img_content_features[i])
    content_loss *= content_weight
# Style loss
style_loss = 0
for i in range(len(g_img_gram_matrices)):
style_loss += F.mse_loss(g_img_gram_matrices[i], s_img_gram_matrices[i])
style_loss *= style_weight
# Total variation regularization loss
tv_loss = tv_weight * (
torch.sum(torch.abs(g_img[:, :, :, :-1] - g_img[:, :, :, 1:])) +
torch.sum(torch.abs(g_img[:, :, :-1, :] - g_img[:, :, 1:, :]))
)
# Total loss
loss = content_loss + style_loss + tv_loss
return loss, content_loss, style_loss, tv_loss
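Before optimizing, it can be useful to sanity check the loss function (a sketch, not in the original post). Evaluated on the content image itself, the content loss should come out as zero:
with torch.no_grad():
    content_features, style_features = loss_network(normalize(c_img, NORMALIZATION_MEAN, NORMALIZATION_STD))
    gram_matrices = compute_gram_matrices(style_features)
    _, content_loss, style_loss, tv_loss = loss_func(
        c_img, content_features, c_img_content_features, gram_matrices, s_img_gram_matrices,
        CONTENT_WEIGHT, STYLE_WEIGHT, TV_WEIGHT)
print(content_loss.item(), style_loss.item(), tv_loss.item())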
# Initialize the generated image as a copy of the content image
g_img = c_img.detach().clone().requires_grad_(True)
plt.imshow(postprocess_img(g_img))
optimizer = optim.LBFGS([g_img], lr=LEARNING_RATE)
i = [0] # Use list for iteration counting because of the closure scope
while i[0] < NUM_ITERATIONS:
def closure():
g_img.data.clamp_(0, 1)
optimizer.zero_grad()
g_img_content_features, g_img_style_features = loss_network(normalize(g_img, NORMALIZATION_MEAN, NORMALIZATION_STD))
g_img_gram_matrices = compute_gram_matrices(g_img_style_features)
loss, content_loss, style_loss, tv_loss = loss_func(
g_img,
g_img_content_features,
c_img_content_features,
g_img_gram_matrices,
s_img_gram_matrices,
CONTENT_WEIGHT,
STYLE_WEIGHT,
TV_WEIGHT
)
loss.backward()
i[0] += 1
if i[0] == 1 or i[0] % 50 == 0:
            print('i: {}, content_loss: {:.4f}, style_loss: {:.4f}, tv_loss: {:.4f}'.format(
                i[0], content_loss.item(), style_loss.item(), tv_loss.item()))
return loss
optimizer.step(closure)
plt.imshow(postprocess_img(g_img))
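Finally, the generated image can be written to disk. A minimal sketch, where 'output/result.jpg' is a hypothetical path (the directory must already exist):
plt.imsave('output/result.jpg', postprocess_img(g_img))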