continue conversation
feed responses back into the llm
parent 77dc1a6d74
commit 1775647f76
11 changed files with 47 additions and 10 deletions
api/types.go (10 changes)

@@ -18,8 +18,9 @@ type PullProgress struct {
 }
 
 type GenerateRequest struct {
 	Model  string `json:"model"`
 	Prompt string `json:"prompt"`
+	Context []int `json:"context,omitempty"`
 
 	Options `json:"options"`
 }
@@ -29,7 +30,8 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response,omitempty"`
 
 	Done bool `json:"done"`
+	Context []int `json:"context,omitempty"`
 
 	TotalDuration   time.Duration `json:"total_duration,omitempty"`
 	PromptEvalCount int           `json:"prompt_eval_count,omitempty"`
@@ -104,7 +106,7 @@ func DefaultOptions() Options {
 
 		UseNUMA: false,
 
-		NumCtx:   512,
+		NumCtx:   2048,
 		NumBatch: 512,
 		NumGPU:   1,
 		LowVRAM:  false,
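A minimal sketch of how a caller might use the new Context field: copy the Context slice from the final GenerateResponse of one exchange into the next GenerateRequest so the server can resume the same conversation. The model name and prompts below are placeholders, and how the response is actually obtained from the server is outside this diff.

package main

import (
	"fmt"

	"github.com/jmorganca/ollama/api"
)

func main() {
	// First turn: no prior Context.
	first := api.GenerateRequest{Model: "orca", Prompt: "Why is the sky blue?"}

	// Stand-in for the final streamed response (Done == true) of the first turn.
	var last api.GenerateResponse

	// Second turn: feed the returned token context back in so the model
	// continues the same conversation.
	next := api.GenerateRequest{
		Model:   first.Model,
		Prompt:  "Explain it to a five year old.",
		Context: last.Context,
	}

	fmt.Printf("%+v\n", next)
}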
cmd/cmd.go (11 changes)

@@ -85,6 +85,8 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	return generateBatch(cmd, args[0])
 }
 
+var generateContextKey struct{}
+
 func generate(cmd *cobra.Command, model, prompt string) error {
 	if len(strings.TrimSpace(prompt)) > 0 {
 		client := api.NewClient()
@@ -110,7 +112,12 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 
 		var latest api.GenerateResponse
 
-		request := api.GenerateRequest{Model: model, Prompt: prompt}
+		generateContext, ok := cmd.Context().Value(generateContextKey).([]int)
+		if !ok {
+			generateContext = []int{}
+		}
+
+		request := api.GenerateRequest{Model: model, Prompt: prompt, Context: generateContext}
 		fn := func(resp api.GenerateResponse) error {
 			if !spinner.IsFinished() {
 				spinner.Finish()
@@ -119,6 +126,8 @@ func generate(cmd *cobra.Command, model, prompt string) error {
 			latest = resp
 
 			fmt.Print(resp.Response)
+
+			cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey, resp.Context))
 			return nil
 		}
 
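The CLI keeps the conversation's token context between prompts on the cobra command's context.Context, keyed by generateContextKey. A standalone sketch of that store-and-retrieve round trip, with a literal []int standing in for resp.Context:

package main

import (
	"context"
	"fmt"
)

// Mirrors cmd/cmd.go above: a zero-size package-level key used with
// context.WithValue to carry the token context between prompts.
var generateContextKey struct{}

func main() {
	ctx := context.Background()

	// After a response with Done == true, its Context would be stored like this.
	ctx = context.WithValue(ctx, generateContextKey, []int{11, 12, 13})

	// Before the next prompt, it is read back; a failed type assertion means
	// this is the first turn, so an empty slice is used instead.
	generateContext, ok := ctx.Value(generateContextKey).([]int)
	if !ok {
		generateContext = []int{}
	}

	fmt.Println(generateContext)
}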
@@ -149,9 +149,14 @@ func (llm *llama) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
-func (llm *llama) Predict(prompt string, fn func(api.GenerateResponse)) error {
-	if tokens := llm.tokenize(prompt); tokens != nil {
-		return llm.generate(tokens, fn)
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+	if input := llm.tokenize(prompt); input != nil {
+		embd := make([]C.llama_token, len(ctx))
+		for i := range ctx {
+			embd[i] = C.llama_token(ctx[i])
+		}
+
+		return llm.generate(append(embd, input...), fn)
 	}
 
 	return errors.New("llama: tokenize")
@@ -194,6 +199,11 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 
 	output := deque[C.llama_token]{capacity: llm.NumCtx}
 
+	context := deque[int]{capacity: llm.NumCtx / 2}
+	for _, in := range input {
+		context.PushLeft(int(in))
+	}
+
 	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
 		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
 			return errors.New("llama: eval")
@@ -212,6 +222,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 		})
 
 		output.PushLeft(token)
+		context.PushLeft(int(token))
 
 		input = []C.llama_token{token}
 	}
@@ -228,6 +239,7 @@ func (llm *llama) generate(input []C.llama_token, fn func(api.GenerateResponse))
 	timings := C.llama_get_timings(llm.ctx)
 	fn(api.GenerateResponse{
 		Done:               true,
+		Context:            context.Data(),
 		PromptEvalCount:    int(timings.n_p_eval),
 		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
 		EvalCount:          int(timings.n_eval),
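A cgo-free sketch of what the new Predict does with its ctx argument: the previous turn's context tokens are prepended to the freshly tokenized prompt before being handed to generate. Plain ints stand in for C.llama_token here, and the tokenizer itself is not modeled.

package main

import "fmt"

// buildInput mirrors the new Predict above with plain ints in place of
// C.llama_token: the prior context tokens come first, then the new prompt's tokens.
func buildInput(prevContext, promptTokens []int) []int {
	embd := make([]int, 0, len(prevContext)+len(promptTokens))
	embd = append(embd, prevContext...)
	return append(embd, promptTokens...)
}

func main() {
	fmt.Println(buildInput([]int{11, 12, 13}, []int{21, 22}))
}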
main.go (4 changes)

@@ -1,9 +1,11 @@
 package main
 
 import (
+	"context"
+
 	"github.com/jmorganca/ollama/cmd"
 )
 
 func main() {
-	cmd.NewCLI().Execute()
+	cmd.NewCLI().ExecuteContext(context.Background())
 }
@@ -94,7 +94,7 @@ func generate(c *gin.Context) {
 		ch <- r
 	}
 
-	if err := llm.Predict(req.Prompt, fn); err != nil {
+	if err := llm.Predict(req.Context, req.Prompt, fn); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request.
+{{- end }}
 
 ### Instruction:
 {{ .Prompt }}
@@ -1,3 +1,5 @@
+{{- if not .Context }}
 A helpful assistant who helps the user with any questions asked.
+{{- end }}
 User: {{ .Prompt }}
 Assistant:
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request. Be concise. Once the request is completed, include no other text.
+{{- end }}
 ### Instruction:
 {{ .Prompt }}
 ### Response:
@@ -1,5 +1,7 @@
+{{- if not .Context }}
 ### System:
 You are an AI assistant that follows instruction extremely well. Help as much as you can.
+{{- end }}
 
 ### User:
 {{ .Prompt }}
@@ -1,4 +1,6 @@
+{{ if not .Context }}
 A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+{{- end }}
 
 USER: {{ .Prompt }}
 ASSISTANT:
@@ -1,4 +1,6 @@
+{{- if not .Context }}
 Below is an instruction that describes a task. Write a response that appropriately completes the request
+{{- end }}
 
 ### Instruction: {{ .Prompt }}
 
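Each prompt template now skips its system preamble when a prior Context accompanies the request. A small text/template sketch of that behavior, reusing one of the templates above with a map standing in for the real template data:

package main

import (
	"fmt"
	"os"
	"text/template"
)

func main() {
	// Same shape as the modelfile templates above: the preamble renders only
	// when no prior Context is present.
	const prompt = `{{- if not .Context }}
A helpful assistant who helps the user with any questions asked.
{{- end }}
User: {{ .Prompt }}
Assistant:`

	tmpl := template.Must(template.New("prompt").Parse(prompt))

	// First turn: an empty context, so the preamble is included.
	_ = tmpl.Execute(os.Stdout, map[string]any{"Prompt": "hi", "Context": []int{}})
	fmt.Println("\n---")

	// Follow-up turn: a non-empty context suppresses the preamble.
	_ = tmpl.Execute(os.Stdout, map[string]any{"Prompt": "and?", "Context": []int{1, 2, 3}})
	fmt.Println()
}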