Merge pull request #813 from jmorganca/mxyng/llama

refactor llm/llama.go
This commit is contained in:
Michael Yang 2023-10-17 14:05:58 -07:00 committed by GitHub
commit 08b0e04f40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
llm.Options = opts llm.Options = opts
} }
type GenerationSettings struct { type prediction struct {
FrequencyPenalty float64 `json:"frequency_penalty"`
IgnoreEOS bool `json:"ignore_eos"`
LogitBias []interface{} `json:"logit_bias"`
Mirostat int `json:"mirostat"`
MirostatEta float64 `json:"mirostat_eta"`
MirostatTau float64 `json:"mirostat_tau"`
Model string `json:"model"`
NCtx int `json:"n_ctx"`
NKeep int `json:"n_keep"`
NPredict int `json:"n_predict"`
NProbs int `json:"n_probs"`
PenalizeNl bool `json:"penalize_nl"`
PresencePenalty float64 `json:"presence_penalty"`
RepeatLastN int `json:"repeat_last_n"`
RepeatPenalty float64 `json:"repeat_penalty"`
Seed uint32 `json:"seed"`
Stop []string `json:"stop"`
Stream bool `json:"stream"`
Temp float64 `json:"temp"`
TfsZ float64 `json:"tfs_z"`
TopK int `json:"top_k"`
TopP float64 `json:"top_p"`
TypicalP float64 `json:"typical_p"`
}
type Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
}
type Prediction struct {
Content string `json:"content"` Content string `json:"content"`
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
Timings `json:"timings"` Timings struct {
PredictedN int `json:"predicted_n"`
PredictedMS float64 `json:"predicted_ms"`
PromptN int `json:"prompt_n"`
PromptMS float64 `json:"prompt_ms"`
} }
type PredictRequest struct {
Prompt string `json:"prompt"`
Stream bool `json:"stream"`
NPredict int `json:"n_predict"`
NKeep int `json:"n_keep"`
Temperature float32 `json:"temperature"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
TfsZ float32 `json:"tfs_z"`
TypicalP float32 `json:"typical_p"`
RepeatLastN int `json:"repeat_last_n"`
RepeatPenalty float32 `json:"repeat_penalty"`
PresencePenalty float32 `json:"presence_penalty"`
FrequencyPenalty float32 `json:"frequency_penalty"`
Mirostat int `json:"mirostat"`
MirostatTau float32 `json:"mirostat_tau"`
MirostatEta float32 `json:"mirostat_eta"`
PenalizeNl bool `json:"penalize_nl"`
Seed int `json:"seed"`
Stop []string `json:"stop,omitempty"`
} }
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
nextContext.WriteString(prevConvo) nextContext.WriteString(prevConvo)
nextContext.WriteString(prompt) nextContext.WriteString(prompt)
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port) request := map[string]any{
predReq := PredictRequest{ "prompt": nextContext.String(),
Prompt: nextContext.String(), "stream": true,
Stream: true, "n_predict": llm.NumPredict,
NPredict: llm.NumPredict, "n_keep": llm.NumKeep,
NKeep: llm.NumKeep, "temperature": llm.Temperature,
Temperature: llm.Temperature, "top_k": llm.TopK,
TopK: llm.TopK, "top_p": llm.TopP,
TopP: llm.TopP, "tfs_z": llm.TFSZ,
TfsZ: llm.TFSZ, "typical_p": llm.TypicalP,
TypicalP: llm.TypicalP, "repeat_last_n": llm.RepeatLastN,
RepeatLastN: llm.RepeatLastN, "repeat_penalty": llm.RepeatPenalty,
RepeatPenalty: llm.RepeatPenalty, "presence_penalty": llm.PresencePenalty,
PresencePenalty: llm.PresencePenalty, "frequency_penalty": llm.FrequencyPenalty,
FrequencyPenalty: llm.FrequencyPenalty, "mirostat": llm.Mirostat,
Mirostat: llm.Mirostat, "mirostat_tau": llm.MirostatTau,
MirostatTau: llm.MirostatTau, "mirostat_eta": llm.MirostatEta,
MirostatEta: llm.MirostatEta, "penalize_nl": llm.PenalizeNewline,
PenalizeNl: llm.PenalizeNewline, "seed": llm.Seed,
Seed: llm.Seed, "stop": llm.Stop,
Stop: llm.Stop,
} }
// Handling JSON marshaling with special characters unescaped. // Handling JSON marshaling with special characters unescaped.
@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
enc := json.NewEncoder(buffer) enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
if err := enc.Encode(predReq); err != nil { if err := enc.Encode(request); err != nil {
return fmt.Errorf("failed to marshal data: %v", err) return fmt.Errorf("failed to marshal data: %v", err)
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil { if err != nil {
return fmt.Errorf("error creating POST request: %v", err) return fmt.Errorf("error creating POST request: %v", err)
@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
// This handles the request cancellation // This handles the request cancellation
return ctx.Err() return ctx.Err()
default: default:
line := scanner.Text() line := scanner.Bytes()
if line == "" { if len(line) == 0 {
continue continue
} }
// Read data from the server-side event stream if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
if strings.HasPrefix(line, "data: ") { var p prediction
evt := line[6:] if err := json.Unmarshal(evt, &p); err != nil {
var p Prediction
if err := json.Unmarshal([]byte(evt), &p); err != nil {
return fmt.Errorf("error unmarshaling llm prediction response: %v", err) return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
} }
@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
fn(api.GenerateResponse{ fn(api.GenerateResponse{
Done: true, Done: true,
Context: embd, Context: embd,
PromptEvalCount: p.PromptN, PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.PromptMS), PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.PredictedN, EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.PredictedMS), EvalDuration: parseDurationMs(p.Timings.PredictedMS),
}) })
return nil return nil