From 4b3507f0366d927a61fb2d21c6c3effd9889d508 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 8 Aug 2023 15:13:22 -0400 Subject: [PATCH] embeddings endpoint Co-Authored-By: Jeffrey Morgan --- api/types.go | 11 +++++ server/routes.go | 105 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 85 insertions(+), 31 deletions(-) diff --git a/api/types.go b/api/types.go index f868f2db..825db36e 100644 --- a/api/types.go +++ b/api/types.go @@ -42,6 +42,17 @@ type GenerateRequest struct { Options map[string]interface{} `json:"options"` } +type EmbeddingRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + + Options map[string]interface{} `json:"options"` +} + +type EmbeddingResponse struct { + Embedding []float64 `json:"embedding"` +} + type CreateRequest struct { Name string `json:"name"` Path string `json:"path"` diff --git a/server/routes.go b/server/routes.go index 731e8078..14040332 100644 --- a/server/routes.go +++ b/server/routes.go @@ -38,35 +38,17 @@ var loaded struct { options api.Options } -func GenerateHandler(c *gin.Context) { - loaded.mu.Lock() - defer loaded.mu.Unlock() - - checkpointStart := time.Now() - - var req api.GenerateRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - model, err := GetModel(req.Model) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - +// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function +func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { log.Printf("could not load model options: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return err } - if err := opts.FromMap(req.Options); err != nil { + if err := opts.FromMap(reqOpts); err != nil { log.Printf("could not merge model options: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return err } if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) { @@ -83,21 +65,18 @@ func GenerateHandler(c *gin.Context) { llm, err := llama.New(model.ModelPath, opts) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return err } if opts.NumKeep < 0 { promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "") if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return err } promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "") if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return + return err } tokensWithSystem := llm.Encode(promptWithSystem) @@ -110,9 +89,8 @@ func GenerateHandler(c *gin.Context) { loaded.digest = model.Digest loaded.options = opts } - sessionDuration := 5 * time.Minute - loaded.expireAt = time.Now().Add(sessionDuration) + if loaded.expireTimer == nil { loaded.expireTimer = time.AfterFunc(sessionDuration, func() { loaded.mu.Lock() @@ -132,6 +110,32 @@ func GenerateHandler(c *gin.Context) { }) } loaded.expireTimer.Reset(sessionDuration) + return nil +} + +func GenerateHandler(c *gin.Context) { + loaded.mu.Lock() + defer loaded.mu.Unlock() + + checkpointStart := time.Now() + + var req api.GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + model, err := GetModel(req.Model) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + sessionDuration := 5 * time.Minute + if err := load(model, req.Options, sessionDuration); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } checkpointLoaded := time.Now() @@ -181,6 +185,44 @@ func GenerateHandler(c *gin.Context) { streamResponse(c, ch) } +func EmbeddingHandler(c *gin.Context) { + loaded.mu.Lock() + defer loaded.mu.Unlock() + + var req api.EmbeddingRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + model, err := GetModel(req.Model) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := load(model, req.Options, 5*time.Minute); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if !loaded.options.EmbeddingOnly { + c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"}) + return + } + + embedding, err := loaded.llm.Embedding(req.Prompt) + if err != nil { + log.Printf("embedding generation failed: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) + return + } + + resp := api.EmbeddingResponse{ + Embedding: embedding, + } + c.JSON(http.StatusOK, resp) +} + func PullModelHandler(c *gin.Context) { var req api.PullRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -381,6 +423,7 @@ func Serve(ln net.Listener, extraOrigins []string) error { r.POST("/api/pull", PullModelHandler) r.POST("/api/generate", GenerateHandler) + r.POST("/api/embeddings", EmbeddingHandler) r.POST("/api/create", CreateModelHandler) r.POST("/api/push", PushModelHandler) r.POST("/api/copy", CopyModelHandler)