ollama/server/routes.go

1397 lines
33 KiB
Go
Raw Normal View History

package server
import (
"context"
2023-07-06 10:40:11 -07:00
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
"net/http"
"net/netip"
"os"
"os/signal"
2023-07-14 17:27:14 -07:00
"path/filepath"
2023-07-31 21:35:18 -04:00
"reflect"
"runtime"
"strconv"
2023-07-06 10:40:11 -07:00
"strings"
2023-07-18 11:59:42 -07:00
"sync"
"syscall"
2023-07-12 18:18:06 -07:00
"time"
2023-07-21 18:01:24 -07:00
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/version"
)
2023-08-22 09:48:35 -07:00
// mode is the gin run mode; init replaces any unrecognized value with gin.DebugMode.
var mode string = gin.DebugMode
2023-12-14 16:47:40 -08:00
// Server carries the state shared by the HTTP routes.
type Server struct {
	// addr is the address the server listens on. Serve fills it in from the
	// listener and GenerateRoutes passes it to allowedHostsMiddleware.
	addr net.Addr
}
2023-08-22 09:48:35 -07:00
// init normalizes mode and applies it to gin. Anything other than the three
// modes gin defines is replaced with gin.DebugMode.
func init() {
	switch mode {
	case gin.DebugMode, gin.ReleaseMode, gin.TestMode:
		// already a valid gin mode; keep it
	default:
		mode = gin.DebugMode
	}

	gin.SetMode(mode)
}
2023-07-31 21:35:18 -04:00
var loaded struct {
2023-07-19 15:00:28 -07:00
mu sync.Mutex
runner llm.LLM
2023-07-19 15:00:28 -07:00
expireAt time.Time
expireTimer *time.Timer
2023-07-31 21:35:18 -04:00
*Model
*api.Options
2023-07-18 11:59:42 -07:00
}
2023-08-15 10:35:39 -03:00
// defaultSessionDuration is the keep-alive returned by getDefaultSessionDuration when OLLAMA_KEEP_ALIVE is unset or cannot be parsed.
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
if needLoad {
if loaded.runner != nil {
slog.Info("changing loaded model")
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
2023-07-18 11:59:42 -07:00
}
2023-07-17 12:08:10 -07:00
llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
2023-07-18 11:59:42 -07:00
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return err
2023-07-18 11:59:42 -07:00
}
loaded.Model = model
loaded.runner = llmRunner
loaded.Options = &opts
2023-07-19 15:00:28 -07:00
}
2023-07-31 21:35:18 -04:00
loaded.expireAt = time.Now().Add(sessionDuration)
2023-07-31 21:35:18 -04:00
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
2023-07-19 15:00:28 -07:00
2023-07-31 21:35:18 -04:00
if time.Now().Before(loaded.expireAt) {
2023-07-19 15:00:28 -07:00
return
}
if loaded.runner != nil {
loaded.runner.Close()
2023-07-19 15:00:28 -07:00
}
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
2023-07-19 15:00:28 -07:00
})
2023-07-06 10:40:11 -07:00
}
2023-07-31 21:35:18 -04:00
loaded.expireTimer.Reset(sessionDuration)
return nil
}
// modelOptions builds the effective options for a request: the API defaults,
// overlaid with the model's own options, overlaid with the request's options.
// The first FromMap failure is returned together with a zero api.Options.
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
	opts := api.DefaultOptions()

	// Later layers win, so the request options are applied last.
	for _, layer := range []map[string]interface{}{model.Options, requestOpts} {
		if err := opts.FromMap(layer); err != nil {
			return api.Options{}, err
		}
	}

	return opts, nil
}
// isSupportedImageType reports whether the sniffed content type of image is
// one of the accepted JPEG or PNG types.
func isSupportedImageType(image []byte) bool {
	sniffed := http.DetectContentType(image)
	return slices.Contains([]string{"image/jpeg", "image/jpg", "image/png"}, sniffed)
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
2023-10-18 16:08:42 -07:00
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
2023-12-05 14:57:33 -05:00
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-12-05 14:57:33 -05:00
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
2023-12-05 14:57:33 -05:00
if req.Prompt == "" && req.Template == "" && req.System == "" {
2023-12-01 11:37:17 -08:00
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
2023-12-15 14:25:12 -08:00
Done: true,
})
return
}
checkpointLoaded := time.Now()
2023-12-05 14:57:33 -05:00
var prompt string
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
2023-12-05 14:57:33 -05:00
}
if req.System == "" {
req.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
for i := range req.Images {
fmt.Fprintf(&sb, "[img-%d] ", i)
}
sb.WriteString(req.Prompt)
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.Reset()
2023-12-05 14:57:33 -05:00
if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
2023-12-05 14:57:33 -05:00
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
}
sb.WriteString(p)
prompt = sb.String()
}
2024-01-31 16:47:26 -08:00
slog.Debug("generate handler", "prompt", prompt)
ch := make(chan any)
2023-12-05 14:57:33 -05:00
var generated strings.Builder
go func() {
defer close(ch)
2023-12-05 14:57:33 -05:00
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
2023-12-05 14:57:33 -05:00
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
2023-12-05 14:57:33 -05:00
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
2023-12-05 14:57:33 -05:00
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
2023-12-22 17:07:05 -05:00
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
2023-12-22 17:07:05 -05:00
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
2023-12-05 14:57:33 -05:00
}
}
ch <- resp
}
2024-01-31 18:56:12 -08:00
var images []llm.ImageData
2024-01-31 17:39:38 -08:00
for i := range req.Images {
2024-01-31 18:56:12 -08:00
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
2024-01-31 17:39:38 -08:00
}
2023-12-05 14:57:33 -05:00
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
2024-01-31 17:39:38 -08:00
Images: images,
Options: opts,
2023-12-05 14:57:33 -05:00
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
2023-12-05 14:57:33 -05:00
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
// getDefaultSessionDuration returns the keep-alive configured through the
// OLLAMA_KEEP_ALIVE environment variable.
//
// The value is read as a whole number of seconds first and, failing that, as
// a time.ParseDuration string. A negative value yields the largest
// representable duration. When the variable is unset or unparsable,
// defaultSessionDuration is returned.
func getDefaultSessionDuration() time.Duration {
	raw, ok := os.LookupEnv("OLLAMA_KEEP_ALIVE")
	if !ok {
		return defaultSessionDuration
	}

	var keepAlive time.Duration
	secs, err := strconv.Atoi(raw)
	if err == nil {
		keepAlive = time.Duration(secs) * time.Second
	} else {
		keepAlive, err = time.ParseDuration(raw)
		if err != nil {
			return defaultSessionDuration
		}
	}

	if keepAlive < 0 {
		return time.Duration(math.MaxInt64)
	}
	return keepAlive
}
func EmbeddingsHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
model, err := GetModel(req.Model)
if err != nil {
2023-12-05 14:57:33 -05:00
var pErr *fs.PathError
if errors.As(err, &pErr) {
2023-12-05 14:57:33 -05:00
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
2023-12-05 14:57:33 -05:00
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
2023-12-05 14:57:33 -05:00
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
}
2023-07-20 16:09:23 -07:00
func PullModelHandler(c *gin.Context) {
2023-07-11 11:54:22 -07:00
var req api.PullRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
2024-02-14 11:29:49 -08:00
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
2023-07-20 16:09:23 -07:00
func PushModelHandler(c *gin.Context) {
var req api.PushRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
2023-07-06 10:40:11 -07:00
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
2024-02-14 11:29:49 -08:00
regOpts := &registryOptions{
Insecure: req.Insecure,
}
2023-10-09 10:24:27 -07:00
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
2023-07-20 16:09:23 -07:00
func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-12 19:07:15 -07:00
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if err := ParseModelPath(model).Validate(); err != nil {
2023-11-29 15:54:29 -05:00
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-11-14 13:45:07 -08:00
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
2023-11-14 12:30:34 -08:00
return
}
2023-11-14 13:45:07 -08:00
var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
2023-11-21 15:43:17 -05:00
mf, err := os.Open(req.Path)
2023-11-14 13:45:07 -08:00
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
2023-11-21 15:43:17 -05:00
defer mf.Close()
2023-11-14 13:45:07 -08:00
2023-11-21 15:43:17 -05:00
modelfile = mf
2023-11-14 13:45:07 -08:00
}
2023-11-14 12:30:34 -08:00
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-07-11 11:54:22 -07:00
ch := make(chan any)
go func() {
defer close(ch)
fn := func(resp api.ProgressResponse) {
ch <- resp
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
2023-07-07 15:29:17 -07:00
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
2023-07-05 15:37:33 -04:00
}
2023-07-20 16:09:23 -07:00
func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-20 16:09:23 -07:00
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
2023-07-20 16:09:23 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
2023-09-26 17:28:14 -07:00
manifestsPath, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := PruneDirectory(manifestsPath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
2023-07-20 16:09:23 -07:00
}
2023-09-06 11:04:17 -07:00
func ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-09-06 11:04:17 -07:00
return
}
if req.Model != "" {
2024-01-18 15:36:50 -08:00
// noop
} else if req.Name != "" {
2024-01-18 15:36:50 -08:00
req.Model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
resp, err := GetModelInfo(req)
2023-09-06 11:04:17 -07:00
if err != nil {
if os.IsNotExist(err) {
2024-01-18 15:36:50 -08:00
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
2023-09-06 11:04:17 -07:00
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model, err := GetModel(req.Model)
2023-09-06 11:04:17 -07:00
if err != nil {
return nil, err
}
modelDetails := api.ModelDetails{
2024-01-25 12:12:36 -08:00
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
if req.System != "" {
model.System = req.System
}
if req.Template != "" {
model.Template = req.Template
}
2024-01-25 12:12:36 -08:00
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
2023-09-06 11:04:17 -07:00
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
Details: modelDetails,
2024-01-25 12:12:36 -08:00
Messages: msgs,
2023-09-06 11:04:17 -07:00
}
var params []string
cs := 30
for k, v := range model.Options {
switch val := v.(type) {
case []interface{}:
for _, nv := range val {
2024-01-16 10:34:44 -08:00
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
2023-09-06 11:04:17 -07:00
}
2024-01-16 10:34:44 -08:00
default:
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
2023-09-06 11:04:17 -07:00
}
}
resp.Parameters = strings.Join(params, "\n")
for k, v := range req.Options {
if _, ok := req.Options[k]; ok {
model.Options[k] = v
}
}
mf, err := ShowModelfile(model)
if err != nil {
return nil, err
}
resp.Modelfile = mf
2023-09-06 11:04:17 -07:00
return resp, nil
}
2023-07-20 16:09:23 -07:00
func ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
2023-12-15 15:50:51 -08:00
manifestsPath, err := GetManifestPath()
2023-07-18 09:09:45 -07:00
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-08-30 14:14:12 -04:00
modelResponse := func(modelName string) (api.ModelResponse, error) {
model, err := GetModel(modelName)
if err != nil {
return api.ModelResponse{}, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
return api.ModelResponse{
2024-01-18 14:32:55 -08:00
Model: model.ShortName,
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
Details: modelDetails,
}, nil
}
2023-08-30 14:14:12 -04:00
walkFunc := func(path string, info os.FileInfo, _ error) error {
2023-07-18 09:09:45 -07:00
if !info.IsDir() {
2023-12-15 15:50:51 -08:00
path, tag := filepath.Split(path)
model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
modelPath := strings.Join([]string{model, tag}, ":")
canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
2023-08-21 21:56:56 -07:00
2023-12-15 15:50:51 -08:00
resp, err := modelResponse(canonicalModelPath)
2023-07-18 09:09:45 -07:00
if err != nil {
slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
2023-12-15 14:07:34 -08:00
// nolint: nilerr
return nil
2023-07-18 09:09:45 -07:00
}
2023-08-30 14:14:12 -04:00
resp.ModifiedAt = info.ModTime()
models = append(models, resp)
2023-07-18 09:09:45 -07:00
}
2023-08-30 14:14:12 -04:00
2023-07-18 09:09:45 -07:00
return nil
2023-08-30 14:14:12 -04:00
}
2023-12-15 15:50:51 -08:00
if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
2023-07-18 09:09:45 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-07-19 15:00:28 -07:00
c.JSON(http.StatusOK, api.ListResponse{Models: models})
2023-07-18 09:09:45 -07:00
}
2023-07-24 11:27:28 -04:00
func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-24 11:27:28 -04:00
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
return
}
2023-11-29 15:54:29 -05:00
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-07-24 11:27:28 -04:00
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
}
2023-11-15 10:59:38 -08:00
func HeadBlobHandler(c *gin.Context) {
2023-11-14 14:07:40 -08:00
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
2023-11-15 13:55:37 -08:00
c.Status(http.StatusOK)
2023-11-14 14:07:40 -08:00
}
func CreateBlobHandler(c *gin.Context) {
2023-11-24 12:01:23 -08:00
layer, err := NewLayer(c.Request.Body, "")
2023-11-17 15:21:57 -08:00
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-11-24 12:01:23 -08:00
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
2023-11-14 14:07:40 -08:00
return
}
2023-11-24 12:01:23 -08:00
if _, err := layer.Commit(); err != nil {
2023-11-14 14:07:40 -08:00
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-11-15 13:55:37 -08:00
c.Status(http.StatusCreated)
2023-11-14 14:07:40 -08:00
}
// defaultAllowOrigins lists the hosts that GenerateRoutes always adds to the
// CORS allow list, each as an http and https origin with and without a port
// wildcard.
var defaultAllowOrigins = []string{
"localhost",
"127.0.0.1",
"0.0.0.0",
}
2024-03-09 00:22:08 -08:00
// isLocalIP reports whether ip is assigned to one of this machine's network
// interfaces. Addresses are compared by their string form. Interfaces whose
// addresses cannot be listed are skipped, and a failure to enumerate
// interfaces yields false.
func isLocalIP(ip netip.Addr) bool {
	interfaces, err := net.Interfaces()
	if err != nil {
		return false
	}

	want := ip.String()
	for _, iface := range interfaces {
		addrs, err := iface.Addrs()
		if err != nil {
			continue
		}

		for _, a := range addrs {
			parsed, _, err := net.ParseCIDR(a.String())
			if err == nil && parsed.String() == want {
				return true
			}
		}
	}

	return false
}
// allowedHost reports whether host is a name this server should answer to:
// the empty string, "localhost", this machine's hostname, or any name under
// the .localhost, .local or .internal suffixes.
//
// Host names are matched case-insensitively, so "LocalHost" and
// "Printer.LOCAL" are accepted just like their lowercase forms.
func allowedHost(host string) bool {
	host = strings.ToLower(host)

	if host == "" || host == "localhost" {
		return true
	}

	if hostname, err := os.Hostname(); err == nil && host == strings.ToLower(hostname) {
		return true
	}

	tlds := []string{
		"localhost",
		"local",
		"internal",
	}

	// check if the host is a local TLD
	for _, tld := range tlds {
		if strings.HasSuffix(host, "."+tld) {
			return true
		}
	}

	return false
}
2024-03-08 23:23:59 -08:00
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
c.Next()
return
}
2024-03-09 00:22:08 -08:00
if addr, err := netip.ParseAddrPort(addr.String()); err == nil && !addr.Addr().IsLoopback() {
c.Next()
return
}
host, _, err := net.SplitHostPort(c.Request.Host)
if err != nil {
host = c.Request.Host
}
2024-03-08 23:23:59 -08:00
if addr, err := netip.ParseAddr(host); err == nil {
2024-03-09 00:22:08 -08:00
if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
2024-03-08 23:23:59 -08:00
c.Next()
return
}
}
if allowedHost(host) {
c.Next()
return
}
c.AbortWithStatus(http.StatusForbidden)
}
2023-12-14 16:47:40 -08:00
}
2023-12-14 16:47:40 -08:00
func (s *Server) GenerateRoutes() http.Handler {
2023-07-21 18:01:24 -07:00
config := cors.DefaultConfig()
config.AllowWildcard = true
config.AllowBrowserExtensions = true
2024-03-27 15:24:28 -07:00
if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
config.AllowOrigins = strings.Split(allowedOrigins, ",")
}
for _, allowOrigin := range defaultAllowOrigins {
config.AllowOrigins = append(config.AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s:*", allowOrigin),
fmt.Sprintf("https://%s:*", allowOrigin),
)
}
2023-07-21 18:01:24 -07:00
2023-07-05 15:37:33 -04:00
r := gin.Default()
r.Use(
cors.New(config),
allowedHostsMiddleware(s.addr),
)
2023-07-05 15:37:33 -04:00
2023-07-20 16:09:23 -07:00
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
2023-12-05 14:57:33 -05:00
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingsHandler)
2023-07-20 16:09:23 -07:00
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
2023-07-24 11:27:28 -04:00
r.POST("/api/copy", CopyModelHandler)
2023-07-20 16:09:23 -07:00
r.DELETE("/api/delete", DeleteModelHandler)
2023-09-06 11:04:17 -07:00
r.POST("/api/show", ShowModelHandler)
2023-11-14 14:07:40 -08:00
r.POST("/api/blobs/:digest", CreateBlobHandler)
2023-11-15 15:22:12 -08:00
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
2023-09-21 16:38:03 -07:00
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
2023-10-12 15:45:07 -07:00
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
2023-09-21 16:38:03 -07:00
}
2023-12-14 16:47:40 -08:00
return r
}
func Serve(ln net.Listener) error {
level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
blobsDir, err := GetBlobsPath("")
if err != nil {
return err
}
if err := fixBlobs(blobsDir); err != nil {
return err
}
2023-12-14 16:47:40 -08:00
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}
s := &Server{addr: ln.Addr()}
2023-12-14 16:47:40 -08:00
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
2023-12-14 16:47:40 -08:00
srvr := &http.Server{
Handler: r,
}
// listen for a ctrl+c and stop any loaded llm
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.runner != nil {
loaded.runner.Close()
2023-09-22 19:41:52 +01:00
}
gpu.Cleanup()
os.Exit(0)
}()
if err := llm.Init(); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error())
}
}
2023-12-14 16:47:40 -08:00
return srvr.Serve(ln)
}
2023-07-06 10:40:11 -07:00
// waitForStream drains ch and writes a single JSON response: the first
// progress message whose status is "success" (200), or the first error (500).
// Other progress messages are discarded. If ch closes before either arrives,
// a 500 is written.
func waitForStream(c *gin.Context, ch chan interface{}) {
	c.Header("Content-Type", "application/json")

	for resp := range ch {
		switch r := resp.(type) {
		case api.ProgressResponse:
			if r.Status != "success" {
				// intermediate progress; keep waiting
				continue
			}
			c.JSON(http.StatusOK, r)
			return
		case gin.H:
			msg, ok := r["error"].(string)
			if !ok {
				msg = "unexpected error format in progress response"
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
			return
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
			return
		}
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
}
func streamResponse(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/x-ndjson")
2023-07-11 11:54:22 -07:00
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
bts, err := json.Marshal(val)
if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
2023-07-11 11:54:22 -07:00
return false
}
// Delineate chunks with new-line delimiter
2023-07-11 11:54:22 -07:00
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
2023-07-11 11:54:22 -07:00
return false
}
return true
})
}
2023-12-05 14:57:33 -05:00
// chatPrompt builds a prompt from a series of messages for the currently
// `loaded` model.
//
// It delegates to ChatPrompt, supplying an encode callback that tokenizes text
// with loaded.runner.Encode under ctx. On failure the error from ChatPrompt is
// returned together with an empty prompt.
//
// NOTE(review): loaded.runner is dereferenced without a nil check, so callers
// presumably must have loaded a model first; confirm against call sites.
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
	tokenize := func(text string) ([]int, error) {
		return loaded.runner.Encode(ctx, text)
	}

	rendered, err := ChatPrompt(template, messages, numCtx, tokenize)
	if err == nil {
		return rendered, nil
	}
	return "", err
}
2023-12-05 14:57:33 -05:00
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
model, err := GetModel(req.Model)
2023-12-05 14:57:33 -05:00
if err != nil {
var pErr *fs.PathError
if errors.As(err, &pErr) {
2023-12-05 14:57:33 -05:00
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
2023-12-05 14:57:33 -05:00
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
2023-12-05 14:57:33 -05:00
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = getDefaultSessionDuration()
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
2023-12-05 14:57:33 -05:00
return
}
checkpointLoaded := time.Now()
// if the first message is not a system message, then add the model's default system message
if len(req.Messages) > 0 && req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{
{
Role: "system",
Content: model.System,
},
}, req.Messages...)
}
prompt, err := chatPrompt(c.Request.Context(), model.Template, req.Messages, opts.NumCtx)
2023-12-05 14:57:33 -05:00
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2024-01-31 17:39:38 -08:00
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
2023-12-05 14:57:33 -05:00
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
2023-12-05 14:57:33 -05:00
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
2023-12-05 14:57:33 -05:00
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
2024-01-31 19:18:25 -08:00
Images: images,
Options: opts,
2023-12-05 14:57:33 -05:00
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
2023-12-05 14:57:33 -05:00
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
sb.WriteString(r.Message.Content)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
2023-12-05 14:57:33 -05:00
}
}
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
2023-12-05 14:57:33 -05:00
return
}
streamResponse(c, ch)
}