Default Keep Alive environment variable (#3094)
--------- Co-authored-by: Chris-AS1 <8493773+Chris-AS1@users.noreply.github.com>
This commit is contained in:
parent
e72c567cfd
commit
47cfe58af5
2 changed files with 81 additions and 3 deletions
50
api/types_test.go
Normal file
50
api/types_test.go
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req string
|
||||||
|
exp *Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Positive Integer",
|
||||||
|
req: `{ "keep_alive": 42 }`,
|
||||||
|
exp: &Duration{42 * time.Second},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Positive Integer String",
|
||||||
|
req: `{ "keep_alive": "42m" }`,
|
||||||
|
exp: &Duration{42 * time.Minute},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Negative Integer",
|
||||||
|
req: `{ "keep_alive": -1 }`,
|
||||||
|
exp: &Duration{math.MaxInt64},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Negative Integer String",
|
||||||
|
req: `{ "keep_alive": "-1m" }`,
|
||||||
|
exp: &Duration{math.MaxInt64},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var dec ChatRequest
|
||||||
|
err := json.Unmarshal([]byte(test.req), &dec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, test.exp, dec.KeepAlive)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -207,7 +209,7 @@ func GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
var sessionDuration time.Duration
|
||||||
if req.KeepAlive == nil {
|
if req.KeepAlive == nil {
|
||||||
sessionDuration = defaultSessionDuration
|
sessionDuration = getDefaultSessionDuration()
|
||||||
} else {
|
} else {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
@ -384,6 +386,32 @@ func GenerateHandler(c *gin.Context) {
|
||||||
streamResponse(c, ch)
|
streamResponse(c, ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getDefaultSessionDuration() time.Duration {
|
||||||
|
if t, exists := os.LookupEnv("OLLAMA_KEEP_ALIVE"); exists {
|
||||||
|
v, err := strconv.Atoi(t)
|
||||||
|
if err != nil {
|
||||||
|
d, err := time.ParseDuration(t)
|
||||||
|
if err != nil {
|
||||||
|
return defaultSessionDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
if d < 0 {
|
||||||
|
return time.Duration(math.MaxInt64)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(v) * time.Second
|
||||||
|
if d < 0 {
|
||||||
|
return time.Duration(math.MaxInt64)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultSessionDuration
|
||||||
|
}
|
||||||
|
|
||||||
func EmbeddingsHandler(c *gin.Context) {
|
func EmbeddingsHandler(c *gin.Context) {
|
||||||
loaded.mu.Lock()
|
loaded.mu.Lock()
|
||||||
defer loaded.mu.Unlock()
|
defer loaded.mu.Unlock()
|
||||||
|
@ -427,7 +455,7 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
var sessionDuration time.Duration
|
||||||
if req.KeepAlive == nil {
|
if req.KeepAlive == nil {
|
||||||
sessionDuration = defaultSessionDuration
|
sessionDuration = getDefaultSessionDuration()
|
||||||
} else {
|
} else {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
@ -1228,7 +1256,7 @@ func ChatHandler(c *gin.Context) {
|
||||||
|
|
||||||
var sessionDuration time.Duration
|
var sessionDuration time.Duration
|
||||||
if req.KeepAlive == nil {
|
if req.KeepAlive == nil {
|
||||||
sessionDuration = defaultSessionDuration
|
sessionDuration = getDefaultSessionDuration()
|
||||||
} else {
|
} else {
|
||||||
sessionDuration = req.KeepAlive.Duration
|
sessionDuration = req.KeepAlive.Duration
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue