본문 바로가기
Data Analysis

다양한 옵션을 갖는 Normalization Layer 설정 옵션 만들어보기

by but_poor 2022. 9. 14.

커스터마이징 백본을 만들다 보면 계층별로 서로 다른 normalization layer를 설정해줘야 할 때가 있는데

 

일일이 입력해줘도 전혀 문제 없지만

 

조금더 파이써닉한 레이어 설계 툴을 만들어두면 편합니당ㅎㅎ

 

Declare

class Identity(nn.Module):
    def forward(self, x):
        return x

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x): return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

 

이외에도 group normalizatoin이나 channel normalization도 있지만 제 기준에서는 아직 흔히 쓰이지 않아서 구현은 안했습니다만 만들기가 전혀 어렵지는 않습니당ㅎㅎ

 

위에 4줄만 추가해주면 되요!!ㅎㅎ

 

Identity class는 제 경험상 꼭 한번씩 쓰게 되었던 것 같아요

그냥 자기 자신 뱉어내게 하면 되는데 특히 bottle neck을 가지는 backbone 만들 땐 고려하게 되는 것 같아요

 

 

Usage

norm_layer = get_norm_layer(norm_type='instance')  # instance | batch | none
NetG = G(..., norm_layer=norm_layer, ...)
NetD = D(...)

...

class G(nn.Module):
    def __init__(self, ..., norm_layer=norm_layer, ...):
        super(G, self).__init__()
        
        if type(norm_layer) == functools.partial:  # 요기
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
            
        down = [nn.ReflectionPad2d(3),
                 nn.Conv2d(dim, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),  # 요기
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            down += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),  # 요기
                      nn.ReLU(True)
                     ]

...

요렇게 쓰시면 됩니당!!ㅎㅎ

댓글