embeddings endpoint

Co-Authored-By: Jeffrey Morgan <jmorganca@gmail.com>
This commit is contained in:
Bruce MacDonald 2023-08-08 15:13:22 -04:00
parent 5ebce03c77
commit 4b3507f036
2 changed files with 85 additions and 31 deletions

View file

@ -42,6 +42,17 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"`
}
// EmbeddingRequest is the JSON request body bound by the embeddings handler.
type EmbeddingRequest struct {
	// Model is the name used to look up the model.
	Model string `json:"model"`

	// Prompt is the text handed to the loaded model's Embedding call.
	Prompt string `json:"prompt"`

	// Options carries per-request option overrides that are merged on top of
	// the model's own options when the model is loaded.
	Options map[string]interface{} `json:"options"`
}
// EmbeddingResponse is the JSON response body written by the embeddings
// handler on success.
type EmbeddingResponse struct {
	// Embedding is the vector returned by the model for the request prompt.
	Embedding []float64 `json:"embedding"`
}
type CreateRequest struct {
Name string `json:"name"`
Path string `json:"path"`

View file

@ -38,35 +38,17 @@ var loaded struct {
options api.Options
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
model, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// load loads a model into memory if it is not already loaded.
// The caller must hold loaded.mu before calling this function.
func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
return err
}
if err := opts.FromMap(req.Options); err != nil {
if err := opts.FromMap(reqOpts); err != nil {
log.Printf("could not merge model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
return err
}
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
@ -83,21 +65,18 @@ func GenerateHandler(c *gin.Context) {
llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
return err
}
if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
return err
}
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
return err
}
tokensWithSystem := llm.Encode(promptWithSystem)
@ -110,9 +89,8 @@ func GenerateHandler(c *gin.Context) {
loaded.digest = model.Digest
loaded.options = opts
}
sessionDuration := 5 * time.Minute
loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
@ -132,6 +110,32 @@ func GenerateHandler(c *gin.Context) {
})
}
loaded.expireTimer.Reset(sessionDuration)
return nil
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
model, err := GetModel(req.Model)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
sessionDuration := 5 * time.Minute
if err := load(model, req.Options, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
checkpointLoaded := time.Now()
@ -181,6 +185,44 @@ func GenerateHandler(c *gin.Context) {
streamResponse(c, ch)
}
// EmbeddingHandler serves POST /api/embeddings. It binds the request body to
// an api.EmbeddingRequest, makes sure the requested model is loaded, and
// replies with the prompt's embedding as an api.EmbeddingResponse.
//
// Status codes:
//   - 400 when the body cannot be bound, the model cannot be resolved, or the
//     loaded options do not have EmbeddingOnly enabled.
//   - 500 when the model fails to load or the embedding call fails.
func EmbeddingHandler(c *gin.Context) {
	// load expects its caller to hold loaded.mu; the lock is kept until the
	// handler returns because loaded.options and loaded.llm are read below.
	loaded.mu.Lock()
	defer loaded.mu.Unlock()

	var req api.EmbeddingRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Same session lifetime as GenerateHandler.
	sessionDuration := 5 * time.Minute
	if err := load(model, req.Options, sessionDuration); err != nil {
		// A load failure is reported as a server error, matching how
		// GenerateHandler reports the same failure.
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Embeddings are only produced when the effective options enable them.
	if !loaded.options.EmbeddingOnly {
		c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
		return
	}

	embedding, err := loaded.llm.Embedding(req.Prompt)
	if err != nil {
		// Log the underlying error but return a generic message to the client.
		log.Printf("embedding generation failed: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
		return
	}

	c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
}
func PullModelHandler(c *gin.Context) {
var req api.PullRequest
if err := c.ShouldBindJSON(&req); err != nil {
@ -381,6 +423,7 @@ func Serve(ln net.Listener, extraOrigins []string) error {
r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler)
r.POST("/api/embeddings", EmbeddingHandler)
r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler)
r.POST("/api/copy", CopyModelHandler)