Package inference#
import "github.com/zerfoo/zerfoo/inference"

Package inference provides a high-level API for loading GGUF models and running text generation, chat, embedding, and speculative decoding with minimal boilerplate.
For lower-level control over text generation, KV caching, and sampling, see the generate package. For an OpenAI-compatible HTTP server built on top of this package, see the serve package.
Full method signatures: pkg.go.dev/github.com/zerfoo/zerfoo/inference
Quick Start#
m, err := inference.Load("gemma-3-1b-q4",
inference.WithDevice("cuda"),
inference.WithMaxSeqLen(4096),
)
if err != nil {
log.Fatal(err)
}
defer m.Close()
text, err := m.Generate(ctx, "Explain gradient descent briefly.",
inference.WithMaxTokens(256),
inference.WithTemperature(0.7),
)

Model Loading#
func Load#
func Load(modelID string, opts ...Option) (*Model, error)

Load resolves a model by name or HuggingFace repo ID, pulling it from the registry if not already cached, and returns a ready-to-use Model.
Short aliases such as "gemma-3-1b-q4" and "llama-3-8b-q4" map to full HuggingFace repository IDs. Use ResolveAlias to look up the mapping.
m, err := inference.Load("llama-3-8b-q4",
inference.WithDevice("cuda:0"),
inference.WithCacheDir("/models"),
)
if err != nil {
log.Fatal(err)
}
defer m.Close()

func LoadFile#
func LoadFile(path string, opts ...Option) (*Model, error)

LoadFile loads a model directly from a local GGUF file path and returns a ready-to-use Model.
m, err := inference.LoadFile("/data/models/gemma-3-1b-q4_k_m.gguf",
inference.WithDevice("cuda"),
)
if err != nil {
log.Fatal(err)
}
defer m.Close()

func LoadGGUF#
func LoadGGUF(path string) (*GGUFModel, error)

LoadGGUF loads a GGUF model file and returns its configuration and tensors as an intermediate representation. This is useful for inspecting model metadata or building custom computation graphs. Tensor names are mapped from GGUF convention (blk.N.attn_q.weight) to Zerfoo canonical names (model.layers.N.self_attn.q_proj.weight).
gguf, err := inference.LoadGGUF("/data/models/llama-3-8b.gguf")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Architecture: %s\n", gguf.Config.Architecture)
fmt.Printf("Tensors: %d\n", len(gguf.Tensors))

Load Options#
type Option#
type Option func(*loadOptions)

Option configures model loading. Pass these to Load or LoadFile.
func WithDevice#
func WithDevice(device string) Option

Sets the compute device. Supported values: "cpu", "cuda", "cuda:N" (specific GPU), "rocm", "opencl".
m, _ := inference.Load("gemma-3-1b-q4", inference.WithDevice("cuda:0"))

func WithCacheDir#
func WithCacheDir(dir string) OptionSets the local directory for cached model files.
func WithMaxSeqLen#
func WithMaxSeqLen(n int) OptionOverrides the model’s default maximum sequence length.
func WithRegistry#
func WithRegistry(r registry.ModelRegistry) OptionSupplies a custom model registry for model resolution.
func WithBackend#
func WithBackend(backend string) OptionSelects the inference backend. Supported values: "" or "default" for the standard Engine path, "tensorrt" for TensorRT-optimized inference. TensorRT requires the cuda build tag and a CUDA device.
func WithPrecision#
func WithPrecision(precision string) OptionSets the compute precision for the TensorRT backend. Supported values: "" or "fp32" for full precision, "fp16" for half precision. Has no effect when the backend is not "tensorrt".
func WithDType#
func WithDType(dtype string) OptionSets the compute precision for the GPU engine. Supported values: "" or "fp32" for full precision, "fp16" for FP16 compute. FP16 mode converts activations F32->FP16 before GPU kernels and back after. Has no effect on CPU engines.
func WithKVDtype#
func WithKVDtype(dtype string) OptionSets the KV cache storage dtype. Supported: "fp32" (default), "fp16". FP16 halves KV cache bandwidth by storing keys/values in half precision.
func WithMmap#
func WithMmap(enabled bool) OptionControls memory-mapped model loading. mmap is enabled by default. When enabled, the GGUF file is mapped into virtual address space using syscall.Mmap; tensor data is paged from disk on demand by the OS, avoiding heap allocation and enabling models larger than physical RAM. Pass false to use heap loading, which is required for CUDA graph capture. Only supported on unix platforms.
Model#
type Model#
type Model struct {
// unexported fields
}

