
Spiking Neural Networks

Grilly provides a complete SNN framework with neuron models, surrogate gradients, temporal containers, ANN-to-SNN conversion, and specialized normalization layers.


Neuron Models

IFNode (Integrate-and-Fire)

from grilly import nn

node = nn.IFNode(v_threshold=1.0, surrogate_function=nn.ATan())
output = node(x)  # x: input current, output: binary spikes

The simplest spiking neuron: it integrates input current into its membrane potential, fires a spike when the potential reaches the threshold, then resets.
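
These dynamics can be sketched in plain NumPy. This is an illustrative model of integrate-and-fire behavior, not Grilly's implementation; `if_step` and its parameter names are hypothetical:

```python
import numpy as np

def if_step(x, v, v_threshold=1.0, v_reset=0.0):
    """One integrate-and-fire step: accumulate input, spike at threshold, hard reset."""
    v = v + x                                 # integrate input current
    spike = (v >= v_threshold).astype(x.dtype)
    v = np.where(spike > 0, v_reset, v)       # reset wherever a spike fired
    return spike, v

# Three neurons, one step: only the input that reaches threshold fires
spike, v = if_step(np.array([0.6, 0.3, 1.2]), np.zeros(3))
```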

LIFNode (Leaky Integrate-and-Fire)

node = nn.LIFNode(tau=2.0, v_threshold=1.0, surrogate_function=nn.ATan())
output = node(x)

Adds exponential membrane potential decay (leak). The tau parameter controls the time constant.
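
The leak can be sketched as a first-order update, v ← v + (x − v)/tau: the larger tau is, the slower the membrane decays toward rest. Again a hypothetical NumPy sketch of the dynamics, not Grilly's code:

```python
import numpy as np

def lif_step(x, v, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """Leaky integrate-and-fire: membrane decays toward rest, driven by input."""
    v = v + (x - v) / tau                     # exponential leak with time constant tau
    spike = (v >= v_threshold).astype(x.dtype)
    v = np.where(spike > 0, v_reset, v)
    return spike, v

# A strong input drives the membrane to threshold in a single step
spike, v = lif_step(np.array([2.0]), np.zeros(1))
```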

ParametricLIFNode

node = nn.ParametricLIFNode(init_tau=2.0, surrogate_function=nn.ATan())
output = node(x)

Like LIFNode but with a learnable time constant. The network learns the optimal decay rate.
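
Parametric LIF neurons commonly store the decay as an unconstrained learnable scalar passed through a sigmoid, so that 1/tau stays in (0, 1) during training. Whether Grilly uses exactly this parameterization is an assumption; a sketch of the idea:

```python
import math

def plif_decay(w):
    """Map a learnable scalar w to a decay factor 1/tau in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-w))

# init_tau=2.0 corresponds to w = -log(tau - 1) = 0, since sigmoid(0) = 0.5 = 1/2
w_init = -math.log(2.0 - 1.0)
```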


Surrogate Gradients

Spiking neurons fire through a step function, which is non-differentiable. Surrogate gradient functions provide smooth approximations of its derivative for backpropagation:

atan = nn.ATan(alpha=2.0)         # Arctangent surrogate
sig = nn.Sigmoid(alpha=4.0)       # Sigmoid surrogate
fast = nn.FastSigmoid(alpha=2.0)  # Piecewise linear approximation

All neuron nodes accept a surrogate_function parameter.
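
The common ATan surrogate replaces the step's gradient with the derivative of a scaled arctangent, which peaks at the threshold crossing and falls off away from it. Whether Grilly uses exactly this scaling for alpha is an assumption:

```python
import math

def atan_surrogate_grad(x, alpha=2.0):
    """Derivative of the smooth ATan approximation to the Heaviside step,
    used in the backward pass in place of the true (zero a.e.) gradient.
    x is the distance from threshold, e.g. v - v_threshold."""
    return alpha / (2.0 * (1.0 + (math.pi / 2.0 * alpha * x) ** 2))

g_at_threshold = atan_surrogate_grad(0.0)   # maximal: alpha / 2
g_far_away = atan_surrogate_grad(1.0)       # much smaller
```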


Temporal Containers

MultiStepContainer

Runs a module over multiple time steps:

container = nn.MultiStepContainer(nn.LIFNode(tau=2.0))
# input: (T, batch, ...) -> output: (T, batch, ...)
output = container(x_temporal)
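
Conceptually, a multi-step container just loops a stateful single-step module over the leading time dimension and stacks the outputs. A hypothetical sketch of that loop (not Grilly's container):

```python
import numpy as np

def multi_step(step_fn, x_temporal, v0):
    """Apply a stateful single-step module over the leading time dimension T."""
    v = v0
    outputs = []
    for t in range(x_temporal.shape[0]):      # iterate over T
        out, v = step_fn(x_temporal[t], v)    # state carries across steps
        outputs.append(out)
    return np.stack(outputs)                  # (T, batch, ...)

def if_step(x, v, v_th=1.0):
    """Minimal integrate-and-fire step used to drive the container."""
    v = v + x
    s = (v >= v_th).astype(x.dtype)
    return s, np.where(s > 0, 0.0, v)

x_temporal = np.full((4, 2), 0.5)             # T=4, batch=2, constant drive
spikes = multi_step(if_step, x_temporal, np.zeros(2))
```

With a constant sub-threshold input of 0.5 and threshold 1.0, each neuron fires on every second step.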

SeqToANNContainer

Wraps an ANN layer (e.g., Conv2d) to process temporal data by merging the time and batch dimensions:

container = nn.SeqToANNContainer(nn.Conv2d(3, 64, 3, padding=1))
# Reshapes (T, B, C, H, W) -> (T*B, C, H, W), applies conv, reshapes back
output = container(x_temporal)
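
Because the wrapped layer is stateless, the fold-apply-unfold trick is equivalent to calling it separately at every time step, but with a single larger batch. A sketch of the reshape (any stateless function stands in for the conv):

```python
import numpy as np

def seq_to_ann(stateless_fn, x_temporal):
    """Fold time into batch, apply a stateless layer once, unfold back."""
    T, B = x_temporal.shape[:2]
    flat = x_temporal.reshape(T * B, *x_temporal.shape[2:])   # (T*B, ...)
    out = stateless_fn(flat)
    return out.reshape(T, B, *out.shape[1:])                  # back to (T, B, ...)

x = np.random.rand(10, 4, 3, 8, 8)            # (T, B, C, H, W)
y = seq_to_ann(lambda a: a * 2.0, x)          # stateless op applied once over T*B
```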

SNN Normalization

Specialized batch normalization layers that account for temporal dynamics:

Layer                              Description
BatchNormThroughTime1d/2d          BN computed across the time dimension
TemporalEffectiveBatchNorm1d/2d    Time-weighted batch normalization
ThresholdDependentBatchNorm1d/2d   BN scaled by firing threshold
NeuNorm                            Neuron-level normalization
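
Batch norm through time is commonly implemented with separate batch statistics at each time step, as in the original BNTT formulation; whether Grilly follows exactly this scheme is an assumption. A minimal sketch of per-time-step normalization (training-mode statistics only, no learnable affine):

```python
import numpy as np

def bn_through_time(x_temporal, eps=1e-5):
    """Normalize each time step with its own batch statistics."""
    out = np.empty_like(x_temporal)
    for t in range(x_temporal.shape[0]):
        xt = x_temporal[t]                    # (B, features) at time t
        mean = xt.mean(axis=0)
        var = xt.var(axis=0)
        out[t] = (xt - mean) / np.sqrt(var + eps)
    return out

x = np.random.rand(5, 16, 8)                  # (T, B, features)
y = bn_through_time(x)
```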

Synapse Models

from grilly import nn

# Basic synapse with delay
synapse = nn.SynapseFilter(tau_syn=5.0)

# Short-term plasticity
stp = nn.STPSynapse(tau_d=200.0, tau_f=20.0, U=0.2)

# Dual timescale (fast + slow)
dual = nn.DualTimescaleSynapse(tau_fast=5.0, tau_slow=50.0)
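
Short-term plasticity of the Tsodyks-Markram kind combines a facilitation variable u (boosted by spikes, decaying with tau_f) with a depression variable r tracking available resources (consumed by spikes, recovering with tau_d). The parameter names mirror the constructor above, but the exact update rule is an assumption, not Grilly's code:

```python
def stp_step(spike, u, r, dt=1.0, tau_d=200.0, tau_f=20.0, U=0.2):
    """One short-term-plasticity step: u = facilitation, r = available resources."""
    u = u + dt * (-u / tau_f)                 # facilitation decays toward 0
    r = r + dt * ((1.0 - r) / tau_d)          # resources recover toward 1
    u = u + U * (1.0 - u) * spike             # a spike boosts utilization
    eff = u * r * spike                       # effective synaptic drive
    r = r - u * r * spike                     # the spike consumes resources
    return eff, u, r

# First spike from rest (u=0, r=1) transmits with strength U * 1 = 0.2
eff, u, r = stp_step(1.0, 0.0, 1.0)
```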

SNN Attention

# Temporal-wise attention for SNN
twa = nn.TemporalWiseAttention(T=10, reduction=4)

# Spiking self-attention
ssa = nn.SpikingSelfAttention(embed_dim=256, num_heads=4)

# QK attention variants
qk = nn.QKAttention(embed_dim=256, num_heads=4)

ANN-to-SNN Conversion

Convert a trained ANN to an SNN:

from grilly.nn import Converter, VoltageScaler

# Convert ReLU model to spiking equivalent
converter = Converter(mode="max_norm")
snn_model = converter(ann_model)

# Scale voltage thresholds
scaler = VoltageScaler(snn_model)
scaler.scale_thresholds(calibration_data)
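
Max-norm conversion rescales each layer's weights by the ratio of the maximum activations observed on calibration data at its input and output, so that spiking rates (bounded by 1) can stand in for ReLU activations. A toy sketch of that scaling rule, not Grilly's converter:

```python
import numpy as np

def max_norm_scale(weights, max_act_in, max_act_out):
    """Data-based weight normalization: rescale so post-activations of the
    converted layer stay within the spiking range."""
    return weights * (max_act_in / max_act_out)

w = np.array([[0.5, 1.5],
              [2.0, 0.1]])
# Suppose calibration saw activations up to 1.0 at the input, 4.0 at the output
w_scaled = max_norm_scale(w, max_act_in=1.0, max_act_out=4.0)
```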

Legacy SNN Layers

The original flat-array SNN layers are still available:

lif = nn.LIFNeuron(threshold=1.0, decay=0.9)
snn_layer = nn.SNNLayer(in_features=256, out_features=128)
hebbian = nn.HebbianLayer(in_features=256, out_features=128)
stdp = nn.STDPLayer(in_features=256, out_features=128)

These use the 19 SNN compute shaders (snn-*.glsl) via the VulkanCompute.snn namespace.


Functional API

import grilly.functional as F

# Single LIF step
spike, v = F.lif_step(x, v_prev, tau=2.0, v_threshold=1.0)

# Multi-step forward
output = F.multi_step_forward(model, x_temporal, T=10)

# Reset all neuron states
F.reset_net(model)
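
Stateful neurons keep their membrane potentials between calls, so a reset pass is needed before each new sample or sequence. A toy sketch of what such a pass does; the classes here are hypothetical stand-ins, not Grilly's internals:

```python
class StatefulNode:
    """Toy stateful module holding a membrane potential."""
    def __init__(self):
        self.v = 0.0
    def reset(self):
        self.v = 0.0

def reset_net(modules):
    """Reset every stateful module so the next sample starts from rest."""
    for m in modules:
        if hasattr(m, "reset"):
            m.reset()

net = [StatefulNode(), StatefulNode()]
net[0].v = 0.7                 # leftover state from a previous sequence
reset_net(net)
```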