Last updated: 2026-06-12

Shader Dispatch

Elementwise


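Wraps the fused element-wise shader family used by the decode loop and the batched prefill path.

This helper loads the pipelines for RMS norm, SwiGLU/GEGLU, RoPE, SSM (conv1d, Q/K norm, DeltaNet, gated norm), and MoE routing, among others, and records the push constants each dispatch needs.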

34 exports, 24 methods (source: src/compute/elementwise.zig)


struct

RmsNormPush

pub const RmsNormPush = extern struct

Push constants for RMS norm shader.

src/compute/elementwise.zig:17

struct

SwigluPush

pub const SwigluPush = extern struct

Push constants for SwiGLU shader.

src/compute/elementwise.zig:23

struct

DeinterleavePush

pub const DeinterleavePush = extern struct

Push constants for deinterleave shader.

src/compute/elementwise.zig:33

struct

SigmoidMulPush

pub const SigmoidMulPush = extern struct

Push constants for sigmoid multiply shader.

src/compute/elementwise.zig:39

struct

ScaleAccPush

pub const ScaleAccPush = extern struct

Push constants for scale-accumulate shader.

src/compute/elementwise.zig:44

struct

BiasAddPush

pub const BiasAddPush = extern struct

Push constants for bias add shader.

src/compute/elementwise.zig:50

struct

RopePush

pub const RopePush = extern struct

Push constants for RoPE shader (with partial rotation / IMRoPE support).

src/compute/elementwise.zig:56

struct

RopeBatchedPush

pub const RopeBatchedPush = extern struct

Push constants for rope_batched (multi-token prefill variant).

Layout mirrors src/shaders/rope_batched.comp.

src/compute/elementwise.zig:67

struct

SsmConv1dPush

pub const SsmConv1dPush = extern struct

Push constants for SSM conv1d + SiLU shader.

src/compute/elementwise.zig:77

struct

SsmConv1dBatchedPush

pub const SsmConv1dBatchedPush = extern struct

Push constants for the batched SSM conv1d shader.

src/compute/elementwise.zig:87

struct

F32DualBatchPush

pub const F32DualBatchPush = extern struct

Push constants for batched f32 dual DMMV (SSM alpha/beta).

src/compute/elementwise.zig:96

struct

SsmDeltaNetPush

pub const SsmDeltaNetPush = extern struct

Push constants for SSM delta-net state update shader.

src/compute/elementwise.zig:104

struct

SsmQkNormPush

pub const SsmQkNormPush = extern struct

Push constants for the SSM Q/K RMS-norm shader.

Drives the per-group normalization applied to query and key projections inside Mamba/SSM blocks.

src/compute/elementwise.zig:126

struct

SsmGatedNormPush

pub const SsmGatedNormPush = extern struct

Push constants for SSM gated norm shader.

src/compute/elementwise.zig:133

struct

SoftmaxTopkPush

pub const SoftmaxTopkPush = extern struct

Push constants for softmax + top-k MoE router shader.

src/compute/elementwise.zig:142

struct

RouterF32BatchPush

pub const RouterF32BatchPush = extern struct

Push constants for token-batched f32 router matvec.

src/compute/elementwise.zig:151

struct

RmsNormScaleDmmvF32BatchPush

pub const RmsNormScaleDmmvF32BatchPush = extern struct

Push constants for token-batched Gemma router RMS norm + scale + f32 DMMV.

src/compute/elementwise.zig:160

struct

SoftmaxTopkBatchPush

pub const SoftmaxTopkBatchPush = extern struct

Push constants for token-batched MoE top-k.

src/compute/elementwise.zig:168

struct

MoeWeightedAccPush

pub const MoeWeightedAccPush = extern struct

Push constants for the MoE weighted-accumulate shader (single token; all selected experts combined in one dispatch).

Sums all expert outputs at once: `a[i] = sum_j(weight_j * b[j*src_stride + i])`.

src/compute/elementwise.zig:180

struct

MoeWeightedAccBatchPush

pub const MoeWeightedAccBatchPush = extern struct

Push constants for the **batched** MoE weighted-accumulate shader.

Sums each token's `n_used` selected-expert outputs across a token batch in one dispatch: `a[t,i] = sum_j(weight_{t,j} * b[...])`.

src/compute/elementwise.zig:189

struct

SigmoidScaleAccBatchPush

pub const SigmoidScaleAccBatchPush = extern struct

Push constants for the **batched** `sigmoid_scale_acc` shader.

Applies a per-token sigmoid-gated shared-expert add across a token batch: `accum[t,i] += sigmoid(gate_t) * src[t,i]`.

src/compute/elementwise.zig:201

struct

KvCacheWritePush

pub const KvCacheWritePush = extern struct

Push constants for KV cache write compute shader.

src/compute/elementwise.zig:208

struct

KvCacheWriteBatchedPush

pub const KvCacheWriteBatchedPush = extern struct

Push constants for batched KV cache write (prefillBatched path).

Matches src/shaders/kv_cache_write_batched.comp.

src/compute/elementwise.zig:215

struct

ResidualRmsNormPush

pub const ResidualRmsNormPush = extern struct

Push constants for fused residual-add + RMS norm (src/shaders/residual_rms_norm.comp).

A single dispatch of `n_tokens` workgroups (one per token) replaces the scale_accumulate → barrier → rms_norm_mul chain.

src/compute/elementwise.zig:225

struct

PostNormResidualRmsNormPush

pub const PostNormResidualRmsNormPush = extern struct

Push constants for fused post-norm + residual-add + RMS norm (src/shaders/post_norm_residual_rms_norm.comp).

One dispatch replaces Gemma's post_attention_norm → barrier → residual_rms_norm sequence.

src/compute/elementwise.zig:234

struct

ResidualRmsNormQuantQ8_1Push

pub const ResidualRmsNormQuantQ8_1Push = extern struct

Push constants for fused residual-add + RMS norm + Q8_1 activation quantize (src/shaders/residual_rms_norm_quant_q8_1.comp).

Same residual/RMS math as ResidualRmsNormPush, but also emits packed int8 lanes + (scale, dsum) so the downstream Qwen3.6-27B dense FFN DP4a gate+up GEMM can skip its separate quantize_act_q8_1 dispatch.

src/compute/elementwise.zig:244

struct

NormRopePush

pub const NormRopePush = extern struct

Push constants for fused RMS norm + RoPE shader.

src/compute/elementwise.zig:254

struct

RmsNormAddPush

pub const RmsNormAddPush = extern struct

Push constants for fused rmsnorm(src) + hidden accumulate shader (src/shaders/rms_norm_add.comp).

Used by Gemma prefillBatched to fold post_ffw_norm + residual add into one dispatch.

src/compute/elementwise.zig:267

struct

RmsNormDmmvF32Push

pub const RmsNormDmmvF32Push = extern struct

Push constants for fused RMS norm + f32 router DMMV shader (src/shaders/rms_norm_dmmv_f32.comp).

Folds the per-MoE-layer rms_norm_mul → router DMMV pair into a single dispatch on architectures whose router weights are f32 (Qwen 3.5/3.6, etc.).

src/compute/elementwise.zig:276

struct

RmsNormDmmvQ4kAlphaBetaPush

pub const RmsNormDmmvQ4kAlphaBetaPush = extern struct

Push constants for fused RMS norm + Q4_K alpha+beta SSM proj DMMV (src/shaders/rms_norm_dmmv_q4k_alpha_beta.comp).

Folds the per-SSM-layer (rms_norm_mul → alpha DMMV → beta DMMV) trio into a single dispatch on the qwen35moe / qwen36moe SSM proj fast path.

src/compute/elementwise.zig:286

struct

QkNormRopeKvWritePush

pub const QkNormRopeKvWritePush = extern struct

Push constants for fused Q+K norm + RoPE + KV cache write shader (src/shaders/qk_norm_rope_kv_write.comp).

Folds the per-attention-layer (Q norm+rope → K norm+rope → kv_cache_write) trio on Qwen 3 family dense attention into a single dispatch.

src/compute/elementwise.zig:296

struct

QkNormRopeKvWriteBatchedPush

pub const QkNormRopeKvWriteBatchedPush = extern struct

Batched sliding-window attention (SWA) variant of QkNormRopeKvWritePush.

Binding 4 is the KV page table, so this variant computes RoPE frequencies from freq_base_bits instead of reading a frequency buffer.

src/compute/elementwise.zig:312

struct

KNormRopeKvWriteBatchedPush

pub const KNormRopeKvWriteBatchedPush = extern struct

Batched full-attention K/V sibling used when Q must keep the precomputed RoPE frequency buffer binding.

It fuses K RMS norm, K RoPE, optional V unit norm, and paged KV cache write. Q norm/RoPE stays on the existing path.

src/compute/elementwise.zig:329

struct

ElementwiseDispatch

pub const ElementwiseDispatch = struct

Manages element-wise fused kernel pipelines.

src/compute/elementwise.zig:343

Methods (24)

method

ElementwiseDispatch.init

pub fn init( instance: *const Instance, shader_dir: []const u8, allocator: std.mem.Allocator, ) !ElementwiseDispatch

Create the fused element-wise dispatch wrapper and load its shaders.

Parameters
instance
Active Vulkan instance and logical device.
shader_dir
Directory containing compiled SPIR-V shader binaries.
allocator
Allocator used for temporary pipeline creation state.
Returns

An ElementwiseDispatch ready to record element-wise passes.

src/compute/elementwise.zig:481

method

ElementwiseDispatch.recordRmsNorm

pub fn recordRmsNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, hidden_dim: u32, n_tokens: u32, eps: f32, ) !void

Record an RMS-norm-plus-scale dispatch for a batch of tokens.

This binds the fused normalization shader used before the attention and MLP projections: each token's hidden vector is divided by its root-mean-square over `hidden_dim` elements (stabilized by `eps`) and then multiplied by the weight buffer.

Parameters
self
Dispatch wrapper containing the RMS norm pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set containing input, weight, and output buffers.
hidden_dim
Hidden width processed per token.
n_tokens
Number of tokens covered by the dispatch.
eps
Numerical stability epsilon passed to the shader.
Returns

`error.ShaderNotLoaded` when the RMS norm pipeline is unavailable.

Notes

The helper dispatches one workgroup per token.

src/compute/elementwise.zig:945
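As a host-side cross-check, the math above can be sketched on the CPU. This is a minimal sketch, not the shader: it assumes the standard RMSNorm formulation `out = x / sqrt(mean(x^2) + eps) * weight` with `eps` inside the square root, and a token-major layout of `n_tokens` rows of `hidden_dim` floats. The function names are hypothetical.

```zig
const std = @import("std");

/// Hypothetical CPU reference for one token: out[i] = x[i] / sqrt(mean(x^2) + eps) * w[i].
/// Assumes eps is added inside the square root (standard RMSNorm).
fn rmsNormScaleToken(x: []const f32, w: []const f32, out: []f32, eps: f32) void {
    std.debug.assert(x.len == w.len and x.len == out.len);
    var sum_sq: f32 = 0;
    for (x) |v| sum_sq += v * v;
    const mean_sq = sum_sq / @as(f32, @floatFromInt(x.len));
    const inv_rms = 1.0 / @sqrt(mean_sq + eps);
    for (x, w, out) |v, wi, *o| o.* = v * inv_rms * wi;
}

/// Mirrors the dispatch shape: one "workgroup" (here, one loop iteration) per token.
/// Assumes input and output are token-major with `hidden_dim` floats per row.
fn rmsNormScaleBatch(x: []const f32, w: []const f32, out: []f32, hidden_dim: usize, n_tokens: usize, eps: f32) void {
    for (0..n_tokens) |t| {
        const row = t * hidden_dim;
        rmsNormScaleToken(x[row..][0..hidden_dim], w, out[row..][0..hidden_dim], eps);
    }
}
```

Each token is normalized independently, which is why the shader can assign exactly one workgroup per token.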

method

ElementwiseDispatch.recordSwiglu

pub fn recordSwiglu( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a SwiGLU activation dispatch.

Parameters
self
Dispatch wrapper containing the SwiGLU pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set containing gate, up, and output buffers.
n_elements
Total number of output elements to compute.
Returns

`error.ShaderNotLoaded` when the SwiGLU pipeline is unavailable.

Notes

The dispatch uses `ceil(n_elements / 64)` workgroups.

src/compute/elementwise.zig:971
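The grid size above is `ceil(n_elements / 64)` workgroups, which presumably means 64 invocations per workgroup, one element each. The sketch below shows that arithmetic next to a CPU reference for the activation. It assumes the standard SwiGLU definition `out = silu(gate) * up` with `silu(g) = g * sigmoid(g)`, which the documentation does not spell out; the function names are hypothetical.

```zig
const std = @import("std");

/// ceil(n_elements / local_size) without the overflow risk of (n + local_size - 1).
fn workgroupCount(n_elements: u32, local_size: u32) u32 {
    return n_elements / local_size + @intFromBool(n_elements % local_size != 0);
}

/// silu(x) = x * sigmoid(x) = x / (1 + e^-x)
fn silu(x: f32) f32 {
    return x / (1.0 + @exp(-x));
}

/// Hypothetical CPU reference (assumed standard SwiGLU): out[i] = silu(gate[i]) * up[i].
fn swiglu(gate: []const f32, up: []const f32, out: []f32) void {
    for (gate, up, out) |g, u, *o| o.* = silu(g) * u;
}
```

For example, `workgroupCount(4097, 64)` yields 65: the last workgroup covers a single trailing element, so the shader presumably bounds-checks its global index.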

method

ElementwiseDispatch.recordSwigluOai

pub fn recordSwigluOai( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a GPT-OSS / OAI-variant SwiGLU activation dispatch.

Uses the same 3-binding layout as `recordSwiglu` (gate, up → output) but selects the swiglu_oai shader whose activation function matches gpt-oss.

Parameters
self
Dispatch wrapper containing the OAI SwiGLU pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set containing gate, up, and output buffers.
n_elements
Total number of output elements to compute.
Returns

`error.ShaderNotLoaded` when the OAI SwiGLU pipeline is unavailable.

Notes

The dispatch uses `ceil(n_elements / 64)` workgroups.

src/compute/elementwise.zig:993
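For reference, the gpt-oss reference activation clamps the gate from above, clamps the linear half on both sides, and adds a bias of 1 to the linear half: `out = g * sigmoid(1.702 * g) * (clamp(u, -7, 7) + 1)` with `g = min(gate, 7)`. The sketch below transcribes that formula with separate gate and up buffers. Whether the `swiglu_oai` shader uses exactly these constants (alpha = 1.702, limit = 7.0) is an assumption; the documentation only says its activation matches gpt-oss.

```zig
const std = @import("std");

// Assumed constants, taken from the gpt-oss reference activation.
const alpha: f32 = 1.702;
const limit: f32 = 7.0;

fn sigmoid(x: f32) f32 {
    return 1.0 / (1.0 + @exp(-x));
}

/// Hypothetical CPU reference for a gpt-oss-style SwiGLU.
fn swigluOai(gate: []const f32, up: []const f32, out: []f32) void {
    for (gate, up, out) |g_raw, u_raw, *o| {
        const g = @min(g_raw, limit); // gate is clamped from above only
        const u = std.math.clamp(u_raw, -limit, limit); // linear half clamped on both sides
        o.* = g * sigmoid(alpha * g) * (u + 1.0);
    }
}
```

The clamps bound the activation's output, which is the main behavioral difference from the plain SwiGLU used by `recordSwiglu`.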

method

ElementwiseDispatch.recordBiasAdd

pub fn recordBiasAdd( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, src_offset: u32, ) !void

Record an in-place bias add dispatch: `out[i] += bias[src_offset + i]`.

Parameters
self
Dispatch wrapper containing the bias add pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with two bindings: output buffer (rw) and bias buffer (ro).
n_elements
Number of elements to update.
src_offset
Element offset into the bias buffer (allows a shared bias tensor to be sliced).
Returns

`error.ShaderNotLoaded` when the bias add pipeline is unavailable.

src/compute/elementwise.zig:1012
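Because `src_offset` indexes into the bias buffer, a single shared bias tensor can be sliced across several dispatches. A CPU sketch of the per-element update (the function name is hypothetical):

```zig
const std = @import("std");

/// Hypothetical CPU reference: out[i] += bias[src_offset + i] for i in 0..out.len.
fn biasAdd(out: []f32, bias: []const f32, src_offset: u32) void {
    std.debug.assert(@as(usize, src_offset) + out.len <= bias.len);
    const slice = bias[src_offset..][0..out.len];
    for (out, slice) |*o, b| o.* += b;
}
```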

method

ElementwiseDispatch.recordGeglu

pub fn recordGeglu( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a GEGLU activation dispatch (GELU-gated, used by Gemma).

Same buffer layout as SwiGLU: gate, up → output.

src/compute/elementwise.zig:1030
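A CPU sketch of GEGLU, `out = gelu(gate) * up`. It assumes the tanh approximation of GELU, the variant Gemma checkpoints conventionally use; the documentation does not state which GELU form the shader implements. Function names are hypothetical.

```zig
const std = @import("std");

/// tanh(y) = 1 - 2 / (e^(2y) + 1); saturates cleanly to +/-1 for large |y|.
fn tanhF32(y: f32) f32 {
    return 1.0 - 2.0 / (@exp(2.0 * y) + 1.0);
}

/// GELU, tanh approximation (assumed): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn geluTanh(x: f32) f32 {
    const sqrt_2_over_pi: f32 = 0.7978845608;
    return 0.5 * x * (1.0 + tanhF32(sqrt_2_over_pi * (x + 0.044715 * x * x * x)));
}

/// Hypothetical CPU reference: out[i] = gelu(gate[i]) * up[i], same gate/up -> output layout as SwiGLU.
fn geglu(gate: []const f32, up: []const f32, out: []f32) void {
    for (gate, up, out) |g, u, *o| o.* = geluTanh(g) * u;
}
```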

method

ElementwiseDispatch.recordRope

pub fn recordRope( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, stride: u32, rope_dim: u32, n_heads: u32, position: u32, freq_base: f32, attn_scale: f32, ) !void

Record a RoPE dispatch with partial rotation support (IMRoPE).

Rotates the first `rope_dim` dimensions of each attention head at the given sequence position; the remaining `stride - rope_dim` dimensions are copied unchanged, which enables IMRoPE layouts.

Parameters
self
Dispatch wrapper containing the RoPE pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with three bindings: input, output, and freq buffer.
stride
Full head dimension in f32 elements (distance between heads in the buffer).
rope_dim
Number of dimensions to rotate (must be <= stride; pass stride for plain RoPE).
n_heads
Number of query heads to rotate; one workgroup is dispatched per head.
position
Current decode token position used to compute rotation angles.
freq_base
Base frequency for the sinusoidal schedule (e.g. 10000.0 for standard RoPE).
attn_scale
YaRN magnitude scale applied after rotation; use 1.0 for plain RoPE.
Returns

`error.ShaderNotLoaded` when the RoPE pipeline is unavailable.

src/compute/elementwise.zig:1056
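A CPU sketch of the per-head rotation with partial rotation and the post-rotation `attn_scale`. Several details are assumptions the documentation does not state: rotated dimensions are paired as adjacent elements `(2k, 2k+1)` (some RoPE variants pair `k` with `k + rope_dim/2` instead), the angle is `position * freq_base^(-2k/rope_dim)` with no per-dimension factors from the freq buffer, and `attn_scale` multiplies only the rotated dimensions. Function names are hypothetical.

```zig
const std = @import("std");

/// Rotate one head slice (length = stride). Elements at or beyond rope_dim pass through unchanged.
/// Assumes adjacent-pair rotation and theta = position * freq_base^(-i / rope_dim), i = 2k.
fn ropeHead(src: []const f32, dst: []f32, rope_dim: usize, position: u32, freq_base: f32, attn_scale: f32) void {
    std.debug.assert(src.len == dst.len and rope_dim <= src.len and rope_dim % 2 == 0);
    const pos: f32 = @floatFromInt(position);
    const rd: f32 = @floatFromInt(rope_dim);
    var i: usize = 0;
    while (i < rope_dim) : (i += 2) {
        const fi: f32 = @floatFromInt(i);
        const theta = pos * @exp(-@log(freq_base) * fi / rd);
        const c = @cos(theta) * attn_scale; // assumed: scale applies to rotated dims only
        const s = @sin(theta) * attn_scale;
        const x0 = src[i];
        const x1 = src[i + 1];
        dst[i] = x0 * c - x1 * s;
        dst[i + 1] = x0 * s + x1 * c;
    }
    @memcpy(dst[rope_dim..], src[rope_dim..]);
}

/// One "workgroup" per head: heads are laid out `stride` floats apart.
fn ropeHeads(src: []const f32, dst: []f32, stride: usize, rope_dim: usize, n_heads: usize, position: u32, freq_base: f32, attn_scale: f32) void {
    for (0..n_heads) |h| {
        const off = h * stride;
        ropeHead(src[off..][0..stride], dst[off..][0..stride], rope_dim, position, freq_base, attn_scale);
    }
}
```

Passing `rope_dim == stride` and `attn_scale == 1.0` reduces this to plain RoPE, matching the parameter notes above.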

method

ElementwiseDispatch.recordRoPEBatched

pub fn recordRoPEBatched( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, stride: u32, rope_dim: u32, n_heads: u32, position_base: u32, n_tokens: u32, freq_base: f32, attn_scale: f32, ) !void

Record a batched RoPE dispatch that rotates `n_tokens` tokens at consecutive positions `[position_base, position_base + n_tokens)` in a single call.

Grid is (n_heads, n_tokens, 1); each (head, token) workgroup rotates `rope_dim` elements of the token's head slice.

src/compute/elementwise.zig:1087
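The batched grid can be read as: workgroup `(h, t)` rotates head `h` of token `t` at position `position_base + t`. The sketch below computes that mapping. The token-major layout (each token holding `n_heads * stride` floats) is an assumption, and `BatchedRopeCell` and `batchedRopeCell` are hypothetical names.

```zig
const std = @import("std");

const BatchedRopeCell = struct {
    position: u32, // sequence position used for the rotation angles
    offset: usize, // index of the first float of this (token, head) slice
};

/// Map workgroup (head, token) of the (n_heads, n_tokens, 1) grid to its position and buffer offset.
/// Assumes a token-major buffer with `stride` floats per head.
fn batchedRopeCell(head: u32, token: u32, n_heads: u32, stride: u32, position_base: u32) BatchedRopeCell {
    std.debug.assert(head < n_heads);
    return .{
        .position = position_base + token,
        .offset = (@as(usize, token) * n_heads + head) * stride,
    };
}
```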

method

ElementwiseDispatch.recordDeinterleave

pub fn recordDeinterleave( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, ) !void

Record a deinterleave dispatch: split element-interleaved Q+gate into separate buffers.

src/compute/elementwise.zig:1112

method

ElementwiseDispatch.recordDeinterleaveBatched

pub fn recordDeinterleaveBatched( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_tokens: u32, ) !void

Record a token-batched deinterleave dispatch.

Splits each token's packed `[Q(head_dim), gate(head_dim)]` interleaved per-head layout into separate Q and gate output buffers in one dispatch. Grid is `(ceil(head_dim * n_heads / 64), n_tokens, 1)`.

Parameters
self
Dispatch wrapper containing the batched deinterleave pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 3 bindings: packed input, Q output, gate output.
head_dim
Per-head dimension in elements.
n_heads
Number of query heads per token.
n_tokens
Number of tokens to process (Y dimension of the dispatch grid).
Returns

`error.ShaderNotLoaded` when the batched deinterleave pipeline is unavailable.

src/compute/elementwise.zig:1140
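A CPU sketch of the per-head split described for the batched variant: each head's `[Q(head_dim), gate(head_dim)]` block is copied into the Q and gate outputs. It assumes all three buffers are token-major, with `2 * head_dim * n_heads` packed floats per token; the function names are hypothetical.

```zig
const std = @import("std");

/// Hypothetical CPU reference for one token: split per-head [Q(head_dim), gate(head_dim)] blocks.
fn deinterleaveToken(packed_in: []const f32, q: []f32, gate: []f32, head_dim: usize, n_heads: usize) void {
    std.debug.assert(packed_in.len == 2 * head_dim * n_heads);
    std.debug.assert(q.len == head_dim * n_heads and gate.len == head_dim * n_heads);
    for (0..n_heads) |h| {
        const src = packed_in[2 * h * head_dim ..][0 .. 2 * head_dim];
        @memcpy(q[h * head_dim ..][0..head_dim], src[0..head_dim]);
        @memcpy(gate[h * head_dim ..][0..head_dim], src[head_dim..]);
    }
}

/// Token-batched version: the n_tokens loop plays the role of the grid's Y dimension.
fn deinterleaveBatched(packed_in: []const f32, q: []f32, gate: []f32, head_dim: usize, n_heads: usize, n_tokens: usize) void {
    const width = head_dim * n_heads;
    for (0..n_tokens) |t| {
        deinterleaveToken(
            packed_in[2 * t * width ..][0 .. 2 * width],
            q[t * width ..][0..width],
            gate[t * width ..][0..width],
            head_dim,
            n_heads,
        );
    }
}
```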

method

ElementwiseDispatch.recordSigmoidMul

pub fn recordSigmoidMul( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a sigmoid multiply dispatch: out = input * sigmoid(gate).

src/compute/elementwise.zig:1156

method

ElementwiseDispatch.recordVadd

pub fn recordVadd( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a vector add dispatch: c = a + b.

src/compute/elementwise.zig:1170

method

ElementwiseDispatch.recordScaleAcc

pub fn recordScaleAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, scale: f32, ) !void

Record a scale-accumulate dispatch: a[i] += scale * b[i].

Vec4-coalesced: each thread handles one vec4 (4 f32 elements). Caller must pass n_elements divisible by 4; every in-tree caller already does.

src/compute/elementwise.zig:1186
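The vec4 coalescing can be mirrored on the CPU with Zig's `@Vector`: each loop iteration handles one group of four floats, as each shader invocation handles one vec4, so the divisibility precondition on `n_elements` carries over (checked here with an assertion). A minimal sketch; the function name is hypothetical.

```zig
const std = @import("std");

/// Hypothetical CPU mirror of the vec4-coalesced shader: a[i] += scale * b[i],
/// processed four floats at a time.
fn scaleAccVec4(a: []f32, b: []const f32, scale: f32) void {
    std.debug.assert(a.len == b.len);
    std.debug.assert(a.len % 4 == 0); // same precondition the shader places on n_elements
    const V = @Vector(4, f32);
    const s: V = @splat(scale);
    var i: usize = 0;
    while (i < a.len) : (i += 4) {
        const va: V = a[i..][0..4].*;
        const vb: V = b[i..][0..4].*;
        a[i..][0..4].* = va + s * vb;
    }
}
```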

method

ElementwiseDispatch.recordScaleInPlace

pub fn recordScaleInPlace( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, scale: f32, ) !void

Record an in-place element-wise scale dispatch: `data[i] *= scale`.

Parameters
self
Dispatch wrapper containing the scale-in-place pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with one binding: the buffer to scale in place.
n_elements
Number of f32 elements to scale.
scale
Scalar multiplier applied to every element.
Returns

`error.ShaderNotLoaded` when the scale-in-place pipeline is unavailable.

src/compute/elementwise.zig:1207

method

ElementwiseDispatch.recordSsmConv1d

pub fn recordSsmConv1d( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, conv_channels: u32, d_conv: u32, kernel_is_f16: bool, state_offset: u32, ) !void

Record a single-token SSM depthwise conv1d + SiLU dispatch.

Reads the current SSM conv state via `state_offset` (a rotating index into the circular state buffer), applies a depthwise conv kernel of width `d_conv`, and writes the SiLU-activated result to the output binding.

Parameters
self
Dispatch wrapper containing the SSM conv1d pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with four bindings: input, kernel, state, output.
conv_channels
Number of SSM channels (width of the depthwise conv).
d_conv
Kernel width of the depthwise convolution.
kernel_is_f16
True when the kernel weight buffer is f16; false for f32.
state_offset
Current rotation index (0 to d_conv - 2, inclusive) into the circular state buffer.
Returns

`error.ShaderNotLoaded` when the SSM conv1d pipeline is unavailable.

src/compute/elementwise.zig:1232

method

ElementwiseDispatch.recordSsmQkNorm

pub fn recordSsmQkNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, d_state: u32, n_group: u32, ) !void

Record an in-place SSM Q/K RMS-norm dispatch.

Applies per-group RMS normalization to the concatenated Q and K projections inside a Mamba/DeltaNet SSM block; dispatches one workgroup per group.

Parameters
self
Dispatch wrapper containing the SSM Q/K norm pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with one binding: the Q+K buffer (in-place).
d_state
Per-group state dimension (qk_dim = d_state * n_group).
n_group
Number of normalization groups; one workgroup per group.
Returns

`error.ShaderNotLoaded` when the SSM Q/K norm pipeline is unavailable.

src/compute/elementwise.zig:1261

method

ElementwiseDispatch.recordSsmDeltaNet

pub fn recordSsmDeltaNet( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void

Record an SSM DeltaNet state-update dispatch (baseline variant).

Executes the DeltaNet recurrence over a single token (or `push.n_tok` prefill tokens when `n_tok > 1`). Grid is `(dt_rank, head_v_dim, 1)`: one wave64 workgroup per (head, row) pair.

Parameters
self
Dispatch wrapper containing the SSM delta-net pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 7 bindings: conv_out, dt_bias, alpha, beta, ssm_a, state, output.
push
Fully populated push-constant struct describing the SSM dimensions and flags.
Returns

`error.ShaderNotLoaded` when the SSM delta-net pipeline is unavailable.

src/compute/elementwise.zig:1286

method

ElementwiseDispatch.recordSsmDeltaNetCols8

pub fn recordSsmDeltaNetCols8( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void

Record an SSM DeltaNet state-update dispatch using the cols8 tiled variant.

Each wave64 workgroup processes four output rows (head_v_dim / 4 workgroups per head), improving register reuse relative to the baseline 1-row shader.

Parameters
self
Dispatch wrapper containing the SSM delta-net cols8 pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 7 bindings (same layout as `recordSsmDeltaNet`).
push
Fully populated push-constant struct describing the SSM dimensions and flags.
Returns

`error.ShaderNotLoaded` when the SSM delta-net cols8 pipeline is unavailable.

src/compute/elementwise.zig:1306

method

ElementwiseDispatch.recordSsmDeltaNetCols8Normed

pub fn recordSsmDeltaNetCols8Normed( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmDeltaNetPush, ) !void

Record an SSM DeltaNet state-update dispatch using the cols8 normed variant.

Identical semantics to `recordSsmDeltaNetCols8` but selects the shader that expects Q/K inputs to be pre-normalized (skipping the in-shader norm). Each wave64 workgroup processes eight output rows (head_v_dim / 8 workgroups per head).

Parameters
self
Dispatch wrapper containing the SSM delta-net cols8 normed pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 7 bindings (same layout as `recordSsmDeltaNet`).
push
Fully populated push-constant struct describing the SSM dimensions and flags.
Returns

`error.ShaderNotLoaded` when the SSM delta-net cols8 normed pipeline is unavailable.

src/compute/elementwise.zig:1327

method

ElementwiseDispatch.recordSsmGatedNorm

pub fn recordSsmGatedNorm( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, push: SsmGatedNormPush, ) !void

Record an SSM gated norm dispatch: applies z-gate * RMS-norm(delta_output).

Dispatches one wave64 workgroup per head (`push.dt_rank` workgroups total).

Parameters
self
Dispatch wrapper containing the SSM gated norm pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 4 bindings: delta_output, z_gate, norm_weights, output.
push
Push-constant struct specifying d_inner, dt_rank, head_v_dim, d_state, and norm_per_head.
Returns

`error.ShaderNotLoaded` when the SSM gated norm pipeline is unavailable.

src/compute/elementwise.zig:1345

method

ElementwiseDispatch.recordSoftmaxTopk

pub fn recordSoftmaxTopk( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_experts: u32, k: u32, ) !void

Record a softmax + top-k MoE router dispatch that selects `k` of `n_experts` experts.

src/compute/elementwise.zig:1356
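A CPU sketch of softmax followed by top-k selection over `n_experts` router logits, where `k` is the length of the output slices. It assumes the shader produces the `k` selected expert indices together with their softmax probabilities and does not renormalize the selected weights; the documentation states neither the output layout nor whether renormalization happens. The function name is hypothetical.

```zig
const std = @import("std");

/// Softmax over `logits` (into `probs`), then pick the ids.len most probable experts.
/// Ties go to the lower expert index. Selected weights are NOT renormalized (assumption).
fn softmaxTopk(logits: []const f32, probs: []f32, ids: []u32, weights: []f32) void {
    std.debug.assert(logits.len == probs.len and ids.len == weights.len and ids.len <= logits.len);
    // Subtract the max logit before exponentiating for numerical stability.
    var max_logit = logits[0];
    for (logits[1..]) |l| max_logit = @max(max_logit, l);
    var sum: f32 = 0;
    for (logits, probs) |l, *p| {
        p.* = @exp(l - max_logit);
        sum += p.*;
    }
    for (probs) |*p| p.* /= sum;
    // One selection pass per output slot; O(k * n_experts) is fine for small k.
    for (ids, weights, 0..) |*id, *w, slot| {
        var best: usize = 0;
        var best_p: f32 = -1.0;
        for (probs, 0..) |p, e| {
            const e32: u32 = @intCast(e);
            const taken = std.mem.indexOfScalar(u32, ids[0..slot], e32) != null;
            if (!taken and p > best_p) {
                best = e;
                best_p = p;
            }
        }
        id.* = @intCast(best);
        w.* = best_p;
    }
}
```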

method

ElementwiseDispatch.recordSigmoidScaleAcc

pub fn recordSigmoidScaleAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, ) !void

Record a sigmoid-gated scale-accumulate dispatch: `a[i] += sigmoid(c[0]) * b[i]`.

src/compute/elementwise.zig:1369
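This is the single-token form of the shared-expert gate that `SigmoidScaleAccBatchPush` describes per token (`accum[t,i] += sigmoid(gate_t) * src[t,i]`): one scalar gate, read from `c[0]`, scales every element. A CPU sketch; the function names are hypothetical.

```zig
const std = @import("std");

fn sigmoid(x: f32) f32 {
    return 1.0 / (1.0 + @exp(-x));
}

/// Hypothetical CPU reference: a[i] += sigmoid(c[0]) * b[i].
/// The gate is a single scalar read once from c[0] and shared by every element.
fn sigmoidScaleAcc(a: []f32, b: []const f32, c: []const f32) void {
    std.debug.assert(a.len == b.len and c.len >= 1);
    const g = sigmoid(c[0]);
    for (a, b) |*ai, bi| ai.* += g * bi;
}
```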

method

ElementwiseDispatch.recordMoeWeightedAcc

pub fn recordMoeWeightedAcc( self: *const ElementwiseDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, n_elements: u32, n_used: u32, src_stride: u32, ) !void

Record a MoE weighted accumulate dispatch: `a[i] += routing_weight[j] * b[j*src_stride + i]` summed over `n_used` selected experts.

Routing weights are read from the GPU routing buffer (binding 2), not from a push constant.

Parameters
self
Dispatch wrapper containing the MoE weighted accumulate pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with 3 bindings: accum (rw), src experts, routing weights.
n_elements
Hidden dimension of the accumulation buffer (elements updated per token).
n_used
Number of selected experts whose outputs are summed.
src_stride
Elements per expert in the source buffer (typically equal to n_elements).
Returns

`error.ShaderNotLoaded` when the MoE weighted accumulate pipeline is unavailable.

src/compute/elementwise.zig:1392
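A CPU sketch of the accumulate formula given for this method. It assumes the `n_used` routing weights sit contiguously at the start of the routing buffer (`routing[0..n_used]`) and that expert `j`'s output starts at `b[j * src_stride]`; the documentation does not describe the routing buffer's exact layout. The function name is hypothetical.

```zig
const std = @import("std");

/// Hypothetical CPU reference: a[i] += sum over j < n_used of routing[j] * b[j * src_stride + i].
/// Assumes the n_used routing weights are stored contiguously at routing[0..n_used].
fn moeWeightedAcc(a: []f32, b: []const f32, routing: []const f32, n_used: usize, src_stride: usize) void {
    std.debug.assert(routing.len >= n_used);
    std.debug.assert(n_used == 0 or (n_used - 1) * src_stride + a.len <= b.len);
    for (a, 0..) |*ai, i| {
        var sum: f32 = 0;
        for (0..n_used) |j| sum += routing[j] * b[j * src_stride + i];
        ai.* += sum;
    }
}
```

Keeping the weights on the GPU (binding 2) rather than in a push constant means the router's output can feed this dispatch without a host round trip.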

method

ElementwiseDispatch.deinit

pub fn deinit(self: *ElementwiseDispatch) void

Destroy the loaded pipelines and descriptor pool.

Parameters
self
Dispatch wrapper to tear down in place.

src/compute/elementwise.zig:1408