ollama/api/client.go
2023-07-06 16:34:44 -04:00

170 lines
3.5 KiB
Go

package api
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
)
// Client issues JSON requests against an HTTP API rooted at URL.
type Client struct {
// URL is the base address of the API. Request paths are appended to it
// verbatim, with no separator inserted between the two.
URL string
// HTTP is the underlying client used to send every request.
HTTP http.Client
}
// checkError converts a failed HTTP response into an Error value.
//
// Status codes from 200 through 399 count as success and yield nil. For any
// other status an Error is created with Code set to the status code, and body
// is decoded into it as JSON. If body cannot be decoded, the raw body text
// becomes the Error's Message instead.
func checkError(resp *http.Response, body []byte) error {
	status := resp.StatusCode
	if status < 200 || status >= 400 {
		failure := Error{Code: int32(status)}
		if decodeErr := json.Unmarshal(body, &failure); decodeErr != nil {
			// Not a JSON error payload: surface the body as-is.
			failure.Message = string(body)
		}
		return failure
	}
	return nil
}
// stream issues an HTTP request to c.URL+path and hands the response body to
// callback one line at a time.
//
// When reqData is non-nil it is JSON-encoded and sent as the request body.
// The Content-Type and Accept headers are always set to application/json.
//
// A response that checkError rejects is never streamed: its entire body is
// read and the resulting error is returned. Otherwise every line of the body
// is passed to callback with its trailing "\n" removed, including a final
// line that is not newline-terminated. Errors from encoding the request,
// building or sending it, and reading the body are returned unchanged.
func (c *Client) stream(ctx context.Context, method string, path string, reqData any, callback func(data []byte)) error {
	var reqBody io.Reader
	if reqData != nil {
		data, err := json.Marshal(reqData)
		if err != nil {
			return err
		}
		reqBody = bytes.NewReader(data)
	}

	url := fmt.Sprintf("%s%s", c.URL, path)
	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	res, err := c.HTTP.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	// checkError decides success or failure from the status code alone, so
	// probe it once with an empty body. On failure, rebuild the error from
	// the complete body: error payloads are frequently empty or lack a
	// trailing newline, and a line-by-line check would never see them.
	if statusErr := checkError(res, nil); statusErr != nil {
		body, err := io.ReadAll(res.Body)
		if err != nil {
			return err
		}
		return checkError(res, body)
	}

	reader := bufio.NewReader(res.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil && err != io.EOF {
			return err
		}
		// ReadBytes returns the bytes read so far together with io.EOF when
		// the body does not end in a newline; deliver that last line too.
		if len(line) > 0 {
			callback(bytes.TrimSuffix(line, []byte("\n")))
		}
		if err == io.EOF {
			return nil
		}
	}
}
// do performs a single request/response exchange with the API.
//
// When reqData is non-nil it is JSON-encoded and sent as the request body to
// c.URL+path using the given method. The whole response body is read and
// passed through checkError; if that succeeds, respData is non-nil and the
// body is not empty, the body is JSON-decoded into respData. The first error
// encountered at any step is returned unchanged.
func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error {
	var payload io.Reader
	if reqData != nil {
		encoded, err := json.Marshal(reqData)
		if err != nil {
			return err
		}
		payload = bytes.NewReader(encoded)
	}

	request, err := http.NewRequestWithContext(ctx, method, c.URL+path, payload)
	if err != nil {
		return err
	}
	request.Header.Set("Content-Type", "application/json")
	request.Header.Set("Accept", "application/json")

	response, err := c.HTTP.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()

	raw, err := io.ReadAll(response.Body)
	if err != nil {
		return err
	}
	if err := checkError(response, raw); err != nil {
		return err
	}

	// Nothing to decode, or nowhere to put it.
	if respData == nil || len(raw) == 0 {
		return nil
	}
	return json.Unmarshal(raw, respData)
}
// Generate posts req to /api/generate and forwards every chunk that stream
// delivers to callback, converted to a string.
//
// On failure it returns nil and the error from stream. On success it returns
// a pointer to a zero-value GenerateResponse; nothing in this method fills
// the response in.
func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) {
	forward := func(chunk []byte) {
		callback(string(chunk))
	}
	if err := c.stream(ctx, http.MethodPost, "/api/generate", req, forward); err != nil {
		return nil, err
	}
	return new(GenerateResponse), nil
}
// Pull posts req to /api/pull and reports progress through callback.
//
// Progress events arrive in the following format:
//
//	event:progress
//	data:{"total":123,"completed":123,"percent":0.1}
//
// Only the part after "data:" is used. Lines without a "data:" marker are
// ignored. A payload that fails to decode is printed with fmt.Println and
// skipped. Every successfully decoded PullProgress is passed to callback.
//
// Pull returns the error from stream, or nil once the stream has ended.
func (c *Client) Pull(ctx context.Context, req *PullRequest, callback func(progress PullProgress)) error {
	// done is a one-shot completion latch. The sync.Once guard matters
	// because several events can satisfy Completed >= Total, and a second
	// call to Done would panic with a negative WaitGroup counter.
	var (
		done    sync.WaitGroup
		release sync.Once
	)
	done.Add(1)
	finish := func() { release.Do(done.Done) }

	err := c.stream(ctx, http.MethodPost, "/api/pull", req, func(progressBytes []byte) {
		// Cut splits on the first "data:" only, so a payload that itself
		// contains "data:" is kept intact rather than truncated.
		_, eventData, found := strings.Cut(string(progressBytes), "data:")
		if !found {
			// no data part, ignore
			return
		}

		var progress PullProgress
		if err := json.Unmarshal([]byte(eventData), &progress); err != nil {
			fmt.Println(err)
			return
		}

		if progress.Completed >= progress.Total {
			finish()
		}
		callback(progress)
	})

	// stream invokes the callback on this goroutine, so no further events can
	// arrive once it has returned. Release the latch here as well; otherwise
	// a stream that ends without a completing event would block Wait forever.
	finish()
	if err != nil {
		return err
	}

	done.Wait()
	return nil
}