[논문구현]Pytorch로 ConvNeXt 구현

2023. 8. 10. 20:26

ConvNeXt 는 Attention 을 제외한 Convolution 연산만을 사용하는 모델로서, 여전히 Convolution 연산이 CV 분야에 유용하다는 것을 알렸습니다.

 

이번 포스팅에서는 ConvNeXt 를 구현해보겠습니다. ConvNeXt 리뷰는 아래 링크에서 확인할 수 있습니다.

https://dreamrunning.tistory.com/15

 

 

1. 구현방법

기본적으로, ConvNeXt 는 ResNet 구조를 개량한 구조이기 때문에 ResNet 과 비슷하게 구현할 수 있습니다. 이 포스팅에서는 ResNet 과 유사한 방법으로 구현하겠습니다.

 

공식 GitHub 이 공개되어있으므로, 공식 구현 방법이 궁금하시면 찾아보면 되겠습니다.

 

2. LayerNorm

ConvNeXt 에서는 Normalizaiton Layer 로 LayerNorm 을 사용합니다. Pytorch 에서는 nn.LayerNorm 으로 LayerNorm 을 사용할 수 있습니다.

 

nn.LayerNorm 은 마지막 차원을 Normalization 하므로, LayerNorm layer 를 통과하기 전 Tensor 의 차원을 바꿔줘야 할 필요성이 있습니다. 따라서 따로 클래스로 분리해서 구현하는 것이 편하므로 LayerNorm Class 를 정의하고 사용합니다.

class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim, eps=eps)
    
    def forward(self, x):
        x = x.permute(0, 2, 3, 1) # (b c h w) -> (b h w c) for Layernorm
        x = self.layernorm(x)
        x = x.permute(0, 3, 1, 2) # (b h w c) -> (b c h w)
        return x

 

3. Block

ConvNeXt 에서는 Reverse Bottleneck 구조를 가진 하나의 Block 으로 모든 stage 를 구성합니다. 다음과 같이 구현할 수 있습니다.

class Block(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, in_ch, kernel_size=7, padding=3, groups=in_ch)
        self.conv2 = nn.Conv2d(in_ch, in_ch*4, kernel_size=1,)
        self.conv3 = nn.Conv2d(in_ch*4, in_ch, kernel_size=1,)
        self.norm = LayerNorm(in_ch)
        self.gelu = nn.GELU()
        self.residual = nn.Identity()
    
    def forward(self, x):
        res = self.residual(x)
        x = self.conv1(x)
        x = self.norm(x)
        x = self.conv2(x)
        x = self.gelu(x)
        x = self.conv3(x)
        x = x + res
        return x

 

4. ConvNeXt Class

 

4.1 make_layer Method

이 포스팅에서는 ResNet 과 비슷하게 make_layer 메소드를 정의하여 각 Stage 를 만들었습니다. for 문을 이용해서 Block을 반복하고, 각 Stage 이후 별도의 DownSampling Layer를 추가합니다.

class ConvNeXt(nn.Module):
    def __init__(self, 
                 in_ch: int,
                 inner_dim: int = 96,
                 stages = [3, 3, 9, 3]):
        super().__init__()

		...

        self.block = Block # Block Class
        self.stage1 = self.make_layer(channels=inner_dim, num_block=stages[0])
        self.stage2 = self.make_layer(channels=inner_dim*2, num_block=stages[1])
        self.stage3 = self.make_layer(channels=inner_dim*4, num_block=stages[2])
        self.stage4 = self.make_layer(channels=inner_dim*8, num_block=stages[3])

		...


    def make_layer(self, channels, num_block=3):
        self.downsampling = nn.Sequential(
            nn.Conv2d(channels, channels*2, kernel_size=2, stride=2),
            LayerNorm(channels*2)
        )

        layers = nn.ModuleList([])
        for i in range(num_block):
            layers.append(self.block(channels))
        layers.append(self.downsampling)

        return nn.Sequential(*layers)

 

 

