[코드 이모저모] Flax library -1

2024. 7. 23. 10:33Code/코드 이모저모

Flax는 Jax와 연동되는 library.

기본적인 Network architecture를 제공 [e.g. Dense, Conv ,...]

간단하게 torch.nn.Linear, torch.nn.Conv1d,.. 를 사용할 수 있도록 함.

 

대표적인 사용 예시는 아래와 같음

1. Dense Layer [torch.nn.Linear]

import jax
import flax.linen as nn
from typing import Sequnece, Callable


def default_init(scale: Optional[float] = jnp.sqrt(2)):
    return nn.initializers.orthogonal(scale)


class Linear(nn.Module)
	hidden_layers: Sequnece[int]
    activate_func: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    dropout_rate: int
    
    @nn.compact
    def __call__(self, x:jnp.ndarray, training:bool = False)-> jnp.ndarray:
    	
        for i, size in enumerate(self.hidden_dims):
            x = nn.Dense(size, kernel_init=default_init())(x)
            if i + 1 < len(self.hidden_dims) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                    x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        return x

 

2. Conv Layer [torch.nn.Conv1d]

class ConvLayer(nn.Module):
    features : Sequence[int]
    kernel_sizes: Sequence[int]
    strides: Sequence[int]
    paddings: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    activate_final: int = False
    dropout_rate: Optional[float] = None
    after_flatten: int = 128

    @nn.compact
    def __call__(self, x, training: bool = False):
        batch, feat, seq = x.shape
        layer_meta_data = zip(self.features, self.kernel_sizes, self.strides, self.paddings)
        for i, (feat, kernel, stride, padding) in enumerate(layer_meta_data):
            x = nn.Conv(features=feat, kernel_size=(kernel, ), strides= (stride, ), padding= padding, kernel_init=default_init()) (x)
            if i + 1 < len(self.features) or self.activate_final:
                x = self.activations(x)
                if self.dropout_rate is not None:
                     x = nn.Dropout(rate=self.dropout_rate)(
                        x, deterministic=not training)
        x = x.reshape(batch, -1)
        x = nn.Dense(self.after_flatten, kernel_init=default_init())(x)
        return x

 

주의할점

1. convolution layer의 dimenstion을 결정하는 것은 class 명칭이 아니라 kernel_size 의 input 형태에 따라 달라짐.

                    i.e. nn.Conv의 parameter 중 kernel_size= (3,3)인 경우 2D (3,)인 경우 1D

 

2. features는 torch.nn.Conv 에서 out_channel과 동일, in_channel은 선언하지 않음.

 

3. Conv 연산이 기본 torch.nn.Conv 연산과 다름

                    i.e. nn.Conv의 input size: Batch, Feature, Seq    vs   torch.nn.Conv: Batch, Seq, Feature

 

4. __call__은 torch.nn.layer에서 forward와 동일 [애초에 forward는 callable member function으로 정의되어 있음]

'Code > 코드 이모저모' 카테고리의 다른 글

[코드 이모저모] Jax Device print error  (0) 2024.07.16