commit
db77dfe01f
5 changed files with 343 additions and 212 deletions
60
api/types.go
60
api/types.go
|
@ -1,7 +1,9 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
@ -28,6 +30,9 @@ func (e StatusError) Error() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
|
SessionID int64 `json:"session_id"`
|
||||||
|
SessionDuration Duration `json:"session_duration,omitempty"`
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
@ -81,6 +86,9 @@ type ListResponseModel struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
|
SessionID int64 `json:"session_id"`
|
||||||
|
SessionExpiresAt time.Time `json:"session_expires_at"`
|
||||||
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
Response string `json:"response,omitempty"`
|
Response string `json:"response,omitempty"`
|
||||||
|
@ -89,6 +97,9 @@ type GenerateResponse struct {
|
||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
|
SampleCount int `json:"sample_count,omitempty"`
|
||||||
|
SampleDuration time.Duration `json:"sample_duration,omitempty"`
|
||||||
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
EvalCount int `json:"eval_count,omitempty"`
|
EvalCount int `json:"eval_count,omitempty"`
|
||||||
|
@ -100,6 +111,19 @@ func (r *GenerateResponse) Summary() {
|
||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.LoadDuration > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.SampleCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "sample count: %d token(s)\n", r.SampleCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.SampleDuration > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "sample duration: %s\n", r.SampleDuration)
|
||||||
|
fmt.Fprintf(os.Stderr, "sample rate: %.2f tokens/s\n", float64(r.SampleCount)/r.SampleDuration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
if r.PromptEvalCount > 0 {
|
if r.PromptEvalCount > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
|
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
|
||||||
}
|
}
|
||||||
|
@ -127,6 +151,7 @@ type Options struct {
|
||||||
|
|
||||||
// Model options
|
// Model options
|
||||||
NumCtx int `json:"num_ctx,omitempty"`
|
NumCtx int `json:"num_ctx,omitempty"`
|
||||||
|
NumKeep int `json:"num_keep,omitempty"`
|
||||||
NumBatch int `json:"num_batch,omitempty"`
|
NumBatch int `json:"num_batch,omitempty"`
|
||||||
NumGPU int `json:"num_gpu,omitempty"`
|
NumGPU int `json:"num_gpu,omitempty"`
|
||||||
MainGPU int `json:"main_gpu,omitempty"`
|
MainGPU int `json:"main_gpu,omitempty"`
|
||||||
|
@ -151,6 +176,7 @@ type Options struct {
|
||||||
Mirostat int `json:"mirostat,omitempty"`
|
Mirostat int `json:"mirostat,omitempty"`
|
||||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
||||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
||||||
|
PenalizeNewline bool `json:"penalize_newline,omitempty"`
|
||||||
|
|
||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
@ -162,14 +188,14 @@ func DefaultOptions() Options {
|
||||||
UseNUMA: false,
|
UseNUMA: false,
|
||||||
|
|
||||||
NumCtx: 2048,
|
NumCtx: 2048,
|
||||||
NumBatch: 512,
|
NumBatch: 1024,
|
||||||
NumGPU: 1,
|
NumGPU: 1,
|
||||||
LowVRAM: false,
|
LowVRAM: false,
|
||||||
F16KV: true,
|
F16KV: true,
|
||||||
UseMMap: true,
|
UseMMap: true,
|
||||||
UseMLock: false,
|
UseMLock: false,
|
||||||
|
|
||||||
RepeatLastN: 512,
|
RepeatLastN: 64,
|
||||||
RepeatPenalty: 1.1,
|
RepeatPenalty: 1.1,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
|
@ -181,7 +207,37 @@ func DefaultOptions() Options {
|
||||||
Mirostat: 0,
|
Mirostat: 0,
|
||||||
MirostatTau: 5.0,
|
MirostatTau: 5.0,
|
||||||
MirostatEta: 0.1,
|
MirostatEta: 0.1,
|
||||||
|
PenalizeNewline: true,
|
||||||
|
|
||||||
NumThread: runtime.NumCPU(),
|
NumThread: runtime.NumCPU(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||||
|
var v any
|
||||||
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Duration = 5 * time.Minute
|
||||||
|
|
||||||
|
switch t := v.(type) {
|
||||||
|
case float64:
|
||||||
|
if t < 0 {
|
||||||
|
t = math.MaxFloat64
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Duration = time.Duration(t)
|
||||||
|
case string:
|
||||||
|
d.Duration, err = time.ParseDuration(t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
24
cmd/cmd.go
24
cmd/cmd.go
|
@ -244,7 +244,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
return generateBatch(cmd, args[0])
|
return generateBatch(cmd, args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
var generateContextKey struct{}
|
type generateContextKey string
|
||||||
|
|
||||||
func generate(cmd *cobra.Command, model, prompt string) error {
|
func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
if len(strings.TrimSpace(prompt)) > 0 {
|
if len(strings.TrimSpace(prompt)) > 0 {
|
||||||
|
@ -255,22 +255,25 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
|
|
||||||
var latest api.GenerateResponse
|
var latest api.GenerateResponse
|
||||||
|
|
||||||
generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
|
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
|
||||||
if !ok {
|
if !ok {
|
||||||
generateContext = []int{}
|
generateContext = []int{}
|
||||||
}
|
}
|
||||||
|
|
||||||
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
|
generateSession, ok := cmd.Context().Value(generateContextKey("session")).(int64)
|
||||||
fn := func(resp api.GenerateResponse) error {
|
if !ok {
|
||||||
|
generateSession = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext, SessionID: generateSession}
|
||||||
|
fn := func(response api.GenerateResponse) error {
|
||||||
if !spinner.IsFinished() {
|
if !spinner.IsFinished() {
|
||||||
spinner.Finish()
|
spinner.Finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
latest = resp
|
latest = response
|
||||||
|
|
||||||
fmt.Print(resp.Response)
|
fmt.Print(response.Response)
|
||||||
|
|
||||||
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -289,6 +292,11 @@ func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
if verbose {
|
if verbose {
|
||||||
latest.Summary()
|
latest.Summary()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := cmd.Context()
|
||||||
|
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
|
||||||
|
ctx = context.WithValue(ctx, generateContextKey("session"), latest.SessionID)
|
||||||
|
cmd.SetContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
273
llama/llama.go
273
llama/llama.go
|
@ -1,8 +1,8 @@
|
||||||
package llama
|
package llama
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
|
#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
|
||||||
#cgo CXXFLAGS: -std=c++11
|
#cgo CXXFLAGS: -std=gnu++11
|
||||||
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||||
#cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
#cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
@ -21,6 +21,7 @@ struct llama_sample_options
|
||||||
int mirostat;
|
int mirostat;
|
||||||
float mirostat_tau;
|
float mirostat_tau;
|
||||||
float mirostat_eta;
|
float mirostat_eta;
|
||||||
|
bool penalize_newline;
|
||||||
};
|
};
|
||||||
|
|
||||||
llama_token llama_sample(
|
llama_token llama_sample(
|
||||||
|
@ -37,6 +38,8 @@ llama_token llama_sample(
|
||||||
false,
|
false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct llama_token_data newline = candidates_p.data[llama_token_nl()];
|
||||||
|
|
||||||
llama_sample_repetition_penalty(
|
llama_sample_repetition_penalty(
|
||||||
ctx, &candidates_p,
|
ctx, &candidates_p,
|
||||||
last_tokens, n_last_tokens,
|
last_tokens, n_last_tokens,
|
||||||
|
@ -47,6 +50,10 @@ llama_token llama_sample(
|
||||||
last_tokens, n_last_tokens,
|
last_tokens, n_last_tokens,
|
||||||
opts->frequency_penalty, opts->presence_penalty);
|
opts->frequency_penalty, opts->presence_penalty);
|
||||||
|
|
||||||
|
if (!opts->penalize_newline) {
|
||||||
|
candidates_p.data[llama_token_nl()] = newline;
|
||||||
|
}
|
||||||
|
|
||||||
if (opts->temperature <= 0) {
|
if (opts->temperature <= 0) {
|
||||||
return llama_sample_token_greedy(ctx, &candidates_p);
|
return llama_sample_token_greedy(ctx, &candidates_p);
|
||||||
}
|
}
|
||||||
|
@ -82,29 +89,37 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"sync"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type llama struct {
|
type LLM struct {
|
||||||
params *C.struct_llama_context_params
|
params *C.struct_llama_context_params
|
||||||
model *C.struct_llama_model
|
model *C.struct_llama_model
|
||||||
ctx *C.struct_llama_context
|
ctx *C.struct_llama_context
|
||||||
|
|
||||||
|
last []C.llama_token
|
||||||
|
embd []C.llama_token
|
||||||
|
cursor int
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
gc bool
|
||||||
|
|
||||||
api.Options
|
api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(model string, opts api.Options) (*llama, error) {
|
func New(model string, opts api.Options) (*LLM, error) {
|
||||||
if _, err := os.Stat(model); err != nil {
|
if _, err := os.Stat(model); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
llm := llama{Options: opts}
|
llm := LLM{Options: opts}
|
||||||
|
|
||||||
C.llama_backend_init(C.bool(llm.UseNUMA))
|
C.llama_backend_init(C.bool(llm.UseNUMA))
|
||||||
|
|
||||||
|
@ -144,27 +159,118 @@ func New(model string, opts api.Options) (*llama, error) {
|
||||||
return &llm, nil
|
return &llm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) Close() {
|
func (llm *LLM) Close() {
|
||||||
|
llm.gc = true
|
||||||
|
|
||||||
|
llm.mu.Lock()
|
||||||
|
defer llm.mu.Unlock()
|
||||||
|
|
||||||
defer C.llama_free_model(llm.model)
|
defer C.llama_free_model(llm.model)
|
||||||
defer C.llama_free(llm.ctx)
|
defer C.llama_free(llm.ctx)
|
||||||
|
|
||||||
C.llama_print_timings(llm.ctx)
|
C.llama_print_timings(llm.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||||
if input := llm.tokenize(prompt); input != nil {
|
C.llama_reset_timings(llm.ctx)
|
||||||
embd := make([]C.llama_token, len(ctx))
|
|
||||||
for i := range ctx {
|
|
||||||
embd[i] = C.llama_token(ctx[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
return llm.generate(append(embd, input...), fn)
|
tokens := make([]C.llama_token, len(ctx))
|
||||||
|
for i := range tokens {
|
||||||
|
tokens[i] = C.llama_token(ctx[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.New("llama: tokenize")
|
if len(tokens) == 0 {
|
||||||
|
tokens = llm.tokenize(" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
llm.marshalPrompt(tokens, prompt)
|
||||||
|
|
||||||
|
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
for {
|
||||||
|
token, err := llm.next()
|
||||||
|
if llm.gc {
|
||||||
|
return nil
|
||||||
|
} else if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(llm.detokenize(token))
|
||||||
|
if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
|
||||||
|
fn(api.GenerateResponse{Response: b.String()})
|
||||||
|
b.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
last := make([]int, 0, len(llm.last))
|
||||||
|
for _, i := range llm.last {
|
||||||
|
if i != 0 {
|
||||||
|
last = append(last, int(i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
timings := C.llama_get_timings(llm.ctx)
|
||||||
|
fn(api.GenerateResponse{
|
||||||
|
Done: true,
|
||||||
|
Context: last,
|
||||||
|
SampleCount: int(timings.n_sample),
|
||||||
|
SampleDuration: parseDurationMs(float64(timings.t_sample_ms)),
|
||||||
|
PromptEvalCount: int(timings.n_p_eval),
|
||||||
|
PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
|
||||||
|
EvalCount: int(timings.n_eval),
|
||||||
|
EvalDuration: parseDurationMs(float64(timings.t_eval_ms)),
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) tokenize(prompt string) []C.llama_token {
|
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
|
||||||
|
tokens := append(ctx, llm.tokenize(prompt)...)
|
||||||
|
if llm.NumKeep < 0 {
|
||||||
|
llm.NumKeep = len(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// min(llm.NumCtx - 4, llm.NumKeep)
|
||||||
|
if llm.NumCtx-4 < llm.NumKeep {
|
||||||
|
llm.NumKeep = llm.NumCtx - 4
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) >= llm.NumCtx {
|
||||||
|
// truncate input
|
||||||
|
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
||||||
|
truncated := tokens[:llm.NumKeep]
|
||||||
|
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
|
||||||
|
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
|
||||||
|
copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
|
||||||
|
|
||||||
|
tokens = truncated
|
||||||
|
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
|
||||||
|
} else {
|
||||||
|
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
|
||||||
|
llm.last = append(llm.last, tokens...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var i int
|
||||||
|
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
|
||||||
|
// noop
|
||||||
|
}
|
||||||
|
|
||||||
|
llm.embd = tokens
|
||||||
|
if i == len(tokens) {
|
||||||
|
// evaluate at least one token to generate logits
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
|
||||||
|
llm.cursor = i
|
||||||
|
|
||||||
|
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func (llm *LLM) tokenize(prompt string) []C.llama_token {
|
||||||
cPrompt := C.CString(prompt)
|
cPrompt := C.CString(prompt)
|
||||||
defer C.free(unsafe.Pointer(cPrompt))
|
defer C.free(unsafe.Pointer(cPrompt))
|
||||||
|
|
||||||
|
@ -176,7 +282,7 @@ func (llm *llama) tokenize(prompt string) []C.llama_token {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) detokenize(tokens ...C.llama_token) string {
|
func (llm *LLM) detokenize(tokens ...C.llama_token) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
|
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
|
||||||
|
@ -185,98 +291,93 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
|
func (llm *LLM) next() (C.llama_token, error) {
|
||||||
var opts C.struct_llama_sample_options
|
llm.mu.Lock()
|
||||||
opts.repeat_penalty = C.float(llm.RepeatPenalty)
|
defer llm.mu.Unlock()
|
||||||
opts.frequency_penalty = C.float(llm.FrequencyPenalty)
|
|
||||||
opts.presence_penalty = C.float(llm.PresencePenalty)
|
|
||||||
opts.temperature = C.float(llm.Temperature)
|
|
||||||
opts.top_k = C.int(llm.TopK)
|
|
||||||
opts.top_p = C.float(llm.TopP)
|
|
||||||
opts.tfs_z = C.float(llm.TFSZ)
|
|
||||||
opts.typical_p = C.float(llm.TypicalP)
|
|
||||||
opts.mirostat = C.int(llm.Mirostat)
|
|
||||||
opts.mirostat_tau = C.float(llm.MirostatTau)
|
|
||||||
opts.mirostat_eta = C.float(llm.MirostatEta)
|
|
||||||
|
|
||||||
output := deque[C.llama_token]{capacity: llm.NumCtx}
|
if len(llm.embd) >= llm.NumCtx {
|
||||||
|
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
||||||
|
truncated := llm.embd[:llm.NumKeep]
|
||||||
|
truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
|
||||||
|
|
||||||
context := deque[int]{capacity: llm.NumCtx / 2}
|
llm.embd = truncated
|
||||||
for _, in := range input {
|
llm.cursor = llm.NumKeep
|
||||||
context.PushLeft(int(in))
|
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
for {
|
||||||
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
|
if llm.gc {
|
||||||
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
|
return 0, io.EOF
|
||||||
return errors.New("llama: eval")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := llm.sample(output, &opts)
|
if llm.cursor >= len(llm.embd) {
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString(llm.detokenize(token))
|
numEval := len(llm.embd) - llm.cursor
|
||||||
if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
|
if numEval > llm.NumBatch {
|
||||||
// call the callback
|
numEval = llm.NumBatch
|
||||||
fn(api.GenerateResponse{
|
|
||||||
Response: b.String(),
|
|
||||||
})
|
|
||||||
|
|
||||||
output.PushLeft(token)
|
|
||||||
context.PushLeft(int(token))
|
|
||||||
b.Reset()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
input = []C.llama_token{token}
|
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
|
||||||
|
return 0, fmt.Errorf("llama_eval: %d", retval)
|
||||||
|
}
|
||||||
|
|
||||||
|
llm.cursor += numEval
|
||||||
}
|
}
|
||||||
|
|
||||||
dur := func(ms float64) time.Duration {
|
var sampleOpts C.struct_llama_sample_options
|
||||||
d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
|
||||||
if err != nil {
|
sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
|
||||||
panic(err)
|
sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
|
||||||
}
|
sampleOpts.temperature = C.float(llm.Temperature)
|
||||||
|
sampleOpts.top_k = C.int(llm.TopK)
|
||||||
|
sampleOpts.top_p = C.float(llm.TopP)
|
||||||
|
sampleOpts.tfs_z = C.float(llm.TFSZ)
|
||||||
|
sampleOpts.typical_p = C.float(llm.TypicalP)
|
||||||
|
sampleOpts.mirostat = C.int(llm.Mirostat)
|
||||||
|
sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
|
||||||
|
sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
|
||||||
|
sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
|
||||||
|
|
||||||
return d
|
numVocab := C.llama_n_vocab(llm.ctx)
|
||||||
}
|
|
||||||
|
|
||||||
timings := C.llama_get_timings(llm.ctx)
|
|
||||||
fn(api.GenerateResponse{
|
|
||||||
Done: true,
|
|
||||||
Context: context.Data(),
|
|
||||||
PromptEvalCount: int(timings.n_p_eval),
|
|
||||||
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
|
|
||||||
EvalCount: int(timings.n_eval),
|
|
||||||
EvalDuration: dur(float64(timings.t_eval_ms)),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
|
|
||||||
numVocab := int(C.llama_n_vocab(llm.ctx))
|
|
||||||
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
|
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
|
||||||
|
|
||||||
candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
|
// TODO: logit bias
|
||||||
for i := 0; i < candidates.Cap(); i++ {
|
|
||||||
candidates.PushLeft(C.struct_llama_token_data{
|
candidates := make([]C.llama_token_data, numVocab)
|
||||||
|
for i := range logits {
|
||||||
|
candidates[i] = C.llama_token_data{
|
||||||
id: C.int(i),
|
id: C.int(i),
|
||||||
logit: logits[i],
|
logit: logits[i],
|
||||||
p: 0,
|
p: 0,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
repeatLastN := llm.RepeatLastN
|
||||||
|
if len(llm.last) < repeatLastN {
|
||||||
|
repeatLastN = len(llm.last)
|
||||||
|
}
|
||||||
|
|
||||||
|
if llm.NumCtx < repeatLastN {
|
||||||
|
repeatLastN = llm.NumCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
lastN := llm.last[len(llm.last)-repeatLastN:]
|
||||||
|
|
||||||
token := C.llama_sample(
|
token := C.llama_sample(
|
||||||
llm.ctx,
|
llm.ctx,
|
||||||
unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
|
unsafe.SliceData(candidates), C.size_t(len(candidates)),
|
||||||
unsafe.SliceData(output.Data()), C.size_t(output.Len()),
|
unsafe.SliceData(lastN), C.size_t(len(lastN)),
|
||||||
opts)
|
&sampleOpts,
|
||||||
if token != C.llama_token_eos() {
|
)
|
||||||
return token, nil
|
|
||||||
|
llm.last = append(llm.last, token)
|
||||||
|
llm.embd = append(llm.embd, token)
|
||||||
|
|
||||||
|
if token == C.llama_token_eos() {
|
||||||
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, io.EOF
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
107
llama/utils.go
107
llama/utils.go
|
@ -1,104 +1,15 @@
|
||||||
package llama
|
package llama
|
||||||
|
|
||||||
type node[T any] struct {
|
import (
|
||||||
t T
|
"fmt"
|
||||||
next *node[T]
|
"time"
|
||||||
prev *node[T]
|
)
|
||||||
}
|
|
||||||
|
|
||||||
type deque[T any] struct {
|
func parseDurationMs(ms float64) time.Duration {
|
||||||
head *node[T]
|
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||||
tail *node[T]
|
if err != nil {
|
||||||
size int
|
panic(err)
|
||||||
capacity int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Empty() bool {
|
|
||||||
return d.size == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Len() int {
|
|
||||||
return d.size
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Cap() int {
|
|
||||||
return d.capacity
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Push(t T) {
|
|
||||||
if d.capacity > 0 && d.size >= d.capacity {
|
|
||||||
d.PopLeft()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
n := node[T]{t: t}
|
return dur
|
||||||
if d.head != nil {
|
|
||||||
n.next = d.head
|
|
||||||
d.head.prev = &n
|
|
||||||
d.head = &n
|
|
||||||
} else {
|
|
||||||
d.head = &n
|
|
||||||
d.tail = &n
|
|
||||||
}
|
|
||||||
|
|
||||||
d.size++
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) PushLeft(t T) {
|
|
||||||
if d.capacity > 0 && d.size >= d.capacity {
|
|
||||||
d.Pop()
|
|
||||||
}
|
|
||||||
|
|
||||||
n := node[T]{t: t}
|
|
||||||
if d.tail != nil {
|
|
||||||
n.prev = d.tail
|
|
||||||
d.tail.next = &n
|
|
||||||
d.tail = &n
|
|
||||||
} else {
|
|
||||||
d.head = &n
|
|
||||||
d.tail = &n
|
|
||||||
}
|
|
||||||
|
|
||||||
d.size++
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Pop() *T {
|
|
||||||
if d.Empty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
head := d.head
|
|
||||||
d.head = head.next
|
|
||||||
if d.head != nil {
|
|
||||||
d.head.prev = nil
|
|
||||||
} else {
|
|
||||||
d.tail = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.size--
|
|
||||||
return &head.t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) PopLeft() *T {
|
|
||||||
if d.Empty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
tail := d.tail
|
|
||||||
d.tail = tail.prev
|
|
||||||
if d.tail != nil {
|
|
||||||
d.tail.next = nil
|
|
||||||
} else {
|
|
||||||
d.head = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
d.size--
|
|
||||||
return &tail.t
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deque[T]) Data() (data []T) {
|
|
||||||
for n := d.head; n != nil; n = n.next {
|
|
||||||
data = append(data, n.t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return data
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"dario.cat/mergo"
|
"dario.cat/mergo"
|
||||||
|
@ -21,8 +22,21 @@ import (
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llama"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var activeSession struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
id int64
|
||||||
|
llm *llama.LLM
|
||||||
|
|
||||||
|
expireAt time.Time
|
||||||
|
expireTimer *time.Timer
|
||||||
|
}
|
||||||
|
|
||||||
func GenerateHandler(c *gin.Context) {
|
func GenerateHandler(c *gin.Context) {
|
||||||
start := time.Now()
|
activeSession.mu.Lock()
|
||||||
|
defer activeSession.mu.Unlock()
|
||||||
|
|
||||||
|
checkpointStart := time.Now()
|
||||||
|
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
@ -36,16 +50,58 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := api.DefaultOptions()
|
if req.SessionID == 0 || req.SessionID != activeSession.id {
|
||||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
if activeSession.llm != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
activeSession.llm.Close()
|
||||||
return
|
activeSession.llm = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := api.DefaultOptions()
|
||||||
|
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
llm, err := llama.New(model.ModelPath, opts)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
activeSession.id = time.Now().UnixNano()
|
||||||
|
activeSession.llm = llm
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
sessionDuration := req.SessionDuration
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
sessionID := activeSession.id
|
||||||
return
|
|
||||||
|
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||||
|
if activeSession.expireTimer == nil {
|
||||||
|
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
|
||||||
|
activeSession.mu.Lock()
|
||||||
|
defer activeSession.mu.Unlock()
|
||||||
|
|
||||||
|
if sessionID != activeSession.id {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Before(activeSession.expireAt) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
activeSession.llm.Close()
|
||||||
|
activeSession.llm = nil
|
||||||
|
activeSession.id = 0
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||||
|
|
||||||
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
prompt, err := model.Prompt(req)
|
prompt, err := model.Prompt(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -53,27 +109,26 @@ func GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
llm, err := llama.New(model.ModelPath, opts)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer llm.Close()
|
|
||||||
|
|
||||||
ch := make(chan any)
|
ch := make(chan any)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
fn := func(r api.GenerateResponse) {
|
fn := func(r api.GenerateResponse) {
|
||||||
|
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||||
|
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||||
|
|
||||||
r.Model = req.Model
|
r.Model = req.Model
|
||||||
r.CreatedAt = time.Now().UTC()
|
r.CreatedAt = time.Now().UTC()
|
||||||
|
r.SessionID = activeSession.id
|
||||||
|
r.SessionExpiresAt = activeSession.expireAt.UTC()
|
||||||
if r.Done {
|
if r.Done {
|
||||||
r.TotalDuration = time.Since(start)
|
r.TotalDuration = time.Since(checkpointStart)
|
||||||
|
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := llm.Predict(req.Context, prompt, fn); err != nil {
|
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
|
||||||
ch <- gin.H{"error": err.Error()}
|
ch <- gin.H{"error": err.Error()}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -223,7 +278,7 @@ func ListModelsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.ListResponse{models})
|
c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
||||||
}
|
}
|
||||||
|
|
||||||
func CopyModelHandler(c *gin.Context) {
|
func CopyModelHandler(c *gin.Context) {
|
||||||
|
|
Loading…
Reference in a new issue