session id
This commit is contained in:
parent
dbb3174cbc
commit
35af37a2cb
4 changed files with 67 additions and 36 deletions
|
@ -28,9 +28,10 @@ func (e StatusError) Error() string {
|
|||
}
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
SessionID int64 `json:"session_id"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
Options `json:"options"`
|
||||
}
|
||||
|
@ -81,6 +82,7 @@ type ListResponseModel struct {
|
|||
}
|
||||
|
||||
type GenerateResponse struct {
|
||||
SessionID int64 `json:"session_id"`
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Response string `json:"response,omitempty"`
|
||||
|
|
24
cmd/cmd.go
|
@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
|||
return generateBatch(cmd, args[0])
|
||||
}
|
||||
|
||||
var generateContextKey struct{}
|
||||
type generateContextKey string
|
||||
|
||||
func generate(cmd *cobra.Command, model, prompt string) error {
|
||||
if len(strings.TrimSpace(prompt)) > 0 {
|
||||
|
@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
|||
|
||||
var latest api.GenerateResponse
|
||||
|
||||
generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
|
||||
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
|
||||
if !ok {
|
||||
generateContext = []int{}
|
||||
}
|
||||
|
||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
|
||||
if !ok {
|
||||
generateSession = 0
|
||||
}
|
||||
|
||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
if !spinner.IsFinished() {
|
||||
spinner.Finish()
|
||||
}
|
||||
|
||||
latest = resp
|
||||
latest = response
|
||||
|
||||
fmt.Print(resp.Response)
|
||||
|
||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
|
||||
fmt.Print(response.Response)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
|||
if verbose {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
ctx := cmd.Context()
|
||||
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
|
||||
ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
|
||||
cmd.SetContext(ctx)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -91,7 +91,7 @@ import (
|
|||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
type llama struct {
|
||||
type LLM struct {
|
||||
params *C.struct_llama_context_params
|
||||
model *C.struct_llama_model
|
||||
ctx *C.struct_llama_context
|
||||
|
@ -99,12 +99,12 @@ type llama struct {
|
|||
api.Options
|
||||
}
|
||||
|
||||
func New(model string, opts api.Options) (*llama, error) {
|
||||
func New(model string, opts api.Options) (*LLM, error) {
|
||||
if _, err := os.Stat(model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
llm := llama{Options: opts}
|
||||
llm := LLM{Options: opts}
|
||||
|
||||
C.llama_backend_init(C.bool(llm.UseNUMA))
|
||||
|
||||
|
@ -144,14 +144,14 @@ func New(model string, opts api.Options) (*llama, error) {
|
|||
return &llm, nil
|
||||
}
|
||||
|
||||
func (llm *llama) Close() {
|
||||
func (llm *LLM) Close() {
|
||||
defer C.llama_free_model(llm.model)
|
||||
defer C.llama_free(llm.ctx)
|
||||
|
||||
C.llama_print_timings(llm.ctx)
|
||||
}
|
||||
|
||||
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||
if input := llm.tokenize(prompt); input != nil {
|
||||
embd := make([]C.llama_token, len(ctx))
|
||||
for i := range ctx {
|
||||
|
@ -164,7 +164,7 @@ func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse
|
|||
return errors.New("llama: tokenize")
|
||||
}
|
||||
|
||||
func (llm *llama) tokenize(prompt string) []C.llama_token {
|
||||
func (llm *LLM) tokenize(prompt string) []C.llama_token {
|
||||
cPrompt := C.CString(prompt)
|
||||
defer C.free(unsafe.Pointer(cPrompt))
|
||||
|
||||
|
@ -176,7 +176,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (llm *llama) detokenize(tokens ...C.llama_token) string {
|
||||
func (llm *LLM) detokenize(tokens ...C.llama_token) string {
|
||||
var sb strings.Builder
|
||||
for _, token := range tokens {
|
||||
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
|
||||
|
@ -185,7 +185,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
|
|||
return sb.String()
|
||||
}
|
||||
|
||||
func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
|
||||
func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
|
||||
var opts C.struct_llama_sample_options
|
||||
opts.repeat_penalty = C.float(llm.RepeatPenalty)
|
||||
opts.frequency_penalty = C.float(llm.FrequencyPenalty)
|
||||
|
@ -256,7 +256,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
|
|||
return nil
|
||||
}
|
||||
|
||||
func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
|
||||
func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
|
||||
numVocab := int(C.llama_n_vocab(llm.ctx))
|
||||
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
|
@ -21,7 +22,17 @@ import (
|
|||
"github.com/jmorganca/ollama/llama"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
var activeSession struct {
|
||||
ID int64
|
||||
*llama.LLM
|
||||
}
|
||||
|
||||
func GenerateHandler(c *gin.Context) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var req api.GenerateRequest
|
||||
|
@ -36,15 +47,31 @@ func GenerateHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.SessionID == 0 || req.SessionID != activeSession.ID {
|
||||
if activeSession.LLM != nil {
|
||||
activeSession.Close()
|
||||
activeSession.LLM = nil
|
||||
}
|
||||
|
||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
opts := api.DefaultOptions()
|
||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
llm, err := llama.New(model.ModelPath, opts)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
activeSession.ID = time.Now().UnixNano()
|
||||
activeSession.LLM = llm
|
||||
}
|
||||
|
||||
prompt, err := model.Prompt(req)
|
||||
|
@ -53,19 +80,13 @@ func GenerateHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
llm, err := llama.New(model.ModelPath, opts)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer llm.Close()
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
fn := func(r api.GenerateResponse) {
|
||||
r.Model = req.Model
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
r.SessionID = activeSession.ID
|
||||
if r.Done {
|
||||
r.TotalDuration = time.Since(start)
|
||||
}
|
||||
|
@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) {
|
|||
ch <- r
|
||||
}
|
||||
|
||||
if err := llm.Predict(req.Context, prompt, fn); err != nil {
|
||||
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
|
Loading…
Reference in a new issue