ollama/llama/runner/runner.go
Jeffrey Morgan 96efd9052f
Re-introduce the llama package (#5034)
* Re-introduce the llama package

This PR brings back the llama package, making it possible to call llama.cpp and
ggml APIs from Go directly via CGo. This has a few advantages:

- C APIs can be called directly from Go without needing to use the previous
  "server" REST API
- On macOS and for CPU builds on Linux and Windows, Ollama can be built without
  a go generate ./... step, making it easy to get up and running to hack on
  parts of Ollama that don't require fast inference
- Faster build times for AVX,AVX2,CUDA and ROCM (a full build of all runners
  takes <5 min on a fast CPU)
- No git submodule making it easier to clone and build from source

This is a big PR, but much of it is vendor code except for:

- llama.go CGo bindings
- example/: a simple example of running inference
- runner/: a subprocess server designed to replace the llm/ext_server package
- Makefile: a minimal Makefile that builds the runner package for
  different targets (cpu, avx, avx2, cuda, rocm)

Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>

* cache: Clear old KV cache entries when evicting a slot

When forking a cache entry, if no empty slots are available we
evict the least recently used one and copy over the KV entries
from the closest match. However, this copy does not overwrite
existing values but only adds new ones. Therefore, we need to
clear the old slot first.

This change fixes two issues:
 - The KV cache fills up and runs out of space even though we think
   we are managing it correctly
 - Performance gets worse over time as we use new cache entries that
   are not hot in the processor caches
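
The eviction problem above can be sketched with a toy model. This is
only an illustration, assuming a cache where each cell records a position
and the set of sequences that own it; the type and function names
(kvCache, copyPrefix, clearSeq, evictAndFork) are hypothetical and are not
the runner's or llama.cpp's API.

```go
package main

import "fmt"

// cell is a toy KV cache cell: one position, shared by any number of sequences.
type cell struct {
	pos  int
	seqs map[int]bool
}

type kvCache struct{ cells []cell }

// add stores n consecutive positions for seq, starting at position 0.
func (c *kvCache) add(seq, n int) {
	for p := 0; p < n; p++ {
		c.cells = append(c.cells, cell{pos: p, seqs: map[int]bool{seq: true}})
	}
}

// copyPrefix tags the first n positions of src as also belonging to dst.
// It only adds membership; cells that dst already owns are left alone.
func (c *kvCache) copyPrefix(src, dst, n int) {
	for i := range c.cells {
		if c.cells[i].seqs[src] && c.cells[i].pos < n {
			c.cells[i].seqs[dst] = true
		}
	}
}

// clearSeq removes seq from every cell it owns.
func (c *kvCache) clearSeq(seq int) {
	for i := range c.cells {
		delete(c.cells[i].seqs, seq)
	}
}

// count returns how many cells seq occupies.
func (c *kvCache) count(seq int) int {
	n := 0
	for _, cl := range c.cells {
		if cl.seqs[seq] {
			n++
		}
	}
	return n
}

// evictAndFork reuses slot dst for a fork of the first n positions of src
// and reports how many cells dst occupies afterwards.
func evictAndFork(c *kvCache, src, dst, n int, clearFirst bool) int {
	if clearFirst {
		c.clearSeq(dst)
	}
	c.copyPrefix(src, dst, n)
	return c.count(dst)
}

func main() {
	for _, clearFirst := range []bool{false, true} {
		c := &kvCache{}
		c.add(0, 4) // sequence 0 is the closest match
		c.add(1, 6) // sequence 1 is the slot being evicted
		fmt.Println(clearFirst, evictAndFork(c, 0, 1, 3, clearFirst))
	}
}
```

In this model the evicted slot ends up holding 9 cells without the clear
(6 stale plus 3 copied) and 3 cells with it, which matches the "cache
fills up" symptom described above.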

* doc: explain golang objc linker warning (#6830)

* llama: gather transitive dependencies for rocm for dist packaging (#6848)

* Refine go server makefiles to be more DRY (#6924)

This breaks up the monolithic Makefile for the Go based runners into a
set of utility files as well as recursive Makefiles for the runners.
Files starting with the name "Makefile" are buildable, while files that
end with ".make" are utilities to include in other Makefiles.  This
reduces the amount of nearly identical targets and helps set a pattern
for future community contributions for new GPU runner architectures.

When we are ready to switch over to the Go runners, these files should
move to the top of the repo, and we should add targets for the main CLI,
as well as a helper "install" (put all the built binaries on the local
system in a runnable state) and "dist" target (generate the various
tar/zip files for distribution) for local developer use.

* llama: don't create extraneous directories (#6988)

* llama: Exercise the new build in CI (#6989)

Wire up some basic sanity testing in CI for the Go runner.  GPU runners are not covered yet.

* llama: Refine developer docs for Go server (#6842)

This enhances the documentation for development focusing on the new Go
server.  After we complete the transition further doc refinements
can remove the "transition" discussion.

* runner.go: Allocate batches for all sequences during init

We should tell the model that we could have full batches for all
sequences. We already do this when we allocate the batches but it was
missed during initialization.

* llama.go: Don't return nil from Tokenize on zero length input

Potentially receiving nil in a non-error condition is surprising to
most callers - it's better to return an empty slice.
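
One place the difference shows up is JSON encoding, where a nil slice
and an empty slice are not interchangeable. A minimal sketch follows;
tokenize and encode are hypothetical stand-ins, not the llama package's
Tokenize.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// tokenize is a hypothetical tokenizer that maps each rune to one token.
// It always returns a non-nil slice, even for empty input.
func tokenize(text string) []int {
	tokens := make([]int, 0, len(text))
	for _, r := range text {
		tokens = append(tokens, int(r))
	}
	return tokens
}

// encode marshals tokens the way an HTTP handler might.
func encode(tokens []int) string {
	b, err := json.Marshal(tokens)
	if err != nil {
		return ""
	}
	return string(b)
}

func main() {
	var missing []int // nil slice
	fmt.Println(encode(missing))      // prints: null
	fmt.Println(encode(tokenize(""))) // prints: []
}
```

Callers that range over the result or take its length behave the same
either way; the empty slice just avoids the surprise when the value is
compared to nil or serialized.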

* runner.go: Remove stop tokens from cache

If the last token is EOG then we don't return this and it isn't
present in the cache (because it was never submitted to Decode).
This works well for extending the cache entry with a new sequence.

However, for multi-token stop sequences, we won't return any of the
tokens but all but the last one will be in the cache. This means
when the conversation continues the cache will contain tokens that
don't overlap with the new prompt.

This works (we will pick up the portion where there is overlap) but
it causes unnecessary cache thrashing because we will fork the original
cache entry as it is not a perfect match.

By trimming the cache to the tokens that we actually return this
issue can be avoided.
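
The trimming arithmetic can be sketched as below. truncateAtStop is a
hypothetical, simplified stand-in for the runner's truncateStop helper,
and trimForStop is a name made up for this illustration; tokens are plain
ints here rather than the runner's input type.

```go
package main

import (
	"fmt"
	"strings"
)

// truncateAtStop drops everything from the stop sequence onward and
// returns the pieces that remain (the last one may be shortened).
func truncateAtStop(pieces []string, stop string) []string {
	joined := strings.Join(pieces, "")
	idx := strings.Index(joined, stop)
	if idx < 0 {
		return pieces
	}
	joined = joined[:idx]
	var out []string
	for _, p := range pieces {
		if len(joined) == 0 {
			break
		}
		if len(p) > len(joined) {
			p = joined
		}
		out = append(out, p)
		joined = joined[len(p):]
	}
	return out
}

// trimForStop returns the cache with unreturned tokens removed, plus the
// pieces to send back. The most recently sampled token was never decoded,
// so it is not in the cache (hence the -1).
func trimForStop(cache []int, pending []string, stop string) ([]int, []string) {
	trim := len(pending) - 1
	pending = truncateAtStop(pending, stop)
	trim -= len(pending)
	if trim > 0 {
		cache = cache[:len(cache)-trim]
	}
	return cache, pending
}

func main() {
	// Two prompt tokens plus the first four generated tokens; the fifth
	// piece ("s>") was sampled but never submitted for decoding.
	cache := []int{10, 11, 20, 21, 22, 23}
	pending := []string{"Hello", " wor", "ld", "</", "s>"}
	cache, pending = trimForStop(cache, pending, "</s>")
	fmt.Println(len(cache), len(pending))
}
```

Here the token for "</" is removed from the cache, so the cache ends
exactly where the returned text ends and the next prompt can extend it.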

* runner.go: Simplify flushing of pending tokens

* runner.go: Update TODOs

* runner.go: Don't panic when processing sequences

If there is an error processing a sequence, we should return a
clean HTTP error back to Ollama rather than panicking. This will
make us more resilient to transient failures.

Panics can still occur during startup as there is no way to serve
requests if that fails.

Co-authored-by: jmorganca <jmorganca@gmail.com>

* runner.go: More accurately capture timings

Currently prompt processing time doesn't capture the time that it takes
to tokenize the input, only decoding time. We should capture the
full process to more accurately reflect reality. This is especially
true once we start processing images where the initial processing
can take significant time. This is also more consistent with the
existing C++ runner.

* runner.go: Support for vision models

In addition to bringing feature parity with the C++ runner, this also
incorporates several improvements:
 - Cache prompting works with images, avoiding the need to re-decode
   embeddings for every message in a conversation
 - Parallelism is supported, avoiding the need to restrict to one
   sequence at a time. (Though for now Ollama will not schedule
   them while we might need to fall back to the old runner.)

Co-authored-by: jmorganca <jmorganca@gmail.com>

* runner.go: Move Unicode checking code and add tests

* runner.go: Export external cache members

Runner and cache are in the same package so the change doesn't
affect anything but it is more internally consistent.

* runner.go: Image embedding cache

Generating embeddings from images can take significant time (on
my machine between 100ms and 8s depending on the model). Although
we already cache the result of decoding these images, the embeddings
need to be regenerated every time. This is not necessary if we get
the same image over and over again, for example, during a conversation.

This currently uses a very small cache with a very simple algorithm
but it is easy to improve as is warranted.
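
A small cache of this kind might look like the sketch below. It is an
assumption-laden illustration, not the runner's implementation: the
names (embedCache, find, add), the SHA-256 key, and the
least-recently-used eviction are all choices made for this example.

```go
package main

import (
	"crypto/sha256"
	"fmt"
)

// entry holds one cached embedding keyed by a hash of the image bytes.
type entry struct {
	key   [sha256.Size]byte
	embed [][]float32
	used  int // logical time of last use; 0 means the entry is empty
}

// embedCache is a fixed-size cache searched linearly. Size must be >= 1.
type embedCache struct {
	entries []entry
	clock   int
}

func newEmbedCache(size int) *embedCache {
	return &embedCache{entries: make([]entry, size)}
}

// find returns the cached embedding for image, if present, and marks it used.
func (c *embedCache) find(image []byte) ([][]float32, bool) {
	key := sha256.Sum256(image)
	for i := range c.entries {
		if c.entries[i].used > 0 && c.entries[i].key == key {
			c.clock++
			c.entries[i].used = c.clock
			return c.entries[i].embed, true
		}
	}
	return nil, false
}

// add stores embed, replacing an empty or the least recently used entry.
func (c *embedCache) add(image []byte, embed [][]float32) {
	oldest := 0
	for i := range c.entries {
		if c.entries[i].used < c.entries[oldest].used {
			oldest = i
		}
	}
	c.clock++
	c.entries[oldest] = entry{key: sha256.Sum256(image), embed: embed, used: c.clock}
}

func main() {
	c := newEmbedCache(2)
	c.add([]byte("cat.png"), [][]float32{{0.1, 0.2}})
	_, hit := c.find([]byte("cat.png"))
	_, miss := c.find([]byte("dog.png"))
	fmt.Println(hit, miss)
}
```

A linear scan is adequate while the cache holds only a handful of
images; a map keyed by the hash would be the natural next step if it
grows.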

* llama: catch up on patches

Carry forward solar-pro and cli-unicode patches

* runner.go: Don't re-allocate memory for every batch

We can reuse memory allocated from batch to batch since batch
size is fixed. This both saves the cost of reallocation and
keeps the cache lines hot.

This results in a roughly 1% performance improvement for token
generation with Nvidia GPUs on Linux.
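
The reuse pattern, reduced to plain Go slices, looks roughly like this.
The batch type here is a hypothetical stand-in for illustration and is
not llama.Batch, which wraps memory allocated on the C side.

```go
package main

import "fmt"

// batch is a toy fixed-capacity batch of tokens.
type batch struct{ tokens []int }

// newBatch allocates the backing storage once.
func newBatch(size int) *batch {
	return &batch{tokens: make([]int, 0, size)}
}

// Add appends a token; within capacity this does not allocate.
func (b *batch) Add(t int) { b.tokens = append(b.tokens, t) }

// Clear resets the length but keeps the backing array for the next batch.
func (b *batch) Clear() { b.tokens = b.tokens[:0] }

func main() {
	b := newBatch(4)
	for round := 0; round < 3; round++ {
		for t := 0; t < 4; t++ {
			b.Add(t)
		}
		fmt.Println(len(b.tokens), cap(b.tokens))
		b.Clear()
	}
}
```

Allocating outside the loop and clearing between iterations is the
same shape as the run loop in this file, which creates tokenBatch and
embedBatch once and calls Clear after each processBatch.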

* runner.go: Default to classic input cache policy

The input cache as part of the go runner implemented a cache
policy that aims to maximize hit rate in both single and multi-
user scenarios. When there is a cache hit, the response is
very fast.

However, performance is actually slower when there is an input
cache miss due to worse GPU VRAM locality. This means that
performance is generally better overall for multi-user scenarios
(better input cache hit rate, locality was relatively poor already)
but worse for single users (input cache hit rate is about the same,
locality is now worse).

This defaults the policy back to the old one to avoid a regression
but keeps the new one available through an environment variable
OLLAMA_MULTIUSER_CACHE. This is left undocumented as the goal is
to improve this in the future to get the best of both worlds
without user configuration.

For inputs that result in cache misses, on Nvidia/Linux this
change improves performance by 31% for prompt processing and
13% for token generation.

* runner.go: Increase size of response channel

Generally the CPU can easily keep up with handling responses that
are generated but there's no reason not to let generation continue
and handle things in larger batches if needed.
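
The effect of the buffer size can be seen with a producer that has no
receiver ready. This is a standalone sketch; fillWithoutReceiver is a
hypothetical helper and the capacities are only examples.

```go
package main

import "fmt"

// fillWithoutReceiver tries to send n responses on a channel with the
// given capacity while nobody is receiving, and reports how many sends
// succeeded before the producer would have had to block.
func fillWithoutReceiver(capacity, n int) int {
	ch := make(chan string, capacity)
	sent := 0
	for i := 0; i < n; i++ {
		select {
		case ch <- "tok":
			sent++
		default:
			// The buffer is full: a blocking send would stall generation here.
			return sent
		}
	}
	return sent
}

func main() {
	fmt.Println(fillWithoutReceiver(0, 150))   // unbuffered
	fmt.Println(fillWithoutReceiver(100, 150)) // buffered
}
```

With an unbuffered channel the producer cannot get ahead of the HTTP
handler at all, while a capacity of 100 lets generation run that many
responses ahead before it has to wait.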

* llama: Add CI to verify all vendored changes have patches (#7066)

Make sure we don't accidentally merge changes in the vendored code
that aren't also reflected in the patches.

* llama: adjust clip patch for mingw utf-16 (#7065)

* llama: adjust clip patch for mingw utf-16

* llama: ensure static linking of runtime libs

Avoid runtime dependencies on non-standard libraries

* runner.go: Enable llamafile (all platforms) and BLAS (Mac OS)

These are two features that are shown on llama.cpp's system info
that are currently different between the two runners. On my test
systems the performance difference is very small to negligible
but it is probably still good to equalize the features.

* llm: Don't add BOS/EOS for tokenize requests

This is consistent with what server.cpp currently does. It affects
things like token processing counts for embedding requests.

* runner.go: Don't cache prompts for embeddings

Our integration with server.cpp implicitly disables prompt caching
because it is not part of the JSON object being parsed. This makes
the Go runner behave similarly.

Prompt caching has been seen to affect the results of text completions
on certain hardware. The results are not wrong either way but they
are non-deterministic. However, embeddings seem to be affected even
on hardware that does not show this behavior for completions. For
now, it is best to maintain consistency with the existing behavior.

* runner.go: Adjust debug log levels

Add system info printed at startup and quiet down noisier logging.

* llama: fix compiler flag differences (#7082)

Adjust the flags for the new Go server to more closely match the
generate flow

* llama: refine developer docs (#7121)

* llama: doc and example clean up (#7122)

* llama: doc and example clean up

* llama: Move new dockerfile into llama dir

Temporary home until we fully transition to the Go server

* llama: runner doc cleanup

* llama.go: Add description for Tokenize error case

---------

Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Daniel Hiltgen <dhiltgen@users.noreply.github.com>
2024-10-08 08:53:54 -07:00

878 lines
23 KiB
Go

package main
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"log"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
)
// input is an element of the prompt to process, either
// a token or an image embedding (generated from a vision projector)
type input struct {
token int
// embed is an image embedding
embed []float32
}
type Sequence struct {
// number of inputs evaluated
numPast int
// batch index
iBatch int
// number of tokens predicted so far
numPredicted int
// prompt inputs left to evaluate
inputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
// input cache being used by this sequence
cache *InputCacheSlot
// channel to send responses over
responses chan string
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
// number of tokens to predict
numPredict int
samplingCtx *llama.SamplingContext
// channel to send back the embedding if embedding only
embedding chan []float32
// stop sequences
stop []string
// number of inputs to keep at the beginning when shifting context window
numKeep int
// true if an embedding is to be returned instead of generated text
embeddingOnly bool
doneReason string
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
numDecoded int
numPromptInputs int
}
type NewSequenceParams struct {
numPredict int
stop []string
numKeep int
samplingParams *llama.SamplingParams
embedding bool
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
inputs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
return nil, errors.New("no input provided")
}
if params.numKeep < 0 {
params.numKeep = len(inputs)
}
if !params.embedding {
// Subtracting 4 ensures that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-4)
params.numKeep += s.bosToken
} else {
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
params.numKeep = min(params.numKeep, s.cache.numCtx)
}
// truncate to fit in context window
if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
inputs = newInputs
}
var sc *llama.SamplingContext
if params.samplingParams != nil {
sc = llama.NewSamplingContext(*params.samplingParams)
for _, input := range inputs {
if input.embed == nil {
sc.Accept(s.lc, input.token, false)
}
}
}
return &Sequence{
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
}, nil
}
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts := re.Split(prompt, -1)
matches := re.FindAllStringSubmatch(prompt, -1)
for i, part := range parts {
// text - tokenize
if strings.TrimSpace(part) != "" {
tokens, err := s.lc.Model().Tokenize(part, i == 0, true)
if err != nil {
return nil, err
}
for _, t := range tokens {
inputs = append(inputs, input{token: t})
}
}
// image - generate image embedding
if i < len(matches) {
n, _ := strconv.Atoi(matches[i][1])
imageIndex := -1
for j := range images {
if images[j].ID == n {
imageIndex = j
break
}
}
if imageIndex < 0 {
return nil, fmt.Errorf("invalid image index: %d", n)
}
hash := s.cache.HashImage(images[imageIndex].Data)
// Vision models cannot be accessed concurrently
s.clip.mu.Lock()
embed, err := s.cache.FindImage(hash)
if err != nil {
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
s.cache.AddImage(hash, embed)
}
s.clip.mu.Unlock()
for _, e := range embed {
inputs = append(inputs, input{embed: e})
}
}
}
return inputs, nil
}
type clip struct {
cc *llama.ClipContext
mu sync.Mutex
}
type Server struct {
model *llama.Model
lc *llama.Context
// required for image embeddings
clip clip
batchSize int
// parallel is the number of parallel requests to handle
parallel int
// seqs is the list of parallel sequences being evaluated
// TODO (jmorganca): this can probably be moved into run()
seqs []*Sequence
// KV cache
cache *InputCache
// does this model require a beginning of sequence token?
bosToken int
// next sequence for prompt processing to avoid starvation
nextSeq int
// is the server ready to process requests?
ready sync.WaitGroup
mu sync.Mutex
cond *sync.Cond
progress float32
status ServerStatus
}
func (s *Server) allNil() bool {
for _, item := range s.seqs {
if item != nil {
return false
}
}
return true
}
func (s *Server) shiftContext(seq *Sequence) {
numLeft := seq.numPast - seq.numKeep
numDiscard := numLeft / 2
slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
seq.numPast -= numDiscard
}
func flushPending(seq *Sequence) bool {
for _, p := range seq.pendingResponses {
select {
case seq.responses <- p:
case <-seq.quit:
seq.pendingResponses = []string{}
return false
}
}
seq.pendingResponses = []string{}
return true
}
func (s *Server) removeSequence(seqIndex int, reason string) {
seq := s.seqs[seqIndex]
flushPending(seq)
seq.doneReason = reason
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
s.seqs[seqIndex] = nil
}
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
// logically these batches are used only within the context of processBatch
// but it is better for performance to allocate them once here
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
defer tokenBatch.Free()
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs))
defer embedBatch.Free()
for {
select {
case <-ctx.Done():
return
default:
s.processBatch(tokenBatch, embedBatch)
tokenBatch.Clear()
embedBatch.Clear()
}
}
}
// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
// * metrics
// these should instead be handled by the handlers
// it should only be responsible for accepting tokens or embeddings and
// processing batches as fast as possible
func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) {
s.mu.Lock()
for s.allNil() {
s.cond.Wait() // Wait until an item is added
}
defer s.mu.Unlock()
var batch *llama.Batch
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
s.removeSequence(seqIdx, "limit")
continue
}
if seq.numPast+len(seq.inputs) > s.cache.numCtx {
s.shiftContext(seq)
}
var numInputsProcessed int
for i, input := range seq.inputs {
embedding := input.embed != nil
// If we don't currently have a batch, use one of the correct type and
// fill it up as much as possible across all sequences. If we encounter an
// input of the opposite type, stop for that sequence but then pick up from
// there for the next batch, ensuring that we alternate types
if batch == nil {
if !embedding {
batch = tokenBatch
} else {
batch = embedBatch
}
} else if embedding != batch.IsEmbedding() {
s.nextSeq = seqIdx
break
}
// todo: make this n_batch
if i >= s.batchSize {
break
}
batch.Add(input.token, input.embed, seq.numPast, []int{seq.cache.Id}, numInputsProcessed+1 == len(seq.inputs))
seq.numPast++
numInputsProcessed++
}
if numInputsProcessed > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
seq.inputs = seq.inputs[numInputsProcessed:]
seq.iBatch = batch.NumTokens() - 1
}
}
if batch == nil || batch.NumTokens() == 0 {
return
}
err := s.lc.Decode(batch)
if err != nil {
slog.Error("failed to decode batch", "error", err)
return
}
for i, seq := range s.seqs {
if seq == nil {
continue
}
// don't sample prompt processing
if len(seq.inputs) != 0 {
continue
}
seq.numDecoded += 1
if seq.numDecoded == 1 {
seq.startGenerationTime = time.Now()
}
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
embed := s.lc.GetEmbeddingsSeq(i)
if embed == nil {
embed = s.lc.GetEmbeddingsIth(seq.iBatch)
}
seq.embedding <- embed
s.removeSequence(i, "")
continue
}
// sample a token
token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
seq.samplingCtx.Accept(s.lc, token, true)
piece := s.model.TokenToPiece(token)
seq.numPredicted++
// if it's an end of sequence token, break
if s.model.TokenIsEog(token) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, "stop")
continue
}
seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "stop", seq.stop)
trimCacheLen := len(seq.pendingResponses) - 1
seq.pendingResponses = truncateStop(seq.pendingResponses, stop)
trimCacheLen -= len(seq.pendingResponses)
// remove any tokens from the cache that we don't actually return
seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen]
s.removeSequence(i, "stop")
continue
}
if containsStopSuffix(sequence, seq.stop) {
continue
}
if incompleteUnicode(sequence) {
continue
}
if !flushPending(seq) {
s.removeSequence(i, "connection")
}
}
}
// TODO (jmorganca): use structs from the api package to avoid duplication
// this way the api acts as a proxy instead of using a different api for the
// runner
type Options struct {
api.Runner
NumKeep int `json:"n_keep"`
Seed int `json:"seed"`
NumPredict int `json:"n_predict"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
TFSZ float32 `json:"tfs_z"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
Temperature float32 `json:"temperature"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
PenalizeNewline bool `json:"penalize_nl"`
Stop []string `json:"stop"`
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
}
type CompletionRequest struct {
Prompt string `json:"prompt"`
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Options
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
PromptN int `json:"prompt_n,omitempty"`
PromptMS float64 `json:"prompt_ms,omitempty"`
Timings Timings `json:"timings"`
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var req CompletionRequest
req.Options = Options(api.DefaultOptions())
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
var samplingParams llama.SamplingParams
samplingParams.TopK = req.TopK
samplingParams.TopP = req.TopP
samplingParams.MinP = req.MinP
samplingParams.TfsZ = req.TFSZ
samplingParams.TypicalP = req.TypicalP
samplingParams.Temp = req.Temperature
samplingParams.RepeatLastN = req.RepeatLastN
samplingParams.PenaltyRepeat = req.RepeatPenalty
samplingParams.PenaltyFreq = req.FrequencyPenalty
samplingParams.PenaltyPresent = req.PresencePenalty
samplingParams.Mirostat = req.Mirostat
samplingParams.MirostatTau = req.MirostatTau
samplingParams.MirostatEta = req.MirostatEta
samplingParams.PenalizeNl = req.PenalizeNewline
samplingParams.Seed = uint32(req.Seed)
samplingParams.Grammar = req.Grammar
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// TODO (jmorganca): add to sequence queue instead of
// failing if a slot isn't available
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
break
}
}
s.mu.Unlock()
for {
select {
case <-r.Context().Done():
close(seq.quit)
return
case content, ok := <-seq.responses:
if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
}
flusher.Flush()
} else {
// Send the final response
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Stop: true,
StoppedLimit: seq.doneReason == "limit",
Timings: Timings{
PromptN: seq.numPromptInputs,
PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()),
PredictedN: seq.numDecoded,
PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()),
},
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
return
}
}
}
}
type EmbeddingRequest struct {
Content string `json:"content"`
CachePrompt bool `json:"cache_prompt"`
}
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// TODO (jessegross): Wait for a free slot instead of failing and blocking forever
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
break
}
}
s.mu.Unlock()
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct {
Status string `json:"status"`
Progress float32 `json:"progress"`
}
type ServerStatus int
const (
ServerStatusReady ServerStatus = iota
ServerStatusLoadingModel
ServerStatusError
)
func (s ServerStatus) ToString() string {
switch s {
case ServerStatusReady:
return "ok"
case ServerStatusLoadingModel:
return "loading model"
default:
return "server error"
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&HealthResponse{
Status: s.status.ToString(),
Progress: s.progress,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
func (s *Server) loadModel(
params llama.ModelParams,
mpath string,
lpath string,
ppath string,
kvSize int,
flashAttention bool,
threads int,
multiUserCache bool,
) {
llama.BackendInit()
s.model = llama.LoadModelFromFile(mpath, params)
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
s.lc = llama.NewContextWithModel(s.model, ctxParams)
if lpath != "" {
err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
if err != nil {
panic(err)
}
}
if s.model.AddBOSToken() {
s.bosToken = 1
}
if ppath != "" {
s.clip.cc = llama.NewClipContext(ppath)
}
s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
s.status = ServerStatusReady
s.ready.Done()
}
func main() {
mpath := flag.String("model", "", "Path to model binary file")
ppath := flag.String("mmproj", "", "Path to projector binary file")
parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := flag.Int("batch-size", 512, "Batch size")
nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
lpath := flag.String("lora", "", "Path to lora layer file")
port := flag.Int("port", 8080, "Port to expose the server on")
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
noMmap := flag.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
mlock := flag.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
tensorSplit := flag.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
// Expose requirements as a JSON output to stdout
requirements := flag.Bool("requirements", false, "print json requirement information")
// These are either ignored by llama.cpp or have no significance to us
_ = flag.Bool("embedding", false, "enable embedding vector output (default: disabled)")
_ = flag.Bool("log-disable", false, "disables logging to a file")
_ = flag.Bool("memory-f32", false, "use f32 instead of f16 for memory key+value (default: disabled) not recommended: doubles context memory required and no measurable increase in quality")
flag.Parse()
if *requirements {
printRequirements(os.Stdout)
return
}
level := slog.LevelInfo
if *verbose {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.Info("starting go runner")
slog.Debug("system info", "cpu", llama.PrintSystemInfo(), "threads", *threads)
server := &Server{
batchSize: *batchSize,
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
status: ServerStatusLoadingModel,
}
var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, 0, len(stringFloats))
for _, s := range stringFloats {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
}
}
params := llama.ModelParams{
NumGpuLayers: *nGpuLayers,
MainGpu: *mainGpu,
UseMmap: !*noMmap && *lpath == "",
UseMlock: *mlock,
TensorSplit: tensorSplitFloats,
Progress: func(progress float32) {
server.progress = progress
},
}
server.ready.Add(1)
go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
return
}
defer listener.Close()
mux := http.NewServeMux()
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)
httpServer := http.Server{
Handler: mux,
}
log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error: ", err)
}
cancel()
}