session id

This commit is contained in:
Michael Yang 2023-07-18 11:59:42 -07:00
parent dbb3174cbc
commit 35af37a2cb
4 changed files with 67 additions and 36 deletions

View file

@ -28,6 +28,7 @@ func (e StatusError) Error() string {
} }
type GenerateRequest struct { type GenerateRequest struct {
SessionID int64 `json:"session_id"`
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
@ -81,6 +82,7 @@ type ListResponseModel struct {
} }
type GenerateResponse struct { type GenerateResponse struct {
SessionID int64 `json:"session_id"`
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"` Response string `json:"response,omitempty"`

View file

@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
return generateBatch(cmd, args[0]) return generateBatch(cmd, args[0])
} }
var generateContextKey struct{} type generateContextKey string
func generate(cmd *cobra.Command, model, prompt string) error { func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 { if len(strings.TrimSpace(prompt)) > 0 {
@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
var latest api.GenerateResponse var latest api.GenerateResponse
generateContext, ok := cmd.Context().Value(generateContextKey).([]int) generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
if !ok { if !ok {
generateContext = []int{} generateContext = []int{}
} }
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext} generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
fn := func(resp api.GenerateResponse) error { if !ok {
generateSession = 0
}
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
fn := func(response api.GenerateResponse) error {
if !spinner.IsFinished() { if !spinner.IsFinished() {
spinner.Finish() spinner.Finish()
} }
latest = resp latest = response
fmt.Print(resp.Response) fmt.Print(response.Response)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
return nil return nil
} }
@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
if verbose { if verbose {
latest.Summary() latest.Summary()
} }
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
cmd.SetContext(ctx)
} }
return nil return nil

View file

@ -91,7 +91,7 @@ import (
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
) )
type llama struct { type LLM struct {
params *C.struct_llama_context_params params *C.struct_llama_context_params
model *C.struct_llama_model model *C.struct_llama_model
ctx *C.struct_llama_context ctx *C.struct_llama_context
@ -99,12 +99,12 @@ type llama struct {
api.Options api.Options
} }
func New(model string, opts api.Options) (*llama, error) { func New(model string, opts api.Options) (*LLM, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
llm := llama{Options: opts} llm := LLM{Options: opts}
C.llama_backend_init(C.bool(llm.UseNUMA)) C.llama_backend_init(C.bool(llm.UseNUMA))
@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) {
return &llm, nil return &llm, nil
} }
func (llm *llama) Close() { func (llm *LLM) Close() {
defer C.llama_free_model(llm.model) defer C.llama_free_model(llm.model)
defer C.llama_free(llm.ctx) defer C.llama_free(llm.ctx)
C.llama_print_timings(llm.ctx) C.llama_print_timings(llm.ctx)
} }
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
if input := llm.tokenize(prompt); input != nil { if input := llm.tokenize(prompt); input != nil {
embd := make([]C.llama_token, len(ctx)) embd := make([]C.llama_token, len(ctx))
for i := range ctx { for i := range ctx {
@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
return errors.New("llama: tokenize") return errors.New("llama: tokenize")
} }
func (llm *llama) tokenize(prompt string) []C.llama_token { func (llm *LLM) tokenize(prompt string) []C.llama_token {
cPrompt := C.CString(prompt) cPrompt := C.CString(prompt)
defer C.free(unsafe.Pointer(cPrompt)) defer C.free(unsafe.Pointer(cPrompt))
@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
return nil return nil
} }
func (llm *llama) detokenize(tokens ...C.llama_token) string { func (llm *LLM) detokenize(tokens ...C.llama_token) string {
var sb strings.Builder var sb strings.Builder
for _, token := range tokens { for _, token := range tokens {
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
return sb.String() return sb.String()
} }
func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
var opts C.struct_llama_sample_options var opts C.struct_llama_sample_options
opts.repeat_penalty = C.float(llm.RepeatPenalty) opts.repeat_penalty = C.float(llm.RepeatPenalty)
opts.frequency_penalty = C.float(llm.FrequencyPenalty) opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
return nil return nil
} }
func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
numVocab := int(C.llama_n_vocab(llm.ctx)) numVocab := int(C.llama_n_vocab(llm.ctx))
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)

View file

@ -11,6 +11,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"dario.cat/mergo" "dario.cat/mergo"
@ -21,7 +22,17 @@ import (
"github.com/jmorganca/ollama/llama" "github.com/jmorganca/ollama/llama"
) )
var mu sync.Mutex
var activeSession struct {
ID int64
*llama.LLM
}
func GenerateHandler(c *gin.Context) { func GenerateHandler(c *gin.Context) {
mu.Lock()
defer mu.Unlock()
start := time.Now() start := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
@ -36,6 +47,12 @@ func GenerateHandler(c *gin.Context) {
return return
} }
if req.SessionID == 0 || req.SessionID != activeSession.ID {
if activeSession.LLM != nil {
activeSession.Close()
activeSession.LLM = nil
}
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil { if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@ -47,18 +64,21 @@ func GenerateHandler(c *gin.Context) {
return return
} }
prompt, err := model.Prompt(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
llm, err := llama.New(model.ModelPath, opts) llm, err := llama.New(model.ModelPath, opts)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
defer llm.Close()
activeSession.ID = time.Now().UnixNano()
activeSession.LLM = llm
}
prompt, err := model.Prompt(req)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ch := make(chan any) ch := make(chan any)
go func() { go func() {
@ -66,6 +86,7 @@ func GenerateHandler(c *gin.Context) {
fn := func(r api.GenerateResponse) { fn := func(r api.GenerateResponse) {
r.Model = req.Model r.Model = req.Model
r.CreatedAt = time.Now().UTC() r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.ID
if r.Done { if r.Done {
r.TotalDuration = time.Since(start) r.TotalDuration = time.Since(start)
} }
@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r ch <- r
} }
if err := llm.Predict(req.Context, prompt, fn); err != nil { if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()