Model is a loaded model ready for generation. Created by Load or LoadFile.
func (*Model) Generate#
func (m *Model) Generate(ctx context.Context, prompt string, opts ...GenerateOption) (string, error)

Produces text from a prompt. Sessions are pooled to reuse GPU memory addresses, enabling CUDA graph replay across calls. Concurrent Generate calls get separate sessions from the pool.
text, err := m.Generate(ctx, "What is backpropagation?",
inference.WithMaxTokens(512),
inference.WithTemperature(0.7),
inference.WithTopP(0.9),
)func (*Model) GenerateStream#
func (m *Model) GenerateStream(ctx context.Context, prompt string, handler generate.TokenStream, opts ...GenerateOption) errorDelivers tokens one at a time via a TokenStream callback. Sessions are pooled to preserve GPU memory addresses for CUDA graph replay.
err := m.GenerateStream(ctx, "Tell me a story.",
generate.TokenStreamFunc(func(token string, done bool) error {
if !done {
fmt.Print(token)
}
return nil
}),
inference.WithMaxTokens(256),
)func (*Model) GenerateBatch#
func (m *Model) GenerateBatch(ctx context.Context, prompts []string, opts ...GenerateOption) ([]string, error)Processes multiple prompts concurrently and returns the generated text for each prompt. Results are returned in the same order as the input prompts.
prompts := []string{
"Explain neural networks.",
"What is gradient descent?",
"Define overfitting.",
}
results, err := m.GenerateBatch(ctx, prompts,
inference.WithMaxTokens(128),
)func (*Model) Chat#
func (m *Model) Chat(ctx context.Context, messages []Message, opts ...GenerateOption) (Response, error)

Formats a slice of Message values using the model’s chat template and generates a Response with token usage statistics.
resp, err := m.Chat(ctx, []inference.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is machine learning?"},
},
inference.WithMaxTokens(256),
)
fmt.Printf("Response: %s\n", resp.Content)
fmt.Printf("Tokens used: %d\n", resp.TokensUsed)func (*Model) Embed#
func (m *Model) Embed(text string) ([]float32, error)Returns an L2-normalized embedding vector for the given text by looking up token embeddings from the model’s embedding table and mean-pooling them.
vec, err := m.Embed("machine learning")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Embedding dimension: %d\n", len(vec))func (*Model) SpeculativeGenerate#
func (m *Model) SpeculativeGenerate(
ctx context.Context,
draft *Model,
prompt string,
draftLen int,
opts ...GenerateOption,
) (string, error)

Runs speculative decoding using this model as the target and the draft model for token proposal. draftLen controls how many tokens are proposed per verification step (typically 2-8).
target, _ := inference.Load("llama-3-70b-q4", inference.WithDevice("cuda"))
draft, _ := inference.Load("llama-3-8b-q4", inference.WithDevice("cuda"))
defer target.Close()
defer draft.Close()
text, err := target.SpeculativeGenerate(ctx, draft,
"Explain quantum computing.",
4, // propose 4 tokens per step
inference.WithMaxTokens(256),
)func (*Model) Close#
func (m *Model) Close() errorReleases resources held by the model. If the model was loaded on a GPU, this frees the CUDA engine’s handles, pool, and stream. If loaded with mmap, this releases the memory mapping.
func (*Model) Config#
func (m *Model) Config() ModelMetadataReturns the model metadata.
func (*Model) Generator#
func (m *Model) Generator() *generate.Generator[float32]Returns the underlying generator for lower-level access.
func (*Model) Tokenizer#
func (m *Model) Tokenizer() tokenizer.TokenizerReturns the model’s tokenizer for token counting.
func (*Model) Info#
func (m *Model) Info() *registry.ModelInfoReturns the registry info for this model.
func (*Model) EmbeddingWeights#
func (m *Model) EmbeddingWeights() ([]float32, int)Returns the flattened token embedding table and the hidden dimension. Returns nil, 0 if embeddings are not available.
func (*Model) SetEmbeddingWeights#
func (m *Model) SetEmbeddingWeights(weights []float32, hiddenSize int)Sets the token embedding table for Embed(). weights is a flattened [vocabSize, hiddenSize] matrix.
Generate Options#
type GenerateOption#
type GenerateOption func(*generate.SamplingConfig)

