Last updated: 2026-06-12
Shader Dispatch
Attention
Wrap the flash-attention compute shaders and their dispatch parameters.
This helper owns the pipeline resources needed to bind paged attention inputs and record a flash-attention compute pass.
struct
FlashAttnPush
pub const FlashAttnPush = extern struct
Push constants for the flash-attention shader.
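The fields of `FlashAttnPush` are not listed on this page. The sketch below assumes, purely for illustration, that they mirror the scalar arguments of `recordFlashAttn` in declaration order (`FlashAttnPushSketch` is a hypothetical name). It shows why `extern struct` matters: it fixes C layout so the bytes line up with a shader push-constant block, and the Vulkan spec only guarantees 128 bytes of push-constant space.

```zig
const std = @import("std");

// HYPOTHETICAL layout: assumes FlashAttnPush mirrors the scalar arguments of
// recordFlashAttn in declaration order. The real field list may differ.
pub const FlashAttnPushSketch = extern struct {
    head_dim: u32,
    n_heads: u32,
    n_kv_heads: u32,
    seq_len: u32,
    page_size: u32,
    attn_scale: f32,
    sink_offset: u32,
};

// Vulkan guarantees every implementation at least 128 bytes of push-constant
// space (the required minimum of VkPhysicalDeviceLimits.maxPushConstantsSize).
const min_guaranteed_push_bytes = 128;

comptime {
    // extern struct => C layout: seven 4-byte scalars, no padding, 28 bytes.
    std.debug.assert(@sizeOf(FlashAttnPushSketch) == 28);
    std.debug.assert(@sizeOf(FlashAttnPushSketch) <= min_guaranteed_push_bytes);
}

pub fn main() void {
    const push = FlashAttnPushSketch{
        .head_dim = 64,
        .n_heads = 32,
        .n_kv_heads = 8,
        .seq_len = 100,
        .page_size = 16,
        .attn_scale = 0.125, // 1 / sqrt(head_dim)
        .sink_offset = 0,
    };
    // These are the bytes a caller would hand to vkCmdPushConstants.
    const bytes = std.mem.asBytes(&push);
    std.debug.print("push block: {d} bytes\n", .{bytes.len});
}
```

The comptime checks fail the build, rather than corrupting a dispatch at runtime, if a field change ever pushes the block past the guaranteed limit.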
struct
FlashAttnBatchedPush
pub const FlashAttnBatchedPush = extern struct
Push constants for `flash_attn_batched`.
Shared by two callers:
- Prefill batched path: processes N queries sharing a KV cache; `seq_start` is the position of the first query (0 on a fresh prefill).
- Decode-shape foundation (`ZINC_BATCH_ATTN=1`): `n_queries` = 1 with `seq_start` = `state.position`, bit-equivalent to the non-batched shader.
`sink_offset` is the per-layer offset into the per-head sinks buffer (`layer_idx * n_heads`). It is honoured by gpt-oss and NaN-gated otherwise.
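A minimal sketch of how the two callers might fill the batched push constants. The struct `BatchedPushSketch`, its field order, and the helpers `Shape`, `prefillPush`, and `decodePush` are hypothetical; only the parameter values (`seq_start`, `n_queries`, and `sink_offset = layer_idx * n_heads`) come from the description above.

```zig
const std = @import("std");

// HYPOTHETICAL layout mirroring recordFlashAttnBatched's scalar arguments;
// the real FlashAttnBatchedPush field order is not documented here.
pub const BatchedPushSketch = extern struct {
    head_dim: u32,
    n_heads: u32,
    n_kv_heads: u32,
    seq_start: u32,
    n_queries: u32,
    page_size: u32,
    attn_scale: f32,
    sink_offset: u32,
};

// Hypothetical bundle of model-shape parameters shared by both callers.
pub const Shape = struct {
    head_dim: u32,
    n_heads: u32,
    n_kv_heads: u32,
    page_size: u32,
    attn_scale: f32,
};

// Per-layer offset into the per-head sinks buffer.
fn sinkOffset(layer_idx: u32, n_heads: u32) u32 {
    return layer_idx * n_heads;
}

// Prefill batched path: n_queries queries starting at seq_start.
fn prefillPush(s: Shape, layer_idx: u32, seq_start: u32, n_queries: u32) BatchedPushSketch {
    return .{
        .head_dim = s.head_dim,
        .n_heads = s.n_heads,
        .n_kv_heads = s.n_kv_heads,
        .seq_start = seq_start,
        .n_queries = n_queries,
        .page_size = s.page_size,
        .attn_scale = s.attn_scale,
        .sink_offset = sinkOffset(layer_idx, s.n_heads),
    };
}

// Decode-shape path (ZINC_BATCH_ATTN=1): one query at the current position.
fn decodePush(s: Shape, layer_idx: u32, position: u32) BatchedPushSketch {
    return prefillPush(s, layer_idx, position, 1);
}

pub fn main() void {
    const shape = Shape{ .head_dim = 64, .n_heads = 32, .n_kv_heads = 8, .page_size = 16, .attn_scale = 0.125 };
    const p = prefillPush(shape, 3, 0, 128);
    const d = decodePush(shape, 3, 128);
    std.debug.print("prefill: seq_start={d} n_queries={d} sink_offset={d}\n", .{ p.seq_start, p.n_queries, p.sink_offset });
    std.debug.print("decode:  seq_start={d} n_queries={d} sink_offset={d}\n", .{ d.seq_start, d.n_queries, d.sink_offset });
}
```

Routing decode through the same helper with `n_queries = 1` keeps both callers on one code path, which is what makes the bit-equivalence with the non-batched shader straightforward to check.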
struct
FlashAttnSplitMergePush
pub const FlashAttnSplitMergePush = extern struct
Push constants for `flash_attn_split_merge`.
Reads `N_I_CHUNKS` partial outputs per head from binding 0, applies the per-head sink, normalizes, and writes the final output to binding 1.
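The merge can be pictured as the standard log-sum-exp combination used in split-K flash attention. This CPU sketch makes several assumptions that the page does not confirm: each partial stores an unnormalized output `o` together with the chunk's maximum logit `m` and softmax sum `l`; the sink behaves as one extra logit that enlarges the denominator but contributes no value; and a NaN sink means "no sink". `Partial` and `mergeHead` are hypothetical names.

```zig
const std = @import("std");

// ASSUMED partial format for one (head, chunk): `o` is the unnormalized sum
// of exp(s_i - m) * v_i over the chunk's keys, `m` is the chunk's max logit,
// and `l` is the sum of exp(s_i - m).
pub fn Partial(comptime D: usize) type {
    return struct { o: [D]f32, m: f32, l: f32 };
}

// Merge every chunk's partial for one head into the normalized output.
// ASSUMPTIONS: the sink acts as one extra logit that enlarges the softmax
// denominator but contributes no value, and a NaN sink means "no sink".
pub fn mergeHead(comptime D: usize, parts: []const Partial(D), sink: f32, out: *[D]f32) void {
    const has_sink = !std.math.isNan(sink);

    // Global max over the chunk maxima and the sink logit.
    var m: f32 = if (has_sink) sink else -std.math.inf(f32);
    for (parts) |p| m = @max(m, p.m);

    var l: f32 = if (has_sink) @exp(sink - m) else 0;
    var acc = [_]f32{0} ** D;
    for (parts) |p| {
        // Rescale each chunk from its local max to the global max.
        const w = @exp(p.m - m);
        l += w * p.l;
        for (&acc, p.o) |*a, o| a.* += w * o;
    }
    for (out, acc) |*dst, a| dst.* = a / l;
}

pub fn main() void {
    // Two chunks of one key each (head_dim = 1): logit 1 with value 2,
    // and logit 0 with value 4.
    const parts = [_]Partial(1){
        .{ .o = .{2.0}, .m = 1.0, .l = 1.0 },
        .{ .o = .{4.0}, .m = 0.0, .l = 1.0 },
    };
    var out: [1]f32 = undefined;
    mergeHead(1, &parts, std.math.nan(f32), &out);
    std.debug.print("merged output: {d}\n", .{out[0]});
}
```

Rescaling by `exp(m_c - m)` keeps every exponent non-positive, so the merge stays numerically stable however large the raw logits are.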
struct
AttentionDispatch
pub const AttentionDispatch = struct
Owns the Vulkan compute pipelines for flash attention and records their dispatches into a command buffer.
Supports three variants: single-query decode (`pipeline`), batched prefill/decode (`pipeline_batched`), and split-K decode (`pipeline_split` + `pipeline_split_merge`). The active variant is selected by the caller; all pipelines are loaded during `init` and destroyed by `deinit`.
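Because the caller picks the variant, it needs some selection policy. The one below is entirely hypothetical (`Variant`, `chooseVariant`, and the `split_k_chunks` input are illustrative names, not part of this module). It only encodes what the page states: batched handles prefill and the `ZINC_BATCH_ATTN=1` decode shape, split-K is a decode path, and single-query decode is the fallback.

```zig
const std = @import("std");

// The three dispatch variants AttentionDispatch can record.
pub const Variant = enum { single, batched, split_k };

// HYPOTHETICAL caller-side policy. The real selection logic lives in the
// caller and is not documented here; `split_k_chunks` stands in for
// whatever the caller uses to decide that split-K is active.
pub fn chooseVariant(n_queries: u32, batch_attn_env: bool, split_k_chunks: u32) Variant {
    if (n_queries > 1) return .batched; // prefill: many queries share one KV cache
    if (batch_attn_env) return .batched; // ZINC_BATCH_ATTN=1: decode via the batched shader
    if (split_k_chunks > 1) return .split_k; // decode with the KV range split into chunks
    return .single; // plain single-query decode
}

pub fn main() void {
    std.debug.print("prefill of 128 tokens -> {s}\n", .{@tagName(chooseVariant(128, false, 1))});
    std.debug.print("decode, split-K off   -> {s}\n", .{@tagName(chooseVariant(1, false, 1))});
    std.debug.print("decode, 4 chunks      -> {s}\n", .{@tagName(chooseVariant(1, false, 4))});
}
```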
Methods
method
AttentionDispatch.init
pub fn init(instance: *const Instance, shader_dir: []const u8, allocator: std.mem.Allocator) !AttentionDispatch
Create the flash-attention dispatch wrapper and load all of its shader pipelines from `shader_dir`.
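Since `init` loads several pipelines, a failure part-way through has to release whatever was already acquired. The sketch shows that `errdefer` pattern on the simplest resource involved, shader file paths under `shader_dir`. The `.spv` file names and the `resolveShaderPaths` helper are hypothetical.

```zig
const std = @import("std");

// HYPOTHETICAL SPIR-V file names: the names init actually loads are not documented.
const shader_names = [_][]const u8{
    "flash_attn.spv",
    "flash_attn_batched.spv",
    "flash_attn_split.spv",
    "flash_attn_split_merge.spv",
};

// Build one path per shader under `shader_dir`. If a later join fails, the
// errdefer frees the paths already built, the same cleanup pattern init
// needs when one of several pipeline loads fails part-way through.
pub fn resolveShaderPaths(allocator: std.mem.Allocator, shader_dir: []const u8) ![shader_names.len][]u8 {
    var paths: [shader_names.len][]u8 = undefined;
    var built: usize = 0;
    errdefer {
        for (paths[0..built]) |p| allocator.free(p);
    }
    for (shader_names) |name| {
        paths[built] = try std.fs.path.join(allocator, &.{ shader_dir, name });
        built += 1;
    }
    return paths;
}

pub fn main() !void {
    const allocator = std.heap.page_allocator;
    const paths = try resolveShaderPaths(allocator, "shaders");
    defer {
        for (paths) |p| allocator.free(p);
    }
    for (paths) |p| std.debug.print("{s}\n", .{p});
}
```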
method
AttentionDispatch.recordFlashAttn
pub fn recordFlashAttn(self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_len: u32, page_size: u32, attn_scale: f32, sink_offset: u32) !void
Record a single-query flash-attention dispatch for the current decode position.
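Two of the arguments above carry index arithmetic that the page does not spell out. The sketch assumes the common grouped-query convention: query heads split into `n_kv_heads` contiguous groups, each group sharing one KV head. It also counts the KV-cache pages that `seq_len` positions span at a given `page_size`. `kvHeadFor` and `pagesTouched` are hypothetical helpers, not part of this module.

```zig
const std = @import("std");

// ASSUMPTION: grouped-query attention with query heads split into
// n_kv_heads contiguous groups, each group sharing one KV head.
pub fn kvHeadFor(head: u32, n_heads: u32, n_kv_heads: u32) u32 {
    std.debug.assert(n_kv_heads != 0 and n_heads % n_kv_heads == 0);
    const group_size = n_heads / n_kv_heads;
    return head / group_size;
}

// KV-cache pages covered by seq_len positions (ceiling division).
pub fn pagesTouched(seq_len: u32, page_size: u32) u32 {
    std.debug.assert(page_size != 0);
    return (seq_len + page_size - 1) / page_size;
}

pub fn main() void {
    // 32 query heads sharing 8 KV heads: heads 0..3 read KV head 0, and so on.
    std.debug.print("query head 13 -> kv head {d}\n", .{kvHeadFor(13, 32, 8)});
    std.debug.print("seq_len 33, page_size 16 -> {d} pages\n", .{pagesTouched(33, 16)});
}
```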
method
AttentionDispatch.recordFlashAttnBatched
pub fn recordFlashAttnBatched(self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_start: u32, n_queries: u32, page_size: u32, attn_scale: f32, sink_offset: u32) !void
Record a batched flash-attention dispatch.
Grid is (n_heads, n_queries, 1); each (head, query) workgroup streams over the paged KV cache with causal_len = seq_start + query + 1.
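The grid shape and per-workgroup causal length stated above, written out as a small sketch (`batchedGrid` and `causalLen` are illustrative names). Each query in a batch sees every cached position before `seq_start` plus the queries up to and including itself, so the causal length grows by one per query.

```zig
const std = @import("std");

// Workgroup grid recorded by recordFlashAttnBatched: (n_heads, n_queries, 1).
pub fn batchedGrid(n_heads: u32, n_queries: u32) [3]u32 {
    return .{ n_heads, n_queries, 1 };
}

// KV positions that workgroup (head, query) attends to; independent of head.
pub fn causalLen(seq_start: u32, query: u32) u32 {
    return seq_start + query + 1;
}

pub fn main() void {
    const seq_start: u32 = 10; // e.g. 10 tokens already cached
    const n_queries: u32 = 4;
    const grid = batchedGrid(32, n_queries);
    std.debug.print("grid = ({d}, {d}, {d})\n", .{ grid[0], grid[1], grid[2] });
    var q: u32 = 0;
    while (q < n_queries) : (q += 1) {
        std.debug.print("query {d}: causal_len = {d}\n", .{ q, causalLen(seq_start, q) });
    }
}
```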
method
AttentionDispatch.recordFlashAttnSplit
pub fn recordFlashAttnSplit(self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, n_kv_heads: u32, seq_len: u32, page_size: u32, attn_scale: f32, sink_offset: u32) !void
Record the split-K flash-attention dispatch (per-chunk partial pass).
Grid: (n_heads, fa_split_k_active, 1). Each workgroup runs the standard flash_attn body restricted to its (head, chunk_id) slice of the KV positions (its i-range) and writes (O_partial, M, L) to the partial output buffer bound at slot 4. Must be followed by `recordFlashAttnSplitMerge` to produce the final output.
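How the KV positions are divided among chunks is not documented here. This sketch assumes an even, contiguous partition of `[0, seq_len)` indexed by the workgroup's `chunk_id`. `Range` and `chunkRange` are hypothetical names.

```zig
const std = @import("std");

pub const Range = struct { start: u32, end: u32 };

// HYPOTHETICAL even partition of KV positions [0, seq_len) into n_chunks
// contiguous i-ranges; chunk_id is the workgroup's y index in the
// (n_heads, n_chunks, 1) grid. Trailing chunks may be empty when
// seq_len is small.
pub fn chunkRange(seq_len: u32, n_chunks: u32, chunk_id: u32) Range {
    std.debug.assert(n_chunks != 0 and chunk_id < n_chunks);
    const per_chunk = (seq_len + n_chunks - 1) / n_chunks;
    const start = @min(chunk_id * per_chunk, seq_len);
    const end = @min(start + per_chunk, seq_len);
    return .{ .start = start, .end = end };
}

pub fn main() void {
    var c: u32 = 0;
    while (c < 4) : (c += 1) {
        const r = chunkRange(10, 4, c);
        std.debug.print("chunk {d}: [{d}, {d})\n", .{ c, r.start, r.end });
    }
}
```

Under this scheme, empty chunks still launch a workgroup. Their partial should carry a zero denominator so that the merge gives them no weight.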
method
AttentionDispatch.recordFlashAttnSplitMerge
pub fn recordFlashAttnSplitMerge(self: *const AttentionDispatch, cmd: *CommandBuffer, descriptor_set: vk.c.VkDescriptorSet, head_dim: u32, n_heads: u32, sink_offset: u32) !void
Record the split-K merge pass dispatch: combine the per-chunk partials for each head, apply the per-head sink term, and write the final normalized attention output.
Grid: (n_heads, 1, 1). Must be called after `recordFlashAttnSplit` and a pipeline barrier.
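The ordering rule (split pass, then a pipeline barrier, then merge) can be checked mechanically. This sketch validates a hypothetical command log (`Cmd` and `validSplitKOrder` stand in for the real `CommandBuffer` and are not part of this module). It rejects a merge with no barrier before it and a split that is never merged.

```zig
const std = @import("std");

// HYPOTHETICAL command log standing in for the real CommandBuffer.
pub const Cmd = enum { split, barrier, merge };

// Checks the documented ordering: each recordFlashAttnSplit is followed by a
// pipeline barrier and then recordFlashAttnSplitMerge, and no merge runs
// without a fenced split before it.
pub fn validSplitKOrder(cmds: []const Cmd) bool {
    var split_pending = false; // split recorded, no barrier yet
    var split_ready = false; // split fenced by a barrier, merge may read it
    for (cmds) |c| {
        switch (c) {
            .split => {
                if (split_pending or split_ready) return false; // previous split never merged
                split_pending = true;
            },
            .barrier => {
                if (split_pending) {
                    split_pending = false;
                    split_ready = true;
                }
            },
            .merge => {
                if (!split_ready) return false; // missing split or missing barrier
                split_ready = false;
            },
        }
    }
    return !split_pending and !split_ready; // every split must be merged
}

pub fn main() void {
    const good = [_]Cmd{ .split, .barrier, .merge };
    const bad = [_]Cmd{ .split, .merge };
    std.debug.print("split, barrier, merge: {s}\n", .{if (validSplitKOrder(&good)) "valid" else "invalid"});
    std.debug.print("split, merge:          {s}\n", .{if (validSplitKOrder(&bad)) "valid" else "invalid"});
}
```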
method
AttentionDispatch.deinit
pub fn deinit(self: *AttentionDispatch) void
Destroy the loaded pipelines and the descriptor pool.