
Trying Out SRGAN's VGGPerceptualLoss

by but_poor 2022. 9. 11.

Reference: https://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf

 

VGGPerceptualLoss measures the distance between feature maps extracted by a pretrained VGG network rather than between raw pixels, which is why it is known to preserve the fine details of generated images well.
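To make the idea concrete, here is a minimal sketch of my own (not the code from the gist): both images are passed through a frozen pretrained VGG, and the distance is taken between intermediate feature maps instead of pixels. The full class below additionally normalizes, resizes, and sums over several layers.

import torch
import torchvision

# Sketch only: a frozen VGG16 slice (up to relu3_3) used as a fixed feature extractor
vgg = torchvision.models.vgg16(pretrained=True).features[:16].eval()
for p in vgg.parameters():
    p.requires_grad = False  # the loss network itself is never trained

def simple_perceptual_loss(generated, target):
    # MSE between VGG feature maps instead of raw pixels
    # (the class used in this post takes L1 over several layers and
    #  also applies ImageNet normalization first)
    return torch.nn.functional.mse_loss(vgg(generated), vgg(target))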


I studied this to apply it to my own model, so I'm writing it up while it's fresh. If you spot any mistakes, please point them out :)

 

The GitHub gist I referenced is below:

https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49

 


 

Declare Loss Function Class

 

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Slices of VGG16 features ending at relu1_2, relu2_2, relu3_3, relu4_3
        # (pretrained=True is deprecated in recent torchvision;
        #  weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1 is the newer form)
        vgg_features = torchvision.models.vgg16(pretrained=True).features
        blocks = []
        blocks.append(vgg_features[:4].eval())
        blocks.append(vgg_features[4:9].eval())
        blocks.append(vgg_features[9:16].eval())
        blocks.append(vgg_features[16:23].eval())
        # Freeze the VGG weights; the loss network is not trained
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = nn.ModuleList(blocks)
        self.resize = resize
        # ImageNet normalization constants expected by the pretrained VGG
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))


    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        # Convert to RGB if the input is a grayscale image
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input - self.mean) / self.std
        target = (target - self.mean) / self.std
        # Optionally resize to the 224x224 resolution VGG was trained on
        if self.resize:
            input = F.interpolate(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = F.interpolate(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                # Feature (perceptual) loss: L1 distance between feature maps
                loss += F.l1_loss(x, y)
            if i in style_layers:
                # Style loss: L1 distance between Gram matrices of the feature maps
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += F.l1_loss(gram_x, gram_y)
        return loss

 

Sample Test

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
VGGLoss = VGGPerceptualLoss().to(device)  # move the loss network to the same device as the model

batch_size = 4
img_chn = 3
img_size = 128

lambda_VGG = 10  # Default 10 | Typically 10 to 750

x = torch.normal(0, 1, (batch_size, img_chn, img_size, img_size)).to(device)
y = torch.normal(0.3, 0.7, (batch_size, img_chn, img_size, img_size)).to(device)

loss_G_VGG = VGGLoss(x, y) * lambda_VGG

print(f'x mean : {x.mean().item():.4f}, y mean : {y.mean().item():.4f}, loss : {loss_G_VGG.item():.4f}')
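For reference, in an SRGAN-style training loop this term is usually just added to the generator's other losses. The following is only a rough sketch; generator, optimizer_G, lr_imgs, and hr_imgs are hypothetical names, not code from this post.

# Hedged sketch of a generator update using the loss above
sr_imgs = generator(lr_imgs)                       # super-resolved output
loss_pixel = F.l1_loss(sr_imgs, hr_imgs)           # pixel-wise content loss
loss_G_VGG = VGGLoss(sr_imgs, hr_imgs) * lambda_VGG
loss_G = loss_pixel + loss_G_VGG                   # (+ adversarial loss in full SRGAN)

optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()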

 

That's it for today's quick read and quick write-up!
