SympFormer

SympFormerBlock applies symplectic (momentum-based) integration to transformer attention, treating the attention update as a Hamiltonian system.


Usage

from grilly import nn

block = nn.SympFormerBlock(
    embed_dim=512,
    num_heads=8,
    h_x_init=0.1,  # Position step size
    h_y_init=0.1,  # Momentum step size
    c_log=3.0,     # Logarithmic damping
    c_lin=0.1,     # Linear damping
)

output = block(x)  # x: (batch, seq_len, embed_dim)

How It Works

Standard transformer layers apply a residual update:

x = x + Attention(x)

SympFormer treats this as a symplectic integrator. It maintains a "momentum" state y alongside the position x:

y_{n+1} = y_n - h_y * (c_log * y_n / (1 + |y_n|) + c_lin * y_n) + h_y * grad_x H
x_{n+1} = x_n + h_x * y_{n+1}

The damping terms (c_log, c_lin) provide stable convergence. The Hamiltonian H is defined by the attention function.
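The update rule above can be sketched in plain NumPy. A toy quadratic Hamiltonian stands in for the attention function, and the function name, step sizes, and damping values are illustrative assumptions, not the library's internals:

```python
import numpy as np

def symp_step(x, y, grad_H, h_x=0.1, h_y=0.1, c_log=3.0, c_lin=0.1):
    """One damped symplectic update: momentum first, then position.

    grad_H: callable returning grad_x H at x (in SympFormer this
    role is played by the attention output).
    """
    damping = c_log * y / (1.0 + np.abs(y)) + c_lin * y
    y_next = y - h_y * damping + h_y * grad_H(x)
    x_next = x + h_x * y_next  # position uses the *updated* momentum
    return x_next, y_next

# Toy example: H = -0.5 * ||x||^2, so grad_x H = -x, and the damped
# iteration should drive x toward zero without oscillating.
x = np.array([1.0, -2.0])
y = np.zeros_like(x)
for _ in range(200):
    x, y = symp_step(x, y, lambda x: -x)
```

Note the semi-implicit ordering: the position step x_{n+1} consumes the already-updated momentum y_{n+1}, which is what makes the integrator symplectic rather than a plain forward-Euler step.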

Benefits:

  • Better convergence on long training runs
  • Built-in momentum prevents oscillation
  • Symplectic structure preserves energy (stable dynamics)

Selective Momentum

SelectiveMomentumAttention combines SympFormer with TAPPA q-similarity to apply momentum only where it helps:

from grilly.nn.selective_momentum import SelectiveMomentumAttention

sel_attn = SelectiveMomentumAttention(
    attention=nn.MultiheadAttention(512, 8),
    embed_dim=512,
    num_heads=8,
    threshold=0.5,  # q-similarity threshold
)

output = sel_attn(query, key, value)

Heads with q-similarity below the threshold (unpredictable, retrieval heads) get momentum acceleration. Predictable heads skip the overhead. This gives SympFormer's convergence benefit at roughly 1x wall-clock cost instead of 2.4x.
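The per-head gating can be sketched as follows. All names, shapes, and the residual fallback are illustrative assumptions about how such a gate could work, not SelectiveMomentumAttention's actual internals:

```python
import numpy as np

def selective_momentum_update(x_heads, y_heads, attn_heads, q_sim,
                              threshold=0.5, h_x=0.1, h_y=0.1):
    """Momentum update only for heads below the q-similarity threshold.

    x_heads, y_heads, attn_heads: (num_heads, head_dim) arrays
    q_sim: (num_heads,) per-head q-similarity scores
    Heads at or above `threshold` fall back to a plain residual update.
    """
    use_momentum = q_sim < threshold  # (num_heads,) boolean mask
    y_next = np.where(use_momentum[:, None],
                      y_heads + h_y * attn_heads,  # accumulate momentum
                      y_heads)                     # skipped heads: unchanged
    x_next = np.where(use_momentum[:, None],
                      x_heads + h_x * y_next,      # symplectic position step
                      x_heads + attn_heads)        # standard residual update
    return x_next, y_next

# Two heads: head 0 is unpredictable (q_sim 0.2, gets momentum),
# head 1 is predictable (q_sim 0.9, plain residual).
x = np.zeros((2, 4))
y = np.zeros((2, 4))
attn = np.ones((2, 4))
x2, y2 = selective_momentum_update(x, y, attn, np.array([0.2, 0.9]))
```

Because the mask is computed once per head rather than per token, the gate itself is cheap; the saved cost comes from not maintaining the momentum state for predictable heads.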


Parameters

Param      Default   Description
---------  --------  -------------------------------
embed_dim  required  Model dimension
num_heads  required  Number of attention heads
h_x_init   0.1       Position step size (learnable)
h_y_init   0.1       Momentum step size (learnable)
c_log      3.0       Logarithmic damping coefficient
c_lin      0.1       Linear damping coefficient