diff --git a/llm/llama.go b/llm/llama.go
index 33052e3c..26a7ee77 100644
--- a/llm/llama.go
+++ b/llm/llama.go
@@ -412,10 +412,6 @@ func newLlama(model string, adapters, projectors []string, runners []ModelRunner
 	port := rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	params := append(params, "--port", strconv.Itoa(port))
 
-	if runner.Type == "gguf" {
-		params = append(params, "--parallel", "2")
-	}
-
 	ctx, cancel := context.WithCancel(context.Background())
 	cmd := exec.CommandContext(
 		ctx,
@@ -549,6 +545,8 @@ type prediction struct {
 }
 
 const maxBufferSize = 512 * format.KiloByte
+const maxRetries = 3
+const retryDelay = 1 * time.Second
 
 type PredictOpts struct {
 	Prompt string
@@ -570,6 +568,11 @@ type PredictResult struct {
 	EvalDuration       time.Duration
 }
 
+// IsRetryable checks if the line matches a condition that can be retried
+func isRetryable(line []byte) bool {
+	return bytes.Contains(line, []byte("slot unavailable"))
+}
+
 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 	imageData := llm.ImageData
 	if len(predict.Images) > 0 {
@@ -607,98 +610,116 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
 		request["grammar"] = jsonGrammar
 	}
 
-	// Handling JSON marshaling with special characters unescaped.
-	buffer := &bytes.Buffer{}
-	enc := json.NewEncoder(buffer)
-	enc.SetEscapeHTML(false)
+	for retries := 0; retries < maxRetries; retries++ {
+		if retries > 0 {
+			time.Sleep(retryDelay) // wait before retrying
+		}
 
-	if err := enc.Encode(request); err != nil {
-		return fmt.Errorf("failed to marshal data: %v", err)
-	}
+		// Handling JSON marshaling with special characters unescaped.
+		buffer := &bytes.Buffer{}
+		enc := json.NewEncoder(buffer)
+		enc.SetEscapeHTML(false)
 
-	endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
-	if err != nil {
-		return fmt.Errorf("error creating POST request: %v", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
+		if err := enc.Encode(request); err != nil {
+			return fmt.Errorf("failed to marshal data: %v", err)
+		}
 
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return fmt.Errorf("POST predict: %v", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		bodyBytes, err := io.ReadAll(resp.Body)
+		endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", llm.Port)
+		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 		if err != nil {
-			return fmt.Errorf("failed reading llm error response: %w", err)
+			return fmt.Errorf("error creating POST request: %v", err)
 		}
-		log.Printf("llm predict error: %s", bodyBytes)
-		return fmt.Errorf("%s", bodyBytes)
-	}
+		req.Header.Set("Content-Type", "application/json")
 
-	scanner := bufio.NewScanner(resp.Body)
-	// increase the buffer size to avoid running out of space
-	buf := make([]byte, 0, maxBufferSize)
-	scanner.Buffer(buf, maxBufferSize)
-	for scanner.Scan() {
-		select {
-		case <-ctx.Done():
-			// This handles the request cancellation
-			return ctx.Err()
-		default:
-			line := scanner.Bytes()
-			if len(line) == 0 {
-				continue
-			}
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			return fmt.Errorf("POST predict: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode >= 400 {
+			bodyBytes, err := io.ReadAll(resp.Body)
+			if err != nil {
+				return fmt.Errorf("failed reading llm error response: %w", err)
+			}
+			log.Printf("llm predict error: %s", bodyBytes)
+			return fmt.Errorf("%s", bodyBytes)
+		}
 
-			evt, ok := bytes.CutPrefix(line, []byte("data: "))
-			if !ok {
-				return fmt.Errorf("error parsing llm response stream: %s", line)
+		scanner := bufio.NewScanner(resp.Body)
+		// increase the buffer size to avoid running out of space
+		buf := make([]byte, 0, maxBufferSize)
+		scanner.Buffer(buf, maxBufferSize)
+
+		retryNeeded := false
+		for scanner.Scan() {
+			select {
+			case <-ctx.Done():
+				// This handles the request cancellation
+				return ctx.Err()
+			default:
+				line := scanner.Bytes()
+				if len(line) == 0 {
+					continue
+				}
+
+				if isRetryable(line) {
+					retryNeeded = true
+					break
+				}
+
+				evt, ok := bytes.CutPrefix(line, []byte("data: "))
+				if !ok {
+					return fmt.Errorf("error parsing llm response stream: %s", line)
+				}
+
+				var p prediction
+				if err := json.Unmarshal(evt, &p); err != nil {
+					return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+				}
+
+				if p.Content != "" {
+					fn(PredictResult{
+						CreatedAt: time.Now().UTC(),
+						Content:   p.Content,
+					})
+				}
+
+				if p.Stop {
+					fn(PredictResult{
+						CreatedAt:     time.Now().UTC(),
+						TotalDuration: time.Since(predict.CheckpointStart),
+
+						Done:               true,
+						PromptEvalCount:    p.Timings.PromptN,
+						PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
+						EvalCount:          p.Timings.PredictedN,
+						EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
+					})
+					return nil
+				}
 			}
+		}
 
-			var p prediction
-			if err := json.Unmarshal(evt, &p); err != nil {
-				return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
+		if err := scanner.Err(); err != nil {
+			if strings.Contains(err.Error(), "unexpected EOF") {
+				// this means the llama runner subprocess crashed
+				llm.Close()
+				if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
+					return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
+				}
+				return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
 			}
+			return fmt.Errorf("error reading llm response: %v", err)
+		}
 
-			if p.Content != "" {
-				fn(PredictResult{
-					CreatedAt: time.Now().UTC(),
-					Content:   p.Content,
-				})
-			}
-
-			if p.Stop {
-				fn(PredictResult{
-					CreatedAt:     time.Now().UTC(),
-					TotalDuration: time.Since(predict.CheckpointStart),
-
-					Done:               true,
-					PromptEvalCount:    p.Timings.PromptN,
-					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
-					EvalCount:          p.Timings.PredictedN,
-					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
-				})
-				return nil
-			}
+		if !retryNeeded {
+			return nil // success
 		}
 	}
 
-	if err := scanner.Err(); err != nil {
-		if strings.Contains(err.Error(), "unexpected EOF") {
-			// this means the llama runner subprocess crashed
-			llm.Close()
-			if llm.StatusWriter != nil && llm.StatusWriter.LastErrMsg != "" {
-				return fmt.Errorf("llama runner exited: %v", llm.StatusWriter.LastErrMsg)
-			}
-			return fmt.Errorf("llama runner exited, you may not have enough available memory to run this model")
-		}
-		return fmt.Errorf("error reading llm response: %v", err)
-	}
-
-	return nil
+	// should never reach here ideally
+	return fmt.Errorf("max retries exceeded")
 }
 
 type TokenizeRequest struct {
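
For context, a minimal standalone sketch of the bounded-retry pattern this patch introduces: the completion request is re-issued up to `maxRetries` times, sleeping `retryDelay` between attempts, and only when the runner reports a retryable condition such as "slot unavailable". The names `doPredict` and `errSlotUnavailable` below are illustrative stand-ins, not identifiers from this codebase.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

const (
	maxRetries = 3
	retryDelay = 1 * time.Second
)

// errSlotUnavailable stands in for the "slot unavailable" line the llama.cpp
// server streams back when no completion slot is free.
var errSlotUnavailable = errors.New("slot unavailable")

// doPredict is a placeholder for one POST /completion attempt.
func doPredict(attempt int) error {
	if attempt < 2 {
		return errSlotUnavailable // pretend the first attempts hit a busy slot
	}
	return nil
}

func predictWithRetry() error {
	for retries := 0; retries < maxRetries; retries++ {
		if retries > 0 {
			time.Sleep(retryDelay) // back off before re-sending the request
		}

		err := doPredict(retries)
		if err == nil {
			return nil // success
		}
		if !errors.Is(err, errSlotUnavailable) {
			return err // non-retryable errors surface immediately
		}
	}
	return fmt.Errorf("max retries exceeded")
}

func main() {
	fmt.Println(predictWithRetry()) // <nil> after retrying twice
}
```

Any other error, or exhausting the retry budget, is returned to the caller, mirroring the `max retries exceeded` fallback in the patch above.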