본문 바로가기
Data Analysis

쉽게 쓰는 Torch FID Score 모듈

by but_poor 2022. 8. 20.

Reference Address : https://github.com/vict0rsch/pytorch-fid-wrapper

 

GitHub - vict0rsch/pytorch-fid-wrapper: A simple wrapper around @mseitzer's great pytorch-fid work to compute Fréchet Inception

A simple wrapper around @mseitzer's great pytorch-fid work to compute Fréchet Inception Distance in-memory from batches of images, using PyTorch - GitHub - vict0rsch/pytorch-fid-wrapper: A simp...

github.com


FID는 Frechat Inception Score의 약자로 대충 SOTA 논문 및 모델에서 많이 인용하고 있는 GAN의 정량적인 지표에요
구체적인 설명은 구글링 부탁드립니당ㅎㅎ
대충 이미지 vs 이미지의 유클리디언 디스턴스랑 피쳐맵 디스턴스 구해서 더해주는 스코어에요!!ㅎㅎ

오늘 포스팅에선 그냥 빠르게 읽고 빠르게 갖다쓰는 데 집중해보겠습니당!!ㅎㅎ

Library install and import

pip install pytorch-fid-wrapper  # 주피터 사용자를 위한 래퍼
import pytorch_fid_wrapper as pfw  # fid 모듈 임포트

 

FID Configuration Setting

pfw.set_config(batch_size=1, device="cuda:0")
pfw.fid()
fid_score = pfw.fid(FIDPreporcess(fake), FIDPreporcess(real))


Sample Test

batch_size = 4
img_chn = 3  #maybe RGB
img_size = 128
x = torch.normal(0, 1, (batch_size, img_chn, img_size, img_size))
y = torch.normal(0.2, 0.8, (batch_size, img_chn, img_size, img_size))
pfw.set_config(batch_size=batch_size, device="cpu")
fid_score = pfw.fid(x, y)
fid_score

 

Usage Example(Paired Images)

Generator = model(...)  # Declare Model
for a, (src, dst) in enumerate(dataloader):
    src, dst = src.to(device), dst.to(device)  # Send by GPU
    fake = Generator(src)  # Generate Fake Image by Generator, Using GPU
    # Use CPU for Memory Management
    if a == 0 :
        fake_images = fake.detach().cpu()
        real_images = dst.detach().cpu()
    else :
        fake_images = torch.cat([fake_images, fake.detach().cpu()], 0)
        real_images = torch.cat([real_images, dst.detach().cpu()], 0)
        
print('FID Score : ', pfw.fid(fake_images, real_images))

댓글