Merge pull request #74 from jmorganca/timings

Timings
This commit is contained in:
Michael Yang 2023-07-13 10:17:13 -07:00 committed by GitHub
commit 77dc1a6d74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 32 deletions

View file

@ -1,6 +1,11 @@
package api package api
import "runtime" import (
"fmt"
"os"
"runtime"
"time"
)
type PullRequest struct { type PullRequest struct {
Model string `json:"model"` Model string `json:"model"`
@ -20,7 +25,41 @@ type GenerateRequest struct {
} }
type GenerateResponse struct { type GenerateResponse struct {
Response string `json:"response"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"`
Done bool `json:"done"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
}
if r.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
}
if r.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
}
if r.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
}
if r.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duraiton: %s\n", r.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
}
} }
type Options struct { type Options struct {

View file

@ -72,20 +72,20 @@ func pull(model string) error {
) )
} }
func RunGenerate(_ *cobra.Command, args []string) error { func RunGenerate(cmd *cobra.Command, args []string) error {
if len(args) > 1 { if len(args) > 1 {
// join all args into a single prompt // join all args into a single prompt
return generate(args[0], strings.Join(args[1:], " ")) return generate(cmd, args[0], strings.Join(args[1:], " "))
} }
if term.IsTerminal(int(os.Stdin.Fd())) { if term.IsTerminal(int(os.Stdin.Fd())) {
return generateInteractive(args[0]) return generateInteractive(cmd, args[0])
} }
return generateBatch(args[0]) return generateBatch(cmd, args[0])
} }
func generate(model, prompt string) error { func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 { if len(strings.TrimSpace(prompt)) > 0 {
client := api.NewClient() client := api.NewClient()
@ -108,12 +108,16 @@ func generate(model, prompt string) error {
} }
}() }()
var latest api.GenerateResponse
request := api.GenerateRequest{Model: model, Prompt: prompt} request := api.GenerateRequest{Model: model, Prompt: prompt}
fn := func(resp api.GenerateResponse) error { fn := func(resp api.GenerateResponse) error {
if !spinner.IsFinished() { if !spinner.IsFinished() {
spinner.Finish() spinner.Finish()
} }
latest = resp
fmt.Print(resp.Response) fmt.Print(resp.Response)
return nil return nil
} }
@ -124,16 +128,25 @@ func generate(model, prompt string) error {
fmt.Println() fmt.Println()
fmt.Println() fmt.Println()
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
}
if verbose {
latest.Summary()
}
} }
return nil return nil
} }
func generateInteractive(model string) error { func generateInteractive(cmd *cobra.Command, model string) error {
fmt.Print(">>> ") fmt.Print(">>> ")
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() { for scanner.Scan() {
if err := generate(model, scanner.Text()); err != nil { if err := generate(cmd, model, scanner.Text()); err != nil {
return err return err
} }
@ -143,12 +156,12 @@ func generateInteractive(model string) error {
return nil return nil
} }
func generateBatch(model string) error { func generateBatch(cmd *cobra.Command, model string) error {
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() { for scanner.Scan() {
prompt := scanner.Text() prompt := scanner.Text()
fmt.Printf(">>> %s\n", prompt) fmt.Printf(">>> %s\n", prompt)
if err := generate(model, prompt); err != nil { if err := generate(cmd, model, prompt); err != nil {
return err return err
} }
} }
@ -200,6 +213,8 @@ func NewCLI() *cobra.Command {
RunE: RunRun, RunE: RunRun,
} }
runCmd.Flags().Bool("verbose", false, "Show timings for response")
serveCmd := &cobra.Command{ serveCmd := &cobra.Command{
Use: "serve", Use: "serve",
Aliases: []string{"start"}, Aliases: []string{"start"},

View file

@ -79,9 +79,11 @@ llama_token llama_sample(
import "C" import "C"
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"strings" "strings"
"time"
"unsafe" "unsafe"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
@ -147,7 +149,7 @@ func (llm *llama) Close() {
C.llama_print_timings(llm.ctx) C.llama_print_timings(llm.ctx)
} }
func (llm *llama) Predict(prompt string, fn func(string)) error { func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
if tokens := llm.tokenize(prompt); tokens != nil { if tokens := llm.tokenize(prompt); tokens != nil {
return llm.generate(tokens, fn) return llm.generate(tokens, fn)
} }
@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
return sb.String() return sb.String()
} }
func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error { func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
var opts C.struct_llama_sample_options var opts C.struct_llama_sample_options
opts.repeat_penalty = C.float(llm.RepeatPenalty) opts.repeat_penalty = C.float(llm.RepeatPenalty)
opts.frequency_penalty = C.float(llm.FrequencyPenalty) opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
opts.mirostat_tau = C.float(llm.MirostatTau) opts.mirostat_tau = C.float(llm.MirostatTau)
opts.mirostat_eta = C.float(llm.MirostatEta) opts.mirostat_eta = C.float(llm.MirostatEta)
pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN} output := deque[C.llama_token]{capacity: llm.NumCtx}
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
return errors.New("llama: eval") return errors.New("llama: eval")
} }
token, err := llm.sample(pastTokens, &opts) token, err := llm.sample(output, &opts)
switch { if errors.Is(err, io.EOF) {
case errors.Is(err, io.EOF): break
return nil } else if err != nil {
case err != nil:
return err return err
} }
fn(llm.detokenize(token)) // call the callback
fn(api.GenerateResponse{
Response: llm.detokenize(token),
})
tokens = []C.llama_token{token} output.PushLeft(token)
pastTokens.PushLeft(token) input = []C.llama_token{token}
} }
dur := func(ms float64) time.Duration {
d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return d
}
timings := C.llama_get_timings(llm.ctx)
fn(api.GenerateResponse{
Done: true,
PromptEvalCount: int(timings.n_p_eval),
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
EvalCount: int(timings.n_eval),
EvalDuration: dur(float64(timings.t_eval_ms)),
})
return nil return nil
} }
func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) { func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
numVocab := int(C.llama_n_vocab(llm.ctx)) numVocab := int(C.llama_n_vocab(llm.ctx))
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab) logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
candidates := make([]C.struct_llama_token_data, 0, numVocab) candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
for i := 0; i < numVocab; i++ { for i := 0; i < candidates.Cap(); i++ {
candidates = append(candidates, C.llama_token_data{ candidates.PushLeft(C.struct_llama_token_data{
id: C.int(i), id: C.int(i),
logit: logits[i], logit: logits[i],
p: 0, p: 0,
@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
token := C.llama_sample( token := C.llama_sample(
llm.ctx, llm.ctx,
unsafe.SliceData(candidates), C.ulong(len(candidates)), unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()), unsafe.SliceData(output.Data()), C.ulong(output.Len()),
opts) opts)
if token != C.llama_token_eos() { if token != C.llama_token_eos() {
return token, nil return token, nil

View file

@ -13,6 +13,7 @@ import (
"path" "path"
"strings" "strings"
"text/template" "text/template"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy" "github.com/lithammer/fuzzysearch/fuzzy"
@ -35,6 +36,8 @@ func cacheDir() string {
} }
func generate(c *gin.Context) { func generate(c *gin.Context) {
start := time.Now()
req := api.GenerateRequest{ req := api.GenerateRequest{
Options: api.DefaultOptions(), Options: api.DefaultOptions(),
} }
@ -81,8 +84,14 @@ func generate(c *gin.Context) {
} }
defer llm.Close() defer llm.Close()
fn := func(s string) { fn := func(r api.GenerateResponse) {
ch <- api.GenerateResponse{Response: s} r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(start)
}
ch <- r
} }
if err := llm.Predict(req.Prompt, fn); err != nil { if err := llm.Predict(req.Prompt, fn); err != nil {
@ -147,7 +156,7 @@ func Serve(ln net.Listener) error {
c.String(http.StatusOK, "Ollama is running") c.String(http.StatusOK, "Ollama is running")
}) })
r.POST("api/pull", pull) r.POST("/api/pull", pull)
r.POST("/api/generate", generate) r.POST("/api/generate", generate)
log.Printf("Listening on %s", ln.Addr()) log.Printf("Listening on %s", ln.Addr())