Last updated: 2026-06-12
Inference Runtime
Forward Metal
Metal inference engine — decode loop for Apple Silicon.
This is the Metal equivalent of forward.zig (Vulkan). Uses MSL compute shaders dispatched via the Metal shim.
constant
CommandEncoderMode
pub const CommandEncoderMode = metal_command.CommandEncoderMode
Command-encoder mode re-export used by runtime init options.
constant
runtime_context_cap
pub const runtime_context_cap: u32 = 262144
Upper bound on the Metal KV-cache allocation: we still honour the model's architectural context length and the unified-memory budget, but we refuse to allocate more tokens than this in one block to keep allocation latency and staging buffers sane.
Callers that already right-sized `cfg.context_length` from the device budget (see `memory_plan.autoContextTokensForDeviceBudget`) see this as a soft safety net rather than the primary limit.
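As a rough illustration of how the cap interacts with the other two limits, the sketch below clamps a context size to the smallest of the three. `effectiveContextTokens`, `model_context`, and `budget_tokens` are hypothetical names introduced here; the engine's actual clamping logic may differ.

```zig
const std = @import("std");

/// Mirrors the documented constant above.
pub const runtime_context_cap: u32 = 262144;

/// Hypothetical helper: choose how many KV-cache tokens to allocate.
/// `model_context` is the architectural context length and `budget_tokens`
/// is what the device memory budget allows; both are illustrative inputs.
pub fn effectiveContextTokens(model_context: u32, budget_tokens: u32) u32 {
    return @min(model_context, budget_tokens, runtime_context_cap);
}
```

When the caller has already sized the context from the device budget, the first two arguments are normally the smaller ones and the cap does not change the result.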
struct
DecodeState
pub const DecodeState = struct
Runtime state for the decode loop.
struct
GenerateMetrics
pub const GenerateMetrics = struct
Metrics from generateWithMetrics: prefill/decode token counts, timing, and throughput.
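For readers who want to reproduce the throughput figure from the token counts and timings, a minimal sketch follows. The helper name `tokensPerSecond` and its parameters are assumptions made for this example; the real struct's field names are not shown in this reference.

```zig
const std = @import("std");

/// Hypothetical helper: tokens per second from a token count and elapsed nanoseconds.
/// Returns 0 when no time has elapsed, to avoid dividing by zero.
fn tokensPerSecond(tokens: u64, elapsed_ns: u64) f64 {
    if (elapsed_ns == 0) return 0.0;
    const t: f64 = @floatFromInt(tokens);
    const ns: f64 = @floatFromInt(elapsed_ns);
    return t * @as(f64, std.time.ns_per_s) / ns;
}
```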
struct
GenerateResult
pub const GenerateResult = struct
Output tokens and performance metrics from a generation run.
Methods (1)
method
GenerateResult.deinit
pub fn deinit(self: *GenerateResult, allocator: std.mem.Allocator) void
Free the generated token slice returned by `generateWithMetrics`.
struct
SamplingParams
pub const SamplingParams = struct
Token sampling parameters: temperature, top-k, top-p, and repetition penalty.
Methods (1)
method
SamplingParams.requiresLogitsReadback
pub fn requiresLogitsReadback(self: @This()) bool
Return whether sampling settings require CPU-visible logits instead of greedy argmax.
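One plausible way to express such a check is sketched below. The struct `SketchSamplingParams`, its field names, its defaults, and the decision rule are all assumptions made for illustration; the real `SamplingParams` may use different fields and thresholds.

```zig
const std = @import("std");

/// Illustrative stand-in for the real struct; fields and defaults are assumptions.
const SketchSamplingParams = struct {
    temperature: f32 = 0.0,
    top_k: u32 = 0,
    top_p: f32 = 1.0,
    repetition_penalty: f32 = 1.0,

    /// True when the settings could pick something other than the plain argmax.
    fn requiresLogitsReadback(self: @This()) bool {
        // A repetition penalty rescales logits, so the argmax itself may change.
        if (self.repetition_penalty != 1.0) return true;
        // Temperature 0 or top_k == 1 both reduce to picking the largest logit.
        if (self.temperature <= 0.0 or self.top_k == 1) return false;
        return true;
    }
};
```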
struct
InitOptions
pub const InitOptions = struct
Options for `InferenceEngine.init`: profiling, debug validation, KV cache, and dispatch tuning.
struct
RuntimeProfile
pub const RuntimeProfile = struct
Per-request profiling counters for dispatch, barrier, and timing breakdown.
function
attentionLayerCount
pub fn attentionLayerCount(cfg: ModelConfig) u32
Return the number of full-attention (non-SSM) transformer layers in the model.
function
defaultKvCacheQ8Enabled
pub fn defaultKvCacheQ8Enabled(config: ModelConfig, debug_validation_enabled: bool) bool
Return whether Q8 KV cache quantization should be enabled by default for this model.
is available for numerical validation. gpt-oss (SwiGLU sensitivity) and Gemma4 with SWA (ISWA rotation path not yet ported).
function
kvCacheBytesPerToken
pub fn kvCacheBytesPerToken(config: ModelConfig, q8_enabled: bool) u64
Return the bytes consumed per token in the KV cache.
When `q8_enabled` is set, returns the Q8-quantized size; otherwise returns the unquantized f32 size.
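To make the per-token arithmetic concrete, here is a minimal sketch under stated assumptions: a reduced, hypothetical `SketchConfig`; one K row and one V row per attention layer; and a GGML-style Q8_0 layout of 34 bytes per 32 values (a 2-byte f16 scale plus 32 int8 quants). The engine's actual layout and config fields may differ.

```zig
const std = @import("std");

/// Hypothetical, reduced model description; the real ModelConfig has more fields.
const SketchConfig = struct {
    attention_layers: u32,
    kv_heads: u32,
    head_dim: u32,
};

/// Bytes needed to store one token's K and V rows across all attention layers.
/// Assumes kv_heads * head_dim is a multiple of 32 when Q8 is enabled.
fn sketchKvBytesPerToken(cfg: SketchConfig, q8_enabled: bool) u64 {
    const elems_per_row: u64 = @as(u64, cfg.kv_heads) * cfg.head_dim;
    const bytes_per_row: u64 = if (q8_enabled)
        (elems_per_row / 32) * 34 // Q8_0 blocks: f16 scale + 32 int8 values
    else
        elems_per_row * 4; // unquantized f32
    // Factor 2: one K row plus one V row per layer.
    return 2 * @as(u64, cfg.attention_layers) * bytes_per_row;
}
```

Under these assumptions, 32 attention layers with 8 KV heads of dimension 128 cost 262144 bytes per token in f32 and 69632 bytes per token in Q8.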
struct
InferenceEngine
pub const InferenceEngine = struct
Metal inference engine: owns GPU buffers, pipelines, and KV cache.
Methods (11)
method
InferenceEngine.init
pub fn init(
    model: *const metal_loader.Model,
    device: *const metal_device.MetalDevice,
    allocator: std.mem.Allocator,
    options: InitOptions,
) !InferenceEngine
Initialize the Metal inference engine, allocating GPU buffers and compiling pipelines.
method
InferenceEngine.deinit
pub fn deinit(self: *InferenceEngine) void
Release all GPU buffers, pipelines, and associated resources.
method
InferenceEngine.sampleGreedy
pub fn sampleGreedy(self: *const InferenceEngine) u32
Sample the next token greedily (argmax over logits).
method
InferenceEngine.sample
pub fn sample(self: *const InferenceEngine, history: []const u32, params: SamplingParams, random: std.Random) u32
Sample the next token using temperature, top-k, top-p, and repetition penalty.
Falls back to greedy if parameters are near-default or buffers are private.
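The sketch below shows two of the steps named above, repetition penalty and temperature, on a CPU-visible logits slice. It is not the engine's implementation: the helper names are hypothetical, the penalty uses the common convention of dividing positive logits and multiplying the others, and top-k, top-p, and the random draw are left out to keep it short.

```zig
const std = @import("std");

/// Penalize every token that already occurs in `history`, once per distinct token.
/// Convention assumed here: positive logits are divided, others are multiplied.
fn applyRepetitionPenalty(logits: []f32, history: []const u32, penalty: f32) void {
    for (history, 0..) |tok, i| {
        if (tok >= logits.len) continue;
        // Skip tokens that were already handled earlier in the history.
        if (std.mem.indexOfScalar(u32, history[0..i], tok) != null) continue;
        const v = logits[tok];
        logits[tok] = if (v > 0) v / penalty else v * penalty;
    }
}

/// Turn logits into probabilities in place after dividing by `temperature` (> 0).
fn softmaxWithTemperature(logits: []f32, temperature: f32) void {
    var max_v = logits[0];
    for (logits) |v| max_v = @max(max_v, v);
    var sum: f32 = 0;
    for (logits) |*v| {
        // Subtracting the maximum keeps the exponent from overflowing.
        v.* = @exp((v.* - max_v) / temperature);
        sum += v.*;
    }
    for (logits) |*v| v.* /= sum;
}
```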
method
InferenceEngine.enableLogitsReadback
pub fn enableLogitsReadback(self: *InferenceEngine) void
No public doc comment yet.
method
InferenceEngine.resetRequestState
pub fn resetRequestState(self: *InferenceEngine, requested_context_tokens: u32) !void
Reset position, profiling counters, and SSM state for a new request.
method
InferenceEngine.prefillBatch
pub fn prefillBatch(self: *InferenceEngine, state: *DecodeState, prompt_tokens: []const u32) !void
Run prompt prefill in token-major order, processing one token through all layers at a time.
Uses a queued async-submit path for short prompts when available; falls back to sequential `commitAndWait` per token otherwise. Callers that want the layer-major batched path should use `prefillBatched` instead. engine's internal position (stale or mismatched state).
method
InferenceEngine.prefillBatched
pub fn prefillBatched(self: *InferenceEngine, state: *DecodeState, prompt_tokens: []const u32) !void
Experimental batched prompt prefill gated by `ZINC_BATCHED_PREFILL`.
Processes all prompt tokens in a single batched forward pass using the gemm_q4k / gemm_q6k / rope_batched / flash_attn_batched shaders, so the weight matrix for each projection is read once for the whole prompt.
Falls back to the per-token `prefillBatch` when the env flag is off or when the model architecture is outside the supported slice (see `canUseBatchedPrefill`).
Supports both fresh prefill (state.position==0) and prefix reuse (state.position>0). In the latter case, the batched pass extends the KV cache at offset `state.position` and flash attention causal masking is computed relative to that offset.
With `ZINC_BATCHED_PREFILL=validate` the batched path runs first, then the per-token path is replayed on a fresh state and the last-token logits are diffed; max abs diff is logged and a warning is emitted if it exceeds 1e-3.
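The comparison performed in validate mode can be pictured with the following minimal sketch. The function names are hypothetical; only the max-abs-diff idea and the 1e-3 threshold come from the description above.

```zig
const std = @import("std");

/// Warning threshold quoted in the description above.
const warn_threshold: f32 = 1e-3;

/// Largest element-wise absolute difference between two logit vectors.
fn maxAbsDiff(a: []const f32, b: []const f32) f32 {
    std.debug.assert(a.len == b.len);
    var worst: f32 = 0.0;
    for (a, b) |x, y| {
        worst = @max(worst, @abs(x - y));
    }
    return worst;
}

/// True when the two paths disagree by more than the threshold.
fn shouldWarn(batched_logits: []const f32, per_token_logits: []const f32) bool {
    return maxAbsDiff(batched_logits, per_token_logits) > warn_threshold;
}
```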
method
InferenceEngine.decodeStep
pub fn decodeStep(self: *InferenceEngine, state: *DecodeState, token_id: u32) !void
Advance one autoregressive decode step from the given input token.
method
InferenceEngine.enableProfiling
pub fn enableProfiling(self: *InferenceEngine) !void
Enable request-level profiling counters for subsequent decode work.
method
InferenceEngine.logRequestProfileSummary
pub fn logRequestProfileSummary(self: *const InferenceEngine, label: []const u8, prompt_tokens: usize, completion_tokens: u32) void
Log the collected Metal profiling summary for the current request to the scoped logger.
Does nothing when profiling was not enabled via `enableProfiling`.
function
dequantRow
pub fn dequantRow(raw_data: []const u8, row: u32, cols: u32, quant_type: GGMLType, output: []f32) void
Dequantize one row of quantized weight data to f32 values.
Supports f32, f16, Q5_0, Q5_1, Q8_0, Q4_K, Q5_K, Q6_K, and MXFP4. Unsupported types log a warning and zero the output slice.
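As an illustration of what row dequantization involves for one of the listed types, the sketch below handles Q8_0 only. It assumes the GGML Q8_0 block layout (a little-endian f16 scale followed by 32 int8 values) and a column count that is a multiple of 32; `dequantRowQ8_0` is a hypothetical name, not part of this module.

```zig
const std = @import("std");

const q8_0_block_values = 32;
const q8_0_block_bytes = 2 + q8_0_block_values; // f16 scale + 32 int8 quants

/// Dequantize row `row` of a Q8_0 matrix with `cols` columns into `output`.
/// Assumes `cols` is a multiple of 32 and `output.len >= cols`.
fn dequantRowQ8_0(raw: []const u8, row: u32, cols: u32, output: []f32) void {
    const blocks_per_row: usize = cols / q8_0_block_values;
    const row_start = @as(usize, row) * blocks_per_row * q8_0_block_bytes;
    var b: usize = 0;
    while (b < blocks_per_row) : (b += 1) {
        const start = row_start + b * q8_0_block_bytes;
        const block: []const u8 = raw[start .. start + q8_0_block_bytes];
        // The first two bytes hold the per-block scale as an f16.
        const scale_bits = std.mem.readInt(u16, block[0..2], .little);
        const scale: f32 = @as(f16, @bitCast(scale_bits));
        for (block[2..], 0..) |byte, i| {
            const q: i8 = @bitCast(byte);
            output[b * q8_0_block_values + i] = scale * @as(f32, @floatFromInt(q));
        }
    }
}
```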
function
topKSoftmaxWeight
pub fn topKSoftmaxWeight(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void
Select the top-k entries by raw logit value, then apply softmax over only those k values.
Used for the SOFTMAX_WEIGHT expert-routing variant (gpt-oss), which differs from `topKSoftmax` in that softmax is applied to the pre-selected raw logits rather than to the full probability distribution first.
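A minimal CPU sketch of the select-then-softmax order is given below. It is written for clarity rather than speed (a simple O(n·k) selection), the name `sketchTopKSoftmaxWeight` is hypothetical, and it assumes `1 <= k <= logits.len`; the real routine may differ in tie-breaking and other details.

```zig
const std = @import("std");

/// Pick the k largest raw logits (descending), then softmax over just those k.
/// Assumes 1 <= k <= logits.len and that both output slices hold at least k items.
fn sketchTopKSoftmaxWeight(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void {
    std.debug.assert(k >= 1 and k <= logits.len);
    var picked: usize = 0;
    while (picked < k) : (picked += 1) {
        var best: ?usize = null;
        for (logits, 0..) |v, i| {
            const id: u32 = @intCast(i);
            // Skip indices that were selected in an earlier round.
            if (std.mem.indexOfScalar(u32, out_ids[0..picked], id) != null) continue;
            if (best == null or v > logits[best.?]) best = i;
        }
        out_ids[picked] = @intCast(best.?);
        out_weights[picked] = logits[best.?];
    }
    // Selection is descending, so out_weights[0] is the largest selected logit.
    const max_logit = out_weights[0];
    var sum: f32 = 0;
    for (out_weights[0..k]) |*w| {
        w.* = @exp(w.* - max_logit);
        sum += w.*;
    }
    for (out_weights[0..k]) |*w| w.* /= sum;
}
```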
function
topKSoftmax
pub fn topKSoftmax(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void
Apply softmax over all logits, then select the top-k entries by probability and renormalize.
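The softmax-first order can be sketched as follows. This is an illustrative CPU version with a hypothetical name, `sketchTopKSoftmax`, assuming `1 <= k <= logits.len`; it is not the module's implementation.

```zig
const std = @import("std");

/// Softmax over all logits, keep the k most probable entries, renormalize them.
/// Assumes 1 <= k <= logits.len and that both output slices hold at least k items.
fn sketchTopKSoftmax(logits: []const f32, k: u32, out_ids: []u32, out_weights: []f32) void {
    std.debug.assert(k >= 1 and k <= logits.len);
    // Denominator of the full softmax, computed with the usual max shift.
    var max_logit = logits[0];
    for (logits) |v| max_logit = @max(max_logit, v);
    var total: f32 = 0;
    for (logits) |v| total += @exp(v - max_logit);

    // Softmax is monotonic, so ranking by logit equals ranking by probability.
    var picked: usize = 0;
    var kept: f32 = 0;
    while (picked < k) : (picked += 1) {
        var best: ?usize = null;
        for (logits, 0..) |v, i| {
            const id: u32 = @intCast(i);
            if (std.mem.indexOfScalar(u32, out_ids[0..picked], id) != null) continue;
            if (best == null or v > logits[best.?]) best = i;
        }
        const idx = best.?;
        out_ids[picked] = @intCast(idx);
        out_weights[picked] = @exp(logits[idx] - max_logit) / total;
        kept += out_weights[picked];
    }
    // Renormalize so the k kept probabilities sum to 1.
    for (out_weights[0..k]) |*w| w.* /= kept;
}
```

In exact arithmetic, renormalizing the kept probabilities cancels the full-softmax denominator, so the resulting weights match a softmax taken over the k selected logits alone. Any difference in this sketch comes from floating-point rounding; the real routines may differ in further details not shown in this reference.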
function
generateWithMetrics
pub fn generateWithMetrics(
    engine: *InferenceEngine,
    prompt_tokens: []const u32,
    max_tokens: u32,
    eos_id: u32,
    allocator: std.mem.Allocator,
) !GenerateResult
Run prompt prefill followed by autoregressive decode and return tokens with timing metrics.
function
generate
pub fn generate(
    engine: *InferenceEngine,
    prompt_tokens: []const u32,
    max_tokens: u32,
    eos_id: u32,
    allocator: std.mem.Allocator,
) ![]u32
Run prefill and decode, log throughput, and return only the generated token slice.
Convenience wrapper around `generateWithMetrics` for callers that do not need the structured `GenerateMetrics` breakdown.
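The overall control flow of a prefill-then-decode run can be sketched with a mock engine, since the real one needs a Metal device and a loaded model. Everything here is hypothetical: `MockEngine`, its methods, and `sketchGenerate` are stand-ins, and the choice to stop before appending the EOS token is an assumption rather than documented behaviour.

```zig
const std = @import("std");

/// Hypothetical stand-in for the engine: "predicts" the previous token plus one.
const MockEngine = struct {
    last: u32 = 0,

    fn prefill(self: *MockEngine, prompt: []const u32) void {
        if (prompt.len > 0) self.last = prompt[prompt.len - 1];
    }

    fn decodeAndSample(self: *MockEngine) u32 {
        self.last += 1;
        return self.last;
    }
};

/// Prefill once, then decode until `eos_id` or `max_tokens`. Caller frees the slice.
/// Assumption: the EOS token itself is not appended to the output.
fn sketchGenerate(
    engine: *MockEngine,
    prompt_tokens: []const u32,
    max_tokens: u32,
    eos_id: u32,
    allocator: std.mem.Allocator,
) ![]u32 {
    const buf = try allocator.alloc(u32, max_tokens);
    errdefer allocator.free(buf);
    engine.prefill(prompt_tokens);
    var n: usize = 0;
    while (n < max_tokens) : (n += 1) {
        const tok = engine.decodeAndSample();
        if (tok == eos_id) break;
        buf[n] = tok;
    }
    // Shrink the buffer to the number of tokens actually produced.
    return allocator.realloc(buf, n);
}
```

As with the documented functions, the returned slice is owned by the caller and must be freed with the same allocator.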