# Automatic Mixed Precision (AMP)

AMP runs eligible operations in float16 compute with float32 accumulation, reducing memory usage and improving throughput. The API mirrors PyTorch's `torch.cuda.amp`.
## Basic Usage

```python
import numpy as np

from grilly.backend.amp import autocast, GradScaler

scaler = GradScaler()

for x_batch, y_batch in train_loader:
    with autocast():
        output = model(x_batch)
        loss = criterion(output, y_batch)
    scaled_loss = scaler.scale(loss)

    # Backward pass on the scaled loss
    grad = criterion.backward(np.ones_like(scaled_loss), output, y_batch)
    model.zero_grad()
    model.backward(grad)

    scaler.step(optimizer)
    scaler.update()
```
## autocast

Context manager that enables mixed precision for eligible operations:

```python
from grilly.backend.amp import autocast, is_autocast_enabled

with autocast():
    assert is_autocast_enabled()
    output = model(x)  # Eligible ops use float16 compute

# Outside the context, autocast is off
assert not is_autocast_enabled()
```
Eligible operations: `linear`, `conv2d`, `attention`. These are flagged to use float16 compute paths in the bridge when autocast is active.

Non-eligible operations: loss computation, normalization, softmax. These stay in float32 for numerical stability.
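To make the dispatch concrete, here is a minimal stand-alone sketch of how an eligible op might consult the flag. The module-level `_AUTOCAST_ENABLED` flag and this `linear` are illustrative stand-ins, not grilly's bridge code:

```python
import numpy as np

_AUTOCAST_ENABLED = False  # stand-in for the real thread-local flag


def is_autocast_enabled():
    return _AUTOCAST_ENABLED


def linear(x, w, b):
    # Eligible op: compute in float16 when autocast is on,
    # then promote the result back to float32
    if is_autocast_enabled():
        y16 = x.astype(np.float16) @ w.astype(np.float16)
        return y16.astype(np.float32) + b
    return x @ w + b
```

The key property is that callers never change: the same `linear(x, w, b)` call transparently picks the float16 path inside an `autocast()` block.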
Autocast contexts can be nested:

```python
with autocast():
    h = model.encoder(x)
    with autocast(enabled=False):
        # Force float32 for a specific section
        loss = loss_fn(h, targets)
```
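The nesting behavior can be modeled with a thread-local flag whose previous value is restored on exit. This is a self-contained re-implementation sketch under that assumption, not grilly's actual source:

```python
import threading
from contextlib import contextmanager

_state = threading.local()


def is_autocast_enabled():
    return getattr(_state, "enabled", False)


@contextmanager
def autocast(enabled=True):
    # Save the outer state so nested contexts restore it correctly
    prev = is_autocast_enabled()
    _state.enabled = enabled
    try:
        yield
    finally:
        _state.enabled = prev
```

Restoring `prev` rather than unconditionally setting `False` is what makes `autocast(enabled=False)` inside an outer `autocast()` behave correctly.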
## GradScaler

Scales loss values to prevent float16 gradient underflow:

```python
from grilly.backend.amp import GradScaler

scaler = GradScaler(
    init_scale=65536.0,    # Initial loss scale (2^16)
    growth_factor=2.0,     # Scale up when no overflow
    backoff_factor=0.5,    # Scale down on overflow
    growth_interval=2000,  # Steps between growth attempts
)
```
Methods:

| Method | Description |
|---|---|
| `scaler.scale(loss)` | Multiply loss by the current scale factor |
| `scaler.step(optimizer)` | Unscale gradients, check for overflow, then call `optimizer.step()` |
| `scaler.update()` | Update the scale factor based on overflow history |
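The scale/step/update cycle in the table can be sketched as a simplified stand-alone class. The `optimizer.grads` dict interface below is an assumption made for illustration, not grilly's optimizer API:

```python
import numpy as np


class TinyGradScaler:
    """Simplified model of a loss scaler; not grilly's implementation."""

    def __init__(self, init_scale=65536.0, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale_ = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_inf = False

    def scale(self, loss):
        # Multiply the loss so small float16 gradients don't underflow
        return loss * self.scale_

    def step(self, optimizer):
        # Unscale gradients in place; skip the step if any are inf/nan
        self._found_inf = False
        for name, g in optimizer.grads.items():
            g = g / self.scale_
            optimizer.grads[name] = g
            if not np.all(np.isfinite(g)):
                self._found_inf = True
        if not self._found_inf:
            optimizer.step()

    def update(self):
        if self._found_inf:
            # Overflow: back off and restart the growth counter
            self.scale_ *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker >= self.growth_interval:
                self.scale_ *= self.growth_factor
                self._growth_tracker = 0
```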
## How It Works

- `autocast` sets a thread-local flag, readable via `is_autocast_enabled()`
- Bridge functions check this flag and dispatch float16 compute paths when safe
- `GradScaler` multiplies the loss by a large scale factor before backward
- Before the optimizer step, gradients are unscaled by dividing by the same factor
- If overflow (inf/nan) is detected, the step is skipped and the scale is reduced
- If there is no overflow for `growth_interval` steps, the scale is increased
This mirrors the behavior of PyTorch's AMP (`torch.cuda.amp.autocast` and `GradScaler`).
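A quick NumPy illustration of the underflow problem that loss scaling solves (independent of grilly; 65536 matches the default `init_scale`):

```python
import numpy as np

tiny = np.float32(1e-8)              # smaller than float16's min subnormal (~6e-8)
assert np.float16(tiny) == 0         # underflows to zero in float16

scale = np.float32(65536.0)          # 2^16, the default loss scale
survives = np.float16(tiny * scale)  # now well inside float16's range
recovered = np.float32(survives) / scale
# recovered is within float16 rounding error of the true 1e-8 gradient
```

Scaling shifts small gradients into float16's representable range; dividing by the same factor before the optimizer step recovers their true magnitude.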