Merge pull request #813 from jmorganca/mxyng/llama
refactor llm/llama.go
commit 08b0e04f40

1 changed file with 38 additions and 90 deletions

llm/llama.go
@@ -442,68 +442,18 @@ func (llm *llama) SetOptions(opts api.Options) {
 	llm.Options = opts
 }
 
-type GenerationSettings struct {
-	FrequencyPenalty float64       `json:"frequency_penalty"`
-	IgnoreEOS        bool          `json:"ignore_eos"`
-	LogitBias        []interface{} `json:"logit_bias"`
-	Mirostat         int           `json:"mirostat"`
-	MirostatEta      float64       `json:"mirostat_eta"`
-	MirostatTau      float64       `json:"mirostat_tau"`
-	Model            string        `json:"model"`
-	NCtx             int           `json:"n_ctx"`
-	NKeep            int           `json:"n_keep"`
-	NPredict         int           `json:"n_predict"`
-	NProbs           int           `json:"n_probs"`
-	PenalizeNl       bool          `json:"penalize_nl"`
-	PresencePenalty  float64       `json:"presence_penalty"`
-	RepeatLastN      int           `json:"repeat_last_n"`
-	RepeatPenalty    float64       `json:"repeat_penalty"`
-	Seed             uint32        `json:"seed"`
-	Stop             []string      `json:"stop"`
-	Stream           bool          `json:"stream"`
-	Temp             float64       `json:"temp"`
-	TfsZ             float64       `json:"tfs_z"`
-	TopK             int           `json:"top_k"`
-	TopP             float64       `json:"top_p"`
-	TypicalP         float64       `json:"typical_p"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}
-
-type Prediction struct {
-	Content string `json:"content"`
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Stop    bool   `json:"stop"`
-
-	Timings `json:"timings"`
-}
-
-type PredictRequest struct {
-	Prompt           string   `json:"prompt"`
-	Stream           bool     `json:"stream"`
-	NPredict         int      `json:"n_predict"`
-	NKeep            int      `json:"n_keep"`
-	Temperature      float32  `json:"temperature"`
-	TopK             int      `json:"top_k"`
-	TopP             float32  `json:"top_p"`
-	TfsZ             float32  `json:"tfs_z"`
-	TypicalP         float32  `json:"typical_p"`
-	RepeatLastN      int      `json:"repeat_last_n"`
-	RepeatPenalty    float32  `json:"repeat_penalty"`
-	PresencePenalty  float32  `json:"presence_penalty"`
-	FrequencyPenalty float32  `json:"frequency_penalty"`
-	Mirostat         int      `json:"mirostat"`
-	MirostatTau      float32  `json:"mirostat_tau"`
-	MirostatEta      float32  `json:"mirostat_eta"`
-	PenalizeNl       bool     `json:"penalize_nl"`
-	Seed             int      `json:"seed"`
-	Stop             []string `json:"stop,omitempty"`
-}
+type prediction struct {
+	Content string `json:"content"`
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Stop    bool   `json:"stop"`
+
+	Timings struct {
+		PredictedN  int     `json:"predicted_n"`
+		PredictedMS float64 `json:"predicted_ms"`
+		PromptN     int     `json:"prompt_n"`
+		PromptMS    float64 `json:"prompt_ms"`
+	}
+}
 
 const maxBufferSize = 512 * format.KiloByte
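Note on the consolidated response type: the nested Timings struct needs no `json:"timings"` tag because encoding/json matches object keys to struct field names case-insensitively. A minimal, standalone sketch (payload values invented for illustration) of how a completion event decodes into it:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // Trimmed copy of the new prediction type, for illustration only.
    type prediction struct {
    	Content string `json:"content"`
    	Stop    bool   `json:"stop"`

    	Timings struct {
    		PredictedN  int     `json:"predicted_n"`
    		PredictedMS float64 `json:"predicted_ms"`
    		PromptN     int     `json:"prompt_n"`
    		PromptMS    float64 `json:"prompt_ms"`
    	}
    }

    func main() {
    	// Example payload shape; real values come from the llama.cpp server.
    	evt := []byte(`{"content":"","stop":true,"timings":{"predicted_n":12,"predicted_ms":340.5,"prompt_n":8,"prompt_ms":90.2}}`)

    	var p prediction
    	if err := json.Unmarshal(evt, &p); err != nil {
    		panic(err)
    	}
    	fmt.Println(p.Timings.PredictedN, p.Timings.PredictedMS) // 12 340.5
    }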
@@ -518,27 +468,26 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	nextContext.WriteString(prevConvo)
 	nextContext.WriteString(prompt)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	predReq := PredictRequest{
-		Prompt:           nextContext.String(),
-		Stream:           true,
-		NPredict:         llm.NumPredict,
-		NKeep:            llm.NumKeep,
-		Temperature:      llm.Temperature,
-		TopK:             llm.TopK,
-		TopP:             llm.TopP,
-		TfsZ:             llm.TFSZ,
-		TypicalP:         llm.TypicalP,
-		RepeatLastN:      llm.RepeatLastN,
-		RepeatPenalty:    llm.RepeatPenalty,
-		PresencePenalty:  llm.PresencePenalty,
-		FrequencyPenalty: llm.FrequencyPenalty,
-		Mirostat:         llm.Mirostat,
-		MirostatTau:      llm.MirostatTau,
-		MirostatEta:      llm.MirostatEta,
-		PenalizeNl:       llm.PenalizeNewline,
-		Seed:             llm.Seed,
-		Stop:             llm.Stop,
+	request := map[string]any{
+		"prompt":            nextContext.String(),
+		"stream":            true,
+		"n_predict":         llm.NumPredict,
+		"n_keep":            llm.NumKeep,
+		"temperature":       llm.Temperature,
+		"top_k":             llm.TopK,
+		"top_p":             llm.TopP,
+		"tfs_z":             llm.TFSZ,
+		"typical_p":         llm.TypicalP,
+		"repeat_last_n":     llm.RepeatLastN,
+		"repeat_penalty":    llm.RepeatPenalty,
+		"presence_penalty":  llm.PresencePenalty,
+		"frequency_penalty": llm.FrequencyPenalty,
+		"mirostat":          llm.Mirostat,
+		"mirostat_tau":      llm.MirostatTau,
+		"mirostat_eta":      llm.MirostatEta,
+		"penalize_nl":       llm.PenalizeNewline,
+		"seed":              llm.Seed,
+		"stop":              llm.Stop,
 	}
 
 	// Handling JSON marshaling with special characters unescaped.
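Note on the request body: a plain map[string]any keyed by the server's snake_case option names replaces the tagged PredictRequest struct, so adding or dropping an option no longer requires a struct change. A small standalone sketch of the same encoding path (placeholder values; in llama.go these come from llm.Options):

    package main

    import (
    	"bytes"
    	"encoding/json"
    	"fmt"
    )

    func main() {
    	// Placeholder option values for illustration.
    	request := map[string]any{
    		"prompt":      "why is the sky blue?",
    		"stream":      true,
    		"temperature": 0.8,
    		"top_k":       40,
    		"stop":        []string{"\n"},
    	}

    	var buf bytes.Buffer
    	enc := json.NewEncoder(&buf)
    	enc.SetEscapeHTML(false) // keep <, >, & literal in the prompt
    	if err := enc.Encode(request); err != nil {
    		panic(err)
    	}
    	fmt.Print(buf.String())
    }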
@@ -546,10 +495,11 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 	enc := json.NewEncoder(buffer)
 	enc.SetEscapeHTML(false)
 
-	if err := enc.Encode(predReq); err != nil {
+	if err := enc.Encode(request); err != nil {
 		return fmt.Errorf("failed to marshal data: %v", err)
 	}
 
+	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 	if err != nil {
 		return fmt.Errorf("error creating POST request: %v", err)
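Note on the request setup: the /completion URL is now built right where the POST is issued, and the caller's context flows into http.NewRequestWithContext, which is what lets an in-flight completion be cancelled. A rough standalone sketch of that pattern (the port, headers, and response handling are illustrative assumptions, not lifted from llama.go):

    package main

    import (
    	"bytes"
    	"context"
    	"fmt"
    	"io"
    	"net/http"
    	"time"
    )

    func post(ctx context.Context, port int, body []byte) error {
    	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", port)

    	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewBuffer(body))
    	if err != nil {
    		return fmt.Errorf("error creating POST request: %v", err)
    	}
    	req.Header.Set("Content-Type", "application/json")

    	resp, err := http.DefaultClient.Do(req)
    	if err != nil {
    		return err // includes context cancellation and deadline errors
    	}
    	defer resp.Body.Close()

    	_, err = io.Copy(io.Discard, resp.Body)
    	return err
    }

    func main() {
    	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    	defer cancel()
    	_ = post(ctx, 8080, []byte(`{"prompt":"hi","stream":false}`))
    }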
@@ -581,16 +531,14 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			// This handles the request cancellation
 			return ctx.Err()
 		default:
-			line := scanner.Text()
-			if line == "" {
+			line := scanner.Bytes()
+			if len(line) == 0 {
 				continue
 			}
 
-			// Read data from the server-side event stream
-			if strings.HasPrefix(line, "data: ") {
-				evt := line[6:]
-				var p Prediction
-				if err := json.Unmarshal([]byte(evt), &p); err != nil {
+			if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
 					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
 				}
 
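Note on the streaming loop: scanner.Bytes plus bytes.CutPrefix (Go 1.20+) keeps the hot path on []byte, avoiding a string copy per line and folding the prefix test and trim into one call. A tiny standalone illustration of the same parsing step (the stream literal is a made-up stand-in for the server-sent events):

    package main

    import (
    	"bufio"
    	"bytes"
    	"fmt"
    	"strings"
    )

    func main() {
    	// Stand-in for the event stream read from the response body.
    	stream := "data: {\"content\":\"Hello\"}\n\ndata: {\"content\":\" world\",\"stop\":true}\n"

    	scanner := bufio.NewScanner(strings.NewReader(stream))
    	for scanner.Scan() {
    		line := scanner.Bytes()
    		if len(line) == 0 {
    			continue
    		}
    		if evt, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
    			fmt.Printf("payload: %s\n", evt)
    		}
    	}
    }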
@@ -608,10 +556,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 				fn(api.GenerateResponse{
 					Done:               true,
 					Context:            embd,
-					PromptEvalCount:    p.PromptN,
-					PromptEvalDuration: parseDurationMs(p.PromptMS),
-					EvalCount:          p.PredictedN,
-					EvalDuration:       parseDurationMs(p.PredictedMS),
+					PromptEvalCount:    p.Timings.PromptN,
+					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+					EvalCount:          p.Timings.PredictedN,
+					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 				})
 
 				return nil
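parseDurationMs is referenced here but not part of this diff; assuming it simply converts the server's millisecond floats into a time.Duration, a plausible (hypothetical) equivalent would be:

    package llm

    import "time"

    // Hypothetical stand-in for the helper used above; the actual
    // implementation lives elsewhere in llm/llama.go and is not shown here.
    func parseDurationMs(ms float64) time.Duration {
    	return time.Duration(ms * float64(time.Millisecond))
    }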