fix not forwarding last token

Michael Yang 2023-09-03 17:46:35 -04:00
parent 5d3f314b0b
commit 59a705525c


@@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }

-type Prediction struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-}
-
 type GenerationSettings struct {
 	FrequencyPenalty float64 `json:"frequency_penalty"`
 	IgnoreEOS        bool    `json:"ignore_eos"`
@@ -385,31 +380,19 @@ type GenerationSettings struct {
 }

 type Timings struct {
-	PredictedMS         float64 `json:"predicted_ms"`
-	PredictedN          int     `json:"predicted_n"`
-	PredictedPerSecond  float64 `json:"predicted_per_second"`
-	PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
-	PromptMS            float64 `json:"prompt_ms"`
-	PromptN             int     `json:"prompt_n"`
-	PromptPerSecond     float64 `json:"prompt_per_second"`
-	PromptPerTokenMS    float64 `json:"prompt_per_token_ms"`
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
 }

-type PredictComplete struct {
-	Content            string             `json:"content"`
-	GenerationSettings GenerationSettings `json:"generation_settings"`
-	Model              string             `json:"model"`
-	Prompt             string             `json:"prompt"`
-	Stop               bool               `json:"stop"`
-	StoppedEOS         bool               `json:"stopped_eos"`
-	StoppedLimit       bool               `json:"stopped_limit"`
-	StoppedWord        bool               `json:"stopped_word"`
-	StoppingWord       string             `json:"stopping_word"`
-	Timings            Timings            `json:"timings"`
-	TokensCached       int                `json:"tokens_cached"`
-	TokensEvaluated    int                `json:"tokens_evaluated"`
-	TokensPredicted    int                `json:"tokens_predicted"`
-	Truncated          bool               `json:"truncated"`
+type Prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings `json:"timings"`
 }

 type PredictRequest struct {
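
A side note on the consolidated Prediction type above: Timings is embedded anonymously, so its fields are promoted and can be read directly as p.PredictedN, p.PromptMS, and so on, which is what the updated stats block further down relies on. The following is a minimal illustrative sketch of that promotion (simplified field set, made-up values), not code from this repository:

package main

import (
	"encoding/json"
	"fmt"
)

type Timings struct {
	PredictedN  int     `json:"predicted_n"`
	PredictedMS float64 `json:"predicted_ms"`
}

// Embedding Timings anonymously promotes its fields onto Prediction,
// while the json tag keeps them nested under "timings" on the wire.
type Prediction struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`

	Timings `json:"timings"`
}

func main() {
	evt := `{"content":"hi","stop":true,"timings":{"predicted_n":2,"predicted_ms":12.5}}`

	var p Prediction
	if err := json.Unmarshal([]byte(evt), &p); err != nil {
		panic(err)
	}

	// Promoted fields are read directly off p, no p.Timings prefix needed.
	fmt.Println(p.PredictedN, p.PredictedMS) // 2 12.5
}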
@@ -509,13 +492,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// Read data from the server-side event stream
 			if strings.HasPrefix(line, "data: ") {
 				evt := line[6:]
-				var complete PredictComplete
-				if err := json.Unmarshal([]byte(evt), &complete); err != nil {
-					return fmt.Errorf("error unmarshaling llm complete response: %v", err)
+				var p Prediction
+				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}

-				if complete.Timings.PredictedMS > 0 {
-					nextContext.WriteString(complete.Content)
+				fn(api.GenerateResponse{Response: p.Content})
+				nextContext.WriteString(p.Content)
+
+				if p.Stop {
 					embd, err := llm.Encode(ctx, nextContext.String())
 					if err != nil {
 						return fmt.Errorf("encoding context: %v", err)
@@ -524,21 +509,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 					fn(api.GenerateResponse{
 						Done:               true,
 						Context:            embd,
-						PromptEvalCount:    int(complete.Timings.PromptN),
-						PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
-						EvalCount:          int(complete.Timings.PredictedN),
-						EvalDuration:       parseDurationMs(float64(complete.Timings.PredictedMS)),
+						PromptEvalCount:    p.PromptN,
+						PromptEvalDuration: parseDurationMs(p.PromptMS),
+						EvalCount:          p.PredictedN,
+						EvalDuration:       parseDurationMs(p.PredictedMS),
 					})

 					return nil
 				}
-
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
-					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
-				}
-				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
 			}
 		}
 	}
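
Taken together, the change is about ordering: every streamed event's content is now forwarded to the callback and appended to the running context before the stop check, instead of the final event being consumed only for its timings and context update. Below is a minimal, self-contained sketch of that flow under assumed simplifications (a hypothetical handleEvent helper, a trimmed-down prediction type, and a plain emit callback standing in for fn); it is an illustration, not the repository's code:

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed-down stand-in for the Prediction type: every streamed event
// carries content, and the final event additionally sets stop=true.
type prediction struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

// handleEvent is a hypothetical helper mirroring the fixed flow: the
// content of every event is emitted and appended to the running context
// before the stop check, so the last token is no longer dropped.
func handleEvent(line string, next *strings.Builder, emit func(string)) (done bool, err error) {
	if !strings.HasPrefix(line, "data: ") {
		return false, nil
	}

	var p prediction
	if err := json.Unmarshal([]byte(line[6:]), &p); err != nil {
		return false, fmt.Errorf("error unmarshaling llm prediction response: %v", err)
	}

	emit(p.Content)             // forward this chunk, including the final one
	next.WriteString(p.Content) // keep it in the context that gets re-encoded

	return p.Stop, nil
}

func main() {
	events := []string{
		`data: {"content":"Hello","stop":false}`,
		`data: {"content":" world","stop":true}`, // the last token now reaches the caller
	}

	var next strings.Builder
	for _, line := range events {
		done, err := handleEvent(line, &next, func(s string) { fmt.Print(s) })
		if err != nil {
			panic(err)
		}
		if done {
			break
		}
	}
	fmt.Println("\ncontext:", next.String())
}

In the pre-change code, the final event was handled by a separate branch that only recorded timings and wrote the content into the context; because fn was never called with that content, the last token never reached the caller, which is what the commit message refers to.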