[코드 이모저모] Flax library -1
2024. 7. 23. 10:33ㆍCode/코드 이모저모
Flax는 Jax와 연동되는 library.
기본적인 Network architecture를 제공 [e.g. Dense, Conv ,...]
간단하게 torch.nn.Linear, torch.nn.Conv1d,.. 를 사용할 수 있도록 함.
대표적인 사용 예시는 아래와 같음
1. Dense Layer [torch.nn.Linear]
import jax
import flax.linen as nn
from typing import Sequnece, Callable
def default_init(scale: Optional[float] = jnp.sqrt(2)):
return nn.initializers.orthogonal(scale)
class Linear(nn.Module)
hidden_layers: Sequnece[int]
activate_func: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
dropout_rate: int
@nn.compact
def __call__(self, x:jnp.ndarray, training:bool = False)-> jnp.ndarray:
for i, size in enumerate(self.hidden_dims):
x = nn.Dense(size, kernel_init=default_init())(x)
if i + 1 < len(self.hidden_dims) or self.activate_final:
x = self.activations(x)
if self.dropout_rate is not None:
x = nn.Dropout(rate=self.dropout_rate)(
x, deterministic=not training)
return x
2. Conv Layer [torch.nn.Conv1d]
class ConvLayer(nn.Module):
features : Sequence[int]
kernel_sizes: Sequence[int]
strides: Sequence[int]
paddings: Sequence[int]
activations: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
activate_final: int = False
dropout_rate: Optional[float] = None
after_flatten: int = 128
@nn.compact
def __call__(self, x, training: bool = False):
batch, feat, seq = x.shape
layer_meta_data = zip(self.features, self.kernel_sizes, self.strides, self.paddings)
for i, (feat, kernel, stride, padding) in enumerate(layer_meta_data):
x = nn.Conv(features=feat, kernel_size=(kernel, ), strides= (stride, ), padding= padding, kernel_init=default_init()) (x)
if i + 1 < len(self.features) or self.activate_final:
x = self.activations(x)
if self.dropout_rate is not None:
x = nn.Dropout(rate=self.dropout_rate)(
x, deterministic=not training)
x = x.reshape(batch, -1)
x = nn.Dense(self.after_flatten, kernel_init=default_init())(x)
return x
주의할점
1. convolution layer의 dimenstion을 결정하는 것은 class 명칭이 아니라 kernel_size 의 input 형태에 따라 달라짐.
i.e. nn.Conv의 parameter 중 kernel_size= (3,3)인 경우 2D (3,)인 경우 1D
2. features는 torch.nn.Conv 에서 out_channel과 동일, in_channel은 선언하지 않음.
3. Conv 연산이 기본 torch.nn.Conv 연산과 다름
i.e. nn.Conv의 input size: Batch, Feature, Seq vs torch.nn.Conv: Batch, Seq, Feature
4. __call__은 torch.nn.layer에서 forward와 동일 [애초에 forward는 callable member function으로 정의되어 있음]
'Code > 코드 이모저모' 카테고리의 다른 글
[코드 이모저모] Jax Device print error (0) | 2024.07.16 |
---|