diff --git a/api/types.go b/api/types.go index e0f1f4da..34ba91b3 100644 --- a/api/types.go +++ b/api/types.go @@ -1,6 +1,11 @@ package api -import "runtime" +import ( + "fmt" + "os" + "runtime" + "time" +) type PullRequest struct { Model string `json:"model"` @@ -20,7 +25,41 @@ type GenerateRequest struct { } type GenerateResponse struct { - Response string `json:"response"` + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Response string `json:"response,omitempty"` + + Done bool `json:"done"` + + TotalDuration time.Duration `json:"total_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration time.Duration `json:"eval_duration,omitempty"` +} + +func (r *GenerateResponse) Summary() { + if r.TotalDuration > 0 { + fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) + } + + if r.PromptEvalCount > 0 { + fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) + } + + if r.PromptEvalDuration > 0 { + fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) + fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) + } + + if r.EvalCount > 0 { + fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount) + } + + if r.EvalDuration > 0 { + fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration) + fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) + } } type Options struct { diff --git a/cmd/cmd.go b/cmd/cmd.go index c783e6a8..dede4600 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -72,20 +72,20 @@ func pull(model string) error { ) } -func RunGenerate(_ *cobra.Command, args []string) error { +func RunGenerate(cmd *cobra.Command, args []string) error { if len(args) > 1 { // join all args into a single prompt - return 
generate(args[0], strings.Join(args[1:], " ")) + return generate(cmd, args[0], strings.Join(args[1:], " ")) } if term.IsTerminal(int(os.Stdin.Fd())) { - return generateInteractive(args[0]) + return generateInteractive(cmd, args[0]) } - return generateBatch(args[0]) + return generateBatch(cmd, args[0]) } -func generate(model, prompt string) error { +func generate(cmd *cobra.Command, model, prompt string) error { if len(strings.TrimSpace(prompt)) > 0 { client := api.NewClient() @@ -108,12 +108,16 @@ func generate(model, prompt string) error { } }() + var latest api.GenerateResponse + request := api.GenerateRequest{Model: model, Prompt: prompt} fn := func(resp api.GenerateResponse) error { if !spinner.IsFinished() { spinner.Finish() } + latest = resp + fmt.Print(resp.Response) return nil } @@ -124,16 +128,25 @@ func generate(model, prompt string) error { fmt.Println() fmt.Println() + + verbose, err := cmd.Flags().GetBool("verbose") + if err != nil { + return err + } + + if verbose { + latest.Summary() + } } return nil } -func generateInteractive(model string) error { +func generateInteractive(cmd *cobra.Command, model string) error { fmt.Print(">>> ") scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { - if err := generate(model, scanner.Text()); err != nil { + if err := generate(cmd, model, scanner.Text()); err != nil { return err } @@ -143,12 +156,12 @@ func generateInteractive(model string) error { return nil } -func generateBatch(model string) error { +func generateBatch(cmd *cobra.Command, model string) error { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { prompt := scanner.Text() fmt.Printf(">>> %s\n", prompt) - if err := generate(model, prompt); err != nil { + if err := generate(cmd, model, prompt); err != nil { return err } } @@ -200,6 +213,8 @@ func NewCLI() *cobra.Command { RunE: RunRun, } + runCmd.Flags().Bool("verbose", false, "Show timings for response") + serveCmd := &cobra.Command{ Use: "serve", Aliases: []string{"start"}, diff --git 
a/llama/llama.go b/llama/llama.go index 8922c18f..80a1b420 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -79,9 +79,11 @@ llama_token llama_sample( import "C" import ( "errors" + "fmt" "io" "os" "strings" + "time" "unsafe" "github.com/jmorganca/ollama/api" @@ -147,7 +149,7 @@ func (llm *llama) Close() { C.llama_print_timings(llm.ctx) } -func (llm *llama) Predict(prompt string, fn func(string)) error { +func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error { if tokens := llm.tokenize(prompt); tokens != nil { return llm.generate(tokens, fn) } @@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string { return sb.String() } -func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error { +func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error { var opts C.struct_llama_sample_options opts.repeat_penalty = C.float(llm.RepeatPenalty) opts.frequency_penalty = C.float(llm.FrequencyPenalty) @@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error { opts.mirostat_tau = C.float(llm.MirostatTau) opts.mirostat_eta = C.float(llm.MirostatEta) - pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN} + output := deque[C.llama_token]{capacity: llm.NumCtx} for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { - if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { + if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { return errors.New("llama: eval") } - token, err := llm.sample(pastTokens, &opts) - switch { - case errors.Is(err, io.EOF): - return nil - case err != nil: + token, err := llm.sample(output, &opts) + if errors.Is(err, io.EOF) { + break + } else if err != nil { return err } - fn(llm.detokenize(token)) + // call 
the callback + fn(api.GenerateResponse{ + Response: llm.detokenize(token), + }) - tokens = []C.llama_token{token} + output.PushLeft(token) - pastTokens.PushLeft(token) + input = []C.llama_token{token} } + dur := func(ms float64) time.Duration { + d, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) + if err != nil { + panic(err) + } + + return d + } + + timings := C.llama_get_timings(llm.ctx) + fn(api.GenerateResponse{ + Done: true, + PromptEvalCount: int(timings.n_p_eval), + PromptEvalDuration: dur(float64(timings.t_p_eval_ms)), + EvalCount: int(timings.n_eval), + EvalDuration: dur(float64(timings.t_eval_ms)), + }) + return nil } -func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { +func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { numVocab := int(C.llama_n_vocab(llm.ctx)) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) - candidates := make([]C.struct_llama_token_data, 0, numVocab) - for i := 0; i < numVocab; i++ { - candidates = append(candidates, C.llama_token_data{ + candidates := deque[C.struct_llama_token_data]{capacity: numVocab} + for i := 0; i < candidates.Cap(); i++ { + candidates.PushLeft(C.struct_llama_token_data{ id: C.int(i), logit: logits[i], p: 0, @@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s token := C.llama_sample( llm.ctx, - unsafe.SliceData(candidates), C.ulong(len(candidates)), - unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()), + unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()), + unsafe.SliceData(output.Data()), C.ulong(output.Len()), opts) if token != C.llama_token_eos() { return token, nil diff --git a/server/routes.go b/server/routes.go index d93a79fc..ace82213 100644 --- a/server/routes.go +++ b/server/routes.go @@ -13,6 +13,7 @@ import ( "path" "strings" "text/template" + "time" "github.com/gin-gonic/gin" 
"github.com/lithammer/fuzzysearch/fuzzy" @@ -35,6 +36,8 @@ func cacheDir() string { } func generate(c *gin.Context) { + start := time.Now() + req := api.GenerateRequest{ Options: api.DefaultOptions(), } @@ -81,8 +84,14 @@ func generate(c *gin.Context) { } defer llm.Close() - fn := func(s string) { - ch <- api.GenerateResponse{Response: s} + fn := func(r api.GenerateResponse) { + r.Model = req.Model + r.CreatedAt = time.Now().UTC() + if r.Done { + r.TotalDuration = time.Since(start) + } + + ch <- r } if err := llm.Predict(req.Prompt, fn); err != nil { @@ -147,7 +156,7 @@ func Serve(ln net.Listener) error { c.String(http.StatusOK, "Ollama is running") }) - r.POST("api/pull", pull) + r.POST("/api/pull", pull) r.POST("/api/generate", generate) log.Printf("Listening on %s", ln.Addr())