increase streaming buffer size (#692)

This commit is contained in:
Bruce MacDonald 2023-10-04 14:09:00 -04:00 committed by GitHub
parent dc87e9c9ae
commit 9e2de1bd2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View file

@ -18,9 +18,7 @@ import (
const DefaultHost = "127.0.0.1:11434" const DefaultHost = "127.0.0.1:11434"
var ( var envHost = os.Getenv("OLLAMA_HOST")
envHost = os.Getenv("OLLAMA_HOST")
)
type Client struct { type Client struct {
Base url.URL Base url.URL
@ -123,6 +121,8 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
return nil return nil
} }
const maxBufferSize = 512 * 1024 // 512KB
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer var buf *bytes.Buffer
if data != nil { if data != nil {
@ -151,6 +151,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
defer response.Body.Close() defer response.Body.Close()
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
// increase the buffer size to avoid running out of space
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() { for scanner.Scan() {
var errorResponse struct { var errorResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`

View file

@ -438,6 +438,8 @@ type PredictRequest struct {
Stop []string `json:"stop,omitempty"` Stop []string `json:"stop,omitempty"`
} }
const maxBufferSize = 512 * 1024 // 512KB
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext) prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil { if err != nil {
@ -498,6 +500,9 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
// increase the buffer size to avoid running out of space
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
for scanner.Scan() { for scanner.Scan() {
select { select {
case <-ctx.Done(): case <-ctx.Done():