Merge pull request #813 from jmorganca/mxyng/llama

refactor llm/llama.go
Michael Yang, 2023-10-17 14:05:58 -07:00 (committed by GitHub)
commit 08b0e04f40


@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
     llm.Options = opts
 }

-type GenerationSettings struct {
-    FrequencyPenalty float64       `json:"frequency_penalty"`
-    IgnoreEOS        bool          `json:"ignore_eos"`
-    LogitBias        []interface{} `json:"logit_bias"`
-    Mirostat         int           `json:"mirostat"`
-    MirostatEta      float64       `json:"mirostat_eta"`
-    MirostatTau      float64       `json:"mirostat_tau"`
-    Model            string        `json:"model"`
-    NCtx             int           `json:"n_ctx"`
-    NKeep            int           `json:"n_keep"`
-    NPredict         int           `json:"n_predict"`
-    NProbs           int           `json:"n_probs"`
-    PenalizeNl       bool          `json:"penalize_nl"`
-    PresencePenalty  float64       `json:"presence_penalty"`
-    RepeatLastN      int           `json:"repeat_last_n"`
-    RepeatPenalty    float64       `json:"repeat_penalty"`
-    Seed             uint32        `json:"seed"`
-    Stop             []string      `json:"stop"`
-    Stream           bool          `json:"stream"`
-    Temp             float64       `json:"temp"`
-    TfsZ             float64       `json:"tfs_z"`
-    TopK             int           `json:"top_k"`
-    TopP             float64       `json:"top_p"`
-    TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-    PredictedN  int     `json:"predicted_n"`
-    PredictedMS float64 `json:"predicted_ms"`
-    PromptN     int     `json:"prompt_n"`
-    PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
+type prediction struct {
     Content string `json:"content"`
     Model   string `json:"model"`
     Prompt  string `json:"prompt"`
     Stop    bool   `json:"stop"`

-    Timings `json:"timings"`
+    Timings struct {
+        PredictedN  int     `json:"predicted_n"`
+        PredictedMS float64 `json:"predicted_ms"`
+        PromptN     int     `json:"prompt_n"`
+        PromptMS    float64 `json:"prompt_ms"`
+    }
 }
-
-type PredictRequest struct {
-    Prompt           string   `json:"prompt"`
-    Stream           bool     `json:"stream"`
-    NPredict         int      `json:"n_predict"`
-    NKeep            int      `json:"n_keep"`
-    Temperature      float32  `json:"temperature"`
-    TopK             int      `json:"top_k"`
-    TopP             float32  `json:"top_p"`
-    TfsZ             float32  `json:"tfs_z"`
-    TypicalP         float32  `json:"typical_p"`
-    RepeatLastN      int      `json:"repeat_last_n"`
-    RepeatPenalty    float32  `json:"repeat_penalty"`
-    PresencePenalty  float32  `json:"presence_penalty"`
-    FrequencyPenalty float32  `json:"frequency_penalty"`
-    Mirostat         int      `json:"mirostat"`
-    MirostatTau      float32  `json:"mirostat_tau"`
-    MirostatEta      float32  `json:"mirostat_eta"`
-    PenalizeNl       bool     `json:"penalize_nl"`
-    Seed             int      `json:"seed"`
-    Stop             []string `json:"stop,omitempty"`
-}

 const maxBufferSize = 512 * format.KiloByte
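
The hunk above collapses four exported types (GenerationSettings, Timings, Prediction, PredictRequest) into one unexported prediction type whose Timings field is an inline anonymous struct. Because encoding/json matches object keys to field names case-insensitively, the untagged Timings field still picks up the server's "timings" object. A minimal standalone sketch, not part of the commit, with an illustrative payload rather than real llama.cpp server output:

package main

import (
    "encoding/json"
    "fmt"
)

type prediction struct {
    Content string `json:"content"`
    Model   string `json:"model"`
    Prompt  string `json:"prompt"`
    Stop    bool   `json:"stop"`

    Timings struct {
        PredictedN  int     `json:"predicted_n"`
        PredictedMS float64 `json:"predicted_ms"`
        PromptN     int     `json:"prompt_n"`
        PromptMS    float64 `json:"prompt_ms"`
    }
}

func main() {
    // Illustrative final completion event; not captured from a real server.
    evt := []byte(`{"content":"","stop":true,"timings":{"predicted_n":42,"predicted_ms":850.5,"prompt_n":12,"prompt_ms":95.2}}`)

    var p prediction
    if err := json.Unmarshal(evt, &p); err != nil {
        panic(err)
    }

    // The untagged Timings field matches the "timings" key case-insensitively.
    fmt.Println(p.Stop, p.Timings.PredictedN, p.Timings.PredictedMS)
}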
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
     nextContext.WriteString(prevConvo)
     nextContext.WriteString(prompt)

-    endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-    predReq := PredictRequest{
-        Prompt:           nextContext.String(),
-        Stream:           true,
-        NPredict:         llm.NumPredict,
-        NKeep:            llm.NumKeep,
-        Temperature:      llm.Temperature,
-        TopK:             llm.TopK,
-        TopP:             llm.TopP,
-        TfsZ:             llm.TFSZ,
-        TypicalP:         llm.TypicalP,
-        RepeatLastN:      llm.RepeatLastN,
-        RepeatPenalty:    llm.RepeatPenalty,
-        PresencePenalty:  llm.PresencePenalty,
-        FrequencyPenalty: llm.FrequencyPenalty,
-        Mirostat:         llm.Mirostat,
-        MirostatTau:      llm.MirostatTau,
-        MirostatEta:      llm.MirostatEta,
-        PenalizeNl:       llm.PenalizeNewline,
-        Seed:             llm.Seed,
-        Stop:             llm.Stop,
+    request := map[string]any{
+        "prompt":            nextContext.String(),
+        "stream":            true,
+        "n_predict":         llm.NumPredict,
+        "n_keep":            llm.NumKeep,
+        "temperature":       llm.Temperature,
+        "top_k":             llm.TopK,
+        "top_p":             llm.TopP,
+        "tfs_z":             llm.TFSZ,
+        "typical_p":         llm.TypicalP,
+        "repeat_last_n":     llm.RepeatLastN,
+        "repeat_penalty":    llm.RepeatPenalty,
+        "presence_penalty":  llm.PresencePenalty,
+        "frequency_penalty": llm.FrequencyPenalty,
+        "mirostat":          llm.Mirostat,
+        "mirostat_tau":      llm.MirostatTau,
+        "mirostat_eta":      llm.MirostatEta,
+        "penalize_nl":       llm.PenalizeNewline,
+        "seed":              llm.Seed,
+        "stop":              llm.Stop,
     }

     // Handling JSON marshaling with special characters unescaped.
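
Replacing the typed PredictRequest with a map[string]any drops the duplicated field list while keeping the wire format identical; encoding/json writes map keys in sorted order, so the payload stays deterministic despite Go's randomized map iteration. A reduced sketch in which hard-coded sample values stand in for the fields read from llm.Options:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
)

func main() {
    // Sample values standing in for llm.Options fields.
    request := map[string]any{
        "prompt":      "why is the sky <blue>?",
        "stream":      true,
        "temperature": 0.8,
        "stop":        []string{"</s>"},
    }

    var buffer bytes.Buffer
    enc := json.NewEncoder(&buffer)
    enc.SetEscapeHTML(false) // keep <, >, & literal inside the prompt

    if err := enc.Encode(request); err != nil {
        panic(err)
    }

    // Keys come out sorted: prompt, stop, stream, temperature.
    fmt.Print(buffer.String())
}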
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
     enc := json.NewEncoder(buffer)
     enc.SetEscapeHTML(false)

-    if err := enc.Encode(predReq); err != nil {
+    if err := enc.Encode(request); err != nil {
         return fmt.Errorf("failed to marshal data: %v", err)
     }

+    endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
     req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
     if err != nil {
         return fmt.Errorf("error creating POST request: %v", err)
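
The endpoint construction simply moves next to its only use. Because the request is built with http.NewRequestWithContext, cancelling ctx aborts the streaming call, which is what the ctx.Done() branch in the next hunk relies on. A minimal sketch, assuming a placeholder port and payload in place of the running llama.cpp server:

package main

import (
    "bytes"
    "context"
    "fmt"
    "net/http"
    "time"
)

func main() {
    // A short deadline stands in for a real client-side cancellation.
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    endpoint := "http://127.0.0.1:8080/completion" // placeholder port
    body := bytes.NewBufferString(`{"prompt":"hi","stream":true}`)

    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body)
    if err != nil {
        fmt.Println("error creating POST request:", err)
        return
    }
    req.Header.Set("Content-Type", "application/json")

    // Once ctx expires, the in-flight request is aborted and Do returns
    // an error instead of blocking on the response stream.
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        fmt.Println("request failed:", err)
        return
    }
    defer resp.Body.Close()
    fmt.Println("status:", resp.Status)
}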
@@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
             // This handles the request cancellation
             return ctx.Err()
         default:
-            line := scanner.Text()
-            if line == "" {
+            line := scanner.Bytes()
+            if len(line) == 0 {
                 continue
             }

             // Read data from the server-side event stream
-            if strings.HasPrefix(line, "data: ") {
-                evt := line[6:]
-                var p Prediction
-                if err := json.Unmarshal([]byte(evt), &p); err != nil {
+            if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+                var p prediction
+                if err := json.Unmarshal(evt, &p); err != nil {
                     return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
                 }
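
bytes.CutPrefix, added in Go 1.20, folds the strings.HasPrefix test and the hard-coded line[6:] slice into a single call, and reading scanner.Bytes() instead of scanner.Text() avoids allocating a string per event line. A small sketch with an illustrative event line:

package main

import (
    "bytes"
    "fmt"
)

func main() {
    // Illustrative server-sent event line.
    line := []byte(`data: {"content":"hello","stop":false}`)

    // CutPrefix returns the remainder and whether the prefix matched,
    // so the JSON bytes can go straight to json.Unmarshal.
    if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
        fmt.Printf("%s\n", evt)
    }
}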
@@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
                 fn(api.GenerateResponse{
                     Done:               true,
                     Context:            embd,
-                    PromptEvalCount:    p.PromptN,
-                    PromptEvalDuration: parseDurationMs(p.PromptMS),
-                    EvalCount:          p.PredictedN,
-                    EvalDuration:       parseDurationMs(p.PredictedMS),
+                    PromptEvalCount:    p.Timings.PromptN,
+                    PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+                    EvalCount:          p.Timings.PredictedN,
+                    EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
                 })

                 return nil
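
parseDurationMs is an existing helper elsewhere in the package, not part of this diff; a minimal equivalent, assuming it simply scales a float millisecond count into a time.Duration:

package main

import (
    "fmt"
    "time"
)

// Guessed minimal equivalent of the package's parseDurationMs helper:
// convert a float millisecond count into a time.Duration.
func parseDurationMs(ms float64) time.Duration {
    return time.Duration(ms * float64(time.Millisecond))
}

func main() {
    fmt.Println(parseDurationMs(850.5)) // 850.5ms
}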