6c5ccb11f9
This refines where we extract the LLM libraries to by adding a new OLLAMA_HOME env var, which defaults to `~/.ollama`. The logic was already idempotent, so this should speed up startups after the first time a new release is deployed, and it cleans up after itself. We now build only a single ROCm version (latest major) on both Windows and Linux. Given the large size of ROCm's tensor files, we split the dependency out: it's bundled into the installer on Windows, and is a separate download on Linux. The Linux install script now detects the presence of AMD GPUs and checks whether ROCm v6 is already present; if not, it downloads our dependency tar file. For Linux discovery, we now use sysfs and check each GPU against what ROCm supports, so we can degrade to CPU gracefully instead of having llama.cpp+ROCm assert/crash on us. For Windows, we now use Go's Windows dynamic library loading to access the amdhip64.dll APIs and query GPU information.
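For illustration, the OLLAMA_HOME lookup described above could be implemented roughly like the sketch below. This is a minimal sketch, not the exact helper in this change; the function name `libraryHome` and its placement are assumptions.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// libraryHome (hypothetical name) resolves the directory the LLM
// libraries are extracted into: $OLLAMA_HOME if set, else ~/.ollama.
func libraryHome() (string, error) {
	if dir := os.Getenv("OLLAMA_HOME"); dir != "" {
		return dir, nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}

	return filepath.Join(home, ".ollama"), nil
}

func main() {
	dir, err := libraryHome()
	if err != nil {
		panic(err)
	}

	// Extraction into this directory can be idempotent: skip the work
	// if the payload for the current version is already present.
	fmt.Println("extracting LLM libraries to", dir)
}
```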
package server

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"log/slog"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
	"golang.org/x/exp/slices"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/gpu"
	"github.com/jmorganca/ollama/llm"
	"github.com/jmorganca/ollama/openai"
	"github.com/jmorganca/ollama/parser"
	"github.com/jmorganca/ollama/version"
)

var mode string = gin.DebugMode

type Server struct {
	WorkDir string
}

func init() {
	switch mode {
	case gin.DebugMode:
	case gin.ReleaseMode:
	case gin.TestMode:
	default:
		mode = gin.DebugMode
	}

	gin.SetMode(mode)
}

var loaded struct {
	mu sync.Mutex

	runner llm.LLM

	expireAt    time.Time
	expireTimer *time.Timer

	*Model
	*api.Options
}

var defaultSessionDuration = 5 * time.Minute

// load a model into memory if it is not already loaded; it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
	needLoad := loaded.runner == nil || // is there a model loaded?
		loaded.ModelPath != model.ModelPath || // has the base model changed?
		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?

	if needLoad {
		if loaded.runner != nil {
			slog.Info("changing loaded model")
			loaded.runner.Close()
			loaded.runner = nil
			loaded.Model = nil
			loaded.Options = nil
		}

		llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
		if err != nil {
			// some older models are not compatible with newer versions of llama.cpp
			// show a generalized compatibility error until there is a better way to
			// check for model compatibility
			if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
				err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
			}

			return err
		}

		loaded.Model = model
		loaded.runner = llmRunner
		loaded.Options = &opts
	}

	loaded.expireAt = time.Now().Add(sessionDuration)

	if loaded.expireTimer == nil {
		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
			loaded.mu.Lock()
			defer loaded.mu.Unlock()

			if time.Now().Before(loaded.expireAt) {
				return
			}

			if loaded.runner != nil {
				loaded.runner.Close()
			}

			loaded.runner = nil
			loaded.Model = nil
			loaded.Options = nil
		})
	}

	loaded.expireTimer.Reset(sessionDuration)
	return nil
}

func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
	opts := api.DefaultOptions()
	if err := opts.FromMap(model.Options); err != nil {
		return api.Options{}, err
	}

	if err := opts.FromMap(requestOpts); err != nil {
		return api.Options{}, err
	}

	return opts, nil
}

func isSupportedImageType(image []byte) bool {
	contentType := http.DetectContentType(image)
	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
	return slices.Contains(allowedTypes, contentType)
}

func GenerateHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	checkpointStart := time.Now()
	var req api.GenerateRequest
	err := c.ShouldBindJSON(&req)

	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
		return
	}

	for _, img := range req.Images {
		if !isSupportedImageType(img) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
			return
		}
	}

	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		if errors.As(err, &pErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if model.IsEmbedding() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
		return
	}

	opts, err := modelOptions(model, req.Options)
	if err != nil {
		if errors.Is(err, api.ErrInvalidOpts) {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var sessionDuration time.Duration
	if req.KeepAlive == nil {
		sessionDuration = defaultSessionDuration
	} else {
		sessionDuration = req.KeepAlive.Duration
	}

	if err := load(c, model, opts, sessionDuration); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// an empty request loads the model
	// note: for a short while template was used in lieu
	// of `raw` mode so we need to check for it too
	if req.Prompt == "" && req.Template == "" && req.System == "" {
		c.JSON(http.StatusOK, api.GenerateResponse{
			CreatedAt: time.Now().UTC(),
			Model:     req.Model,
			Done:      true,
		})
		return
	}

	checkpointLoaded := time.Now()

	var prompt string
	switch {
	case req.Raw:
		prompt = req.Prompt
	case req.Prompt != "":
		if req.Template == "" {
			req.Template = model.Template
		}

		if req.System == "" {
			req.System = model.System
		}

		slog.Debug("generate handler", "prompt", req.Prompt)
		slog.Debug("generate handler", "template", req.Template)
		slog.Debug("generate handler", "system", req.System)

		var sb strings.Builder
		for i := range req.Images {
			fmt.Fprintf(&sb, "[img-%d] ", i)
		}

		sb.WriteString(req.Prompt)

		p, err := Prompt(req.Template, req.System, sb.String(), "", true)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		sb.Reset()
		if req.Context != nil {
			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
			if err != nil {
				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
				return
			}

			sb.WriteString(prev)
		}

		sb.WriteString(p)

		prompt = sb.String()
	}

	slog.Debug("generate handler", "prompt", prompt)

	ch := make(chan any)
	var generated strings.Builder
	go func() {
		defer close(ch)

		fn := func(r llm.PredictResult) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			// Build up the full response
			if _, err := generated.WriteString(r.Content); err != nil {
				ch <- gin.H{"error": err.Error()}
				return
			}

			resp := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Done:      r.Done,
				Response:  r.Content,
				Metrics: api.Metrics{
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			if r.Done {
				resp.TotalDuration = time.Since(checkpointStart)
				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
					if err != nil {
						// send the error on the channel rather than writing to the
						// response from inside this goroutine
						ch <- gin.H{"error": err.Error()}
						return
					}

					// TODO (jmorganca): encode() should not strip special tokens
					tokens, err := loaded.runner.Encode(c.Request.Context(), p)
					if err != nil {
						ch <- gin.H{"error": err.Error()}
						return
					}

					resp.Context = append(req.Context, tokens...)
				}
			}

			ch <- resp
		}

		var images []llm.ImageData
		for i := range req.Images {
			images = append(images, llm.ImageData{
				ID:   i,
				Data: req.Images[i],
			})
		}

		// Start prediction
		predictReq := llm.PredictOpts{
			Prompt:  prompt,
			Format:  req.Format,
			Images:  images,
			Options: opts,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Accumulate responses into the final response
		var final api.GenerateResponse
		var sb strings.Builder
		for resp := range ch {
			switch r := resp.(type) {
			case api.GenerateResponse:
				sb.WriteString(r.Response)
				final = r
			case gin.H:
				if errorMsg, ok := r["error"].(string); ok {
					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
					return
				} else {
					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
					return
				}
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
				return
			}
		}

		final.Response = sb.String()
		c.JSON(http.StatusOK, final)
		return
	}

	streamResponse(c, ch)
}

func EmbeddingsHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	var req api.EmbeddingRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Model == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		if errors.As(err, &pErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	opts, err := modelOptions(model, req.Options)
	if err != nil {
		if errors.Is(err, api.ErrInvalidOpts) {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var sessionDuration time.Duration
	if req.KeepAlive == nil {
		sessionDuration = defaultSessionDuration
	} else {
		sessionDuration = req.KeepAlive.Duration
	}

	if err := load(c, model, opts, sessionDuration); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// an empty request loads the model
	if req.Prompt == "" {
		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
		return
	}

	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
	if err != nil {
		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
		return
	}

	resp := api.EmbeddingResponse{
		Embedding: embedding,
	}
	c.JSON(http.StatusOK, resp)
}

func PullModelHandler(c *gin.Context) {
	var req api.PullRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var model string
	if req.Model != "" {
		model = req.Model
	} else if req.Name != "" {
		model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r api.ProgressResponse) {
			ch <- r
		}

		regOpts := &registryOptions{
			Insecure: req.Insecure,
		}

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		if err := PullModel(ctx, model, regOpts, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}

func PushModelHandler(c *gin.Context) {
	var req api.PushRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var model string
	if req.Model != "" {
		model = req.Model
	} else if req.Name != "" {
		model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(r api.ProgressResponse) {
			ch <- r
		}

		regOpts := &registryOptions{
			Insecure: req.Insecure,
		}

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		if err := PushModel(ctx, model, regOpts, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}

func CreateModelHandler(c *gin.Context) {
	var req api.CreateRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var model string
	if req.Model != "" {
		model = req.Model
	} else if req.Name != "" {
		model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	if err := ParseModelPath(model).Validate(); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Path == "" && req.Modelfile == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
		return
	}

	var modelfile io.Reader = strings.NewReader(req.Modelfile)
	if req.Path != "" && req.Modelfile == "" {
		mf, err := os.Open(req.Path)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
			return
		}
		defer mf.Close()

		modelfile = mf
	}

	commands, err := parser.Parse(modelfile)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(resp api.ProgressResponse) {
			ch <- resp
		}

		ctx, cancel := context.WithCancel(c.Request.Context())
		defer cancel()

		if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}

func DeleteModelHandler(c *gin.Context) {
	var req api.DeleteRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	var model string
	if req.Model != "" {
		model = req.Model
	} else if req.Name != "" {
		model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	if err := DeleteModel(model); err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	manifestsPath, err := GetManifestPath()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if err := PruneDirectory(manifestsPath); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, nil)
}

func ShowModelHandler(c *gin.Context) {
	var req api.ShowRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Model != "" {
		// noop
	} else if req.Name != "" {
		req.Model = req.Name
	} else {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	}

	resp, err := GetModelInfo(req)
	if err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}

	c.JSON(http.StatusOK, resp)
}

func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
	model, err := GetModel(req.Model)
	if err != nil {
		return nil, err
	}

	modelDetails := api.ModelDetails{
		ParentModel:       model.ParentModel,
		Format:            model.Config.ModelFormat,
		Family:            model.Config.ModelFamily,
		Families:          model.Config.ModelFamilies,
		ParameterSize:     model.Config.ModelType,
		QuantizationLevel: model.Config.FileType,
	}

	if req.System != "" {
		model.System = req.System
	}

	if req.Template != "" {
		model.Template = req.Template
	}

	msgs := make([]api.Message, 0)
	for _, msg := range model.Messages {
		msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
	}

	resp := &api.ShowResponse{
		License:  strings.Join(model.License, "\n"),
		System:   model.System,
		Template: model.Template,
		Details:  modelDetails,
		Messages: msgs,
	}

	var params []string
	cs := 30
	for k, v := range model.Options {
		switch val := v.(type) {
		case []interface{}:
			for _, nv := range val {
				params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
			}
		default:
			params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
		}
	}
	resp.Parameters = strings.Join(params, "\n")

	for k, v := range req.Options {
		model.Options[k] = v
	}

	mf, err := ShowModelfile(model)
	if err != nil {
		return nil, err
	}

	resp.Modelfile = mf

	return resp, nil
}

func ListModelsHandler(c *gin.Context) {
	models := make([]api.ModelResponse, 0)
	manifestsPath, err := GetManifestPath()
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	modelResponse := func(modelName string) (api.ModelResponse, error) {
		model, err := GetModel(modelName)
		if err != nil {
			return api.ModelResponse{}, err
		}

		modelDetails := api.ModelDetails{
			Format:            model.Config.ModelFormat,
			Family:            model.Config.ModelFamily,
			Families:          model.Config.ModelFamilies,
			ParameterSize:     model.Config.ModelType,
			QuantizationLevel: model.Config.FileType,
		}

		return api.ModelResponse{
			Model:   model.ShortName,
			Name:    model.ShortName,
			Size:    model.Size,
			Digest:  model.Digest,
			Details: modelDetails,
		}, nil
	}

	walkFunc := func(path string, info os.FileInfo, _ error) error {
		if !info.IsDir() {
			path, tag := filepath.Split(path)
			model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
			modelPath := strings.Join([]string{model, tag}, ":")
			canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")

			resp, err := modelResponse(canonicalModelPath)
			if err != nil {
				slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
				// nolint: nilerr
				return nil
			}

			resp.ModifiedAt = info.ModTime()
			models = append(models, resp)
		}

		return nil
	}

	if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, api.ListResponse{Models: models})
}

func CopyModelHandler(c *gin.Context) {
	var req api.CopyRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if req.Source == "" || req.Destination == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source and destination are required"})
		return
	}

	if err := ParseModelPath(req.Destination).Validate(); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if err := CopyModel(req.Source, req.Destination); err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		}
		return
	}
}

func HeadBlobHandler(c *gin.Context) {
	path, err := GetBlobsPath(c.Param("digest"))
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	if _, err := os.Stat(path); err != nil {
		c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
		return
	}

	c.Status(http.StatusOK)
}

func CreateBlobHandler(c *gin.Context) {
	layer, err := NewLayer(c.Request.Body, "")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if layer.Digest != c.Param("digest") {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
		return
	}

	if _, err := layer.Commit(); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.Status(http.StatusCreated)
}

var defaultAllowOrigins = []string{
	"localhost",
	"127.0.0.1",
	"0.0.0.0",
}

func NewServer() (*Server, error) {
	workDir, err := os.MkdirTemp("", "ollama")
	if err != nil {
		return nil, err
	}

	return &Server{
		WorkDir: workDir,
	}, nil
}

func (s *Server) GenerateRoutes() http.Handler {
	var origins []string
	if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
		origins = strings.Split(o, ",")
	}

	config := cors.DefaultConfig()
	config.AllowWildcard = true
	config.AllowBrowserExtensions = true

	config.AllowOrigins = origins
	for _, allowOrigin := range defaultAllowOrigins {
		config.AllowOrigins = append(config.AllowOrigins,
			fmt.Sprintf("http://%s", allowOrigin),
			fmt.Sprintf("https://%s", allowOrigin),
			fmt.Sprintf("http://%s:*", allowOrigin),
			fmt.Sprintf("https://%s:*", allowOrigin),
		)
	}

	r := gin.Default()
	r.Use(
		cors.New(config),
		func(c *gin.Context) {
			c.Set("workDir", s.WorkDir)
			c.Next()
		},
	)

	r.POST("/api/pull", PullModelHandler)
	r.POST("/api/generate", GenerateHandler)
	r.POST("/api/chat", ChatHandler)
	r.POST("/api/embeddings", EmbeddingsHandler)
	r.POST("/api/create", CreateModelHandler)
	r.POST("/api/push", PushModelHandler)
	r.POST("/api/copy", CopyModelHandler)
	r.DELETE("/api/delete", DeleteModelHandler)
	r.POST("/api/show", ShowModelHandler)
	r.POST("/api/blobs/:digest", CreateBlobHandler)
	r.HEAD("/api/blobs/:digest", HeadBlobHandler)

	// Compatibility endpoints
	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)

	for _, method := range []string{http.MethodGet, http.MethodHead} {
		r.Handle(method, "/", func(c *gin.Context) {
			c.String(http.StatusOK, "Ollama is running")
		})

		r.Handle(method, "/api/tags", ListModelsHandler)
		r.Handle(method, "/api/version", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"version": version.Version})
		})
	}

	return r
}

func Serve(ln net.Listener) error {
	level := slog.LevelInfo
	if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
		level = slog.LevelDebug
	}

	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}

			return attr
		},
	})

	slog.SetDefault(slog.New(handler))

	if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
		// clean up unused layers and manifests
		if err := PruneLayers(); err != nil {
			return err
		}

		manifestsPath, err := GetManifestPath()
		if err != nil {
			return err
		}

		if err := PruneDirectory(manifestsPath); err != nil {
			return err
		}
	}

	s, err := NewServer()
	if err != nil {
		return err
	}
	r := s.GenerateRoutes()

	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
	srvr := &http.Server{
		Handler: r,
	}

	// listen for a ctrl+c and stop any loaded llm
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signals
		if loaded.runner != nil {
			loaded.runner.Close()
		}
		os.RemoveAll(s.WorkDir)
		os.Exit(0)
	}()

	if err := llm.Init(); err != nil {
		return fmt.Errorf("unable to initialize llm library: %w", err)
	}
	if runtime.GOOS == "linux" { // TODO - windows too
		// check compatibility to log warnings
		if _, err := gpu.CheckVRAM(); err != nil {
			slog.Info(err.Error())
		}
	}

	return srvr.Serve(ln)
}

func waitForStream(c *gin.Context, ch chan interface{}) {
	c.Header("Content-Type", "application/json")
	for resp := range ch {
		switch r := resp.(type) {
		case api.ProgressResponse:
			if r.Status == "success" {
				c.JSON(http.StatusOK, r)
				return
			}
		case gin.H:
			if errorMsg, ok := r["error"].(string); ok {
				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
				return
			} else {
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
				return
			}
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
			return
		}
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
}

func streamResponse(c *gin.Context, ch chan any) {
	c.Header("Content-Type", "application/x-ndjson")
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		bts, err := json.Marshal(val)
		if err != nil {
			slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
			return false
		}

		// Delineate chunks with new-line delimiter
		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
			return false
		}

		return true
	})
}

// chatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
	encode := func(s string) ([]int, error) {
		return loaded.runner.Encode(ctx, s)
	}

	prompt, err := ChatPrompt(template, messages, numCtx, encode)
	if err != nil {
		return "", err
	}

	return prompt, nil
}

func ChatHandler(c *gin.Context) {
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	checkpointStart := time.Now()

	var req api.ChatRequest
	err := c.ShouldBindJSON(&req)
	switch {
	case errors.Is(err, io.EOF):
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	case err != nil:
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// validate the request
	switch {
	case req.Model == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
		return
	case len(req.Format) > 0 && req.Format != "json":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		var pErr *fs.PathError
		if errors.As(err, &pErr) {
			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if model.IsEmbedding() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
		return
	}

	opts, err := modelOptions(model, req.Options)
	if err != nil {
		if errors.Is(err, api.ErrInvalidOpts) {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var sessionDuration time.Duration
	if req.KeepAlive == nil {
		sessionDuration = defaultSessionDuration
	} else {
		sessionDuration = req.KeepAlive.Duration
	}

	if err := load(c, model, opts, sessionDuration); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	checkpointLoaded := time.Now()

	// if the first message is not a system message, then add the model's default system message
	if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
		req.Messages = append([]api.Message{
			{
				Role:    "system",
				Content: model.System,
			},
		}, req.Messages...)
	}

	prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// an empty request loads the model
	if len(req.Messages) == 0 || prompt == "" {
		resp := api.ChatResponse{
			CreatedAt: time.Now().UTC(),
			Model:     req.Model,
			Done:      true,
			Message:   api.Message{Role: "assistant"},
		}
		c.JSON(http.StatusOK, resp)
		return
	}

	// only send images that are in the prompt
	var i int
	var images []llm.ImageData
	for _, m := range req.Messages {
		for _, img := range m.Images {
			if !isSupportedImageType(img) {
				c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
				return
			}

			if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
				images = append(images, llm.ImageData{Data: img, ID: i})
			}
			i += 1
		}
	}

	slog.Debug("chat handler", "prompt", prompt, "images", len(images))

	ch := make(chan any)

	go func() {
		defer close(ch)

		fn := func(r llm.PredictResult) {
			// Update model expiration
			loaded.expireAt = time.Now().Add(sessionDuration)
			loaded.expireTimer.Reset(sessionDuration)

			resp := api.ChatResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Message:   api.Message{Role: "assistant", Content: r.Content},
				Done:      r.Done,
				Metrics: api.Metrics{
					PromptEvalCount:    r.PromptEvalCount,
					PromptEvalDuration: r.PromptEvalDuration,
					EvalCount:          r.EvalCount,
					EvalDuration:       r.EvalDuration,
				},
			}

			if r.Done {
				resp.TotalDuration = time.Since(checkpointStart)
				resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}

			ch <- resp
		}

		// Start prediction
		predictReq := llm.PredictOpts{
			Prompt:  prompt,
			Format:  req.Format,
			Images:  images,
			Options: opts,
		}
		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	if req.Stream != nil && !*req.Stream {
		// Accumulate responses into the final response
		var final api.ChatResponse
		var sb strings.Builder
		for resp := range ch {
			switch r := resp.(type) {
			case api.ChatResponse:
				sb.WriteString(r.Message.Content)
				final = r
			case gin.H:
				if errorMsg, ok := r["error"].(string); ok {
					c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
					return
				} else {
					c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
					return
				}
			default:
				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
				return
			}
		}

		final.Message = api.Message{Role: "assistant", Content: sb.String()}
		c.JSON(http.StatusOK, final)
		return
	}

	streamResponse(c, ch)
}
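
For reference, here is a minimal client exercising the `/api/generate` route registered in `GenerateRoutes` above, with `stream=false` so the handler accumulates chunks into one JSON body instead of NDJSON. This is a sketch: it assumes the server is listening on Ollama's default port 11434 and that the named model has already been pulled.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// stream=false asks GenerateHandler for a single accumulated response.
	body, err := json.Marshal(map[string]any{
		"model":  "llama2", // assumed to be pulled already
		"prompt": "Why is the sky blue?",
		"stream": false,
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://127.0.0.1:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```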