ollama/server/routes.go

219 lines
4.2 KiB
Go
Raw Normal View History

package server
import (
2023-07-06 17:40:11 +00:00
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
2023-07-15 00:27:14 +00:00
"path/filepath"
2023-07-06 17:40:11 +00:00
"strings"
"text/template"
2023-07-13 01:18:06 +00:00
"time"
"github.com/gin-gonic/gin"
2023-07-03 20:32:48 +00:00
"github.com/jmorganca/ollama/api"
2023-07-06 17:40:11 +00:00
"github.com/jmorganca/ollama/llama"
)
// cacheDir returns the path of the ".ollama" directory inside the current
// user's home directory, as reported by os.UserHomeDir. The directory is not
// created or checked for existence here.
//
// It panics if the home directory cannot be determined.
func cacheDir() string {
	home, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(home, ".ollama")
	}
	panic(err)
}
2023-07-05 19:37:33 +00:00
// generate handles POST /api/generate.
//
// It binds an api.GenerateRequest (starting from api.DefaultOptions()) and
// looks up the requested model. It renders the model's prompt template
// (model.Prompt) with the request as template data, then loads the model
// with llama.New. Every api.GenerateResponse produced by the predictor is
// streamed to the client via streamResponse. Each response is stamped with
// the model name and creation time; the final (Done) response also gets the
// total request duration.
//
// Failures before streaming starts (bad JSON, unknown model, template
// errors, model load errors) are returned as a single {"error": ...} JSON
// object with an HTTP error status. A failure from the predictor itself
// happens while streaming, so it is sent as the last {"error": ...} element
// of the stream.
func generate(c *gin.Context) {
	start := time.Now()

	req := api.GenerateRequest{
		Options: api.DefaultOptions(),
		Prompt:  "",
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	model, err := GetModel(req.Model)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	templ, err := template.New("").Parse(model.Prompt)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	var sb strings.Builder
	if err = templ.Execute(&sb, req); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	req.Prompt = sb.String()

	fmt.Printf("prompt = >>>%s<<<\n", req.Prompt)

	llm, err := llama.New(model.ModelPath, req.Options)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// llm is released when this handler returns, i.e. after streamResponse
	// returns. NOTE(review): this is only safe if streamResponse does not
	// return while the predictor goroutine below is still using llm —
	// confirm that it consumes ch until it is closed.
	defer llm.Close()

	ch := make(chan any)
	go func() {
		defer close(ch)

		fn := func(r api.GenerateResponse) {
			r.Model = req.Model
			r.CreatedAt = time.Now().UTC()
			if r.Done {
				r.TotalDuration = time.Since(start)
			}
			ch <- r
		}

		// This goroutine must not write to c directly (the handler goroutine
		// owns the response inside streamResponse), so a prediction failure
		// is reported as the final element of the stream instead of being
		// dropped.
		if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	streamResponse(c, ch)
}
2023-07-06 17:40:11 +00:00
2023-07-11 18:54:22 +00:00
// pull handles POST /api/pull.
//
// It binds an api.PullRequest and calls PullModel with the request's name
// and credentials. Each progress callback is streamed to the client as an
// api.PullProgress via streamResponse. A malformed request is rejected with
// 400 and a single {"error": ...} object. A PullModel failure is sent as
// the last {"error": ...} element of the stream.
func pull(c *gin.Context) {
	var req api.PullRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)

		fn := func(status, digest string, total, completed int, percent float64) {
			ch <- api.PullProgress{
				Status:    status,
				Digest:    digest,
				Total:     total,
				Completed: completed,
				Percent:   percent,
			}
		}

		// Do not call c.JSON here: this goroutine runs concurrently with
		// streamResponse, which is writing to the same response, and the
		// status line may already have been sent. Report the failure as
		// the final stream element instead.
		if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	streamResponse(c, ch)
}
// push handles POST /api/push.
//
// It binds an api.PushRequest and calls PushModel with the request's name
// and credentials. Each progress callback is streamed to the client as an
// api.PushProgress via streamResponse. A malformed request is rejected with
// 400 and a single {"error": ...} object. A PushModel failure is sent as
// the last {"error": ...} element of the stream.
func push(c *gin.Context) {
	var req api.PushRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)

		fn := func(status, digest string, total, completed int, percent float64) {
			ch <- api.PushProgress{
				Status:    status,
				Digest:    digest,
				Total:     total,
				Completed: completed,
				Percent:   percent,
			}
		}

		// Do not call c.JSON here: this goroutine runs concurrently with
		// streamResponse, which is writing to the same response, and the
		// status line may already have been sent. Report the failure as
		// the final stream element instead.
		if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
		}
	}()

	streamResponse(c, ch)
}
// create handles POST /api/create.
//
// It binds an api.CreateRequest and opens the Modelfile at req.Path on the
// server's filesystem. It then calls CreateModel to build a model named
// req.Name from that file, streaming each status update as an
// api.CreateProgress via streamResponse.
//
// This handler reports errors under the "message" key, which is kept as-is
// for client compatibility. A malformed request or an unopenable path is
// rejected with 400 and a single {"message": ...} object. A CreateModel
// failure is sent as the last {"message": ...} element of the stream.
func create(c *gin.Context) {
	var req api.CreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}

	// NOTE consider passing the entire Modelfile in the json instead of the path to it
	// NOTE(review): req.Path comes straight from the request body and is
	// opened without any validation, so a client can make the server open
	// any path the process can read. How much of its content can leak back
	// depends on CreateModel — confirm before exposing this endpoint.
	file, err := os.Open(req.Path)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
		return
	}
	// Closed when the handler returns, i.e. after streamResponse returns.
	defer file.Close()

	ch := make(chan any)
	go func() {
		defer close(ch)

		fn := func(status string) {
			ch <- api.CreateProgress{
				Status: status,
			}
		}

		// Do not call c.JSON here: this goroutine runs concurrently with
		// streamResponse, which is writing to the same response, and the
		// status line may already have been sent. Report the failure as
		// the final stream element instead.
		if err := CreateModel(req.Name, file, fn); err != nil {
			ch <- gin.H{"message": err.Error()}
		}
	}()

	streamResponse(c, ch)
}
func Serve(ln net.Listener) error {
r := gin.Default()
2023-07-08 03:46:15 +00:00
r.GET("/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
2023-07-13 00:19:03 +00:00
r.POST("/api/pull", pull)
2023-07-05 19:37:33 +00:00
r.POST("/api/generate", generate)
r.POST("/api/create", create)
r.POST("/api/push", push)
log.Printf("Listening on %s", ln.Addr())
s := &http.Server{
Handler: r,
}
return s.Serve(ln)
}
2023-07-06 17:40:11 +00:00
// streamResponse writes each value received on ch to the client as a single
// line of JSON (newline-delimited JSON) until ch is closed.
//
// Streaming can stop before ch is closed: gin's c.Stream returns when the
// client goes away, and this function stops on a marshal or write failure.
// In that case the remaining values are received and discarded until ch is
// closed. The producers in this file send on unbuffered channels, so
// without this drain the producer goroutine would block forever on its next
// send and leak, along with whatever it holds.
//
// The drain is deliberately synchronous. Deferred cleanup in the calling
// handler (closing the model or file the producer is using) therefore still
// runs only after the producer has finished. NOTE(review): this means a
// handler whose client disconnected keeps running until its producer
// completes.
func streamResponse(c *gin.Context, ch chan any) {
	c.Stream(func(w io.Writer) bool {
		val, ok := <-ch
		if !ok {
			return false
		}

		bts, err := json.Marshal(val)
		if err != nil {
			log.Printf("streamResponse: json.Marshal failed: %v", err)
			return false
		}

		bts = append(bts, '\n')
		if _, err := w.Write(bts); err != nil {
			log.Printf("streamResponse: write failed: %v", err)
			return false
		}

		return true
	})

	// Unblock the producer so it can run to completion and close ch.
	for range ch {
	}
}