return more info in generate response
This commit is contained in:
parent
31590284a7
commit
05e08d2310
4 changed files with 116 additions and 31 deletions
43
api/types.go
43
api/types.go
|
@ -1,6 +1,11 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import "runtime"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
@ -20,7 +25,41 @@ type GenerateRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateResponse struct {
|
type GenerateResponse struct {
|
||||||
Response string `json:"response"`
|
Model string `json:"model"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
Response string `json:"response,omitempty"`
|
||||||
|
|
||||||
|
Done bool `json:"done"`
|
||||||
|
|
||||||
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
||||||
|
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
|
||||||
|
EvalCount int `json:"eval_count,omitempty"`
|
||||||
|
EvalDuration time.Duration `json:"eval_duration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GenerateResponse) Summary() {
|
||||||
|
if r.TotalDuration > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.PromptEvalCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.PromptEvalDuration > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
|
||||||
|
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.EvalCount > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.EvalDuration > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "eval duraiton: %s\n", r.EvalDuration)
|
||||||
|
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
|
|
33
cmd/cmd.go
33
cmd/cmd.go
|
@ -72,20 +72,20 @@ func pull(model string) error {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunGenerate(_ *cobra.Command, args []string) error {
|
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||||
if len(args) > 1 {
|
if len(args) > 1 {
|
||||||
// join all args into a single prompt
|
// join all args into a single prompt
|
||||||
return generate(args[0], strings.Join(args[1:], " "))
|
return generate(cmd, args[0], strings.Join(args[1:], " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
if term.IsTerminal(int(os.Stdin.Fd())) {
|
if term.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
return generateInteractive(args[0])
|
return generateInteractive(cmd, args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
return generateBatch(args[0])
|
return generateBatch(cmd, args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(model, prompt string) error {
|
func generate(cmd *cobra.Command, model, prompt string) error {
|
||||||
if len(strings.TrimSpace(prompt)) > 0 {
|
if len(strings.TrimSpace(prompt)) > 0 {
|
||||||
client := api.NewClient()
|
client := api.NewClient()
|
||||||
|
|
||||||
|
@ -108,12 +108,16 @@ func generate(model, prompt string) error {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var latest api.GenerateResponse
|
||||||
|
|
||||||
request := api.GenerateRequest{Model: model, Prompt: prompt}
|
request := api.GenerateRequest{Model: model, Prompt: prompt}
|
||||||
fn := func(resp api.GenerateResponse) error {
|
fn := func(resp api.GenerateResponse) error {
|
||||||
if !spinner.IsFinished() {
|
if !spinner.IsFinished() {
|
||||||
spinner.Finish()
|
spinner.Finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
latest = resp
|
||||||
|
|
||||||
fmt.Print(resp.Response)
|
fmt.Print(resp.Response)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -124,16 +128,25 @@ func generate(model, prompt string) error {
|
||||||
|
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
|
verbose, err := cmd.Flags().GetBool("verbose")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose {
|
||||||
|
latest.Summary()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInteractive(model string) error {
|
func generateInteractive(cmd *cobra.Command, model string) error {
|
||||||
fmt.Print(">>> ")
|
fmt.Print(">>> ")
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
if err := generate(model, scanner.Text()); err != nil {
|
if err := generate(cmd, model, scanner.Text()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,12 +156,12 @@ func generateInteractive(model string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateBatch(model string) error {
|
func generateBatch(cmd *cobra.Command, model string) error {
|
||||||
scanner := bufio.NewScanner(os.Stdin)
|
scanner := bufio.NewScanner(os.Stdin)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
prompt := scanner.Text()
|
prompt := scanner.Text()
|
||||||
fmt.Printf(">>> %s\n", prompt)
|
fmt.Printf(">>> %s\n", prompt)
|
||||||
if err := generate(model, prompt); err != nil {
|
if err := generate(cmd, model, prompt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -200,6 +213,8 @@ func NewCLI() *cobra.Command {
|
||||||
RunE: RunRun,
|
RunE: RunRun,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||||
|
|
||||||
serveCmd := &cobra.Command{
|
serveCmd := &cobra.Command{
|
||||||
Use: "serve",
|
Use: "serve",
|
||||||
Aliases: []string{"start"},
|
Aliases: []string{"start"},
|
||||||
|
|
|
@ -79,9 +79,11 @@ llama_token llama_sample(
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
|
@ -147,7 +149,7 @@ func (llm *llama) Close() {
|
||||||
C.llama_print_timings(llm.ctx)
|
C.llama_print_timings(llm.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) Predict(prompt string, fn func(string)) error {
|
func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
|
||||||
if tokens := llm.tokenize(prompt); tokens != nil {
|
if tokens := llm.tokenize(prompt); tokens != nil {
|
||||||
return llm.generate(tokens, fn)
|
return llm.generate(tokens, fn)
|
||||||
}
|
}
|
||||||
|
@ -176,7 +178,7 @@ func (llm *llama) detokenize(tokens ...C.llama_token) string {
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
|
func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
|
||||||
var opts C.struct_llama_sample_options
|
var opts C.struct_llama_sample_options
|
||||||
opts.repeat_penalty = C.float(llm.RepeatPenalty)
|
opts.repeat_penalty = C.float(llm.RepeatPenalty)
|
||||||
opts.frequency_penalty = C.float(llm.FrequencyPenalty)
|
opts.frequency_penalty = C.float(llm.FrequencyPenalty)
|
||||||
|
@ -190,38 +192,58 @@ func (llm *llama) generate(tokens []C.llama_token, fn func(string)) error {
|
||||||
opts.mirostat_tau = C.float(llm.MirostatTau)
|
opts.mirostat_tau = C.float(llm.MirostatTau)
|
||||||
opts.mirostat_eta = C.float(llm.MirostatEta)
|
opts.mirostat_eta = C.float(llm.MirostatEta)
|
||||||
|
|
||||||
pastTokens := deque[C.llama_token]{capacity: llm.RepeatLastN}
|
output := deque[C.llama_token]{capacity: llm.NumCtx}
|
||||||
|
|
||||||
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
|
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
|
||||||
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
|
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
|
||||||
return errors.New("llama: eval")
|
return errors.New("llama: eval")
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := llm.sample(pastTokens, &opts)
|
token, err := llm.sample(output, &opts)
|
||||||
switch {
|
if errors.Is(err, io.EOF) {
|
||||||
case errors.Is(err, io.EOF):
|
break
|
||||||
return nil
|
} else if err != nil {
|
||||||
case err != nil:
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fn(llm.detokenize(token))
|
// call the callback
|
||||||
|
fn(api.GenerateResponse{
|
||||||
|
Response: llm.detokenize(token),
|
||||||
|
})
|
||||||
|
|
||||||
tokens = []C.llama_token{token}
|
output.PushLeft(token)
|
||||||
|
|
||||||
pastTokens.PushLeft(token)
|
input = []C.llama_token{token}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dur := func(ms float64) time.Duration {
|
||||||
|
d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
timings := C.llama_get_timings(llm.ctx)
|
||||||
|
fn(api.GenerateResponse{
|
||||||
|
Done: true,
|
||||||
|
PromptEvalCount: int(timings.n_p_eval),
|
||||||
|
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
|
||||||
|
EvalCount: int(timings.n_eval),
|
||||||
|
EvalDuration: dur(float64(timings.t_eval_ms)),
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
|
func (llm *llama) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
|
||||||
numVocab := int(C.llama_n_vocab(llm.ctx))
|
numVocab := int(C.llama_n_vocab(llm.ctx))
|
||||||
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
|
logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
|
||||||
|
|
||||||
candidates := make([]C.struct_llama_token_data, 0, numVocab)
|
candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
|
||||||
for i := 0; i < numVocab; i++ {
|
for i := 0; i < candidates.Cap(); i++ {
|
||||||
candidates = append(candidates, C.llama_token_data{
|
candidates.PushLeft(C.struct_llama_token_data{
|
||||||
id: C.int(i),
|
id: C.int(i),
|
||||||
logit: logits[i],
|
logit: logits[i],
|
||||||
p: 0,
|
p: 0,
|
||||||
|
@ -230,8 +252,8 @@ func (llm *llama) sample(pastTokens deque[C.llama_token], opts *C.struct_llama_s
|
||||||
|
|
||||||
token := C.llama_sample(
|
token := C.llama_sample(
|
||||||
llm.ctx,
|
llm.ctx,
|
||||||
unsafe.SliceData(candidates), C.ulong(len(candidates)),
|
unsafe.SliceData(candidates.Data()), C.ulong(candidates.Len()),
|
||||||
unsafe.SliceData(pastTokens.Data()), C.ulong(pastTokens.Len()),
|
unsafe.SliceData(output.Data()), C.ulong(output.Len()),
|
||||||
opts)
|
opts)
|
||||||
if token != C.llama_token_eos() {
|
if token != C.llama_token_eos() {
|
||||||
return token, nil
|
return token, nil
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||||
|
@ -35,6 +36,8 @@ func cacheDir() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(c *gin.Context) {
|
func generate(c *gin.Context) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Options: api.DefaultOptions(),
|
Options: api.DefaultOptions(),
|
||||||
}
|
}
|
||||||
|
@ -81,8 +84,14 @@ func generate(c *gin.Context) {
|
||||||
}
|
}
|
||||||
defer llm.Close()
|
defer llm.Close()
|
||||||
|
|
||||||
fn := func(s string) {
|
fn := func(r api.GenerateResponse) {
|
||||||
ch <- api.GenerateResponse{Response: s}
|
r.Model = req.Model
|
||||||
|
r.CreatedAt = time.Now().UTC()
|
||||||
|
if r.Done {
|
||||||
|
r.TotalDuration = time.Since(start)
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := llm.Predict(req.Prompt, fn); err != nil {
|
if err := llm.Predict(req.Prompt, fn); err != nil {
|
||||||
|
|
Loading…
Add table
Reference in a new issue