Merge pull request #463 from jmorganca/mxyng/fix-last-token
fix not forwarding last token
This commit is contained in:
commit
8dc68417e7
1 changed files with 23 additions and 45 deletions
|
@ -353,11 +353,6 @@ func (llm *llama) SetOptions(opts api.Options) {
|
||||||
llm.Options = opts
|
llm.Options = opts
|
||||||
}
|
}
|
||||||
|
|
||||||
type Prediction struct {
|
|
||||||
Content string `json:"content"`
|
|
||||||
Stop bool `json:"stop"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerationSettings struct {
|
type GenerationSettings struct {
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty"`
|
FrequencyPenalty float64 `json:"frequency_penalty"`
|
||||||
IgnoreEOS bool `json:"ignore_eos"`
|
IgnoreEOS bool `json:"ignore_eos"`
|
||||||
|
@ -385,31 +380,19 @@ type GenerationSettings struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Timings struct {
|
type Timings struct {
|
||||||
PredictedMS float64 `json:"predicted_ms"`
|
|
||||||
PredictedN int `json:"predicted_n"`
|
PredictedN int `json:"predicted_n"`
|
||||||
PredictedPerSecond float64 `json:"predicted_per_second"`
|
PredictedMS float64 `json:"predicted_ms"`
|
||||||
PredictedPerTokenMS float64 `json:"predicted_per_token_ms"`
|
|
||||||
PromptMS float64 `json:"prompt_ms"`
|
|
||||||
PromptN int `json:"prompt_n"`
|
PromptN int `json:"prompt_n"`
|
||||||
PromptPerSecond float64 `json:"prompt_per_second"`
|
PromptMS float64 `json:"prompt_ms"`
|
||||||
PromptPerTokenMS float64 `json:"prompt_per_token_ms"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PredictComplete struct {
|
type Prediction struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
GenerationSettings GenerationSettings `json:"generation_settings"`
|
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Stop bool `json:"stop"`
|
Stop bool `json:"stop"`
|
||||||
StoppedEOS bool `json:"stopped_eos"`
|
|
||||||
StoppedLimit bool `json:"stopped_limit"`
|
Timings `json:"timings"`
|
||||||
StoppedWord bool `json:"stopped_word"`
|
|
||||||
StoppingWord string `json:"stopping_word"`
|
|
||||||
Timings Timings `json:"timings"`
|
|
||||||
TokensCached int `json:"tokens_cached"`
|
|
||||||
TokensEvaluated int `json:"tokens_evaluated"`
|
|
||||||
TokensPredicted int `json:"tokens_predicted"`
|
|
||||||
Truncated bool `json:"truncated"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PredictRequest struct {
|
type PredictRequest struct {
|
||||||
|
@ -509,13 +492,15 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||||
// Read data from the server-side event stream
|
// Read data from the server-side event stream
|
||||||
if strings.HasPrefix(line, "data: ") {
|
if strings.HasPrefix(line, "data: ") {
|
||||||
evt := line[6:]
|
evt := line[6:]
|
||||||
var complete PredictComplete
|
var p Prediction
|
||||||
if err := json.Unmarshal([]byte(evt), &complete); err != nil {
|
if err := json.Unmarshal([]byte(evt), &p); err != nil {
|
||||||
return fmt.Errorf("error unmarshaling llm complete response: %v", err)
|
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if complete.Timings.PredictedMS > 0 {
|
fn(api.GenerateResponse{Response: p.Content})
|
||||||
nextContext.WriteString(complete.Content)
|
nextContext.WriteString(p.Content)
|
||||||
|
|
||||||
|
if p.Stop {
|
||||||
embd, err := llm.Encode(ctx, nextContext.String())
|
embd, err := llm.Encode(ctx, nextContext.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("encoding context: %v", err)
|
return fmt.Errorf("encoding context: %v", err)
|
||||||
|
@ -524,21 +509,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
|
||||||
fn(api.GenerateResponse{
|
fn(api.GenerateResponse{
|
||||||
Done: true,
|
Done: true,
|
||||||
Context: embd,
|
Context: embd,
|
||||||
PromptEvalCount: int(complete.Timings.PromptN),
|
PromptEvalCount: p.PromptN,
|
||||||
PromptEvalDuration: parseDurationMs(float64(complete.Timings.PromptMS)),
|
PromptEvalDuration: parseDurationMs(p.PromptMS),
|
||||||
EvalCount: int(complete.Timings.PredictedN),
|
EvalCount: p.PredictedN,
|
||||||
EvalDuration: parseDurationMs(float64(complete.Timings.PredictedMS)),
|
EvalDuration: parseDurationMs(p.PredictedMS),
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var p Prediction
|
|
||||||
if err := json.Unmarshal([]byte(evt), &p); err != nil {
|
|
||||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn(api.GenerateResponse{Response: p.Content})
|
|
||||||
nextContext.WriteString(p.Content)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue