Last updated: 2026-06-12
Shader Dispatch
Elementwise
Wrap the fused element-wise shader family used by the decode loop.
This helper loads the RMS norm, SwiGLU, and RoPE pipelines and records the push constants needed for their dispatches.
34 exports shown
struct
RmsNormPush
pub const RmsNormPush = extern struct Push constants for RMS norm shader.
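Because every `*Push` type here is an `extern struct`, its fields follow C ABI layout and its bytes can be handed to the shader's push-constant block directly. The field lists are not shown on this page, so the sketch below uses a hypothetical `ExampleRmsNormPush` whose three fields simply mirror the scalar arguments of `recordRmsNorm`. The helper name `pushBytes` is likewise an assumption.

```zig
const std = @import("std");

// Hypothetical field list: the real RmsNormPush fields are not shown on this page.
// These three mirror the scalar arguments of recordRmsNorm.
const ExampleRmsNormPush = extern struct {
    hidden_dim: u32,
    n_tokens: u32,
    eps: f32,
};

// extern struct gives C ABI layout, so size and offsets are predictable.
comptime {
    std.debug.assert(@sizeOf(ExampleRmsNormPush) == 12);
    std.debug.assert(@offsetOf(ExampleRmsNormPush, "eps") == 8);
    // Vulkan guarantees at least 128 bytes of push-constant space.
    std.debug.assert(@sizeOf(ExampleRmsNormPush) <= 128);
}

/// View the struct as raw bytes, the form vkCmdPushConstants expects.
fn pushBytes(push: *const ExampleRmsNormPush) []const u8 {
    return std.mem.asBytes(push);
}
```

Checking size and offsets at compile time is a cheap way to catch drift between the Zig struct and the matching GLSL `push_constant` block.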
struct
SwigluPush
pub const SwigluPush = extern struct Push constants for SwiGLU shader.
struct
DeinterleavePush
pub const DeinterleavePush = extern struct Push constants for deinterleave shader.
struct
SigmoidMulPush
pub const SigmoidMulPush = extern struct Push constants for sigmoid multiply shader.
struct
ScaleAccPush
pub const ScaleAccPush = extern struct Push constants for scale-accumulate shader.
struct
BiasAddPush
pub const BiasAddPush = extern struct Push constants for bias add shader.
struct
RopePush
pub const RopePush = extern struct Push constants for RoPE shader (with partial rotation / IMRoPE support).
struct
RopeBatchedPush
pub const RopeBatchedPush = extern struct Push constants for rope_batched (multi-token prefill variant).
Layout mirrors src/shaders/rope_batched.comp.
struct
SsmConv1dPush
pub const SsmConv1dPush = extern struct Push constants for SSM conv1d + SiLU shader.
struct
SsmConv1dBatchedPush
pub const SsmConv1dBatchedPush = extern struct Push constants for the batched SSM conv1d shader.
struct
F32DualBatchPush
pub const F32DualBatchPush = extern struct Push constants for batched f32 dual DMMV (SSM alpha/beta).
struct
SsmDeltaNetPush
pub const SsmDeltaNetPush = extern struct Push constants for SSM delta-net state update shader.
struct
SsmQkNormPush
pub const SsmQkNormPush = extern struct Push constants for the SSM Q/K RMS-norm shader.
Drives the per-group normalization applied to query and key projections inside Mamba/SSM blocks.
struct
SsmGatedNormPush
pub const SsmGatedNormPush = extern struct Push constants for SSM gated norm shader.
struct
SoftmaxTopkPush
pub const SoftmaxTopkPush = extern struct Push constants for softmax + top-k MoE router shader.
struct
RouterF32BatchPush
pub const RouterF32BatchPush = extern struct Push constants for token-batched f32 router matvec.
struct
RmsNormScaleDmmvF32BatchPush
pub const RmsNormScaleDmmvF32BatchPush = extern struct Push constants for token-batched Gemma router RMS norm + scale + f32 DMMV.
struct
SoftmaxTopkBatchPush
pub const SoftmaxTopkBatchPush = extern struct Push constants for token-batched MoE top-k.
struct
MoeWeightedAccPush
pub const MoeWeightedAccPush = extern struct Push constants for batched MoE weighted accumulate shader.
Sums all expert outputs at once: a[i] = sum_j(weight_j * b[j*src_stride+i]).
struct
MoeWeightedAccBatchPush
pub const MoeWeightedAccBatchPush = extern struct Push constants for the **batched** MoE weighted-accumulate shader.
Sums each token's `n_used` selected-expert outputs across a token batch in one dispatch: `a[t,i] = sum_j(weight_{t,j} * b[...])`.
struct
SigmoidScaleAccBatchPush
pub const SigmoidScaleAccBatchPush = extern struct Push constants for the **batched** `sigmoid_scale_acc` shader.
Applies a per-token sigmoid-gated shared-expert add across a token batch: `accum[t,i] += sigmoid(gate_t) * src[t,i]`.
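A minimal CPU reference for the formula above can help when validating the shader output. It assumes row-major `[n_tokens][n_elements]` storage and one gate scalar per token. Neither layout detail is confirmed by this page, and the function name is hypothetical.

```zig
const std = @import("std");

fn sigmoid(x: f32) f32 {
    return 1.0 / (1.0 + @exp(-x));
}

/// CPU reference for accum[t,i] += sigmoid(gate[t]) * src[t,i].
/// Assumes row-major [n_tokens][n_elements] storage and one gate scalar per token.
fn sigmoidScaleAccBatchRef(accum: []f32, src: []const f32, gate: []const f32, n_elements: usize) void {
    std.debug.assert(accum.len == src.len);
    std.debug.assert(accum.len == gate.len * n_elements);
    for (gate, 0..) |g, t| {
        const s = sigmoid(g);
        const row = t * n_elements;
        for (0..n_elements) |i| {
            accum[row + i] += s * src[row + i];
        }
    }
}
```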
struct
KvCacheWritePush
pub const KvCacheWritePush = extern struct Push constants for KV cache write compute shader.
struct
KvCacheWriteBatchedPush
pub const KvCacheWriteBatchedPush = extern struct Push constants for batched KV cache write (prefillBatched path).
Matches src/shaders/kv_cache_write_batched.comp.
struct
ResidualRmsNormPush
pub const ResidualRmsNormPush = extern struct Push constants for fused residual-add + RMS norm (src/shaders/residual_rms_norm.comp).
A single dispatch of `n_tokens` workgroups replaces a scale_accumulate → barrier → rms_norm_mul chain.
struct
PostNormResidualRmsNormPush
pub const PostNormResidualRmsNormPush = extern struct Push constants for fused post-norm + residual-add + RMS norm (src/shaders/post_norm_residual_rms_norm.comp).
One dispatch replaces Gemma's post_attention_norm → barrier → residual_rms_norm sequence.
struct
ResidualRmsNormQuantQ8_1Push
pub const ResidualRmsNormQuantQ8_1Push = extern struct Push constants for fused residual-add + RMS norm + Q8_1 activation quantize (src/shaders/residual_rms_norm_quant_q8_1.comp).
Same residual/RMS math as ResidualRmsNormPush, but also emits packed int8 lanes + (scale, dsum) so the downstream Qwen3.6-27B dense FFN DP4a gate+up GEMM can skip its separate quantize_act_q8_1 dispatch.
struct
NormRopePush
pub const NormRopePush = extern struct Push constants for fused RMS norm + RoPE shader.
struct
RmsNormAddPush
pub const RmsNormAddPush = extern struct Push constants for fused rmsnorm(src) + hidden accumulate shader (src/shaders/rms_norm_add.comp).
Used by Gemma prefillBatched to fold post_ffw_norm + residual add into one dispatch.
struct
RmsNormDmmvF32Push
pub const RmsNormDmmvF32Push = extern struct Push constants for fused RMS norm + f32 router DMMV shader (src/shaders/rms_norm_dmmv_f32.comp).
Folds the per-MoE-layer rms_norm_mul → router DMMV pair into a single dispatch on architectures whose router weights are f32 (Qwen 3.5/3.6, etc.).
struct
RmsNormDmmvQ4kAlphaBetaPush
pub const RmsNormDmmvQ4kAlphaBetaPush = extern struct Push constants for fused RMS norm + Q4_K alpha+beta SSM proj DMMV (src/shaders/rms_norm_dmmv_q4k_alpha_beta.comp).
Folds the per-SSM-layer (rms_norm_mul → alpha DMMV → beta DMMV) trio into a single dispatch on the qwen35moe / qwen36moe SSM proj fast path.
struct
QkNormRopeKvWritePush
pub const QkNormRopeKvWritePush = extern struct Push constants for fused Q+K norm + RoPE + KV cache write shader (src/shaders/qk_norm_rope_kv_write.comp).
Folds the per-attention-layer (Q norm+rope → K norm+rope → kv_cache_write) trio on Qwen 3 family dense attention into a single dispatch.
struct
QkNormRopeKvWriteBatchedPush
pub const QkNormRopeKvWriteBatchedPush = extern struct Batched SWA variant of QkNormRopeKvWritePush.
Binding 4 is the KV page table, so this variant computes RoPE frequencies from freq_base_bits instead of reading a frequency buffer.
struct
KNormRopeKvWriteBatchedPush
pub const KNormRopeKvWriteBatchedPush = extern struct Batched full-attention K/V sibling used when Q must keep the precomputed RoPE frequency buffer binding.
It fuses K RMS norm, K RoPE, optional V unit norm, and paged KV cache write. Q norm/RoPE stays on the existing path.
struct
ElementwiseDispatch
pub const ElementwiseDispatch = struct Manages element-wise fused kernel pipelines.
Methods (24)
method
ElementwiseDispatch.init
pub fn init( instance: *const Instance, shader_dir: []const u8, allocator: std.mem.Allocator, ) !ElementwiseDispatch Create the fused element-wise dispatch wrapper and load its shaders.
method
ElementwiseDispatch.recordRmsNorm
pub fn recordRmsNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, hidden_dim: u32, n_tokens: u32, eps: f32, ) !void Record an RMS-norm-plus-scale dispatch for a batch of tokens.
This binds the fused normalization shader used before attention and MLP projections so each token is normalized against its hidden dimension.
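For checking results on the host, the per-token math can be sketched as a CPU reference. This assumes the usual RMS norm definition with a per-channel weight and row-major `[n_tokens][hidden_dim]` activations. The function name and buffer arrangement are assumptions, not the shader's actual bindings.

```zig
const std = @import("std");

/// CPU reference: out[t,i] = x[t,i] / sqrt(mean_i(x[t,i]^2) + eps) * weight[i].
/// Assumes row-major [n_tokens][hidden_dim] activations and a per-channel weight.
fn rmsNormRef(out: []f32, x: []const f32, weight: []const f32, n_tokens: usize, eps: f32) void {
    const hidden_dim = weight.len;
    std.debug.assert(x.len == n_tokens * hidden_dim);
    std.debug.assert(out.len == x.len);
    for (0..n_tokens) |t| {
        const row = x[t * hidden_dim ..][0..hidden_dim];
        var sum_sq: f32 = 0;
        for (row) |v| sum_sq += v * v;
        const inv_rms = 1.0 / @sqrt(sum_sq / @as(f32, @floatFromInt(hidden_dim)) + eps);
        for (row, weight, 0..) |v, w, i| {
            out[t * hidden_dim + i] = v * inv_rms * w;
        }
    }
}
```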
method
ElementwiseDispatch.recordSwiglu
pub fn recordSwiglu( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record a SwiGLU activation dispatch.
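As a point of comparison for the GPU output, here is a sketch of the standard SwiGLU definition, `out[i] = silu(gate[i]) * up[i]`. It assumes the shader follows that textbook form. The function names are hypothetical.

```zig
const std = @import("std");

/// SiLU, also written x * sigmoid(x).
fn silu(x: f32) f32 {
    return x / (1.0 + @exp(-x));
}

/// CPU reference for the textbook SwiGLU: out[i] = silu(gate[i]) * up[i].
fn swigluRef(out: []f32, gate: []const f32, up: []const f32) void {
    std.debug.assert(out.len == gate.len and gate.len == up.len);
    for (out, gate, up) |*o, g, u| o.* = silu(g) * u;
}
```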
method
ElementwiseDispatch.recordSwigluOai
pub fn recordSwigluOai( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record a GPT-OSS / OAI-variant SwiGLU activation dispatch.
Uses the same 3-binding layout as `recordSwiglu` (gate, up → output) but selects the swiglu_oai shader whose activation function matches gpt-oss.
method
ElementwiseDispatch.recordBiasAdd
pub fn recordBiasAdd( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, src_offset: u32, ) !void Record an in-place bias add dispatch: `out[i] += bias[src_offset + i]`.
method
ElementwiseDispatch.recordGeglu
pub fn recordGeglu( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record a GEGLU activation dispatch (GELU-gated, used by Gemma).
Same buffer layout as SwiGLU: gate, up → output.
method
ElementwiseDispatch.recordRope
pub fn recordRope( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, stride: u32, rope_dim: u32, n_heads: u32, position: u32, freq_base: f32, attn_scale: f32, ) !void Record a RoPE dispatch with partial rotation support (IMRoPE).
Rotates the first `rope_dim` dimensions of each attention head at the given sequence position; the remaining `stride - rope_dim` dimensions are copied unchanged, which is what allows partial-rotation (IMRoPE) layouts.
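The partial rotation can be sketched on the CPU as follows. Several details are assumptions rather than facts from this page: the pairing convention `(i, i + rope_dim/2)`, the angle `theta_i = position * freq_base^(-2i / rope_dim)`, in-place operation, and leaving `attn_scale` out entirely. The shader may differ on any of them.

```zig
const std = @import("std");

/// CPU sketch of partial RoPE on one token, rotating in place.
/// Assumptions (not confirmed by this page): pairs are (i, i + rope_dim/2),
/// theta_i = position * freq_base^(-2i / rope_dim), and attn_scale is ignored.
fn ropePartialRef(x: []f32, stride: usize, rope_dim: usize, n_heads: usize, position: u32, freq_base: f32) void {
    std.debug.assert(rope_dim % 2 == 0 and rope_dim <= stride);
    std.debug.assert(x.len == n_heads * stride);
    const half = rope_dim / 2;
    const pos: f32 = @floatFromInt(position);
    for (0..n_heads) |h| {
        const head = x[h * stride ..][0..stride];
        for (0..half) |i| {
            const exponent = -2.0 * @as(f32, @floatFromInt(i)) / @as(f32, @floatFromInt(rope_dim));
            const theta = pos * std.math.pow(f32, freq_base, exponent);
            const cos_t = @cos(theta);
            const sin_t = @sin(theta);
            const lo = head[i];
            const hi = head[i + half];
            head[i] = lo * cos_t - hi * sin_t;
            head[i + half] = lo * sin_t + hi * cos_t;
        }
        // head[rope_dim..stride] is left untouched.
    }
}
```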
method
ElementwiseDispatch.recordRoPEBatched
pub fn recordRoPEBatched( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, stride: u32, rope_dim: u32, n_heads: u32, position_base: u32, n_tokens: u32, freq_base: f32, attn_scale: f32, ) !void Record a batched RoPE dispatch that rotates N tokens at consecutive positions [position_base, position_base + n_tokens) in a single call.
Grid is (n_heads, n_tokens, 1); each (head, token) workgroup rotates `rope_dim` elements of the token's head slice.
method
ElementwiseDispatch.recordDeinterleave
pub fn recordDeinterleave( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, ) !void Record a deinterleave dispatch: split element-interleaved Q+gate into separate buffers.
method
ElementwiseDispatch.recordDeinterleaveBatched
pub fn recordDeinterleaveBatched( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_tokens: u32, ) !void Record a token-batched deinterleave dispatch.
Splits each token's packed `[Q(head_dim), gate(head_dim)]` interleaved per-head layout into separate Q and gate output buffers in one dispatch. Grid is `(ceil(head_dim * n_heads / 64), n_tokens, 1)`.
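The grid formula above can be written as a small helper, which is handy when reasoning about how many workgroups a given model shape launches. The `Grid` type and function name are hypothetical and only illustrate the arithmetic.

```zig
const std = @import("std");

const Grid = struct { x: u32, y: u32, z: u32 };

/// Workgroup counts for the token-batched deinterleave dispatch,
/// following the quoted formula (ceil(head_dim * n_heads / 64), n_tokens, 1).
fn deinterleaveBatchedGrid(head_dim: u32, n_heads: u32, n_tokens: u32) Grid {
    const total = head_dim * n_heads;
    // Integer ceiling division by 64.
    return .{ .x = (total + 63) / 64, .y = n_tokens, .z = 1 };
}
```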
method
ElementwiseDispatch.recordSigmoidMul
pub fn recordSigmoidMul( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record a sigmoid multiply dispatch: out = input * sigmoid(gate).
method
ElementwiseDispatch.recordVadd
pub fn recordVadd( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record a vector add dispatch: c = a + b.
method
ElementwiseDispatch.recordScaleAcc
pub fn recordScaleAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, scale: f32, ) !void Record a scale-accumulate dispatch: a[i] += scale * b[i].
Vec4-coalesced: each thread handles one vec4 (4 f32 elements). Caller must pass n_elements divisible by 4; every in-tree caller already does.
method
ElementwiseDispatch.recordScaleInPlace
pub fn recordScaleInPlace( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, scale: f32, ) !void Record an in-place element-wise scale dispatch: `data[i] *= scale`.
method
ElementwiseDispatch.recordSsmConv1d
pub fn recordSsmConv1d( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, conv_channels: u32, d_conv: u32, kernel_is_f16: bool, state_offset: u32, ) !void Record a single-token SSM depthwise conv1d + SiLU dispatch.
Reads the current SSM conv state via `state_offset` (a rotating index into the circular state buffer), applies a depthwise conv kernel of width `d_conv`, and writes the SiLU-activated output in-place.
method
ElementwiseDispatch.recordSsmQkNorm
pub fn recordSsmQkNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, d_state: u32, n_group: u32, ) !void Record an in-place SSM Q/K RMS-norm dispatch.
Applies per-group RMS normalization to the concatenated Q and K projections inside a Mamba/DeltaNet SSM block; dispatches one workgroup per group.
method
ElementwiseDispatch.recordSsmDeltaNet
pub fn recordSsmDeltaNet( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void Record an SSM DeltaNet state-update dispatch (baseline variant).
Executes the DeltaNet recurrence over a single token (or `push.n_tok` prefill tokens when n_tok > 1). Grid is `(dt_rank, head_v_dim, 1)` — one wave64 workgroup per (head, row) pair.
method
ElementwiseDispatch.recordSsmDeltaNetCols8
pub fn recordSsmDeltaNetCols8( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void Record an SSM DeltaNet state-update dispatch using the cols8 tiled variant.
Each wave64 workgroup processes eight output rows (head_v_dim / 8 workgroups per head), improving register reuse relative to the baseline 1-row shader.
method
ElementwiseDispatch.recordSsmDeltaNetCols8Normed
pub fn recordSsmDeltaNetCols8Normed( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void Record an SSM DeltaNet state-update dispatch using the cols8 normed variant.
Identical semantics to `recordSsmDeltaNetCols8` but selects the shader that expects Q/K inputs to be pre-normalized (skipping the in-shader norm). Each wave64 workgroup processes eight output rows (head_v_dim / 8 workgroups per head).
method
ElementwiseDispatch.recordSsmGatedNorm
pub fn recordSsmGatedNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmGatedNormPush, ) !void Record an SSM gated norm dispatch: applies z-gate * RMS-norm(delta_output).
Dispatches one wave64 workgroup per head (`push.dt_rank` workgroups total).
method
ElementwiseDispatch.recordSoftmaxTopk
pub fn recordSoftmaxTopk( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_experts: u32, k: u32, ) !void Record softmax + top-k MoE router dispatch.
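A CPU sketch of softmax followed by top-k selection for one token is shown below. It assumes the returned weights are the plain softmax probabilities, not renormalized over the selected `k`. This page does not say which convention the shader uses, and the function name and output arrangement are hypothetical.

```zig
const std = @import("std");

/// CPU sketch of a softmax + top-k router for one token.
/// Writes the k chosen expert ids and their softmax probabilities (k = ids.len).
/// Assumption: probabilities are not renormalized over the selected k.
fn softmaxTopkRef(logits: []const f32, ids: []u32, weights: []f32) void {
    const k = ids.len;
    std.debug.assert(weights.len == k and k <= logits.len and logits.len > 0);

    // Numerically stable softmax denominator.
    var max_logit = logits[0];
    for (logits) |l| max_logit = @max(max_logit, l);
    var denom: f32 = 0;
    for (logits) |l| denom += @exp(l - max_logit);

    // Repeated selection: pick the best not-yet-chosen expert k times.
    for (0..k) |slot| {
        var best: ?usize = null;
        for (logits, 0..) |l, idx| {
            var taken = false;
            for (ids[0..slot]) |id| {
                if (id == idx) taken = true;
            }
            if (taken) continue;
            if (best == null or l > logits[best.?]) best = idx;
        }
        const chosen = best.?;
        ids[slot] = @intCast(chosen);
        weights[slot] = @exp(logits[chosen] - max_logit) / denom;
    }
}
```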
method
ElementwiseDispatch.recordSigmoidScaleAcc
pub fn recordSigmoidScaleAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void Record sigmoid-gated scale-accumulate: a[i] += sigmoid(c[0]) * b[i].
method
ElementwiseDispatch.recordMoeWeightedAcc
pub fn recordMoeWeightedAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, n_used: u32, src_stride: u32, ) !void Record a MoE weighted accumulate dispatch: `a[i] += routing_weight[j] * b[j*src_stride + i]` summed over `n_used` selected experts.
Routing weights are read from the GPU routing buffer (binding 2), not from a push constant.
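The accumulate formula above can be sketched as a CPU reference. The routing weights are passed as a plain slice here purely for illustration, since on the GPU they come from the routing buffer. The function name is hypothetical.

```zig
const std = @import("std");

/// CPU reference for a[i] += weight[j] * b[j * src_stride + i], j in [0, n_used).
/// n_used is taken from weight.len; n_elements from a.len.
fn moeWeightedAccRef(a: []f32, b: []const f32, weight: []const f32, src_stride: usize) void {
    const n_used = weight.len;
    if (n_used == 0) return;
    std.debug.assert(a.len <= src_stride);
    std.debug.assert(b.len >= (n_used - 1) * src_stride + a.len);
    for (weight, 0..) |w, j| {
        const expert_out = b[j * src_stride ..][0..a.len];
        for (a, expert_out) |*dst, v| dst.* += w * v;
    }
}
```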
method
ElementwiseDispatch.deinit
pub fn deinit(self: *ElementwiseDispatch) void Destroy the loaded pipelines and descriptor pool.