
API Reference: functional

grilly.functional provides stateless operations that mirror torch.nn.functional.

Source: functional/


Activations

| Function | Signature | Description |
| --- | --- | --- |
| relu | relu(x) | Rectified linear unit. |
| gelu | gelu(x) | Gaussian error linear unit. |
| silu | silu(x) | Sigmoid linear unit (Swish). |
| softmax | softmax(x, dim=-1) | Softmax normalization. |
| softplus | softplus(x) | Smooth ReLU approximation. |
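
As a reference for what each activation computes, here is a minimal pure-Python sketch over flat lists of floats, assuming the standard textbook definitions. It is not grilly's implementation: grilly operates on tensors and may, for example, use the tanh approximation of GELU. The 1-D softmax here corresponds to dim=-1 applied to a single vector.

```python
import math

def _sigmoid(v):
    # Numerically stable logistic function (avoids overflow for large |v|).
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)

def relu(x):
    # max(0, v) for each element.
    return [max(0.0, v) for v in x]

def gelu(x):
    # Exact GELU: v * Phi(v), where Phi is the standard normal CDF.
    return [0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x]

def silu(x):
    # SiLU / Swish: v * sigmoid(v).
    return [v * _sigmoid(v) for v in x]

def softmax(x):
    # Subtract the max before exponentiating for numerical stability.
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softplus(x):
    # log(1 + exp(v)), rewritten as max(v, 0) + log1p(exp(-|v|)) to avoid overflow.
    return [max(v, 0.0) + math.log1p(math.exp(-abs(v))) for v in x]
```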

Linear

| Function | Signature | Description |
| --- | --- | --- |
| linear | linear(x, weight, bias=None) | Affine transformation: x @ weight.T + bias. |
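
A minimal pure-Python sketch of the affine map x @ weight.T + bias, assuming weight is stored as (out_features, in_features). That is the layout the formula implies and the one torch.nn.functional.linear uses. Nested lists stand in for tensors.

```python
def linear(x, weight, bias=None):
    """Compute x @ weight.T + bias for a batch of row vectors.

    x:      N rows, each of length in_features
    weight: out_features rows, each of length in_features
    bias:   optional list of out_features values
    """
    out = []
    for row in x:
        # Output feature j is the dot product of the input row with weight[j].
        y = [sum(a * w for a, w in zip(row, w_row)) for w_row in weight]
        if bias is not None:
            y = [v + b for v, b in zip(y, bias)]
        out.append(y)
    return out
```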

Attention

| Function | Signature | Description |
| --- | --- | --- |
| attention | attention(q, k, v, mask=None) | Standard scaled dot-product attention. |
| flash_attention2 | flash_attention2(q, k, v) | Memory-efficient FlashAttention-2. |
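
The sketch below, in pure Python on nested lists, shows single-head scaled dot-product attention, softmax(q k^T / sqrt(d)) v, alongside the online-softmax recurrence that FlashAttention-style kernels use to reach the same result one block of keys at a time. The boolean mask convention (False excludes a key) and the helper name attention_online are assumptions for illustration, not grilly's API; note that flash_attention2 as listed takes no mask.

```python
import math

def attention(q, k, v, mask=None):
    """softmax(q @ k.T / sqrt(d)) @ v for one head.

    q: (Lq, d), k: (Lk, d), v: (Lk, dv) as nested lists.
    mask: optional (Lq, Lk) booleans; False excludes that key (assumed convention).
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s if mask[i][j] else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]  # masked keys get weight 0.0
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out

def attention_online(q, k, v, block=2):
    """Same output without storing a full score row: scan keys in blocks while
    keeping a running max (m), normalizer (l), and output accumulator (acc)."""
    d, dv = len(q[0]), len(v[0])
    out = []
    for qi in q:
        m, l, acc = float("-inf"), 0.0, [0.0] * dv
        for start in range(0, len(k), block):
            scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d)
                      for kj in k[start:start + block]]
            m_new = max(m, max(scores))
            scale = math.exp(m - m_new)  # rescale earlier partial sums to the new max
            l *= scale
            acc = [a * scale for a in acc]
            for s, vj in zip(scores, v[start:start + block]):
                p = math.exp(s - m_new)
                l += p
                acc = [a + p * x for a, x in zip(acc, vj)]
            m = m_new
        out.append([a / l for a in acc])
    return out
```

Because the online version keeps only a running max, normalizer, and accumulator per query, it never materializes a full row of scores; real FlashAttention kernels apply the same idea to tiles held in fast on-chip memory.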

Normalization

| Function | Signature | Description |
| --- | --- | --- |
| layer_norm | layer_norm(x, weight, bias, eps=1e-5) | Layer normalization. |
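
A pure-Python sketch of layer normalization over the last dimension, assuming the usual definition y = (x - mean) / sqrt(var + eps) * weight + bias with the biased (population) variance, as in torch.nn.functional.layer_norm. Each row of the nested list plays the role of one normalized slice.

```python
import math

def layer_norm(x, weight, bias, eps=1e-5):
    """Normalize each row of x over its last dimension, then scale and shift."""
    out = []
    for row in x:
        n = len(row)
        mean = sum(row) / n
        var = sum((v - mean) ** 2 for v in row) / n  # biased variance (divide by n)
        inv_std = 1.0 / math.sqrt(var + eps)
        out.append([(v - mean) * inv_std * g + b for v, g, b in zip(row, weight, bias)])
    return out
```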

Dropout

| Function | Signature | Description |
| --- | --- | --- |
| dropout | dropout(x, p=0.5, training=True) | Randomly zero elements with probability p when training=True. |
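
A sketch of inverted dropout on a flat list, assuming grilly follows torch.nn.functional.dropout in scaling kept elements by 1 / (1 - p) during training, so no rescaling is needed at inference. The rng parameter is a hypothetical addition that makes the example reproducible.

```python
import random

def dropout(x, p=0.5, training=True, rng=None):
    """Zero each element with probability p and scale survivors by 1 / (1 - p).
    With training=False the input is returned unchanged.
    `rng` is a hypothetical parameter for reproducibility."""
    if not training or p == 0.0:
        return list(x)
    if p == 1.0:
        return [0.0] * len(x)  # everything dropped; avoids dividing by zero
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in x]
```

Scaling at training time keeps the expected value of each element equal to its input, which is why evaluation can simply pass training=False.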

Loss Functions

| Function | Signature | Description |
| --- | --- | --- |
| cross_entropy | cross_entropy(logits, targets) | Cross-entropy loss. |
| binary_cross_entropy | binary_cross_entropy(predictions, targets) | Binary cross-entropy. |
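
Pure-Python sketches of both losses under stated assumptions: cross_entropy takes raw logits with integer class indices as targets, binary_cross_entropy takes probabilities (not logits), and both average over the batch. These mirror the torch.nn.functional defaults; grilly's reduction and target formats are not stated in the table, and the eps clamp in the BCE sketch is also an assumption.

```python
import math

def cross_entropy(logits, targets):
    """Mean of -log softmax(row)[t], computed with log-sum-exp for stability."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[t]
    return total / len(targets)

def binary_cross_entropy(predictions, targets, eps=1e-12):
    """Mean BCE for probabilities in (0, 1); eps (assumed) keeps log() finite."""
    total = 0.0
    for p, y in zip(predictions, targets):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(targets)
```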

FFT

| Function | Signature | Description |
| --- | --- | --- |
| fft | fft(x) | Fast Fourier Transform. |
| ifft | ifft(x) | Inverse FFT. |
| fft_magnitude | fft_magnitude(x) | Magnitude of FFT output. |
| fft_power_spectrum | fft_power_spectrum(x) | Power spectrum from FFT. |
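
A minimal recursive radix-2 Cooley-Tukey sketch of the four FFT helpers, assuming power-of-two input lengths and a power spectrum defined as |X[k]|^2 with no 1/n normalization. grilly's implementation may accept other lengths and use different scaling.

```python
import cmath

def fft(x):
    """Recursive radix-2 FFT; len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return [complex(x[0])]
    even = fft(x[0::2])
    odd = fft(x[1::2])
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(-2j * cmath.pi * k / n) * odd[k]  # twiddle factor times odd part
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out

def ifft(X):
    """Inverse via conjugation: ifft(X) = conj(fft(conj(X))) / n."""
    n = len(X)
    return [v.conjugate() / n for v in fft([c.conjugate() for c in X])]

def fft_magnitude(x):
    return [abs(v) for v in fft(x)]

def fft_power_spectrum(x):
    # |X[k]|^2, unnormalized (the normalization is an assumption).
    return [abs(v) ** 2 for v in fft(x)]
```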

Memory Operations

| Function | Signature | Description |
| --- | --- | --- |
| memory_read | memory_read(memory, query) | Read from memory via attention. |
| memory_write | memory_write(memory, key, value) | Write to memory. |
| memory_context_aggregate | memory_context_aggregate(contexts) | Aggregate memory contexts. |
| memory_query_pooling | memory_query_pooling(queries) | Pool memory queries. |
| memory_inject_concat | memory_inject_concat(x, memory) | Inject memory via concatenation. |
| memory_inject_gate | memory_inject_gate(x, memory) | Inject memory via gating. |
| memory_inject_residual | memory_inject_residual(x, memory) | Inject memory via residual. |
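
The memory table names techniques rather than formulas, so the following is a hedged illustration of the general ideas, not grilly's semantics: an attention read where memory slots serve as both keys and values, and the three injection styles (concatenation, residual addition, sigmoid gating). All function names here, and the gate_logits argument, are hypothetical.

```python
import math

def read_memory(memory, query):
    """Hypothetical attention read: weight each slot by softmax(slot . query / sqrt(d))
    and return the weighted sum of slots (slots act as keys and values)."""
    d = len(query)
    scores = [sum(m * q for m, q in zip(slot, query)) / math.sqrt(d) for slot in memory]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    z = sum(w)
    return [sum(wi * slot[c] for wi, slot in zip(w, memory)) / z for c in range(d)]

def inject_concat(x, m):
    # Concatenate features so a downstream layer sees both vectors.
    return list(x) + list(m)

def inject_residual(x, m):
    # Add the retrieved vector to the hidden state.
    return [a + b for a, b in zip(x, m)]

def inject_gate(x, m, gate_logits):
    # Per-feature gate g = sigmoid(gate_logits) blends hidden state and memory.
    g = [1.0 / (1.0 + math.exp(-s)) for s in gate_logits]
    return [gi * a + (1.0 - gi) * b for gi, a, b in zip(g, x, m)]
```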

Cell Operations

| Function | Signature | Description |
| --- | --- | --- |
| place_cell | place_cell(position, centers) | Place cell activation (spatial). |
| time_cell | time_cell(time, centers) | Time cell activation (temporal). |
| theta_gamma_encoding | theta_gamma_encoding(phase, frequency) | Theta-gamma neural encoding. |
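
Place and time cells are commonly modeled with Gaussian tuning curves, activation_i = exp(-||p - c_i||^2 / (2 sigma^2)), where c_i is cell i's preferred position or time. The sketch assumes that model; the sigma width parameter and the function names are hypothetical, since the table does not state grilly's tuning function. theta_gamma_encoding is not covered.

```python
import math

def place_cells(position, centers, sigma=1.0):
    """Gaussian place fields: cell i peaks (activation 1.0) at centers[i]."""
    return [math.exp(-sum((p - c) ** 2 for p, c in zip(position, ctr)) / (2 * sigma ** 2))
            for ctr in centers]

def time_cells(t, centers, sigma=1.0):
    """1-D analogue for elapsed time: each cell is tuned to one preferred time."""
    return [math.exp(-((t - c) ** 2) / (2 * sigma ** 2)) for c in centers]
```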

Learning Operations

| Function | Signature | Description |
| --- | --- | --- |
| fisher_info | fisher_info(grads) | Compute the Fisher information matrix. |
| ewc_penalty | ewc_penalty(params, fisher, old_params) | Elastic Weight Consolidation (EWC) penalty. |
| natural_gradient | natural_gradient(grad, fisher) | Natural gradient using the Fisher information. |
| nlms_predict | nlms_predict(weights, x) | Normalized least-mean-squares (NLMS) prediction. |
| nlms_update | nlms_update(weights, x, error, mu) | NLMS weight update with step size mu. |
| whitening_transform | whitening_transform(x) | Compute a whitening transform. |
| whitening_apply | whitening_apply(x, transform) | Apply a whitening transform. |
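
Sketches of the standard formulas behind several of these operations, under a diagonal Fisher approximation (grilly's fisher_info is described as computing the full matrix, which this does not do): the empirical Fisher as the mean of squared per-sample gradients, the EWC penalty (lambda / 2) * sum_i F_i (theta_i - theta*_i)^2, a diagonal natural gradient F^{-1} g, and NLMS prediction and update. The names fisher_diag and natural_gradient_diag, and the lam, damping, and eps parameters, are hypothetical. Whitening is not covered.

```python
def fisher_diag(grads):
    """Diagonal empirical Fisher: mean of squared per-sample gradients."""
    n = len(grads)
    return [sum(g[i] ** 2 for g in grads) / n for i in range(len(grads[0]))]

def ewc_penalty(params, fisher, old_params, lam=1.0):
    """EWC: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    return 0.5 * lam * sum(f * (p - o) ** 2 for p, f, o in zip(params, fisher, old_params))

def natural_gradient_diag(grad, fisher, damping=1e-8):
    """F^{-1} g for a diagonal Fisher; damping avoids division by zero."""
    return [g / (f + damping) for g, f in zip(grad, fisher)]

def nlms_predict(weights, x):
    # Linear prediction y = w . x
    return sum(w * v for w, v in zip(weights, x))

def nlms_update(weights, x, error, mu, eps=1e-8):
    """w <- w + mu * error * x / (eps + ||x||^2), with error = target - prediction."""
    norm = eps + sum(v * v for v in x)
    return [w + mu * error * v / norm for w, v in zip(weights, x)]
```

With mu = 1, a single NLMS update drives the a posteriori error on that sample to nearly zero, which makes a convenient sanity check.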

Bridge Operations

| Function | Signature | Description |
| --- | --- | --- |
| continuous_to_spikes | continuous_to_spikes(x, threshold) | Convert continuous signal to spikes. |
| spikes_to_continuous | spikes_to_continuous(spikes, decay) | Decode spikes to continuous signal. |
| bridge_temporal_weights | bridge_temporal_weights(t, decay) | Temporal weighting for bridge. |
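
The table does not pin down the encoding schemes, so this sketch assumes the simplest ones: threshold encoding (a spike wherever the value reaches threshold) and leaky-trace decoding, y[t] = decay * y[t-1] + s[t]. Both are illustrative assumptions; grilly may instead use rate or temporal coding.

```python
def continuous_to_spikes(x, threshold):
    """Assumed threshold encoding: 1 where the value reaches threshold, else 0."""
    return [1 if v >= threshold else 0 for v in x]

def spikes_to_continuous(spikes, decay):
    """Assumed leaky-trace decoding over time: y[t] = decay * y[t-1] + s[t]."""
    y, out = 0.0, []
    for s in spikes:
        y = decay * y + s
        out.append(y)
    return out
```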

Embedding Operations

| Function | Signature | Description |
| --- | --- | --- |
| embedding_lookup | embedding_lookup(table, indices) | Lookup embeddings by index. |
| embedding_normalize | embedding_normalize(embeddings) | L2-normalize embeddings. |
| embedding_position | embedding_position(embeddings, positions) | Add positional embeddings. |
| embedding_pool | embedding_pool(embeddings, method) | Pool embeddings. |
| embedding_ffn | embedding_ffn(embeddings, weights) | Feed-forward network (FFN) applied to embeddings. |
| embedding_attention | embedding_attention(q, k, v) | Attention over embeddings. |
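
A pure-Python sketch of the lookup, normalization, positional, and pooling steps. It assumes positions already holds one positional vector per token (not integer indices) and that method accepts "mean" or "max"; both are assumptions, as is the eps guard for zero vectors. embedding_ffn and embedding_attention are omitted.

```python
import math

def embedding_lookup(table, indices):
    # Row gather: one embedding vector per index.
    return [list(table[i]) for i in indices]

def embedding_normalize(embeddings, eps=1e-12):
    # Scale each vector to unit L2 norm; eps (assumed) guards against zero vectors.
    out = []
    for e in embeddings:
        n = math.sqrt(sum(v * v for v in e))
        out.append([v / max(n, eps) for v in e])
    return out

def embedding_position(embeddings, positions):
    # Element-wise addition of a positional vector to each token embedding.
    return [[a + b for a, b in zip(e, p)] for e, p in zip(embeddings, positions)]

def embedding_pool(embeddings, method="mean"):
    # Reduce a sequence of vectors to one vector, column by column.
    cols = list(zip(*embeddings))
    if method == "mean":
        return [sum(c) / len(c) for c in cols]
    if method == "max":
        return [max(c) for c in cols]
    raise ValueError(f"unknown pooling method: {method}")
```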

FAISS Operations

| Function | Signature | Description |
| --- | --- | --- |
| faiss_distance | faiss_distance(queries, keys) | Compute distances. |
| faiss_topk | faiss_topk(distances, k) | Top-k selection. |
| faiss_ivf_filter | faiss_ivf_filter(queries, centroids) | IVF coarse filtering. |
| faiss_kmeans_update | faiss_kmeans_update(centroids, data) | K-means centroid update. |
| faiss_quantize | faiss_quantize(data, codebook) | Product quantization. |
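
A sketch of the core primitives behind these operations, assuming squared L2 distance (FAISS also supports inner-product metrics): a full query-to-key distance matrix, top-k selection of the smallest distances, and one Lloyd (k-means) iteration. The helper names are hypothetical.

```python
def l2_distances(queries, keys):
    """Squared L2 distance between every query and every key: shape (Nq, Nk)."""
    return [[sum((a - b) ** 2 for a, b in zip(q, k)) for k in keys] for q in queries]

def topk_smallest(distances, k):
    """Indices of the k nearest keys per query, nearest first."""
    return [sorted(range(len(row)), key=row.__getitem__)[:k] for row in distances]

def kmeans_update(centroids, data):
    """One Lloyd iteration: assign each point to its nearest centroid, then move
    each centroid to the mean of its points (empty clusters keep their centroid)."""
    assign = [row.index(min(row)) for row in l2_distances(data, centroids)]
    new = []
    for c, ctr in enumerate(centroids):
        members = [x for x, a in zip(data, assign) if a == c]
        if members:
            new.append([sum(col) / len(members) for col in zip(*members)])
        else:
            new.append(list(ctr))
    return new
```

IVF coarse filtering is the same top-k applied to query-to-centroid distances, keeping the nprobe closest inverted lists, and vector quantization assigns each vector to its nearest codeword (top-1). Product quantization applies that per sub-vector, with a separate codebook for each sub-space.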

SNN Operations

| Function | Signature | Description |
| --- | --- | --- |
| lif_step | lif_step(x, v, tau, v_threshold) | Single leaky integrate-and-fire (LIF) neuron step. Returns (spike, v_new). |
| if_step | if_step(x, v, v_threshold) | Single integrate-and-fire (IF) neuron step. |
| multi_step_forward | multi_step_forward(module, x, T) | Run module over T time steps. |
| seq_to_ann_forward | seq_to_ann_forward(module, x) | Apply an ANN module to temporal data. |
| reset_net | reset_net(module) | Reset the membrane potentials of all neurons in module. |
| set_step_mode | set_step_mode(module, mode) | Set single-step or multi-step mode. |
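
The names multi_step_forward, seq_to_ann_forward, reset_net, and set_step_mode match SpikingJelly's functional API, so this sketch assumes SpikingJelly's LIF convention: charge h = v + (x - (v - v_reset)) / tau, fire when h >= v_threshold, then hard-reset to v_reset. It also assumes v_new is the post-reset potential, which the table does not state. The v_reset parameter, the default values, and the multi_step helper are hypothetical.

```python
def lif_step(x, v, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """One LIF update per neuron (lists of floats); returns (spikes, v_new)."""
    spikes, v_new = [], []
    for xi, vi in zip(x, v):
        h = vi + (xi - (vi - v_reset)) / tau  # leak toward v_reset, driven by xi
        s = 1.0 if h >= v_threshold else 0.0
        spikes.append(s)
        v_new.append(v_reset if s else h)  # hard reset after a spike
    return spikes, v_new

def if_step(x, v, v_threshold=1.0, v_reset=0.0):
    """Integrate-and-fire: like LIF without the leak (h = v + x)."""
    spikes, v_new = [], []
    for xi, vi in zip(x, v):
        h = vi + xi
        s = 1.0 if h >= v_threshold else 0.0
        spikes.append(s)
        v_new.append(v_reset if s else h)
    return spikes, v_new

def multi_step(step_fn, x_seq, v0):
    """Hypothetical helper: run a single-step neuron over T steps, carrying v forward."""
    v, out = v0, []
    for x_t in x_seq:
        s, v = step_fn(x_t, v)
        out.append(s)
    return out, v
```

With tau = 2 and a constant input of 1.0, the LIF potential approaches 1.0 (0.5, 0.75, 0.875, ...) without ever reaching the threshold, while the leak-free IF neuron driven by 0.4 fires on its third step; that difference is the leak.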