llm: reserve required number of slots for embeddings (#6219)
This commit is contained in:
parent
e04c7012c2
commit
de4fc29773
1 changed files with 12 additions and 7 deletions
|
@ -44,11 +44,12 @@ type LlamaServer interface {
|
||||||
|
|
||||||
// llmServer is an instance of the llama.cpp server
|
// llmServer is an instance of the llama.cpp server
|
||||||
type llmServer struct {
|
type llmServer struct {
|
||||||
port int
|
port int
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
done chan error // Channel to signal when the process exits
|
done chan error // Channel to signal when the process exits
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
|
numParallel int
|
||||||
|
|
||||||
estimate MemoryEstimate
|
estimate MemoryEstimate
|
||||||
totalLayers uint64
|
totalLayers uint64
|
||||||
|
@ -343,6 +344,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
status: NewStatusWriter(os.Stderr),
|
status: NewStatusWriter(os.Stderr),
|
||||||
options: opts,
|
options: opts,
|
||||||
estimate: estimate,
|
estimate: estimate,
|
||||||
|
numParallel: numParallel,
|
||||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
totalLayers: ggml.KV().BlockCount() + 1,
|
totalLayers: ggml.KV().BlockCount() + 1,
|
||||||
gpus: gpus,
|
gpus: gpus,
|
||||||
|
@ -890,11 +892,14 @@ type EmbedResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
|
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
// each input will use a slot, so we need to acquire the semaphore for
|
||||||
|
// the number of inputs up to numParallel
|
||||||
|
slots := int64(min(len(input), s.numParallel))
|
||||||
|
if err := s.sem.Acquire(ctx, slots); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer s.sem.Release(1)
|
defer s.sem.Release(slots)
|
||||||
|
|
||||||
// Make sure the server is ready
|
// Make sure the server is ready
|
||||||
status, err := s.getServerStatusRetry(ctx)
|
status, err := s.getServerStatusRetry(ctx)
|
||||||
|
|
Loading…
Reference in a new issue