Last updated: 2026-06-12

Inference Runtime

Forward Metal


Metal inference engine — decode loop for Apple Silicon.

This is the Metal counterpart of `forward.zig` (the Vulkan backend). It dispatches MSL (Metal Shading Language) compute shaders through the Metal shim.

17 exports, 15 methods (src/compute/forward_metal.zig)


constant

CommandEncoderMode

pub const CommandEncoderMode = metal_command.CommandEncoderMode

Command-encoder mode re-export used by runtime init options.

src/compute/forward_metal.zig:20

constant

runtime_context_cap

pub const runtime_context_cap: u32 = 262144

Upper bound on the Metal KV-cache allocation. The engine still honours the model's architectural context length and the unified-memory budget, but it refuses to allocate more than this many tokens in one block, which keeps allocation latency and staging-buffer sizes bounded.

Callers that already right-sized `cfg.context_length` from the device budget (see `memory_plan.autoContextTokensForDeviceBudget`) see this as a soft safety net rather than the primary limit.

src/compute/forward_metal.zig:31
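The way these limits compose can be sketched as a chain of minimums. This is a minimal sketch, assuming the effective slot count is simply the smallest of the requested length, the architectural context length, the tokens that fit in the memory budget, and the cap. `kvCacheTokens` and its parameters are hypothetical and are not part of forward_metal.zig.

```zig
const std = @import("std");

/// Same value as the documented `runtime_context_cap`.
const runtime_context_cap: u32 = 262144;

/// Hypothetical helper (not part of forward_metal.zig): choose how many
/// KV-cache token slots to allocate.
fn kvCacheTokens(requested: u32, arch_context: u32, budget_bytes: u64, bytes_per_token: u64) u32 {
    // Tokens that fit in the unified-memory budget, saturated to the u32 range.
    const fit: u64 = budget_bytes / bytes_per_token;
    const budget_tokens: u32 = if (fit > std.math.maxInt(u32)) std.math.maxInt(u32) else @intCast(fit);
    // Honour the request, the architectural limit, and the budget...
    const tokens = @min(requested, arch_context, budget_tokens);
    // ...then apply the hard cap as a final safety net.
    return @min(tokens, runtime_context_cap);
}
```

When the caller has already right-sized the request from the budget, the last `@min` never binds. It only comes into play for very long architectural contexts on large-memory machines.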

struct

DecodeState

pub const DecodeState = struct

Runtime state for the decode loop.

src/compute/forward_metal.zig:62

Methods (2)

method

DecodeState.init

pub fn init(allocator: std.mem.Allocator) DecodeState

Initialize decode state for a fresh Metal generation request.

src/compute/forward_metal.zig:69

struct

GenerateMetrics

pub const GenerateMetrics = struct

Metrics from generateWithMetrics: prefill/decode token counts, timing, and throughput.

src/compute/forward_metal.zig:86

struct

GenerateResult

pub const GenerateResult = struct

Output tokens and performance metrics from a generation run.

src/compute/forward_metal.zig:98

Methods (1)

method

GenerateResult.deinit

pub fn deinit(self: *GenerateResult, allocator: std.mem.Allocator) void

Free the generated token slice returned by `generateWithMetrics`.

src/compute/forward_metal.zig:103

struct

SamplingParams

pub const SamplingParams = struct

Token sampling parameters: temperature, top-k, top-p, and repetition penalty.

src/compute/forward_metal.zig:110

Methods (1)

method

SamplingParams.requiresLogitsReadback

pub fn requiresLogitsReadback(self: @This()) bool

Return whether sampling settings require CPU-visible logits instead of greedy argmax.

src/compute/forward_metal.zig:117

struct

InitOptions

pub const InitOptions = struct

Options for `InferenceEngine.init`: profiling, debug validation, KV cache, and dispatch tuning.

src/compute/forward_metal.zig:123

struct

RuntimeProfile

pub const RuntimeProfile = struct

Per-request profiling counters for dispatch, barrier, and timing breakdown.

src/compute/forward_metal.zig:918

function

attentionLayerCount

pub fn attentionLayerCount(cfg: ModelConfig) u32

Return the number of full-attention (non-SSM) transformer layers in the model.

Parameters

cfg
Model configuration supplying `n_layers` and the full-attention interval.

Returns

Count of layers that use full multi-head attention; 0 for pure-SSM models.

src/compute/forward_metal.zig:1405

function

defaultKvCacheQ8Enabled

pub fn defaultKvCacheQ8Enabled(config: ModelConfig, debug_validation_enabled: bool) bool

Return whether Q8 KV cache quantization should be enabled by default for this model.

Parameters

config
Model configuration used to check architecture and key-value dimensions.
debug_validation_enabled
When true, always returns false so the unquantized cache is available for numerical validation.

Returns

True when the architecture and dimensions support Q8 KV cache; false for gpt-oss (SwiGLU sensitivity) and Gemma4 with SWA (ISWA rotation path not yet ported).

src/compute/forward_metal.zig:1420

function

kvCacheBytesPerToken

pub fn kvCacheBytesPerToken(config: ModelConfig, q8_enabled: bool) u64

Return the bytes consumed per token in the KV cache.

Parameters

config
Model configuration providing `n_kv_heads` and `head_dim`.
q8_enabled
When true, returns the Q8_0 packed size (34 bytes per 32 floats); otherwise returns the unquantized f32 size.

Returns

Bytes per KV-cache slot for a single token position.

src/compute/forward_metal.zig:1439
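The size difference follows from the Q8_0 block layout: one f16 scale plus 32 int8 values is 34 bytes per 32 elements, compared with 128 bytes for 32 f32 values. The sketch below computes that per-vector cost. `kvBytesPerTokenSketch` is hypothetical. It assumes one K row and one V row of `n_kv_heads * head_dim` values per token, and that partial blocks are padded. Whether the real function also multiplies by the layer count is not stated here.

```zig
const std = @import("std");

/// Q8_0 stores each block of 32 values as one f16 scale plus 32 int8s = 34 bytes.
const q8_0_block_elems: u64 = 32;
const q8_0_block_bytes: u64 = 34;

/// Bytes needed to store `n` values either as Q8_0 blocks or as plain f32.
fn vectorBytes(n: u64, q8_enabled: bool) u64 {
    if (q8_enabled) {
        // Round up to whole blocks (assumption: partial blocks are padded).
        const blocks = (n + q8_0_block_elems - 1) / q8_0_block_elems;
        return blocks * q8_0_block_bytes;
    }
    return n * @sizeOf(f32);
}

/// Hypothetical per-token, per-layer cost: one K row and one V row,
/// each `n_kv_heads * head_dim` values wide.
fn kvBytesPerTokenSketch(n_kv_heads: u64, head_dim: u64, q8_enabled: bool) u64 {
    return 2 * vectorBytes(n_kv_heads * head_dim, q8_enabled);
}
```

For example, with 8 KV heads of dimension 128, each K or V row is 1024 values. That comes to 4096 bytes in f32 versus 1088 bytes in Q8_0, roughly a 3.8x reduction.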

struct

InferenceEngine

pub const InferenceEngine = struct

Metal inference engine — owns GPU buffers, pipelines, and KV cache.

src/compute/forward_metal.zig:4374

Methods (11)

method

InferenceEngine.init

pub fn init( model: *const metal_loader.Model, device: *const metal_device.MetalDevice, allocator: std.mem.Allocator, options: InitOptions, ) !InferenceEngine

Initialize the Metal inference engine, allocating GPU buffers and compiling pipelines.

src/compute/forward_metal.zig:4798

method

InferenceEngine.sample

pub fn sample(self: *const InferenceEngine, history: []const u32, params: SamplingParams, random: std.Random) u32

Sample next token using temperature, top-k, top-p, and repetition penalty.

Falls back to greedy argmax when the parameters are near their defaults or when the logits buffers are private (not CPU-visible).

src/compute/forward_metal.zig:6442
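For readers unfamiliar with the named controls, here is a CPU-side sketch of one conventional pipeline over read-back logits. It uses the same `std.Random` type as the signature above, so it assumes Zig 0.12 or later. Several choices are assumptions for illustration and do not describe what `sample` does internally: the order (repetition penalty, temperature, top-k, top-p), the CTRL-style penalty rule, and the hypothetical `Params` struct and `sampleSketch` function.

```zig
const std = @import("std");

/// Hypothetical parameter bundle; the field names are illustrative and are
/// not taken from the real `SamplingParams`.
const Params = struct {
    temperature: f32 = 0.8, // assumed > 0; greedy decoding is handled elsewhere
    top_k: u32 = 40, // 0 disables top-k in this sketch
    top_p: f32 = 0.95,
    repetition_penalty: f32 = 1.1,
};

/// Sort comparator: order token ids by descending logit.
fn greaterByLogit(logits: []const f32, a: u32, b: u32) bool {
    return logits[a] > logits[b];
}

/// Sample one token id from non-empty, CPU-visible `logits` (modified in place).
/// `idx` and `probs` are scratch buffers of the same length as `logits`.
fn sampleSketch(logits: []f32, history: []const u32, p: Params, idx: []u32, probs: []f32, random: std.Random) u32 {
    // 1. CTRL-style repetition penalty, applied once per occurrence in `history`.
    for (history) |tok| {
        const l = logits[tok];
        logits[tok] = if (l > 0) l / p.repetition_penalty else l * p.repetition_penalty;
    }
    // 2. Temperature scaling.
    for (logits) |*l| l.* /= p.temperature;
    // 3. Rank token ids by logit (descending) and keep the top k.
    for (idx, 0..) |*v, i| v.* = @intCast(i);
    std.mem.sort(u32, idx, @as([]const f32, logits), greaterByLogit);
    const k: usize = if (p.top_k == 0) idx.len else @min(@as(usize, p.top_k), idx.len);
    // 4. Unnormalized softmax over the k survivors (max-subtracted for stability).
    const max_l = logits[idx[0]];
    var sum: f32 = 0;
    for (idx[0..k], 0..) |id, i| {
        probs[i] = @exp(logits[id] - max_l);
        sum += probs[i];
    }
    // 5. Top-p: keep the shortest prefix whose probability mass reaches top_p.
    var kept: usize = 0;
    var mass: f32 = 0;
    while (kept < k) {
        probs[kept] /= sum;
        mass += probs[kept];
        kept += 1;
        if (mass >= p.top_p) break;
    }
    // 6. Draw from the kept prefix; scaling r by `mass` renormalizes implicitly.
    const r = random.float(f32) * mass;
    var acc: f32 = 0;
    for (idx[0..kept], 0..) |id, i| {
        acc += probs[i];
        if (r < acc) return id;
    }
    return idx[kept - 1]; // guard against floating-point round-off
}
```

Every step after the first needs the full logit vector on the CPU. This is why settings like these call for logits readback instead of an on-GPU argmax.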

method

InferenceEngine.resetRequestState

pub fn resetRequestState(self: *InferenceEngine, requested_context_tokens: u32) !void

Reset position, profiling counters, and SSM state for a new request.

src/compute/forward_metal.zig:6461

method

InferenceEngine.prefillBatch

pub fn prefillBatch(self: *InferenceEngine, state: *DecodeState, prompt_tokens: []const u32) !void

Run prompt prefill in token-major order, processing one token through all layers at a time.

Uses a queued async-submit path for short prompts when available; otherwise falls back to a sequential `commitAndWait` per token. Callers that want the layer-major batched path should use `prefillBatched` instead.

Notes

Returns `error.ContextLengthExceeded` if the prompt would overflow the KV cache.

Returns `error.KvStateNotAvailable` if `state.position` does not match the engine's internal position (stale or mismatched state).

src/compute/forward_metal.zig:6697

method

InferenceEngine.prefillBatched

pub fn prefillBatched(self: *InferenceEngine, state: *DecodeState, prompt_tokens: []const u32) !void

Experimental batched prompt prefill gated by the `ZINC_BATCHED_PREFILL` environment variable.

Processes all prompt tokens in a single batched forward pass using the `gemm_q4k`, `gemm_q6k`, `rope_batched`, and `flash_attn_batched` shaders, so the weight matrix for each projection is read once for the whole prompt rather than once per token.

Falls back to the per-token `prefillBatch` when the env flag is off or when the model architecture is outside the supported slice (see `canUseBatchedPrefill`). Supports both fresh prefill (`state.position == 0`) and prefix reuse (`state.position > 0`). In the prefix-reuse case, the batched pass extends the KV cache at offset `state.position`, and flash-attention causal masking is computed relative to that offset.

With `ZINC_BATCHED_PREFILL=validate`, the batched path runs first. The per-token path is then replayed on a fresh state and the last-token logits are compared. The maximum absolute difference is logged, and a warning is emitted if it exceeds 1e-3.

src/compute/forward_metal.zig:7408
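The validate-mode check reduces to a maximum absolute difference over two logit vectors, compared against the 1e-3 threshold. Below is a minimal sketch of that comparison. It assumes Zig 0.12 or later for `@abs`, and the names `maxAbsDiff` and `shouldWarn` are hypothetical.

```zig
const std = @import("std");

/// Tolerance quoted for `ZINC_BATCHED_PREFILL=validate`.
const validate_tolerance: f32 = 1e-3;

/// Largest absolute element-wise difference between two logit vectors.
fn maxAbsDiff(a: []const f32, b: []const f32) f32 {
    std.debug.assert(a.len == b.len);
    var worst: f32 = 0;
    for (a, b) |x, y| worst = @max(worst, @abs(x - y));
    return worst;
}

/// True when the batched logits diverge enough from the per-token replay
/// to warrant a warning.
fn shouldWarn(batched: []const f32, per_token: []const f32) bool {
    return maxAbsDiff(batched, per_token) > validate_tolerance;
}
```

Using the maximum rather than a mean makes a single badly wrong vocabulary entry visible, which is what matters when the two paths should agree up to floating-point reordering.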

method

InferenceEngine.decodeStep

pub fn decodeStep(self: *InferenceEngine, state: *DecodeState, token_id: u32) !void

Advance one autoregressive decode step from the given input token.

src/compute/forward_metal.zig:7762

method

InferenceEngine.enableProfiling

pub fn enableProfiling(self: *InferenceEngine) !void

Enable request-level profiling counters for subsequent decode work.

src/compute/forward_metal.zig:7776

method

InferenceEngine.logRequestProfileSummary

pub fn logRequestProfileSummary(self: *const InferenceEngine, label: []const u8, prompt_tokens: usize, completion_tokens: u32) void

Log the collected Metal profiling summary for the current request to the scoped logger.

Does nothing when profiling was not enabled via `enableProfiling`.

Parameters

label
Short identifier string included in every log line (e.g. model name).
prompt_tokens
Number of prompt tokens processed during prefill.
completion_tokens
Number of tokens generated during decode.

src/compute/forward_metal.zig:7790

function

dequantRow

pub fn dequantRow(raw_data: []const u8, row: u32, cols: u32, quant_type: GGMLType, output: []f32) void

Dequantize one row of quantized weight data to f32 values.

Supports f32, f16, Q5_0, Q5_1, Q8_0, Q4_K, Q5_K, Q6_K, and MXFP4. Unsupported types log a warning and zero the output slice.

Parameters

raw_data
Raw GGUF tensor bytes for the full matrix.
row
Zero-based row index to dequantize.
cols
Number of columns (elements) per row.
quant_type
GGML quantization type describing the on-disk layout.
output
Caller-allocated slice of at least `cols` f32 values; filled in place.

src/compute/forward_metal.zig:17234
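Q8_0, one of the listed formats, has the simplest block layout (as defined in ggml). Each 34-byte block holds a little-endian f16 scale followed by 32 signed 8-bit values, and each element decodes to `scale * q`. Below is a minimal sketch of dequantizing one Q8_0 row. It assumes rows are stored contiguously and `cols` is a multiple of 32. `dequantRowQ8_0` is hypothetical and covers only this one format, not the full `dequantRow` dispatch.

```zig
const std = @import("std");

const q8_0_block_elems = 32;
const q8_0_block_bytes = 34; // f16 scale + 32 x int8

/// Hypothetical Q8_0-only sketch: dequantize row `row` of a row-major matrix
/// with `cols` columns (a multiple of 32) into `output[0..cols]`.
fn dequantRowQ8_0(raw_data: []const u8, row: u32, cols: u32, output: []f32) void {
    const blocks_per_row: usize = cols / q8_0_block_elems;
    const row_bytes = blocks_per_row * q8_0_block_bytes;
    const row_data = raw_data[@as(usize, row) * row_bytes ..][0..row_bytes];
    var b: usize = 0;
    while (b < blocks_per_row) : (b += 1) {
        const blk: []const u8 = row_data[b * q8_0_block_bytes ..][0..q8_0_block_bytes];
        // Little-endian f16 scale in the first two bytes, widened to f32.
        const bits: u16 = @as(u16, blk[0]) | (@as(u16, blk[1]) << 8);
        const h: f16 = @bitCast(bits);
        const d: f32 = h;
        // The remaining 32 bytes are two's-complement int8 quants.
        const qs: []const u8 = blk[2..];
        for (qs, 0..) |byte, i| {
            const q: i8 = @bitCast(byte);
            output[b * q8_0_block_elems + i] = d * @as(f32, @floatFromInt(q));
        }
    }
}
```

The K-quant formats (Q4_K, Q5_K, Q6_K) use larger super-blocks with per-sub-block scales, so their decoding is more involved than this.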

function

topKSoftmaxWeight

pub fn topKSoftmaxWeight(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void

Select the top-k entries by raw logit value, then apply softmax over only those k values.

Used for the SOFTMAX_WEIGHT expert-routing variant (gpt-oss), which differs from `topKSoftmax` in that softmax is applied to the pre-selected raw logits rather than to the full probability distribution first.

Parameters

logits
Raw router logit values, one per expert.
k
Number of top experts to select.
out_ids
Output slice of length k; filled with the indices of selected experts.
out_weights
Output slice of length k; filled with softmax-normalized weights.

src/compute/forward_metal.zig:17517

function

topKSoftmax

pub fn topKSoftmax(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void

Apply softmax over all logits, then select the top-k entries by probability and renormalize.

Parameters

logits
Raw router logit values, one per expert.
k
Number of top experts to select.
out_ids
Output slice of length k; filled with the indices of selected experts.
out_weights
Output slice of length k; filled with renormalized softmax weights.

src/compute/forward_metal.zig:17552
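The two routing orders described above can be compared directly. Softmax is monotonic, so top-k by raw logit and top-k by probability select the same experts. After renormalization, the full-softmax denominator cancels. In exact arithmetic, both orders therefore yield the same weights, and the SOFTMAX_WEIGHT order simply avoids exponentiating the unselected logits. The helper names below are hypothetical. The selection loop is a naive O(k·n) scan chosen for clarity, not a claim about how forward_metal.zig selects experts.

```zig
const std = @import("std");

/// Indices of the k largest logits in descending order (naive scan; k <= logits.len).
fn selectTopK(logits: []const f32, k: u32, out_ids: []u32) void {
    for (0..k) |j| {
        var best: ?u32 = null;
        for (logits, 0..) |v, i| {
            const id: u32 = @intCast(i);
            // Skip ids already chosen in earlier rounds.
            if (std.mem.indexOfScalar(u32, out_ids[0..j], id) != null) continue;
            if (best == null or v > logits[best.?]) best = id;
        }
        out_ids[j] = best.?;
    }
}

/// SOFTMAX_WEIGHT order: pick the top-k raw logits, then softmax over only those.
fn topKThenSoftmax(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void {
    selectTopK(logits, k, out_ids);
    const max_l = logits[out_ids[0]];
    var sum: f32 = 0;
    for (out_ids[0..k], 0..) |id, j| {
        out_weights[j] = @exp(logits[id] - max_l);
        sum += out_weights[j];
    }
    for (out_weights[0..k]) |*w| w.* /= sum;
}

/// Softmax-first order: softmax over all logits, take the top-k, renormalize.
fn softmaxThenTopK(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void {
    selectTopK(logits, k, out_ids); // softmax is monotonic, so the same ids win
    var max_l: f32 = -std.math.inf(f32);
    for (logits) |v| max_l = @max(max_l, v);
    var z: f32 = 0;
    for (logits) |v| z += @exp(v - max_l);
    var kept: f32 = 0;
    for (out_ids[0..k], 0..) |id, j| {
        out_weights[j] = @exp(logits[id] - max_l) / z;
        kept += out_weights[j];
    }
    for (out_weights[0..k]) |*w| w.* /= kept;
}
```

Any observable difference between the two variants is therefore numerical (rounding in the full-vocabulary sum) rather than a different routing rule.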

function

generateWithMetrics

pub fn generateWithMetrics( engine: *InferenceEngine, prompt_tokens: []const u32, max_tokens: u32, eos_id: u32, allocator: std.mem.Allocator, ) !GenerateResult

Run prompt prefill followed by autoregressive decode and return tokens with timing metrics.

Parameters

engine
Initialized inference engine owning the model weights and KV cache.
prompt_tokens
Tokenized prompt; may be empty for continuation from a prior state.
max_tokens
Upper bound on the number of tokens to generate (not counting prompt).
eos_id
Token id that terminates generation early when sampled.
allocator
Used to allocate the returned `output_tokens` slice; caller must free via `GenerateResult.deinit`.

Returns

`GenerateResult` with the generated token slice and per-phase timing metrics.

src/compute/forward_metal.zig:24865

function

generate

pub fn generate( engine: *InferenceEngine, prompt_tokens: []const u32, max_tokens: u32, eos_id: u32, allocator: std.mem.Allocator, ) ![]u32

Run prefill and decode, log throughput, and return only the generated token slice.

Convenience wrapper around `generateWithMetrics` for callers that do not need the structured `GenerateMetrics` breakdown.

Parameters

engine
Initialized inference engine owning the model weights and KV cache.
prompt_tokens
Tokenized prompt passed directly to `generateWithMetrics`.
max_tokens
Upper bound on tokens to generate.
eos_id
Token id that terminates generation early when sampled.
allocator
Used to allocate the returned slice; caller is responsible for freeing it.

Returns

Caller-owned slice of generated token ids (excludes the prompt).

src/compute/forward_metal.zig:24974