Merge pull request #74 from jmorganca/timings

Timings
This commit is contained in:
Michael Yang 2023-07-13 10:17:13 -07:00 committed by GitHub
commit 77dc1a6d74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 32 deletions

View file

@ -1,6 +1,11 @@
package api
import "runtime"
import (
"fmt"
"os"
"runtime"
"time"
)
type PullRequest struct {
Model string `json:"model"`
@ -20,7 +25,41 @@ type GenerateRequest struct {
}
type GenerateResponse struct {
Response string `json:"response"`
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Response string `json:"response,omitempty"`
Done bool `json:"done"`
TotalDuration time.Duration `json:"total_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
func (r *GenerateResponse) Summary() {
if r.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
}
if r.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
}
if r.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
}
if r.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
}
if r.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duraiton: %s\n", r.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
}
}
type Options struct {

View file

@ -72,20 +72,20 @@ func pull(model string) error {
)
}
func RunGenerate(_ *cobra.Command, args []string) error {
func RunGenerate(cmd *cobra.Command, args []string) error {
if len(args) > 1 {
// join all args into a single prompt
return generate(args[0], strings.Join(args[1:], " "))
return generate(cmd, args[0], strings.Join(args[1:], " "))
}
if term.IsTerminal(int(os.Stdin.Fd())) {
return generateInteractive(args[0])
return generateInteractive(cmd, args[0])
}
return generateBatch(args[0])
return generateBatch(cmd, args[0])
}
func generate(model, prompt string) error {
func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 {
client := api.NewClient()
@ -108,12 +108,16 @@ func generate(model, prompt string) error {
}
}()
var latest api.GenerateResponse
request := api.GenerateRequest{Model: model, Prompt: prompt}
fn := func(resp api.GenerateResponse) error {
if !spinner.IsFinished() {
spinner.Finish()
}
latest = resp
fmt.Print(resp.Response)
return nil
}
@ -124,16 +128,25 @@ func generate(model, prompt string) error {
fmt.Println()
fmt.Println()
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
}
if verbose {
latest.Summary()
}
}
return nil
}
func generateInteractive(model string) error {
func generateInteractive(cmd *cobra.Command, model string) error {
fmt.Print(">>> ")
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
if err := generate(model, scanner.Text()); err != nil {
if err := generate(cmd, model, scanner.Text()); err != nil {
return err
}
@ -143,12 +156,12 @@ func generateInteractive(model string) error {
return nil
}
func generateBatch(model string) error {
func generateBatch(cmd *cobra.Command, model string) error {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
prompt := scanner.Text()
fmt.Printf(">>> %s\n", prompt)
if err := generate(model, prompt); err != nil {
if err := generate(cmd, model, prompt); err != nil {
return err
}
}
@ -200,6 +213,8 @@ func NewCLI() *cobra.Command {
RunE: RunRun,
}
runCmd.Flags().Bool("verbose", false, "Show timings for response")
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},

View file

@ -79,9 +79,11 @@ llama_token llama_sample(
import "C"
import (
"errors"
"fmt"
"io"
"os"
"strings"
"time"
"unsafe"
"github.com/jmorganca/ollama/api"
@ -147,7 +149,7 @@ func (llm *llama) Close() {
C.llama_print_timings(llm.ctx)
}
func (llm *llama) Predict(prompt string, fn func(string)) error {
func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
if tokens := llm.tokenize(prompt); tokens != nil {
return llm.generate(tokens, fn)
}
@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
return sb.String()
}
func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
var opts C.struct_llama_sample_options
opts.repeat_penalty = C.float(llm.RepeatPenalty)
opts.frequency_penalty = C.float(llm.FrequencyPenalty)
@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
opts.mirostat_tau = C.float(llm.MirostatTau)
opts.mirostat_eta = C.float(llm.MirostatEta)
pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
output := deque[C.llama_token]{capacity: llm.NumCtx}
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
return errors.New("llama: eval")
}
token, err := llm.sample(pastTokens, &opts)
switch {
case errors.Is(err, io.EOF):
return nil
case err != nil:
token, err := llm.sample(output, &opts)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
fn(llm.detokenize(token))
// call the callback
fn(api.GenerateResponse{
Response: llm.detokenize(token),
})
tokens = []C.llama_token{token}
output.PushLeft(token)
pastTokens.PushLeft(token)
input = []C.llama_token{token}
}
dur := func(ms float64) time.Duration {
d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
if err != nil {
panic(err)
}
return d
}
timings := C.llama_get_timings(llm.ctx)
fn(api.GenerateResponse{
Done: true,
PromptEvalCount: int(timings.n_p_eval),
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
EvalCount: int(timings.n_eval),
EvalDuration: dur(float64(timings.t_eval_ms)),
})
return nil
}
func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
numVocab := int(C.llama_n_vocab(llm.ctx))
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
candidates := make([]C.struct_llama_token_data, 0, numVocab)
for i := 0; i < numVocab; i++ {
candidates = append(candidates, C.llama_token_data{
candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
for i := 0; i < candidates.Cap(); i++ {
candidates.PushLeft(C.struct_llama_token_data{
id: C.int(i),
logit: logits[i],
p: 0,
@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
token := C.llama_sample(
llm.ctx,
unsafe.SliceData(candidates), C.ulong(len(candidates)),
unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
unsafe.SliceData(output.Data()), C.ulong(output.Len()),
opts)
if token != C.llama_token_eos() {
return token, nil

View file

@ -13,6 +13,7 @@ import (
"path"
"strings"
"text/template"
"time"
"github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy"
@ -35,6 +36,8 @@ func cacheDir() string {
}
func generate(c *gin.Context) {
start := time.Now()
req := api.GenerateRequest{
Options: api.DefaultOptions(),
}
@ -81,8 +84,14 @@ func generate(c *gin.Context) {
}
defer llm.Close()
fn := func(s string) {
ch <- api.GenerateResponse{Response: s}
fn := func(r api.GenerateResponse) {
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
if r.Done {
r.TotalDuration = time.Since(start)
}
ch <- r
}
if err := llm.Predict(req.Prompt, fn); err != nil {
@ -147,7 +156,7 @@ func Serve(ln net.Listener) error {
c.String(http.StatusOK, "Ollama is running")
})
r.POST("api/pull", pull)
r.POST("/api/pull", pull)
r.POST("/api/generate", generate)
log.Printf("Listening on %s", ln.Addr())