remove marshalPrompt which is no longer needed
parent adaa13088b
commit 5d3f314b0b
1 changed file with 19 additions and 42 deletions
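Note: the net effect of this change is that Predict no longer trims the prompt on the client. Previously, marshalPrompt re-tokenized the previous context plus the new prompt and truncated the result to fit NumCtx before every request; now Predict simply decodes the previous context tokens back to text, appends the new prompt with a strings.Builder, and sends the whole thing, since the llama.cpp /completion endpoint already receives n_keep (NKeep) and can presumably handle context overflow itself. A minimal runnable sketch of the new flow follows; decode and encode are hypothetical stand-ins for llm.Decode and llm.Encode (token IDs <-> text via the loaded model):

package main

import (
	"fmt"
	"strings"
)

// decode and encode are hypothetical stand-ins for llm.Decode and
// llm.Encode; the real methods round-trip text through the model's
// tokenizer.
func decode(tokens []int) string { return fmt.Sprintf("<%d prior tokens> ", len(tokens)) }
func encode(text string) []int   { return make([]int, len(text)) }

func main() {
	prevContext := []int{1, 2, 3} // context returned by the previous Predict call
	prompt := "Why is the sky blue?"

	// New flow: decode the previous conversation and append the new
	// prompt; no client-side truncation before the request is sent.
	var nextContext strings.Builder
	nextContext.WriteString(decode(prevContext))
	nextContext.WriteString(prompt)

	// As the response streams in, each Content chunk is appended too,
	// so the Context returned to the client on the final event is just
	// the encoding of everything built up so far.
	nextContext.WriteString(" Rayleigh scattering.")
	fmt.Println(len(encode(nextContext.String())), "tokens in the next context")
}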
@@ -286,8 +286,8 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Options
 		runner.Path,
 		append(params, "--port", strconv.Itoa(port))...,
 	)
-	var stderr bytes.Buffer
-	cmd.Stderr = &stderr
+	cmd.Stdout = os.Stderr
+	cmd.Stderr = os.Stderr
 
 	llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
@@ -437,15 +437,19 @@ type PredictRequest struct {
 	Stop []string `json:"stop,omitempty"`
 }
 
-func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string, fn func(api.GenerateResponse)) error {
-	// we need to find the trimmed prompt context before predicting so that we can return it to the client
-	trimmedPrompt, err := llm.marshalPrompt(ctx, predictCtx, prompt)
+func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
+	prevConvo, err := llm.Decode(ctx, prevContext)
 	if err != nil {
-		return fmt.Errorf("marshaling prompt: %v", err)
+		return err
 	}
+
+	var nextContext strings.Builder
+	nextContext.WriteString(prevConvo)
+	nextContext.WriteString(prompt)
+
 	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	predReq := PredictRequest{
-		Prompt:   trimmedPrompt,
+		Prompt:   nextContext.String(),
 		Stream:   true,
 		NPredict: llm.NumPredict,
 		NKeep:    llm.NumKeep,
@@ -491,7 +495,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	}
 
 	scanner := bufio.NewScanner(resp.Body)
-	genCtx := trimmedPrompt // start with the trimmed prompt
 	for scanner.Scan() {
 		select {
 		case <-ctx.Done():
@@ -512,11 +515,12 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 				}
+
 				if complete.Timings.PredictedMS > 0 {
-					genCtx += complete.Content
-					embd, err := llm.Encode(ctx, genCtx)
+					nextContext.WriteString(complete.Content)
+					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
 					}
 
 					fn(api.GenerateResponse{
 						Done:    true,
 						Context: embd,
@@ -528,12 +532,13 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 					return nil
 				}
 
-				var pred Prediction
-				if err := json.Unmarshal([]byte(evt), &pred); err != nil {
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
-				genCtx += pred.Content
-				fn(api.GenerateResponse{Response: pred.Content})
+
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
 			}
 		}
 	}
@@ -545,34 +550,6 @@ func (llm *llama) Predict(ctx context.Context, predictCtx []int, prompt string,
 	return nil
 }
 
-func (llm *llama) marshalPrompt(ctx context.Context, pCtx []int, prompt string) (string, error) {
-	pEncode, err := llm.Encode(ctx, prompt)
-	if err != nil {
-		return "", fmt.Errorf("encoding prompt context: %w", err)
-	}
-	tokens := append(pCtx, pEncode...)
-	if llm.NumKeep < 0 {
-		llm.NumKeep = len(tokens)
-	}
-
-	// min(llm.NumCtx - 4, llm.NumKeep)
-	if llm.NumCtx-4 < llm.NumKeep {
-		llm.NumKeep = llm.NumCtx - 4
-	}
-
-	if len(tokens) >= llm.NumCtx {
-		// truncate input
-		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		tokens = truncated
-		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
-	}
-
-	return llm.Decode(ctx, tokens)
-}
-
 type TokenizeRequest struct {
 	Content string `json:"content"`
 }
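For reference, the deleted marshalPrompt implemented the client-side truncation that is now dropped: keep the first NumKeep tokens, then discard whole blocks of numLeft = (NumCtx - NumKeep) / 2 tokens from the middle so the rest fits. A standalone sketch of just that arithmetic on plain int slices (the NumKeep clamping to NumCtx-4 is omitted, and the sizes in main are made-up illustrations):

package main

import "fmt"

// truncate reproduces the arithmetic of the removed marshalPrompt:
// keep the first numKeep tokens, then drop whole numLeft-sized blocks
// from the middle so the result fits within numCtx.
func truncate(tokens []int, numCtx, numKeep int) []int {
	if len(tokens) < numCtx {
		return tokens // nothing to do
	}
	numLeft := (numCtx - numKeep) / 2
	erasedBlocks := (len(tokens) - numKeep - numLeft - 1) / numLeft
	out := append([]int{}, tokens[:numKeep]...)
	return append(out, tokens[numKeep+erasedBlocks*numLeft:]...)
}

func main() {
	tokens := make([]int, 3000)
	for i := range tokens {
		tokens[i] = i
	}
	// With a 2048-token window keeping the first 48 tokens, one block of
	// 1000 middle tokens is erased: 48 + 1952 = 2000 tokens survive.
	got := truncate(tokens, 2048, 48)
	fmt.Println(len(got), got[47], got[48]) // 2000 47 1048
}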