commit
7226980fb6
6 changed files with 143 additions and 144 deletions
|
@ -10,6 +10,20 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type StatusError struct {
|
||||||
|
StatusCode int
|
||||||
|
Status string
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e StatusError) Error() string {
|
||||||
|
if e.Message != "" {
|
||||||
|
return fmt.Sprintf("%s: %s", e.Status, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.Status
|
||||||
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
base url.URL
|
base url.URL
|
||||||
}
|
}
|
||||||
|
@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error {
|
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
|
||||||
var buf *bytes.Buffer
|
var buf *bytes.Buffer
|
||||||
if data != nil {
|
if data != nil {
|
||||||
bts, err := json.Marshal(data)
|
bts, err := json.Marshal(data)
|
||||||
|
@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
|
||||||
scanner := bufio.NewScanner(response.Body)
|
scanner := bufio.NewScanner(response.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
var errorResponse struct {
|
var errorResponse struct {
|
||||||
Error string `json:"error"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
bts := scanner.Bytes()
|
bts := scanner.Bytes()
|
||||||
|
@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
|
||||||
return fmt.Errorf("unmarshal: %w", err)
|
return fmt.Errorf("unmarshal: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(errorResponse.Error) > 0 {
|
if response.StatusCode >= 400 {
|
||||||
return fmt.Errorf("stream: %s", errorResponse.Error)
|
return StatusError{
|
||||||
|
StatusCode: response.StatusCode,
|
||||||
|
Status: response.Status,
|
||||||
|
Message: errorResponse.Error,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := callback(bts); err != nil {
|
if err := fn(bts); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
19
cmd/cmd.go
19
cmd/cmd.go
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -34,7 +35,14 @@ func RunRun(cmd *cobra.Command, args []string) error {
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
if err := pull(args[0]); err != nil {
|
if err := pull(args[0]); err != nil {
|
||||||
return err
|
var apiStatusError api.StatusError
|
||||||
|
if !errors.As(err, &apiStatusError) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiStatusError.StatusCode != http.StatusBadGateway {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case err != nil:
|
case err != nil:
|
||||||
return err
|
return err
|
||||||
|
@ -50,11 +58,12 @@ func pull(model string) error {
|
||||||
context.Background(),
|
context.Background(),
|
||||||
&api.PullRequest{Model: model},
|
&api.PullRequest{Model: model},
|
||||||
func(progress api.PullProgress) error {
|
func(progress api.PullProgress) error {
|
||||||
if bar == nil && progress.Percent == 100 {
|
|
||||||
// already downloaded
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if bar == nil {
|
if bar == nil {
|
||||||
|
if progress.Percent == 100 {
|
||||||
|
// already downloaded
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
bar = progressbar.DefaultBytes(progress.Total)
|
bar = progressbar.DefaultBytes(progress.Total)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -39,7 +39,6 @@ require (
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/crypto v0.10.0 // indirect
|
golang.org/x/crypto v0.10.0 // indirect
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.10.0 // indirect
|
||||||
golang.org/x/sync v0.3.0
|
|
||||||
golang.org/x/sys v0.10.0 // indirect
|
golang.org/x/sys v0.10.0 // indirect
|
||||||
golang.org/x/term v0.10.0
|
golang.org/x/term v0.10.0
|
||||||
golang.org/x/text v0.10.0 // indirect
|
golang.org/x/text v0.10.0 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -99,8 +99,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
|
||||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
|
101
server/models.go
101
server/models.go
|
@ -2,14 +2,13 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const directoryURL = "https://ollama.ai/api/models"
|
const directoryURL = "https://ollama.ai/api/models"
|
||||||
|
@ -36,12 +35,12 @@ func (m *Model) FullName() string {
|
||||||
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
return path.Join(home, ".ollama", "models", m.Name+".bin")
|
||||||
}
|
}
|
||||||
|
|
||||||
func pull(model string, progressCh chan<- api.PullProgress) error {
|
func (m *Model) TempFile() string {
|
||||||
remote, err := getRemote(model)
|
fullName := m.FullName()
|
||||||
if err != nil {
|
return path.Join(
|
||||||
return fmt.Errorf("failed to pull model: %w", err)
|
path.Dir(fullName),
|
||||||
}
|
fmt.Sprintf(".%s.part", path.Base(fullName)),
|
||||||
return saveModel(remote, progressCh)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRemote(model string) (*Model, error) {
|
func getRemote(model string) (*Model, error) {
|
||||||
|
@ -68,7 +67,7 @@ func getRemote(model string) (*Model, error) {
|
||||||
return nil, fmt.Errorf("model not found in directory: %s", model)
|
return nil, fmt.Errorf("model not found in directory: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
func saveModel(model *Model, fn func(total, completed int64)) error {
|
||||||
// this models cache directory is created by the server on startup
|
// this models cache directory is created by the server on startup
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
|
@ -76,41 +75,45 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download model: %w", err)
|
return fmt.Errorf("failed to download model: %w", err)
|
||||||
}
|
}
|
||||||
// check for resume
|
|
||||||
alreadyDownloaded := int64(0)
|
// check if completed file exists
|
||||||
fileInfo, err := os.Stat(model.FullName())
|
fi, err := os.Stat(model.FullName())
|
||||||
if err != nil {
|
switch {
|
||||||
if !os.IsNotExist(err) {
|
case errors.Is(err, os.ErrNotExist):
|
||||||
return fmt.Errorf("failed to check resume model file: %w", err)
|
// noop, file doesn't exist so create it
|
||||||
}
|
case err != nil:
|
||||||
// file doesn't exist, create it now
|
return fmt.Errorf("stat: %w", err)
|
||||||
} else {
|
default:
|
||||||
alreadyDownloaded = fileInfo.Size()
|
fn(fi.Size(), fi.Size())
|
||||||
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", alreadyDownloaded))
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var size int64
|
||||||
|
|
||||||
|
// completed file doesn't exist, check partial file
|
||||||
|
fi, err = os.Stat(model.TempFile())
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, os.ErrNotExist):
|
||||||
|
// noop, file doesn't exist so create it
|
||||||
|
case err != nil:
|
||||||
|
return fmt.Errorf("stat: %w", err)
|
||||||
|
default:
|
||||||
|
size = fi.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download model: %w", err)
|
return fmt.Errorf("failed to download model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusRequestedRangeNotSatisfiable {
|
if resp.StatusCode >= 400 {
|
||||||
// already downloaded
|
|
||||||
progressCh <- api.PullProgress{
|
|
||||||
Total: alreadyDownloaded,
|
|
||||||
Completed: alreadyDownloaded,
|
|
||||||
Percent: 100,
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
|
|
||||||
return fmt.Errorf("failed to download model: %s", resp.Status)
|
return fmt.Errorf("failed to download model: %s", resp.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := os.OpenFile(model.FullName(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -118,37 +121,23 @@ func saveModel(model *Model, progressCh chan<- api.PullProgress) error {
|
||||||
|
|
||||||
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
totalSize, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
|
||||||
|
|
||||||
buf := make([]byte, 1024)
|
totalBytes := size
|
||||||
totalBytes := alreadyDownloaded
|
totalSize += size
|
||||||
totalSize += alreadyDownloaded
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, err := resp.Body.Read(buf)
|
n, err := io.CopyN(out, resp.Body, 8192)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if _, err := out.Write(buf[:n]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
totalBytes += int64(n)
|
totalBytes += n
|
||||||
|
fn(totalSize, totalBytes)
|
||||||
// send progress updates
|
|
||||||
progressCh <- api.PullProgress{
|
|
||||||
Total: totalSize,
|
|
||||||
Completed: totalBytes,
|
|
||||||
Percent: float64(totalBytes) / float64(totalSize) * 100,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
progressCh <- api.PullProgress{
|
fn(totalSize, totalSize)
|
||||||
Total: totalSize,
|
return os.Rename(model.TempFile(), model.FullName())
|
||||||
Completed: totalSize,
|
|
||||||
Percent: 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
136
server/routes.go
136
server/routes.go
|
@ -16,7 +16,6 @@ import (
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
|
@ -56,12 +55,8 @@ func generate(c *gin.Context) {
|
||||||
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
|
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
|
||||||
}
|
}
|
||||||
|
|
||||||
llm, err := llama.New(req.Model, req.Options)
|
ch := make(chan any)
|
||||||
if err != nil {
|
go stream(c, ch)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer llm.Close()
|
|
||||||
|
|
||||||
templateNames := make([]string, 0, len(templates.Templates()))
|
templateNames := make([]string, 0, len(templates.Templates()))
|
||||||
for _, template := range templates.Templates() {
|
for _, template := range templates.Templates() {
|
||||||
|
@ -79,39 +74,49 @@ func generate(c *gin.Context) {
|
||||||
req.Prompt = sb.String()
|
req.Prompt = sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan string)
|
llm, err := llama.New(req.Model, req.Options)
|
||||||
g, _ := errgroup.WithContext(c.Request.Context())
|
if err != nil {
|
||||||
g.Go(func() error {
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
defer close(ch)
|
return
|
||||||
return llm.Predict(req.Prompt, func(s string) {
|
}
|
||||||
ch <- s
|
defer llm.Close()
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
g.Go(func() error {
|
fn := func(s string) {
|
||||||
c.Stream(func(w io.Writer) bool {
|
ch <- api.GenerateResponse{Response: s}
|
||||||
s, ok := <-ch
|
}
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts, err := json.Marshal(api.GenerateResponse{Response: s})
|
if err := llm.Predict(req.Prompt, fn); err != nil {
|
||||||
if err != nil {
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return false
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bts = append(bts, '\n')
|
}
|
||||||
if _, err := w.Write(bts); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
func pull(c *gin.Context) {
|
||||||
})
|
var req api.PullRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
remote, err := getRemote(req.Model)
|
||||||
})
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
|
ch := make(chan any)
|
||||||
|
go stream(c, ch)
|
||||||
|
|
||||||
|
fn := func(total, completed int64) {
|
||||||
|
ch <- api.PullProgress{
|
||||||
|
Total: total,
|
||||||
|
Completed: completed,
|
||||||
|
Percent: float64(total) / float64(completed) * 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveModel(remote, fn); err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -124,47 +129,7 @@ func Serve(ln net.Listener) error {
|
||||||
c.String(http.StatusOK, "Ollama is running")
|
c.String(http.StatusOK, "Ollama is running")
|
||||||
})
|
})
|
||||||
|
|
||||||
r.POST("api/pull", func(c *gin.Context) {
|
r.POST("api/pull", pull)
|
||||||
var req api.PullRequest
|
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
progressCh := make(chan api.PullProgress)
|
|
||||||
go func() {
|
|
||||||
defer close(progressCh)
|
|
||||||
if err := pull(req.Model, progressCh); err != nil {
|
|
||||||
var opError *net.OpError
|
|
||||||
if errors.As(err, &opError) {
|
|
||||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
|
||||||
progress, ok := <-progressCh
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts, err := json.Marshal(progress)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
bts = append(bts, '\n')
|
|
||||||
if _, err := w.Write(bts); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
r.POST("/api/generate", generate)
|
r.POST("/api/generate", generate)
|
||||||
|
|
||||||
log.Printf("Listening on %s", ln.Addr())
|
log.Printf("Listening on %s", ln.Addr())
|
||||||
|
@ -186,3 +151,24 @@ func matchRankOne(source string, targets []string) (bestMatch string, bestRank i
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stream(c *gin.Context, ch chan any) {
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
val, ok := <-ch
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bts, err := json.Marshal(val)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
bts = append(bts, '\n')
|
||||||
|
if _, err := w.Write(bts); err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue