ollama/server/routes.go

1095 lines
27 KiB
Go
Raw Normal View History

package server
import (
"context"
2023-07-06 10:40:11 -07:00
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log"
"net"
"net/http"
"os"
"os/signal"
2023-07-14 17:27:14 -07:00
"path/filepath"
2023-07-31 21:35:18 -04:00
"reflect"
"runtime"
2023-09-06 11:04:17 -07:00
"strconv"
2023-07-06 10:40:11 -07:00
"strings"
2023-07-18 11:59:42 -07:00
"sync"
"syscall"
2023-07-12 18:18:06 -07:00
"time"
2023-07-21 18:01:24 -07:00
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
2023-07-03 16:32:48 -04:00
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/gpu"
2023-07-21 13:33:56 -07:00
"github.com/jmorganca/ollama/llm"
2023-11-14 12:30:34 -08:00
"github.com/jmorganca/ollama/parser"
2023-10-13 16:08:35 -07:00
"github.com/jmorganca/ollama/version"
)
2023-08-22 09:48:35 -07:00
// mode is the gin run mode used by this package. It starts as debug; init
// falls back to debug for any value other than debug, release, or test.
var mode = gin.DebugMode
2023-12-14 16:47:40 -08:00
// Server holds the per-process state of the HTTP API.
type Server struct {
	// WorkDir is a scratch directory created by NewServer (os.MkdirTemp) and
	// passed to the llm package; Serve removes it on SIGINT/SIGTERM.
	WorkDir string
}
2023-08-22 09:48:35 -07:00
// init pins gin to a recognized run mode: debug, release, and test are kept
// as-is, and any other value of mode is replaced with debug.
func init() {
	known := mode == gin.DebugMode || mode == gin.ReleaseMode || mode == gin.TestMode
	if !known {
		mode = gin.DebugMode
	}
	gin.SetMode(mode)
}
2023-07-31 21:35:18 -04:00
var loaded struct {
2023-07-19 15:00:28 -07:00
mu sync.Mutex
runner llm.LLM
2023-07-19 15:00:28 -07:00
expireAt time.Time
expireTimer *time.Timer
2023-07-31 21:35:18 -04:00
*Model
*api.Options
2023-07-18 11:59:42 -07:00
}
2023-08-15 10:35:39 -03:00
// defaultSessionDuration is how long a model stays loaded after its most
// recent use before the expiry timer unloads it.
var defaultSessionDuration = time.Minute * 5
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
2023-12-05 14:57:33 -05:00
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
workDir := c.GetString("workDir")
2023-08-03 15:55:35 -04:00
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
2023-12-05 14:57:33 -05:00
return nil, err
2023-08-03 15:55:35 -04:00
}
if err := opts.FromMap(reqOpts); err != nil {
2023-12-05 14:57:33 -05:00
return nil, err
2023-08-03 15:55:35 -04:00
}
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
if needLoad {
if loaded.runner != nil {
log.Println("changing loaded model")
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
2023-07-18 11:59:42 -07:00
}
2023-07-17 12:08:10 -07:00
2023-11-30 10:30:23 -08:00
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
2023-07-18 11:59:42 -07:00
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
2023-12-05 14:57:33 -05:00
return nil, err
2023-07-18 11:59:42 -07:00
}
loaded.Model = model
loaded.runner = llmRunner
loaded.Options = &opts
2023-07-19 15:00:28 -07:00
}
2023-07-31 21:35:18 -04:00
loaded.expireAt = time.Now().Add(sessionDuration)
2023-07-31 21:35:18 -04:00
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
2023-07-19 15:00:28 -07:00
2023-07-31 21:35:18 -04:00
if time.Now().Before(loaded.expireAt) {
2023-07-19 15:00:28 -07:00
return
}
if loaded.runner != nil {
loaded.runner.Close()
2023-07-19 15:00:28 -07:00
}
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
2023-07-19 15:00:28 -07:00
})
2023-07-06 10:40:11 -07:00
}
2023-07-31 21:35:18 -04:00
loaded.expireTimer.Reset(sessionDuration)
2023-12-05 14:57:33 -05:00
return model, nil
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
2023-10-18 16:08:42 -07:00
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
case req.Raw && (req.Template != "" || req.System != "" || len(req.Context) > 0):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "raw mode does not support template, system, or context"})
return
}
2023-12-05 14:57:33 -05:00
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
2023-12-05 14:57:33 -05:00
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
2023-12-05 14:57:33 -05:00
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
2023-12-05 14:57:33 -05:00
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
2023-12-01 11:37:17 -08:00
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true})
return
}
checkpointLoaded := time.Now()
2023-12-05 14:57:33 -05:00
var prompt string
2023-12-22 17:07:05 -05:00
var promptVars PromptVars
2023-12-05 14:57:33 -05:00
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
var rebuild strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
2023-12-22 17:07:05 -05:00
promptVars = PromptVars{
2023-12-05 14:57:33 -05:00
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
2023-12-22 17:07:05 -05:00
}
p, err := model.PreResponsePrompt(promptVars)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-12-05 14:57:33 -05:00
rebuild.WriteString(p)
prompt = rebuild.String()
}
ch := make(chan any)
2023-12-05 14:57:33 -05:00
var generated strings.Builder
go func() {
defer close(ch)
2023-12-05 14:57:33 -05:00
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
2023-12-05 14:57:33 -05:00
// Build up the full response
if _, err := generated.WriteString(r.Content); err != nil {
ch <- gin.H{"error": err.Error()}
return
}
2023-12-05 14:57:33 -05:00
resp := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Done: r.Done,
Response: r.Content,
2023-12-05 14:57:33 -05:00
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
2023-12-22 17:07:05 -05:00
// append the generated text to the history and template it if needed
promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
2023-12-05 14:57:33 -05:00
}
}
ch <- resp
}
2023-12-05 14:57:33 -05:00
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: req.Images,
2023-12-05 14:57:33 -05:00
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.GenerateResponse
2023-12-05 14:57:33 -05:00
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.GenerateResponse:
sb.WriteString(r.Response)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
}
}
final.Response = sb.String()
c.JSON(http.StatusOK, final)
return
}
streamResponse(c, ch)
}
func EmbeddingHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
var req api.EmbeddingRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if req.Model == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
2023-12-05 14:57:33 -05:00
sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
if err != nil {
2023-12-05 14:57:33 -05:00
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
if !loaded.Options.EmbeddingOnly {
c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
return
}
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
log.Printf("embedding generation failed: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
}
2023-07-20 16:09:23 -07:00
func PullModelHandler(c *gin.Context) {
2023-07-11 11:54:22 -07:00
var req api.PullRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
regOpts := &RegistryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
2023-07-20 16:09:23 -07:00
func PushModelHandler(c *gin.Context) {
var req api.PushRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-11 11:54:22 -07:00
return
}
2023-07-06 10:40:11 -07:00
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
ch := make(chan any)
go func() {
defer close(ch)
2023-07-18 18:51:30 -07:00
fn := func(r api.ProgressResponse) {
ch <- r
}
2023-07-18 18:51:30 -07:00
regOpts := &RegistryOptions{
Insecure: req.Insecure,
}
2023-10-09 10:24:27 -07:00
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
}
2023-07-20 16:09:23 -07:00
func CreateModelHandler(c *gin.Context) {
var req api.CreateRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-12 19:07:15 -07:00
return
}
2023-11-14 13:45:07 -08:00
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
2023-11-29 15:54:29 -05:00
if err := ParseModelPath(req.Name).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-11-14 13:45:07 -08:00
if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
2023-11-14 12:30:34 -08:00
return
}
2023-11-14 13:45:07 -08:00
var modelfile io.Reader = strings.NewReader(req.Modelfile)
if req.Path != "" && req.Modelfile == "" {
2023-11-21 15:43:17 -05:00
mf, err := os.Open(req.Path)
2023-11-14 13:45:07 -08:00
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return
}
2023-11-21 15:43:17 -05:00
defer mf.Close()
2023-11-14 13:45:07 -08:00
2023-11-21 15:43:17 -05:00
modelfile = mf
2023-11-14 13:45:07 -08:00
}
2023-11-14 12:30:34 -08:00
commands, err := parser.Parse(modelfile)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-07-11 11:54:22 -07:00
ch := make(chan any)
go func() {
defer close(ch)
fn := func(resp api.ProgressResponse) {
ch <- resp
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
2023-11-21 15:43:17 -05:00
if err := CreateModel(ctx, req.Name, filepath.Dir(req.Path), commands, fn); err != nil {
2023-07-20 12:12:08 -07:00
ch <- gin.H{"error": err.Error()}
}
}()
2023-07-07 15:29:17 -07:00
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
streamResponse(c, ch)
2023-07-05 15:37:33 -04:00
}
2023-07-20 16:09:23 -07:00
func DeleteModelHandler(c *gin.Context) {
var req api.DeleteRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-20 16:09:23 -07:00
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if err := DeleteModel(req.Name); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
2023-07-20 16:09:23 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
2023-09-26 17:28:14 -07:00
manifestsPath, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := PruneDirectory(manifestsPath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, nil)
2023-07-20 16:09:23 -07:00
}
2023-09-06 11:04:17 -07:00
func ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-09-06 11:04:17 -07:00
return
}
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
2023-09-06 11:04:17 -07:00
resp, err := GetModelInfo(req.Name)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(name string) (*api.ShowResponse, error) {
model, err := GetModel(name)
if err != nil {
return nil, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
2023-09-06 11:04:17 -07:00
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
Details: modelDetails,
2023-09-06 11:04:17 -07:00
}
mf, err := ShowModelfile(model)
if err != nil {
return nil, err
}
resp.Modelfile = mf
var params []string
cs := 30
for k, v := range model.Options {
switch val := v.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
case []interface{}:
for _, nv := range val {
switch nval := nv.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
}
}
}
}
resp.Parameters = strings.Join(params, "\n")
return resp, nil
}
2023-07-20 16:09:23 -07:00
func ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
2023-07-18 09:09:45 -07:00
fp, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-08-30 14:14:12 -04:00
modelResponse := func(modelName string) (api.ModelResponse, error) {
model, err := GetModel(modelName)
if err != nil {
return api.ModelResponse{}, err
}
modelDetails := api.ModelDetails{
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
ParameterSize: model.Config.ModelType,
QuantizationLevel: model.Config.FileType,
}
return api.ModelResponse{
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
Details: modelDetails,
}, nil
}
2023-08-30 14:14:12 -04:00
walkFunc := func(path string, info os.FileInfo, _ error) error {
2023-07-18 09:09:45 -07:00
if !info.IsDir() {
2023-08-30 14:14:12 -04:00
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
2023-08-21 21:56:56 -07:00
resp, err := modelResponse(tag)
2023-07-18 09:09:45 -07:00
if err != nil {
log.Printf("skipping file: %s", fp)
return nil
2023-07-18 09:09:45 -07:00
}
2023-08-30 14:14:12 -04:00
resp.ModifiedAt = info.ModTime()
models = append(models, resp)
2023-07-18 09:09:45 -07:00
}
2023-08-30 14:14:12 -04:00
2023-07-18 09:09:45 -07:00
return nil
2023-08-30 14:14:12 -04:00
}
if err := filepath.Walk(fp, walkFunc); err != nil {
2023-07-18 09:09:45 -07:00
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-07-19 15:00:28 -07:00
c.JSON(http.StatusOK, api.ListResponse{Models: models})
2023-07-18 09:09:45 -07:00
}
2023-07-24 11:27:28 -04:00
func CopyModelHandler(c *gin.Context) {
var req api.CopyRequest
2023-10-18 16:08:42 -07:00
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
2023-07-24 11:27:28 -04:00
return
}
if req.Source == "" || req.Destination == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "source add destination are required"})
return
}
2023-11-29 15:54:29 -05:00
if err := ParseModelPath(req.Destination).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
2023-07-24 11:27:28 -04:00
if err := CopyModel(req.Source, req.Destination); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
}
2023-11-15 10:59:38 -08:00
func HeadBlobHandler(c *gin.Context) {
2023-11-14 14:07:40 -08:00
path, err := GetBlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if _, err := os.Stat(path); err != nil {
c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("blob %q not found", c.Param("digest"))})
return
}
2023-11-15 13:55:37 -08:00
c.Status(http.StatusOK)
2023-11-14 14:07:40 -08:00
}
func CreateBlobHandler(c *gin.Context) {
2023-11-24 12:01:23 -08:00
layer, err := NewLayer(c.Request.Body, "")
2023-11-17 15:21:57 -08:00
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-11-24 12:01:23 -08:00
if layer.Digest != c.Param("digest") {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("digest mismatch, expected %q, got %q", c.Param("digest"), layer.Digest)})
2023-11-14 14:07:40 -08:00
return
}
2023-11-24 12:01:23 -08:00
if _, err := layer.Commit(); err != nil {
2023-11-14 14:07:40 -08:00
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
2023-11-15 13:55:37 -08:00
c.Status(http.StatusCreated)
2023-11-14 14:07:40 -08:00
}
// defaultAllowOrigins are hosts that GenerateRoutes always admits as CORS
// origins, over http and https and on any port, in addition to any origins
// supplied via OLLAMA_ORIGINS.
var defaultAllowOrigins = []string{"localhost", "127.0.0.1", "0.0.0.0"}
2023-12-14 16:47:40 -08:00
func NewServer() (*Server, error) {
workDir, err := os.MkdirTemp("", "ollama")
if err != nil {
return nil, err
}
2023-12-14 16:47:40 -08:00
return &Server{
WorkDir: workDir,
}, nil
}
2023-12-14 16:47:40 -08:00
func (s *Server) GenerateRoutes() http.Handler {
var origins []string
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
origins = strings.Split(o, ",")
}
2023-07-21 18:01:24 -07:00
config := cors.DefaultConfig()
config.AllowWildcard = true
2023-12-14 16:47:40 -08:00
config.AllowOrigins = origins
for _, allowOrigin := range defaultAllowOrigins {
config.AllowOrigins = append(config.AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s:*", allowOrigin),
fmt.Sprintf("https://%s:*", allowOrigin),
)
}
2023-07-21 18:01:24 -07:00
2023-07-05 15:37:33 -04:00
r := gin.Default()
r.Use(
cors.New(config),
func(c *gin.Context) {
2023-12-14 16:47:40 -08:00
c.Set("workDir", s.WorkDir)
c.Next()
},
)
2023-07-05 15:37:33 -04:00
2023-07-20 16:09:23 -07:00
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
2023-12-05 14:57:33 -05:00
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingHandler)
2023-07-20 16:09:23 -07:00
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
2023-07-24 11:27:28 -04:00
r.POST("/api/copy", CopyModelHandler)
2023-07-20 16:09:23 -07:00
r.DELETE("/api/delete", DeleteModelHandler)
2023-09-06 11:04:17 -07:00
r.POST("/api/show", ShowModelHandler)
2023-11-14 14:07:40 -08:00
r.POST("/api/blobs/:digest", CreateBlobHandler)
2023-11-15 15:22:12 -08:00
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
2023-09-21 16:38:03 -07:00
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", ListModelsHandler)
2023-10-12 15:45:07 -07:00
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
2023-09-21 16:38:03 -07:00
}
2023-12-14 16:47:40 -08:00
return r
}
func Serve(ln net.Listener) error {
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}
s, err := NewServer()
if err != nil {
return err
}
r := s.GenerateRoutes()
2023-10-13 16:08:35 -07:00
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
2023-12-14 16:47:40 -08:00
srvr := &http.Server{
Handler: r,
}
// listen for a ctrl+c and stop any loaded llm
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.runner != nil {
loaded.runner.Close()
2023-09-22 19:41:52 +01:00
}
2023-12-14 16:47:40 -08:00
os.RemoveAll(s.WorkDir)
os.Exit(0)
}()
if err := llm.Init(s.WorkDir); err != nil {
return fmt.Errorf("unable to initialize llm library %w", err)
}
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
2023-12-10 11:44:27 -05:00
log.Print(err.Error())
}
}
2023-12-14 16:47:40 -08:00
return srvr.Serve(ln)
}
2023-07-06 10:40:11 -07:00
// waitForStream consumes ch for a non-streaming request and writes exactly
// one JSON response:
//   - 200 with the first api.ProgressResponse whose Status is "success"
//   - 500 for an error value, an unexpected value type, or ch closing before
//     success was reported
//
// Other progress responses are discarded.
func waitForStream(c *gin.Context, ch chan interface{}) {
	c.Header("Content-Type", "application/json")

	for msg := range ch {
		switch v := msg.(type) {
		case api.ProgressResponse:
			if v.Status != "success" {
				continue
			}
			c.JSON(http.StatusOK, v)
			return
		case gin.H:
			text, ok := v["error"].(string)
			if !ok {
				text = "unexpected error format in progress response"
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": text})
			return
		default:
			c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected progress response"})
			return
		}
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected end of progress response"})
}
func streamResponse(c *gin.Context, ch chan any) {
c.Header("Content-Type", "application/x-ndjson")
2023-07-11 11:54:22 -07:00
c.Stream(func(w io.Writer) bool {
val, ok := <-ch
if !ok {
return false
}
bts, err := json.Marshal(val)
if err != nil {
2023-07-31 16:46:37 -04:00
log.Printf("streamResponse: json.Marshal failed with %s", err)
2023-07-11 11:54:22 -07:00
return false
}
// Delineate chunks with new-line delimiter
2023-07-11 11:54:22 -07:00
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
2023-07-31 16:46:37 -04:00
log.Printf("streamResponse: w.Write failed with %s", err)
2023-07-11 11:54:22 -07:00
return false
}
return true
})
}
2023-12-05 14:57:33 -05:00
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
2023-12-05 14:57:33 -05:00
return
}
checkpointLoaded := time.Now()
prompt, images, err := model.ChatPrompt(req.Messages)
2023-12-05 14:57:33 -05:00
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResult) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
2023-12-05 14:57:33 -05:00
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
resp.TotalDuration = time.Since(checkpointStart)
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
2023-12-05 14:57:33 -05:00
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: images,
2023-12-05 14:57:33 -05:00
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Accumulate responses into the final response
var final api.ChatResponse
2023-12-05 14:57:33 -05:00
var sb strings.Builder
for resp := range ch {
switch r := resp.(type) {
case api.ChatResponse:
sb.WriteString(r.Message.Content)
final = r
case gin.H:
if errorMsg, ok := r["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in response"})
return
}
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error"})
return
2023-12-05 14:57:33 -05:00
}
}
final.Message = api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, final)
2023-12-05 14:57:33 -05:00
return
}
streamResponse(c, ch)
}