continue conversation

feed responses back into the llm
This commit is contained in:
Michael Yang 2023-07-13 11:02:53 -07:00
parent 77dc1a6d74
commit 1775647f76
11 changed files with 47 additions and 10 deletions

View file

@ -20,6 +20,7 @@ type PullProgress struct {
type GenerateRequest struct { type GenerateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Context []int `json:"context,omitempty"`
Options `json:"options"` Options `json:"options"`
} }
@ -30,6 +31,7 @@ type GenerateResponse struct {
Response string `json:"response,omitempty"` Response string `json:"response,omitempty"`
Done bool `json:"done"` Done bool `json:"done"`
Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"` PromptEvalCount int `json:"prompt_eval_count,omitempty"`
@ -104,7 +106,7 @@ func DefaultOptions() Options {
UseNUMA: false, UseNUMA: false,
NumCtx: 512, NumCtx: 2048,
NumBatch: 512, NumBatch: 512,
NumGPU: 1, NumGPU: 1,
LowVRAM: false, LowVRAM: false,

View file

@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
return generateBatch(cmd, args[0]) return generateBatch(cmd, args[0])
} }
var generateContextKey struct{}
func generate(cmd *cobra.Command, model, prompt string) error { func generate(cmd *cobra.Command, model, prompt string) error {
if len(strings.TrimSpace(prompt)) > 0 { if len(strings.TrimSpace(prompt)) > 0 {
client := api.NewClient() client := api.NewClient()
@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
var latest api.GenerateResponse var latest api.GenerateResponse
request := api.GenerateRequest{Model: model, Prompt: prompt} generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
if !ok {
generateContext = []int{}
}
request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
fn := func(resp api.GenerateResponse) error { fn := func(resp api.GenerateResponse) error {
if !spinner.IsFinished() { if !spinner.IsFinished() {
spinner.Finish() spinner.Finish()
@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
latest = resp latest = resp
fmt.Print(resp.Response) fmt.Print(resp.Response)
cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
return nil return nil
} }

View file

@ -149,9 +149,14 @@ func (llm *llama) Close() {
C.llama_print_timings(llm.ctx) C.llama_print_timings(llm.ctx)
} }
func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error { func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
if tokens := llm.tokenize(prompt); tokens != nil { if input := llm.tokenize(prompt); input != nil {
return llm.generate(tokens, fn) embd := make([]C.llama_token, len(ctx))
for i := range ctx {
embd[i] = C.llama_token(ctx[i])
}
return llm.generate(append(embd, input...), fn)
} }
return errors.New("llama: tokenize") return errors.New("llama: tokenize")
@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
output := deque[C.llama_token]{capacity: llm.NumCtx} output := deque[C.llama_token]{capacity: llm.NumCtx}
context := deque[int]{capacity: llm.NumCtx / 2}
for _, in := range input {
context.PushLeft(int(in))
}
for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) { for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 { if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
return errors.New("llama: eval") return errors.New("llama: eval")
@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
}) })
output.PushLeft(token) output.PushLeft(token)
context.PushLeft(int(token))
input = []C.llama_token{token} input = []C.llama_token{token}
} }
@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
timings := C.llama_get_timings(llm.ctx) timings := C.llama_get_timings(llm.ctx)
fn(api.GenerateResponse{ fn(api.GenerateResponse{
Done: true, Done: true,
Context: context.Data(),
PromptEvalCount: int(timings.n_p_eval), PromptEvalCount: int(timings.n_p_eval),
PromptEvalDuration: dur(float64(timings.t_p_eval_ms)), PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
EvalCount: int(timings.n_eval), EvalCount: int(timings.n_eval),

View file

@ -1,9 +1,11 @@
package main package main
import ( import (
"context"
"github.com/jmorganca/ollama/cmd" "github.com/jmorganca/ollama/cmd"
) )
func main() { func main() {
cmd.NewCLI().Execute() cmd.NewCLI().ExecuteContext(context.Background())
} }

View file

@ -94,7 +94,7 @@ func generate(c *gin.Context) {
ch <- r ch <- r
} }
if err := llm.Predict(req.Prompt, fn); err != nil { if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }

View file

@ -1,4 +1,6 @@
{{- if not .Context }}
Below is an instruction that describes a task. Write a response that appropriately completes the request. Below is an instruction that describes a task. Write a response that appropriately completes the request.
{{- end }}
### Instruction: ### Instruction:
{{ .Prompt }} {{ .Prompt }}

View file

@ -1,3 +1,5 @@
{{- if not .Context }}
A helpful assistant who helps the user with any questions asked. A helpful assistant who helps the user with any questions asked.
{{- end }}
User: {{ .Prompt }} User: {{ .Prompt }}
Assistant: Assistant:

View file

@ -1,4 +1,6 @@
{{- if not .Context }}
Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text. Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
{{- end }}
### Instruction: ### Instruction:
{{ .Prompt }} {{ .Prompt }}
### Response: ### Response:

View file

@ -1,5 +1,7 @@
{{- if not .Context }}
### System: ### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can. You are an AI assistant that follows instruction extremely well. Help as much as you can.
{{- end }}
### User: ### User:
{{ .Prompt }} {{ .Prompt }}

View file

@ -1,4 +1,6 @@
{{ if not .Context }}
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
{{- end }}
USER: {{ .Prompt }} USER: {{ .Prompt }}
ASSISTANT: ASSISTANT:

View file

@ -1,4 +1,6 @@
{{- if not .Context }}
Below is an instruction that describes a task. Write a response that appropriately completes the request Below is an instruction that describes a task. Write a response that appropriately completes the request
{{- end }}
### Instruction: {{ .Prompt }} ### Instruction: {{ .Prompt }}