# Neural Network Modules

All layers inherit from nn.Module, providing parameters(), train()/eval(), state_dict(), and load_state_dict().


Linear Layers

from grilly import nn

linear = nn.Linear(in_features=784, out_features=256)
output = linear(x)  # x: (batch, 784) -> output: (batch, 256)

nn.Embedding maps integer indices to dense vectors:

embed = nn.Embedding(num_embeddings=10000, embedding_dim=512)
output = embed(token_ids)  # token_ids: (batch, seq_len) -> (batch, seq_len, 512)

Activations

| Module | Description |
|----------------|--------------------------------|
| nn.ReLU() | Rectified linear unit |
| nn.GELU() | Gaussian error linear unit |
| nn.SiLU() | Sigmoid linear unit (Swish) |
| nn.SwiGLU() | SwiGLU gated activation |
| nn.GCU() | Growing cosine unit |
| nn.RoSwish() | Rotary Swish |
| nn.Softmax() | Softmax normalization |
| nn.Softplus() | Smooth approximation to ReLU |

model = nn.Sequential(
    nn.Linear(512, 512),
    nn.GELU(),
    nn.Linear(512, 256),
)

Convolution

conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
output = conv(x)  # x: (batch, 3, H, W) -> output: (batch, 64, H, W)

Available: nn.Conv1d, nn.Conv2d.


Normalization

ln = nn.LayerNorm(normalized_shape=512)
rms = nn.RMSNorm(normalized_shape=512)
bn1 = nn.BatchNorm1d(num_features=256)
bn2 = nn.BatchNorm2d(num_features=64)

Attention

# Standard multi-head attention
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8)
output = mha(query, key, value)

# Flash Attention 2 (memory-efficient)
fa = nn.FlashAttention2(embed_dim=512, num_heads=8)
output = fa(query, key, value)

See Flash Attention and HYLA Attention for advanced attention mechanisms.


Recurrent Layers

lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2)
gru = nn.GRU(input_size=256, hidden_size=512)

# Cell-level access
lstm_cell = nn.LSTMCell(input_size=256, hidden_size=512)
gru_cell = nn.GRUCell(input_size=256, hidden_size=512)

Pooling

pool = nn.MaxPool2d(kernel_size=2, stride=2)
avg = nn.AvgPool2d(kernel_size=2, stride=2)
adaptive = nn.AdaptiveMaxPool2d(output_size=(1, 1))

Loss Functions

import numpy as np

mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()
bce = nn.BCELoss()

loss = ce(logits, targets)
grad = ce.backward(np.ones_like(loss), logits, targets)

Containers

# Sequential: chain modules
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Residual: skip connection
block = nn.Residual(
    nn.Sequential(
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
    )
)

Dropout

drop = nn.Dropout(p=0.1)
output = drop(x)  # Zeros elements with probability 0.1 during training

Transformer Layers

encoder = nn.TransformerEncoderLayer(d_model=512, nhead=8)
decoder = nn.TransformerDecoderLayer(d_model=512, nhead=8)

# With rotary position embeddings
rope = nn.RoPE(dim=64, max_seq_len=2048)

LoRA (Low-Rank Adaptation)

from grilly.nn import LoRALinear, LoRAModel, LoRAConfig

config = LoRAConfig(r=8, alpha=16, dropout=0.1)
lora_linear = LoRALinear(base_linear, config)

# Or apply LoRA to an entire model
lora_model = LoRAModel(model, config, target_modules=["linear"])

Multimodal Fusion

from grilly.nn import PerceiverIO, CrossModalAttentionFusion, FlamingoFusion

# Perceiver IO for arbitrary input modalities
perceiver = PerceiverIO(input_dim=768, latent_dim=512, num_latents=64)

# Cross-modal attention
fusion = CrossModalAttentionFusion(dim=512, num_heads=8)

Available: BottleneckFusion, CrossModalAttentionFusion, FlamingoFusion, ImageBindFusion, PerceiverIO, PerceiverResampler, VisionLanguageModel.


Full Module List

See API Reference: nn for the complete list of all exported classes and functions.