GenerateOption configures a generation call. Pass these to Generate, GenerateStream, GenerateBatch, Chat, or SpeculativeGenerate.
func WithTemperature#
func WithTemperature(t float64) GenerateOptionSets the sampling temperature. Higher values produce more random output; 0 uses greedy (argmax) decoding.
func WithTopK#
func WithTopK(k int) GenerateOptionSets the top-K sampling cutoff. Only the top K most probable tokens are considered. 0 disables top-K filtering.
func WithTopP#
func WithTopP(p float64) GenerateOptionSets the nucleus (top-P) sampling threshold. Tokens are selected from the smallest set whose cumulative probability exceeds P. 1.0 disables top-P filtering.
func WithMaxTokens#
func WithMaxTokens(n int) GenerateOptionSets the maximum number of tokens to generate.
func WithRepetitionPenalty#
func WithRepetitionPenalty(p float64) GenerateOptionSets the repetition penalty factor. Values > 1.0 penalize repeated tokens.
func WithStopStrings#
func WithStopStrings(ss ...string) GenerateOptionSets strings that terminate generation when encountered in the output.
func WithGrammar#
func WithGrammar(g *grammar.Grammar) GenerateOptionSets a grammar state machine for constrained decoding. When set, a token mask is applied at each sampling step to restrict output to tokens valid according to the grammar.
Response Types#
type Message#
type Message struct {
Role string // "system", "user", or "assistant"
Content string
Images [][]byte // optional raw image data for vision models
}Message represents a chat message for the Chat method.
type Response#
type Response struct {
Content string
TokensUsed int
PromptTokens int
CompletionTokens int
}Response holds the result of a chat completion.
Model Metadata#
type ModelMetadata#
type ModelMetadata struct {
Architecture string
VocabSize int
HiddenSize int
NumLayers int
MaxPositionEmbeddings int
EOSTokenID int
BOSTokenID int
ChatTemplate string
// Extended fields for multi-architecture support.
IntermediateSize int
NumQueryHeads int
NumKeyValueHeads int
RopeTheta float64
RopeScaling *RopeScalingConfig
TieWordEmbeddings bool
SlidingWindow int
AttentionBias bool
PartialRotaryFactor float64
// DeepSeek MLA and MoE fields.
KVLoRADim int
QLoRADim int
QKRopeHeadDim int
NumExperts int
NumExpertsPerToken int
NumSharedExperts int
}

