Merge pull request #102 from jmorganca/session-id

Session
Michael Yang committed 2023-07-27 16:46:29 -07:00 (merged via GitHub)
commit db77dfe01f
5 changed files with 343 additions and 212 deletions

api/types.go

@@ -1,7 +1,9 @@
 package api

 import (
+    "encoding/json"
     "fmt"
+    "math"
     "os"
     "runtime"
     "time"
@@ -28,6 +30,9 @@ func (e StatusError) Error() string {
 }

 type GenerateRequest struct {
+    SessionID int64 `json:"session_id"`
+    SessionDuration Duration `json:"session_duration,omitempty"`
+
     Model string `json:"model"`
     Prompt string `json:"prompt"`
     Context []int `json:"context,omitempty"`
@@ -81,6 +86,9 @@ type ListResponseModel struct {
 }

 type GenerateResponse struct {
+    SessionID int64 `json:"session_id"`
+    SessionExpiresAt time.Time `json:"session_expires_at"`
+
     Model string `json:"model"`
     CreatedAt time.Time `json:"created_at"`
     Response string `json:"response,omitempty"`
@@ -89,6 +97,9 @@ type GenerateResponse struct {
     Context []int `json:"context,omitempty"`

     TotalDuration time.Duration `json:"total_duration,omitempty"`
     LoadDuration time.Duration `json:"load_duration,omitempty"`
+    SampleCount int `json:"sample_count,omitempty"`
+    SampleDuration time.Duration `json:"sample_duration,omitempty"`
     PromptEvalCount int `json:"prompt_eval_count,omitempty"`
     PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
     EvalCount int `json:"eval_count,omitempty"`
@@ -100,6 +111,19 @@ func (r *GenerateResponse) Summary() {
         fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
     }

+    if r.LoadDuration > 0 {
+        fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
+    }
+
+    if r.SampleCount > 0 {
+        fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
+    }
+
+    if r.SampleDuration > 0 {
+        fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
+        fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
+    }
+
     if r.PromptEvalCount > 0 {
         fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
     }
@@ -127,6 +151,7 @@ type Options struct {
     // Model options
     NumCtx int `json:"num_ctx,omitempty"`
+    NumKeep int `json:"num_keep,omitempty"`
     NumBatch int `json:"num_batch,omitempty"`
     NumGPU int `json:"num_gpu,omitempty"`
     MainGPU int `json:"main_gpu,omitempty"`
@@ -151,6 +176,7 @@ type Options struct {
     Mirostat int `json:"mirostat,omitempty"`
     MirostatTau float32 `json:"mirostat_tau,omitempty"`
     MirostatEta float32 `json:"mirostat_eta,omitempty"`
+    PenalizeNewline bool `json:"penalize_newline,omitempty"`
     NumThread int `json:"num_thread,omitempty"`
 }
@@ -162,14 +188,14 @@ func DefaultOptions() Options {
         UseNUMA: false,

         NumCtx: 2048,
-        NumBatch: 512,
+        NumBatch: 1024,
         NumGPU: 1,
         LowVRAM: false,
         F16KV: true,
         UseMMap: true,
         UseMLock: false,

-        RepeatLastN: 512,
+        RepeatLastN: 64,
         RepeatPenalty: 1.1,
         FrequencyPenalty: 0.0,
         PresencePenalty: 0.0,
@@ -181,7 +207,37 @@ func DefaultOptions() Options {
         Mirostat: 0,
         MirostatTau: 5.0,
         MirostatEta: 0.1,
+        PenalizeNewline: true,

         NumThread: runtime.NumCPU(),
     }
 }
+
+type Duration struct {
+    time.Duration
+}
+
+func (d *Duration) UnmarshalJSON(b []byte) (err error) {
+    var v any
+    if err := json.Unmarshal(b, &v); err != nil {
+        return err
+    }
+
+    d.Duration = 5 * time.Minute
+
+    switch t := v.(type) {
+    case float64:
+        if t < 0 {
+            t = math.MaxFloat64
+        }
+
+        d.Duration = time.Duration(t)
+    case string:
+        d.Duration, err = time.ParseDuration(t)
+        if err != nil {
+            return err
+        }
+    }
+
+    return nil
+}
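
Worth tracing how this decoder behaves. A small illustrative sketch (values made up, error handling elided): a JSON string goes through time.ParseDuration, a bare number is read as nanoseconds, negative numbers are replaced with math.MaxFloat64 before conversion, and the five-minute default applies when the field carries any other JSON type.

    package main

    import (
        "encoding/json"
        "fmt"

        "github.com/jmorganca/ollama/api"
    )

    func main() {
        var req api.GenerateRequest

        // Strings are parsed with time.ParseDuration.
        json.Unmarshal([]byte(`{"session_duration": "10m"}`), &req)
        fmt.Println(req.SessionDuration.Duration) // 10m0s

        // Bare numbers are interpreted as nanoseconds.
        json.Unmarshal([]byte(`{"session_duration": 300000000000}`), &req)
        fmt.Println(req.SessionDuration.Duration) // 5m0s
    }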

cmd/cmd.go

@@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
     return generateBatch(cmd, args[0])
 }

-var generateContextKey struct{}
+type generateContextKey string

 func generate(cmd *cobra.Command, model, prompt string) error {
     if len(strings.TrimSpace(prompt)) > 0 {
@@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
     var latest api.GenerateResponse

-    generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+    generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
     if !ok {
         generateContext = []int{}
     }

-    request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
-    fn := func(resp api.GenerateResponse) error {
+    generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
+    if !ok {
+        generateSession = 0
+    }
+
+    request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
+    fn := func(response api.GenerateResponse) error {
         if !spinner.IsFinished() {
             spinner.Finish()
         }

-        latest = resp
-        fmt.Print(resp.Response)
-        cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
+        latest = response
+        fmt.Print(response.Response)

         return nil
     }
@@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
         if verbose {
             latest.Summary()
         }
+
+        ctx := cmd.Context()
+        ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
+        ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
+        cmd.SetContext(ctx)
     }

     return nil
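
One detail worth calling out: generateContextKey changes from a bare struct{} to a named string type, so the command can stash several values ("context", "session") under distinct, collision-safe keys. A minimal standalone sketch of the pattern:

    package main

    import (
        "context"
        "fmt"
    )

    // A named key type keeps these values private to the package:
    // another package's plain string key can never collide with ours.
    type generateContextKey string

    func main() {
        ctx := context.Background()
        ctx = context.WithValue(ctx, generateContextKey("session"), int64(42))
        ctx = context.WithValue(ctx, generateContextKey("context"), []int{1, 2, 3})

        if id, ok := ctx.Value(generateContextKey("session")).(int64); ok {
            fmt.Println("session:", id) // session: 42
        }
    }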

llama/llama.go

@@ -1,8 +1,8 @@
 package llama

 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
-#cgo CXXFLAGS: -std=c++11
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>
@@ -21,6 +21,7 @@ struct llama_sample_options
     int mirostat;
     float mirostat_tau;
     float mirostat_eta;
+    bool penalize_newline;
 };

 llama_token llama_sample(
@@ -37,6 +38,8 @@ llama_token llama_sample(
         false,
     };

+    struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
     llama_sample_repetition_penalty(
         ctx, &candidates_p,
         last_tokens, n_last_tokens,
@@ -47,6 +50,10 @@ llama_token llama_sample(
         last_tokens, n_last_tokens,
         opts->frequency_penalty, opts->presence_penalty);

+    if (!opts->penalize_newline) {
+        candidates_p.data[llama_token_nl()] = newline;
+    }
+
     if (opts->temperature <= 0) {
         return llama_sample_token_greedy(ctx, &candidates_p);
     }
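
The two C additions above work as a pair: the newline token's candidate entry is saved before the repetition and frequency/presence penalties run, then written back when penalize_newline is off. A rough Go sketch of the same save-and-restore idea (illustrative only, plain slices rather than the real llama.cpp API):

    package main

    import "fmt"

    // penalizeRepeats mimics llama.cpp's repetition penalty: recently seen
    // tokens have positive logits divided and negative logits multiplied by
    // the penalty. Saving the newline entry first lets "\n" opt out.
    func penalizeRepeats(logits []float32, recent []int, newlineID int, penalty float32, penalizeNewline bool) {
        saved := logits[newlineID]
        for _, id := range recent {
            if logits[id] > 0 {
                logits[id] /= penalty
            } else {
                logits[id] *= penalty
            }
        }
        if !penalizeNewline {
            logits[newlineID] = saved // restore the pre-penalty value
        }
    }

    func main() {
        logits := []float32{1.0, 2.0, 3.0} // token 2 is our stand-in newline
        penalizeRepeats(logits, []int{1, 2}, 2, 1.1, false)
        fmt.Println(logits) // [1 1.8181818 3], newline entry untouched
    }
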
@@ -82,29 +89,37 @@ import (
     "errors"
     "fmt"
     "io"
+    "log"
     "os"
     "strings"
-    "time"
+    "sync"
     "unicode/utf8"
     "unsafe"

     "github.com/jmorganca/ollama/api"
 )

-type llama struct {
+type LLM struct {
     params *C.struct_llama_context_params
     model *C.struct_llama_model
     ctx *C.struct_llama_context

+    last []C.llama_token
+    embd []C.llama_token
+    cursor int
+
+    mu sync.Mutex
+    gc bool
+
     api.Options
 }

-func New(model string, opts api.Options) (*llama, error) {
+func New(model string, opts api.Options) (*LLM, error) {
     if _, err := os.Stat(model); err != nil {
         return nil, err
     }

-    llm := llama{Options: opts}
+    llm := LLM{Options: opts}

     C.llama_backend_init(C.bool(llm.UseNUMA))
@@ -144,27 +159,118 @@ func New(model string, opts api.Options) (*llama, error) {
     return &llm, nil
 }

-func (llm *llama) Close() {
+func (llm *LLM) Close() {
+    llm.gc = true
+
+    llm.mu.Lock()
+    defer llm.mu.Unlock()
+
     defer C.llama_free_model(llm.model)
     defer C.llama_free(llm.ctx)

     C.llama_print_timings(llm.ctx)
 }

-func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-    if input := llm.tokenize(prompt); input != nil {
-        embd := make([]C.llama_token, len(ctx))
-        for i := range ctx {
-            embd[i] = C.llama_token(ctx[i])
-        }
-
-        return llm.generate(append(embd, input...), fn)
-    }
-
-    return errors.New("llama: tokenize")
+func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+    C.llama_reset_timings(llm.ctx)
+
+    tokens := make([]C.llama_token, len(ctx))
+    for i := range tokens {
+        tokens[i] = C.llama_token(ctx[i])
+    }
+
+    if len(tokens) == 0 {
+        tokens = llm.tokenize(" ")
+    }
+
+    llm.marshalPrompt(tokens, prompt)
+
+    C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+    var b bytes.Buffer
+    for {
+        token, err := llm.next()
+        if llm.gc {
+            return nil
+        } else if errors.Is(err, io.EOF) {
+            break
+        } else if err != nil {
+            return err
+        }
+
+        b.WriteString(llm.detokenize(token))
+        if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+            fn(api.GenerateResponse{Response: b.String()})
+            b.Reset()
+        }
+    }
+
+    last := make([]int, 0, len(llm.last))
+    for _, i := range llm.last {
+        if i != 0 {
+            last = append(last, int(i))
+        }
+    }
+
+    timings := C.llama_get_timings(llm.ctx)
+    fn(api.GenerateResponse{
+        Done: true,
+        Context: last,
+        SampleCount: int(timings.n_sample),
+        SampleDuration: parseDurationMs(float64(timings.t_sample_ms)),
+        PromptEvalCount: int(timings.n_p_eval),
+        PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+        EvalCount: int(timings.n_eval),
+        EvalDuration: parseDurationMs(float64(timings.t_eval_ms)),
+    })
+
+    return nil
 }

-func (llm *llama) tokenize(prompt string) []C.llama_token {
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+    tokens := append(ctx, llm.tokenize(prompt)...)
+    if llm.NumKeep < 0 {
+        llm.NumKeep = len(tokens)
+    }
+
+    // min(llm.NumCtx - 4, llm.NumKeep)
+    if llm.NumCtx-4 < llm.NumKeep {
+        llm.NumKeep = llm.NumCtx - 4
+    }
+
+    if len(tokens) >= llm.NumCtx {
+        // truncate input
+        numLeft := (llm.NumCtx - llm.NumKeep) / 2
+        truncated := tokens[:llm.NumKeep]
+        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+        copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+        tokens = truncated
+        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+    } else {
+        llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+        llm.last = append(llm.last, tokens...)
+    }
+
+    var i int
+    for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+        // noop
+    }
+
+    llm.embd = tokens
+    if i == len(tokens) {
+        // evaluate at least one token to generate logits
+        i--
+    }
+
+    llm.cursor = i
+
+    log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+
+    return tokens
+}
+
+func (llm *LLM) tokenize(prompt string) []C.llama_token {
     cPrompt := C.CString(prompt)
     defer C.free(unsafe.Pointer(cPrompt))
@@ -176,7 +282,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
     return nil
 }

-func (llm *llama) detokenize(tokens ...C.llama_token) string {
+func (llm *LLM) detokenize(tokens ...C.llama_token) string {
     var sb strings.Builder
     for _, token := range tokens {
         sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
@@ -185,98 +291,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
     return sb.String()
 }

-func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-    var opts C.struct_llama_sample_options
-    opts.repeat_penalty = C.float(llm.RepeatPenalty)
-    opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-    opts.presence_penalty = C.float(llm.PresencePenalty)
-    opts.temperature = C.float(llm.Temperature)
-    opts.top_k = C.int(llm.TopK)
-    opts.top_p = C.float(llm.TopP)
-    opts.tfs_z = C.float(llm.TFSZ)
-    opts.typical_p = C.float(llm.TypicalP)
-    opts.mirostat = C.int(llm.Mirostat)
-    opts.mirostat_tau = C.float(llm.MirostatTau)
-    opts.mirostat_eta = C.float(llm.MirostatEta)
-
-    output := deque[C.llama_token]{capacity: llm.NumCtx}
-
-    context := deque[int]{capacity: llm.NumCtx / 2}
-    for _, in := range input {
-        context.PushLeft(int(in))
-    }
-
-    var b bytes.Buffer
-    for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-        if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-            return errors.New("llama: eval")
-        }
-
-        token, err := llm.sample(output, &opts)
-        if errors.Is(err, io.EOF) {
-            break
-        } else if err != nil {
-            return err
-        }
-
-        b.WriteString(llm.detokenize(token))
-        if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-            // call the callback
-            fn(api.GenerateResponse{
-                Response: b.String(),
-            })
-
-            output.PushLeft(token)
-            context.PushLeft(int(token))
-            b.Reset()
-        }
-
-        input = []C.llama_token{token}
-    }
-
-    dur := func(ms float64) time.Duration {
-        d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-        if err != nil {
-            panic(err)
-        }
-
-        return d
-    }
-
-    timings := C.llama_get_timings(llm.ctx)
-    fn(api.GenerateResponse{
-        Done: true,
-        Context: context.Data(),
-        PromptEvalCount: int(timings.n_p_eval),
-        PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-        EvalCount: int(timings.n_eval),
-        EvalDuration: dur(float64(timings.t_eval_ms)),
-    })
-
-    return nil
-}
-
-func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-    numVocab := int(C.llama_n_vocab(llm.ctx))
-    logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
-
-    candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-    for i := 0; i < candidates.Cap(); i++ {
-        candidates.PushLeft(C.struct_llama_token_data{
-            id: C.int(i),
-            logit: logits[i],
-            p: 0,
-        })
-    }
-
-    token := C.llama_sample(
-        llm.ctx,
-        unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-        unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-        opts)
-    if token != C.llama_token_eos() {
-        return token, nil
-    }
-
-    return 0, io.EOF
-}
+func (llm *LLM) next() (C.llama_token, error) {
+    llm.mu.Lock()
+    defer llm.mu.Unlock()
+
+    if len(llm.embd) >= llm.NumCtx {
+        numLeft := (llm.NumCtx - llm.NumKeep) / 2
+        truncated := llm.embd[:llm.NumKeep]
+        truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
+
+        llm.embd = truncated
+        llm.cursor = llm.NumKeep
+
+        log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
+    }
+
+    for {
+        if llm.gc {
+            return 0, io.EOF
+        }
+
+        if llm.cursor >= len(llm.embd) {
+            break
+        }
+
+        numEval := len(llm.embd) - llm.cursor
+        if numEval > llm.NumBatch {
+            numEval = llm.NumBatch
+        }
+
+        if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+            return 0, fmt.Errorf("llama_eval: %d", retval)
+        }
+
+        llm.cursor += numEval
+    }
+
+    var sampleOpts C.struct_llama_sample_options
+    sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+    sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+    sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+    sampleOpts.temperature = C.float(llm.Temperature)
+    sampleOpts.top_k = C.int(llm.TopK)
+    sampleOpts.top_p = C.float(llm.TopP)
+    sampleOpts.tfs_z = C.float(llm.TFSZ)
+    sampleOpts.typical_p = C.float(llm.TypicalP)
+    sampleOpts.mirostat = C.int(llm.Mirostat)
+    sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+    sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+    sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
+
+    numVocab := C.llama_n_vocab(llm.ctx)
+    logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
+
+    // TODO: logit bias
+    candidates := make([]C.llama_token_data, numVocab)
+    for i := range logits {
+        candidates[i] = C.llama_token_data{
+            id: C.int(i),
+            logit: logits[i],
+            p: 0,
+        }
+    }
+
+    repeatLastN := llm.RepeatLastN
+    if len(llm.last) < repeatLastN {
+        repeatLastN = len(llm.last)
+    }
+
+    if llm.NumCtx < repeatLastN {
+        repeatLastN = llm.NumCtx
+    }
+
+    lastN := llm.last[len(llm.last)-repeatLastN:]
+
+    token := C.llama_sample(
+        llm.ctx,
+        unsafe.SliceData(candidates), C.size_t(len(candidates)),
+        unsafe.SliceData(lastN), C.size_t(len(lastN)),
+        &sampleOpts,
+    )
+
+    llm.last = append(llm.last, token)
+    llm.embd = append(llm.embd, token)
+    if token == C.llama_token_eos() {
+        return 0, io.EOF
+    }
+
+    return token, nil
+}
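
Two pieces of arithmetic above are worth tracing. Truncation in next(): with num_ctx=2048 and, say, num_keep=24 (an invented value), numLeft = (2048-24)/2 = 1012, so a full buffer is cut down to the first 24 tokens plus the most recent 1012, and evaluation resumes at cursor=24. And the prompt cache in marshalPrompt: the new token list is compared against the previous llm.embd, and only the suffix after the longest common prefix is re-evaluated, since that prefix is already in the KV cache. A toy sketch of the prefix scan (plain ints, made-up values):

    package main

    import "fmt"

    // commonPrefix mirrors how marshalPrompt picks the evaluation cursor:
    // the longest prefix shared with the previous buffer is already in the
    // KV cache and is skipped.
    func commonPrefix(prev, next []int) int {
        var i int
        for i = 0; i < len(prev) && i < len(next) && prev[i] == next[i]; i++ {
            // noop
        }
        if i == len(next) {
            // always re-evaluate at least one token so logits exist
            i--
        }
        return i
    }

    func main() {
        prev := []int{1, 2, 3, 4, 5}          // evaluated on the previous request
        next := []int{1, 2, 3, 4, 5, 6, 7}    // returned context + new prompt tokens
        fmt.Println(commonPrefix(prev, next)) // 5: only tokens 6 and 7 are re-evaluated
    }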

llama/utils.go

@@ -1,104 +1,15 @@
 package llama

-type node[T any] struct {
-    t T
-    next *node[T]
-    prev *node[T]
-}
-
-type deque[T any] struct {
-    head *node[T]
-    tail *node[T]
-    size int
-    capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-    return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-    return d.size
-}
-
-func (d *deque[T]) Cap() int {
-    return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-    if d.capacity > 0 && d.size >= d.capacity {
-        d.PopLeft()
-    }
-
-    n := node[T]{t: t}
-    if d.head != nil {
-        n.next = d.head
-        d.head.prev = &n
-        d.head = &n
-    } else {
-        d.head = &n
-        d.tail = &n
-    }
-
-    d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-    if d.capacity > 0 && d.size >= d.capacity {
-        d.Pop()
-    }
-
-    n := node[T]{t: t}
-    if d.tail != nil {
-        n.prev = d.tail
-        d.tail.next = &n
-        d.tail = &n
-    } else {
-        d.head = &n
-        d.tail = &n
-    }
-
-    d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-    if d.Empty() {
-        return nil
-    }
-
-    head := d.head
-    d.head = head.next
-    if d.head != nil {
-        d.head.prev = nil
-    } else {
-        d.tail = nil
-    }
-
-    d.size--
-    return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-    if d.Empty() {
-        return nil
-    }
-
-    tail := d.tail
-    d.tail = tail.prev
-    if d.tail != nil {
-        d.tail.next = nil
-    } else {
-        d.head = nil
-    }
-
-    d.size--
-    return &tail.t
-}
-
-func (d *deque[T]) Data() (data []T) {
-    for n := d.head; n != nil; n = n.next {
-        data = append(data, n.t)
-    }
-
-    return data
-}
+import (
+    "fmt"
+    "time"
+)
+
+func parseDurationMs(ms float64) time.Duration {
+    dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+    if err != nil {
+        panic(err)
+    }
+
+    return dur
+}
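
parseDurationMs turns llama.cpp's float millisecond timings into a time.Duration by formatting and re-parsing a string. An illustrative example-style test (assumed to live in the llama package, since the function is unexported):

    package llama

    import (
        "fmt"
        "time"
    )

    func Example_parseDurationMs() {
        fmt.Println(parseDurationMs(1234.5))
        fmt.Println(parseDurationMs(0.75))
        // Equivalent arithmetic without the string round trip:
        fmt.Println(time.Duration(0.75 * float64(time.Millisecond)))
        // Output:
        // 1.2345s
        // 750µs
        // 750µs
    }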

server/routes.go

@@ -11,6 +11,7 @@ import (
     "os"
     "path/filepath"
     "strings"
+    "sync"
     "time"

     "dario.cat/mergo"
@@ -21,8 +22,21 @@ import (
     "github.com/jmorganca/ollama/llama"
 )

+var activeSession struct {
+    mu sync.Mutex
+
+    id int64
+    llm *llama.LLM
+
+    expireAt time.Time
+    expireTimer *time.Timer
+}
+
 func GenerateHandler(c *gin.Context) {
-    start := time.Now()
+    activeSession.mu.Lock()
+    defer activeSession.mu.Unlock()
+
+    checkpointStart := time.Now()

     var req api.GenerateRequest
     if err := c.ShouldBindJSON(&req); err != nil {
@@ -36,16 +50,58 @@ func GenerateHandler(c *gin.Context) {
         return
     }

-    opts := api.DefaultOptions()
-    if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
-        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-        return
-    }
-
-    if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
-        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-        return
-    }
+    if req.SessionID == 0 || req.SessionID != activeSession.id {
+        if activeSession.llm != nil {
+            activeSession.llm.Close()
+            activeSession.llm = nil
+        }
+
+        opts := api.DefaultOptions()
+        if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+            return
+        }
+
+        if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+            return
+        }
+
+        llm, err := llama.New(model.ModelPath, opts)
+        if err != nil {
+            c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+            return
+        }
+
+        activeSession.id = time.Now().UnixNano()
+        activeSession.llm = llm
+    }
+
+    sessionDuration := req.SessionDuration
+    sessionID := activeSession.id
+
+    activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+    if activeSession.expireTimer == nil {
+        activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
+            activeSession.mu.Lock()
+            defer activeSession.mu.Unlock()
+
+            if sessionID != activeSession.id {
+                return
+            }
+
+            if time.Now().Before(activeSession.expireAt) {
+                return
+            }
+
+            activeSession.llm.Close()
+            activeSession.llm = nil
+            activeSession.id = 0
+        })
+    }
+    activeSession.expireTimer.Reset(sessionDuration.Duration)
+
+    checkpointLoaded := time.Now()

     prompt, err := model.Prompt(req)
     if err != nil {
@@ -53,27 +109,26 @@ func GenerateHandler(c *gin.Context) {
         return
     }

-    llm, err := llama.New(model.ModelPath, opts)
-    if err != nil {
-        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-        return
-    }
-    defer llm.Close()
-
     ch := make(chan any)
     go func() {
         defer close(ch)
         fn := func(r api.GenerateResponse) {
+            activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
+            activeSession.expireTimer.Reset(sessionDuration.Duration)
+
             r.Model = req.Model
             r.CreatedAt = time.Now().UTC()
+            r.SessionID = activeSession.id
+            r.SessionExpiresAt = activeSession.expireAt.UTC()
+
             if r.Done {
-                r.TotalDuration = time.Since(start)
+                r.TotalDuration = time.Since(checkpointStart)
+                r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
             }

             ch <- r
         }

-        if err := llm.Predict(req.Context, prompt, fn); err != nil {
+        if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
            ch <- gin.H{"error": err.Error()}
        }
    }()
@@ -223,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
         return
     }

-    c.JSON(http.StatusOK, api.ListResponse{models})
+    c.JSON(http.StatusOK, api.ListResponse{Models: models})
 }

 func CopyModelHandler(c *gin.Context) {
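
Stepping back, the session lifecycle added to GenerateHandler amounts to one time.AfterFunc timer that is re-armed on every request and on every streamed token; the callback re-checks the session id it captured and the current expireAt, so a timer that fires late, early, or for a superseded session is a no-op. A stripped-down sketch of that pattern (hypothetical names, no HTTP or model handle):

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    type session struct {
        mu       sync.Mutex
        id       int64
        expireAt time.Time
        timer    *time.Timer
    }

    // touch re-arms the expiry timer; callers hold s.mu. The callback checks
    // both the id it captured and the current deadline, so a stale or
    // re-armed timer firing does nothing.
    func (s *session) touch(ttl time.Duration) {
        s.expireAt = time.Now().Add(ttl)
        id := s.id
        if s.timer == nil {
            s.timer = time.AfterFunc(ttl, func() {
                s.mu.Lock()
                defer s.mu.Unlock()
                if id != s.id || time.Now().Before(s.expireAt) {
                    return // superseded or re-armed in the meantime
                }
                fmt.Println("session expired:", s.id)
                s.id = 0
            })
        }
        s.timer.Reset(ttl)
    }

    func main() {
        s := &session{id: time.Now().UnixNano()}
        s.mu.Lock()
        s.touch(50 * time.Millisecond)
        s.mu.Unlock()
        time.Sleep(100 * time.Millisecond) // let the timer fire
    }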