4.2 Stem

 

처음 이미지를 받는 Stem layer 는 논문에 따라 kernel size=4, stride=4 인 Convolution Layer  와 LayerNorm 으로 구성되어 있습니다. Sequential 로 감싸 선언합니다.

class ConvNeXt(nn.Module):
    def __init__(self, 
                 in_ch: int,
                 inner_dim: int = 96,
                 stages = [3, 3, 9, 3]):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, inner_dim, kernel_size=4, stride=4),
            LayerNorm(inner_dim)
        )

		...

 

4.3 Global Avrage Pooling & Linear Layer

 

Global Avrage Pooling

논문에서는 Stage 이후 Linear Layer 를 통과하기 전 Global Average Pooling 을 사용한다고 되어있습니다. Global Avg Pooling 은 여러가지 구현방법이 있는데 여기서는 Tensor.mean 을 사용하여 구현하겠습니다. 이는 forward 에서 정의합니다.

 

LayerNorm

Global Avg Pooling 이후 Tensor 의 차원이 (Batch, Channel, Width, Height) 에서 (Batch, Channel) 로 줄어듭니다. 따라서 기존에 사용되었던 LayerNorm Class 를 사용못하므로 따로 nn.LayerNorm 을 사용해야 합니다.

 

Linear Layer

이후 Linear Layer 를 정의해서 사용하면 됩니다.

 

코드로 구현하면 다음과 같습니다.

class ConvNeXt(nn.Module):
    def __init__(self, 
                 in_ch: int,
                 inner_dim: int = 96,
                 stages = [3, 3, 9, 3]):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, inner_dim, kernel_size=4, stride=4),
            LayerNorm(inner_dim)
        )

		...
        
        self.norm = nn.LayerNorm(inner_dim*16, eps=1e-6)
        self.linear = nn.Linear(inner_dim*16, 1000)
		
        ...


    def forward(self, x):
		...

        x = x.mean([-2, -1]) # global avrage pooling 
        x = self.norm(x)
        x = self.linear(x)
        
        return x

 

 

4.4 ConvNeXt 완성하기

 

위 코드를 합치면 다음과 같습니다. ConvNeXt Tiny 를 기준으로 파라메터를 정의했습니다.

(inner_dim=96, stages=[3,3,9,3] )

 

class ConvNeXt(nn.Module):
    def __init__(self, 
                 in_ch: int,
                 inner_dim: int = 96,
                 stages = [3, 3, 9, 3]):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, inner_dim, kernel_size=4, stride=4),
            LayerNorm(inner_dim)
        )

        self.block = Block # Block Class
        self.stage1 = self.make_layer(channels=inner_dim, num_block=stages[0])
        self.stage2 = self.make_layer(channels=inner_dim*2, num_block=stages[1])
        self.stage3 = self.make_layer(channels=inner_dim*4, num_block=stages[2])
        self.stage4 = self.make_layer(channels=inner_dim*8, num_block=stages[3])

        self.norm = nn.LayerNorm(inner_dim*16, eps=1e-6)
        self.linear = nn.Linear(inner_dim*16, 1000)


    def make_layer(self, channels, num_block=3):
        self.downsampling = nn.Sequential(
            nn.Conv2d(channels, channels*2, kernel_size=2, stride=2),
            LayerNorm(channels*2)
        )

        layers = nn.ModuleList([])
        for i in range(num_block):
            layers.append(self.block(channels))
        layers.append(self.downsampling)

        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = x.mean([-2, -1]) # global avrage pooling 
        x = self.norm(x)
        x = self.linear(x)
        return x

'DeepLeaning > 구현' 카테고리의 다른 글

[논문구현]Pytorch로 ResNet 구현  (0) 2023.07.27
Pytorch 로 Attention 구현하기  (0) 2023.07.12

BELATED ARTICLES

more