ModelMetadata holds model configuration loaded from config.json or GGUF metadata.
type RopeScalingConfig#
type RopeScalingConfig struct {
Type string
Factor float64
OriginalMaxPositionEmbeddings int
}RopeScalingConfig holds configuration for RoPE scaling methods (e.g., YaRN).
GGUF Model Loading#
type GGUFModel#
type GGUFModel struct {
Config *gguf.ModelConfig
Tensors map[string]*tensor.TensorNumeric[float32]
File *gguf.File
}GGUFModel holds a loaded GGUF model’s configuration and tensors. This is an intermediate representation; full inference requires an architecture-specific graph builder.
func (*GGUFModel) ToModelMetadata#
func (m *GGUFModel) ToModelMetadata() *ModelMetadataConverts a GGUF model config to ModelMetadata.
Architecture Registry#
The architecture registry maps GGUF general.architecture values to graph builder functions.
type ArchBuilder#
type ArchBuilder func(
tensors map[string]*tensor.TensorNumeric[float32],
cfg *gguf.ModelConfig,
engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)ArchBuilder builds a computation graph for a model architecture from pre-loaded GGUF tensors. Returns the graph and the embedding table tensor.
func RegisterArchitecture#
func RegisterArchitecture(name string, builder ArchBuilder)Registers an architecture builder under the given name. Names correspond to GGUF general.architecture values (e.g. "llama", "gemma"). Panics if name is empty or a builder is already registered.
inference.RegisterArchitecture("custom", func(
tensors map[string]*tensor.TensorNumeric[float32],
cfg *gguf.ModelConfig,
engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error) {
// Build custom architecture graph...
return g, embedTensor, nil
})func GetArchitecture#
func GetArchitecture(name string) (ArchBuilder, bool)Returns the builder registered for the given architecture name. Returns nil, false if no builder is registered.
func ListArchitectures#
func ListArchitectures() []stringReturns a sorted list of all registered architecture names.
func BuildArchGraph#
func BuildArchGraph(
arch string,
tensors map[string]*tensor.TensorNumeric[float32],
cfg *gguf.ModelConfig,
engine compute.Engine[float32],
) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)Dispatches to the appropriate architecture-specific graph builder. Exported for benchmark and integration tests that construct synthetic weight maps without loading from GGUF files.
Model Aliases#
func RegisterAlias#
func RegisterAlias(shortName, repoID string)Adds a custom short name to HuggingFace repo ID mapping.
inference.RegisterAlias("my-model", "myorg/my-model-7b-q4")
m, _ := inference.Load("my-model")func ResolveAlias#
func ResolveAlias(name string) stringReturns the HuggingFace repo ID for a short alias. If the name is not an alias, it is returned unchanged.
Architecture-Specific Builders#
func BuildJamba#
func BuildJamba(jc JambaConfig, tensors map[string]*tensor.TensorNumeric[float32], engine compute.Engine[float32]) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)Constructs a computation graph for the Jamba hybrid architecture (mixed attention + SSM layers).
type JambaConfig#
type JambaConfig struct {
NumLayers int
HiddenSize int
IntermediateSize int
AttnHeads int
KVHeads int
SSMHeads int
AttentionLayerOffset int
RMSEps float32
VocabSize int
MaxSeqLen int
RopeTheta float64
DConv int
}func JambaConfigFromGGUF#
func JambaConfigFromGGUF(cfg *gguf.ModelConfig) JambaConfigExtracts Jamba configuration from GGUF ModelConfig.
func BuildMamba3#
func BuildMamba3(mc MambaConfig, tensors map[string]*tensor.TensorNumeric[float32], engine compute.Engine[float32]) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)Constructs a computation graph for Mamba-3 from a weight map.
type MambaConfig#
type MambaConfig struct {
NumLayers int
DModel int
DState int
DConv int
DInner int
VocabSize int
EOSTokenID int
RMSNormEps float32
}func MambaConfigFromGGUF#
func MambaConfigFromGGUF(cfg *gguf.ModelConfig) MambaConfigExtracts Mamba configuration from GGUF ModelConfig.
func MambaConfigFromMetadata#
func MambaConfigFromMetadata(meta map[string]interface{}) MambaConfigExtracts Mamba configuration from a raw metadata map.
func BuildRWKV#
func BuildRWKV(rc RWKVConfig, tensors map[string]*tensor.TensorNumeric[float32], engine compute.Engine[float32]) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)Constructs a computation graph for the RWKV-6/7 architecture.
type RWKVConfig#
type RWKVConfig struct {
NumLayers int
HiddenSize int
VocabSize int
HeadSize int // default 64
NumHeads int // HiddenSize / HeadSize
LayerNormEps float32
}func RWKVConfigFromGGUF#
func RWKVConfigFromGGUF(cfg *gguf.ModelConfig) RWKVConfigExtracts RWKV configuration from GGUF ModelConfig.
func BuildWhisperEncoder#
func BuildWhisperEncoder(wc WhisperConfig, tensors map[string]*tensor.TensorNumeric[float32], engine compute.Engine[float32]) (*graph.Graph[float32], *tensor.TensorNumeric[float32], error)Constructs a computation graph for the Whisper encoder.
type WhisperConfig#
type WhisperConfig struct {
NumMels int
HiddenDim int
NumHeads int
NumLayers int
KernelSize int
}func WhisperConfigFromGGUF#
func WhisperConfigFromGGUF(cfg *gguf.ModelConfig) WhisperConfigExtracts Whisper configuration from GGUF ModelConfig. NumMels defaults to 80, KernelSize defaults to 3.
Config Registry#
type ArchConfigRegistry#
type ArchConfigRegistry struct {
// unexported fields
}Maps model_type strings to config parsers.
func DefaultArchConfigRegistry#
func DefaultArchConfigRegistry() *ArchConfigRegistryReturns a registry with all built-in parsers registered.
func (*ArchConfigRegistry) Register#
func (r *ArchConfigRegistry) Register(modelType string, parser ConfigParser)Adds a parser for the given model type.
func (*ArchConfigRegistry) Parse#
func (r *ArchConfigRegistry) Parse(raw map[string]interface{}) (*ModelMetadata, error)Dispatches to the registered parser for the model_type in raw, or falls back to generic field extraction for unknown types.
type ConfigParser#
type ConfigParser func(raw map[string]interface{}) (*ModelMetadata, error)Parses a raw JSON map (from config.json) into ModelMetadata.
TensorRT Integration#
func ConvertGraphToTRT#
func ConvertGraphToTRT(
g *graph.Graph[float32],
workspaceBytes int,
fp16 bool,
dynamicShapes *DynamicShapeConfig,
) (*trtConversionResult, error)Walks a graph in topological order and maps each node to a TensorRT layer. Returns serialized engine bytes or an UnsupportedOpError if the graph contains operations that cannot be converted.
type DynamicShapeConfig#
type DynamicShapeConfig struct {
InputShapes []ShapeRange
}Specifies per-input shape ranges for TensorRT optimization profiles.
type ShapeRange#
type ShapeRange struct {
Min []int32
Opt []int32
Max []int32
}Defines min/opt/max dimensions for a single input tensor.
func TRTCacheKey#
func TRTCacheKey(modelID, precision string) (string, error)Builds a deterministic cache key from model ID, precision, and GPU architecture.
func SaveTRTEngine / LoadTRTEngine#
func SaveTRTEngine(key string, data []byte) error
func LoadTRTEngine(key string) ([]byte, error)Write/read serialized TensorRT engines to/from the cache directory. LoadTRTEngine returns nil, nil on cache miss.
type TRTInferenceEngine#
type TRTInferenceEngine struct {
// unexported fields
}Holds a TensorRT engine and execution context for inference.
func (*TRTInferenceEngine) Forward#
func (e *TRTInferenceEngine) Forward(inputs []*tensor.TensorNumeric[float32], outputSize int) (*tensor.TensorNumeric[float32], error)Runs inference through TensorRT with the given input tensors. Input tensors must already be on GPU.
func (*TRTInferenceEngine) Close#
func (e *TRTInferenceEngine) Close() errorReleases all TensorRT resources.
type UnsupportedOpError#
type UnsupportedOpError struct {
Ops []string
}Lists the operations that cannot be converted to TensorRT.
Testing Utilities#
func NewTestModel#
func NewTestModel(
gen *generate.Generator[float32],
tok tokenizer.Tokenizer,
eng compute.Engine[float32],
meta ModelMetadata,
info *registry.ModelInfo,
) *ModelConstructs a Model from pre-built components. Intended for use in external test packages that need a Model without going through the full Load pipeline.
Interfaces#
type ConstantValueGetter#
type ConstantValueGetter interface {
GetValue() *tensor.TensorNumeric[float32]
}Interface for nodes that hold constant tensor data.
type DTypeSetter#
type DTypeSetter interface {
SetDType(compute.DType)
}Implemented by engines that support setting compute precision.