llm: reserve required number of slots for embeddings (#6219)

This commit is contained in:
Jeffrey Morgan 2024-08-06 23:20:49 -04:00 committed by GitHub
parent e04c7012c2
commit de4fc29773
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -44,11 +44,12 @@ type LlamaServer interface {
// llmServer is an instance of the llama.cpp server // llmServer is an instance of the llama.cpp server
type llmServer struct { type llmServer struct {
port int port int
cmd *exec.Cmd cmd *exec.Cmd
done chan error // Channel to signal when the process exits done chan error // Channel to signal when the process exits
status *StatusWriter status *StatusWriter
options api.Options options api.Options
numParallel int
estimate MemoryEstimate estimate MemoryEstimate
totalLayers uint64 totalLayers uint64
@ -343,6 +344,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
status: NewStatusWriter(os.Stderr), status: NewStatusWriter(os.Stderr),
options: opts, options: opts,
estimate: estimate, estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)), sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: ggml.KV().BlockCount() + 1, totalLayers: ggml.KV().BlockCount() + 1,
gpus: gpus, gpus: gpus,
@ -890,11 +892,14 @@ type EmbedResponse struct {
} }
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) { func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { // each input will use a slot, so we need to acquire the semaphore for
// the number of inputs up to numParallel
slots := int64(min(len(input), s.numParallel))
if err := s.sem.Acquire(ctx, slots); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
} }
defer s.sem.Release(1) defer s.sem.Release(slots)
// Make sure the server is ready // Make sure the server is ready
status, err := s.getServerStatusRetry(ctx) status, err := s.getServerStatusRetry(ctx)