Last updated: 2026-06-12

Shader Dispatch

Attention


Wrap the flash-attention compute shader and its dispatch parameters.

This helper owns the pipeline resources needed to bind paged attention inputs and record a flash-attention compute pass.

4 exports, 6 methods. Source: src/compute/attention.zig


struct

FlashAttnPush

pub const FlashAttnPush = extern struct

Push constants for flash attention shader.
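As a rough illustration of why these blocks are declared `extern struct`, the sketch below defines a hypothetical push-constant layout and checks its size at compile time. The field list of `ExamplePush` is an assumption inferred from the scalar arguments of `recordFlashAttn`; the real `FlashAttnPush` may order or pad its fields differently. The 128-byte figure is the minimum push-constant space that Vulkan guarantees.

```zig
const std = @import("std");

// Hypothetical layout, inferred from the recordFlashAttn arguments.
// The real FlashAttnPush in src/compute/attention.zig may differ.
const ExamplePush = extern struct {
    head_dim: u32,
    n_heads: u32,
    n_kv_heads: u32,
    seq_len: u32,
    page_size: u32,
    attn_scale: f32,
    sink_offset: u32,
};

// Vulkan guarantees at least this many bytes of push-constant space.
const min_push_constant_bytes = 128;

comptime {
    // Seven 4-byte scalars pack without padding in a C-ABI struct.
    std.debug.assert(@sizeOf(ExamplePush) == 28);
    std.debug.assert(@sizeOf(ExamplePush) <= min_push_constant_bytes);
}

pub fn main() void {
    const push = ExamplePush{
        .head_dim = 128,
        .n_heads = 32,
        .n_kv_heads = 8,
        .seq_len = 1,
        .page_size = 16,
        .attn_scale = 0.0,
        .sink_offset = 0,
    };
    // The raw bytes are what a push-constant upload would copy to the GPU.
    const bytes = std.mem.asBytes(&push);
    std.debug.print("push constant block: {d} bytes\n", .{bytes.len});
}
```

An `extern struct` has a defined C-compatible field order, which is what lets the host-side bytes line up with a push-constant block declared in a shader.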

src/compute/attention.zig:15

struct

FlashAttnBatchedPush

pub const FlashAttnBatchedPush = extern struct

Push constants for flash_attn_batched.

Shared by two callers:
- Prefill batched path: processes N queries sharing a KV cache; seq_start is the position of the first query (0 on fresh prefill).
- Decode-shape foundation (ZINC_BATCH_ATTN=1): n_queries=1 with seq_start=state.position, bit-equivalent to the non-batched shader.

sink_offset is the per-layer offset into the per-head sinks buffer (layer_idx * n_heads). It is honoured by gpt-oss and NaN-gated otherwise.
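The two caller shapes and the sink offset rule can be sketched as follows. `BatchShape`, `prefillShape`, `decodeShape`, and `sinkOffset` are hypothetical names used only for illustration; they cover a subset of the values described above and are not the fields of `FlashAttnBatchedPush` itself.

```zig
const std = @import("std");

/// Hypothetical subset of the batched push constants, for illustration only.
const BatchShape = struct {
    seq_start: u32,
    n_queries: u32,
    sink_offset: u32,
};

/// Per-layer offset into the per-head sinks buffer: layer_idx * n_heads.
fn sinkOffset(layer_idx: u32, n_heads: u32) u32 {
    return layer_idx * n_heads;
}

/// Fresh prefill: n_tokens queries starting at position 0.
fn prefillShape(n_tokens: u32, layer_idx: u32, n_heads: u32) BatchShape {
    return .{
        .seq_start = 0,
        .n_queries = n_tokens,
        .sink_offset = sinkOffset(layer_idx, n_heads),
    };
}

/// Decode shape: a single query at the current position.
fn decodeShape(position: u32, layer_idx: u32, n_heads: u32) BatchShape {
    return .{
        .seq_start = position,
        .n_queries = 1,
        .sink_offset = sinkOffset(layer_idx, n_heads),
    };
}

pub fn main() void {
    const p = prefillShape(128, 3, 64);
    const d = decodeShape(128, 3, 64);
    std.debug.print("prefill: start={d} n={d} sink={d}\n", .{ p.seq_start, p.n_queries, p.sink_offset });
    std.debug.print("decode:  start={d} n={d} sink={d}\n", .{ d.seq_start, d.n_queries, d.sink_offset });
}
```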

src/compute/attention.zig:32

struct

FlashAttnSplitMergePush

pub const FlashAttnSplitMergePush = extern struct

Push constants for flash_attn_split_merge.

Reads N_I_CHUNKS partial outputs per head from binding 0, applies the per-head sink, normalizes, and writes the final output to binding 1.

src/compute/attention.zig:46

struct

AttentionDispatch

pub const AttentionDispatch = struct

Owns the Vulkan compute pipelines for flash attention and records their dispatches into a command buffer.

Supports three variants: single-query decode (`pipeline`), batched prefill/decode (`pipeline_batched`), and split-K decode (`pipeline_split` + `pipeline_split_merge`). The active variant is selected by the caller; all pipelines are loaded during `init` and destroyed by `deinit`.

src/compute/attention.zig:58

Methods (6)

method

AttentionDispatch.init

pub fn init( instance: *const Instance, shader_dir: []const u8, allocator: std.mem.Allocator, ) !AttentionDispatch

Create the flash-attention dispatch wrapper and load its shader pipelines.

Parameters
instance
Active Vulkan instance and logical device.
shader_dir
Directory containing compiled SPIR-V shader binaries.
allocator
Allocator used for temporary pipeline creation state.
Returns

An AttentionDispatch ready to record flash-attention passes.

src/compute/attention.zig:81

method

AttentionDispatch.recordFlashAttn

pub fn recordFlashAttn( self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_len: u32, page_size: u32, attn_scale: f32, sink_offset: u32, ) !void

Record a flash-attention dispatch for the current decode position.

Parameters
self
Dispatch wrapper containing the flash-attention pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set containing query, KV-cache, page-table, and output buffers.
head_dim
Hidden width per attention head.
n_heads
Number of query heads to process.
n_kv_heads
Number of KV heads present in the cache.
seq_len
Current decoded sequence length.
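page_size
Tokens stored in each KV-cache page.
attn_scale
Attention softmax scale factor.
sink_offset
Per-layer offset into the sink buffer (layer_idx * n_heads).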
Returns

`error.ShaderNotLoaded` when the flash-attention shader pipeline is unavailable.

Notes

The helper dispatches one workgroup per query head.

src/compute/attention.zig:204

method

AttentionDispatch.recordFlashAttnBatched

pub fn recordFlashAttnBatched( self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_start: u32, n_queries: u32, page_size: u32, attn_scale: f32, sink_offset: u32, ) !void

Record a batched flash-attention dispatch.

Grid is (n_heads, n_queries, 1); each (head, query) workgroup streams over the paged KV cache with causal_len = seq_start + query + 1.
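The grid shape and the causal length rule above can be sketched on the host side as below. `Grid`, `batchedGrid`, `causalLen`, and `pagesSpanned` are hypothetical helpers, and `pagesSpanned` additionally assumes that KV pages are filled densely from position 0, which the source does not state.

```zig
const std = @import("std");

/// Dispatch grid in workgroups (hypothetical helper type).
const Grid = struct { x: u32, y: u32, z: u32 };

/// One workgroup per (head, query) pair.
fn batchedGrid(n_heads: u32, n_queries: u32) Grid {
    return .{ .x = n_heads, .y = n_queries, .z = 1 };
}

/// Number of KV entries visible to the query at index `query` in the batch.
fn causalLen(seq_start: u32, query: u32) u32 {
    return seq_start + query + 1;
}

/// KV pages spanned by `len` entries, assuming densely filled pages.
fn pagesSpanned(len: u32, page_size: u32) u32 {
    return (len + page_size - 1) / page_size;
}

pub fn main() void {
    const grid = batchedGrid(32, 4);
    std.debug.print("grid = ({d}, {d}, {d})\n", .{ grid.x, grid.y, grid.z });

    var query: u32 = 0;
    while (query < grid.y) : (query += 1) {
        const len = causalLen(10, query);
        std.debug.print("query {d}: causal_len={d}, pages={d}\n", .{ query, len, pagesSpanned(len, 16) });
    }
}
```

With seq_start = 0 the first query sees exactly one KV entry (itself), and with n_queries = 1 the rule reduces to the decode case, seq_start + 1.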

Parameters
self
Dispatch wrapper containing the batched flash-attention pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with Q, KV-cache, page-table, output, and sink buffers.
head_dim
Hidden width per attention head.
n_heads
Number of query heads to process.
n_kv_heads
Number of KV heads in the cache (GQA ratio = n_heads / n_kv_heads).
seq_start
Token position of the first query in the sequence (0 on fresh prefill).
n_queries
Number of query tokens processed in this batch.
page_size
Tokens stored per KV-cache page.
attn_scale
Attention softmax scale factor (0 = use 1/sqrt(head_dim)).
sink_offset
Per-layer offset into the sink buffer (layer_idx * n_heads).
Returns

`error.ShaderNotLoaded` when the batched pipeline is unavailable.
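Two of the parameter conventions above, the `attn_scale` default and the GQA ratio, can be mirrored in a small host-side sketch. `effectiveScale` and `kvHeadFor` are hypothetical names; the mapping of query heads to KV heads in contiguous groups is the conventional GQA layout and is assumed here rather than taken from the shader source.

```zig
const std = @import("std");

/// Resolve the softmax scale: 0 means "use 1/sqrt(head_dim)".
fn effectiveScale(attn_scale: f32, head_dim: u32) f32 {
    if (attn_scale != 0) return attn_scale;
    return 1.0 / @sqrt(@as(f32, @floatFromInt(head_dim)));
}

/// KV head serving a given query head, assuming contiguous GQA groups.
fn kvHeadFor(head: u32, n_heads: u32, n_kv_heads: u32) u32 {
    std.debug.assert(n_kv_heads > 0 and n_heads % n_kv_heads == 0);
    const gqa_ratio = n_heads / n_kv_heads;
    return head / gqa_ratio;
}

pub fn main() void {
    std.debug.print("default scale for head_dim=64: {d}\n", .{effectiveScale(0.0, 64)});
    std.debug.print("explicit scale: {d}\n", .{effectiveScale(0.25, 64)});
    // 32 query heads over 8 KV heads: ratio 4, so heads 0..3 share KV head 0.
    std.debug.print("query head 5 -> kv head {d}\n", .{kvHeadFor(5, 32, 8)});
}
```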

src/compute/attention.zig:255

method

AttentionDispatch.recordFlashAttnSplit

pub fn recordFlashAttnSplit( self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_len: u32, page_size: u32, attn_scale: f32, sink_offset: u32, ) !void

Record the split-K flash attention dispatch (per-chunk partial pass).

Grid: (n_heads, fa_split_k_active, 1). Each workgroup runs the standard flash_attn body scoped to its (head, chunk_id) i-range and writes (O_partial, M, L) to the partial output buffer bound at slot 4. Must be followed by `recordFlashAttnSplitMerge` to produce final output.
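One way to picture the per-chunk i-range is an even partition of the KV sequence across chunks, sketched below. This is an assumption for illustration: the source does not say how the shader derives each chunk's range, and `Range` and `chunkRange` are hypothetical names. `n_chunks` plays the role of `fa_split_k_active`.

```zig
const std = @import("std");

/// Half-open range [begin, end) of KV positions handled by one chunk.
const Range = struct { begin: u32, end: u32 };

/// Evenly partition seq_len KV entries over n_chunks (assumed scheme).
/// Trailing chunks may be empty when seq_len is small.
fn chunkRange(seq_len: u32, n_chunks: u32, chunk_id: u32) Range {
    std.debug.assert(n_chunks > 0 and chunk_id < n_chunks);
    const per_chunk = (seq_len + n_chunks - 1) / n_chunks; // ceiling division
    const begin = @min(chunk_id * per_chunk, seq_len);
    const end = @min(begin + per_chunk, seq_len);
    return .{ .begin = begin, .end = end };
}

pub fn main() void {
    const seq_len: u32 = 10;
    const n_chunks: u32 = 4;
    var chunk_id: u32 = 0;
    while (chunk_id < n_chunks) : (chunk_id += 1) {
        const r = chunkRange(seq_len, n_chunks, chunk_id);
        std.debug.print("chunk {d}: [{d}, {d})\n", .{ chunk_id, r.begin, r.end });
    }
}
```

Whatever partition the shader actually uses, the ranges need to cover every KV position exactly once so that the merge pass reconstructs the full softmax.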

Parameters
self
Dispatch wrapper containing the split-K pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with Q, KV-cache, page-table, partial-output, and sink buffers.
head_dim
Hidden width per attention head.
n_heads
Number of query heads to process.
n_kv_heads
Number of KV heads in the cache (GQA ratio = n_heads / n_kv_heads).
seq_len
Current decoded sequence length (total KV entries to attend over).
page_size
Tokens stored per KV-cache page.
attn_scale
Attention softmax scale factor (0 = use 1/sqrt(head_dim)).
sink_offset
Per-layer offset into the sink buffer (layer_idx * n_heads).
Returns

`error.ShaderNotLoaded` when the split-K pipeline is unavailable.

src/compute/attention.zig:298

method

AttentionDispatch.recordFlashAttnSplitMerge

pub fn recordFlashAttnSplitMerge( self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, sink_offset: u32, ) !void

Record the split-K merge pass dispatch. It combines the per-chunk partials for each head, applies the per-head sink term, and writes the final normalized attention output.

Grid: (n_heads, 1, 1). Must be recorded after `recordFlashAttnSplit`, with a pipeline barrier between the two dispatches so that the merge pass reads completed partials.
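The merge step can be illustrated with the standard log-sum-exp combination of softmax partials, sketched below for one head and a scalar output. Several points are assumptions rather than facts from the source: that each partial's output is the unnormalized weighted sum, that M is the running maximum logit and L the sum of exp(logit - M), that the sink contributes an extra term to the denominator only, and that a NaN sink means "no sink". `Partial` and `mergePartials` are hypothetical names.

```zig
const std = @import("std");

/// One chunk's partial result for a single head (scalar output for brevity).
const Partial = struct {
    o: f32, // unnormalized weighted sum of values (assumed)
    m: f32, // maximum logit seen in the chunk
    l: f32, // sum of exp(logit - m) over the chunk
};

/// Combine partials with a shared maximum, then normalize once.
/// Expects at least one partial with a finite `m`.
fn mergePartials(parts: []const Partial, sink: f32) f32 {
    const has_sink = !std.math.isNan(sink);

    // Global maximum across chunks, including the sink logit when present.
    var m_max: f32 = if (has_sink) sink else -std.math.inf(f32);
    for (parts) |p| m_max = @max(m_max, p.m);

    var o_sum: f32 = 0;
    // The sink adds to the denominator only (assumed behaviour).
    var l_sum: f32 = if (has_sink) @exp(sink - m_max) else 0;
    for (parts) |p| {
        const w = @exp(p.m - m_max); // rescale chunk to the global maximum
        o_sum += p.o * w;
        l_sum += p.l * w;
    }
    return o_sum / l_sum;
}

pub fn main() void {
    const parts = [_]Partial{
        .{ .o = 6.0, .m = 0.0, .l = 2.0 },
        .{ .o = 3.0, .m = 2.0, .l = 1.0 },
    };
    std.debug.print("no sink:   {d}\n", .{mergePartials(&parts, std.math.nan(f32))});
    std.debug.print("with sink: {d}\n", .{mergePartials(&parts, 0.0)});
}
```

Rescaling every chunk to one shared maximum before summing is what makes the merged result match a single-pass softmax over the whole sequence.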

Parameters
self
Dispatch wrapper containing the split-K merge pipeline.
cmd
Command buffer currently being recorded.
descriptor_set
Descriptor set with partial-input, final-output, and sink buffers (3 bindings).
head_dim
Hidden width per attention head.
n_heads
Number of query heads whose partials are to be merged.
sink_offset
Per-layer offset into the sink buffer (layer_idx * n_heads).
Returns

`error.ShaderNotLoaded` when the split-K merge pipeline is unavailable.

src/compute/attention.zig:334

method

AttentionDispatch.deinit

pub fn deinit(self: *AttentionDispatch) void

Destroy the loaded pipelines and the descriptor pool.

Parameters
self
Dispatch wrapper to tear down in place.

src/compute/attention.zig:353