diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index e0424a92..1f1b8bde 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -2714,21 +2714,6 @@ static json format_partial_response(
     return res;
 }
 
-static json format_tokenizer_response(const std::vector<llama_token> &tokens)
-{
-    return json {
-        {"tokens", tokens}
-    };
-}
-
-static json format_detokenized_response(std::string content)
-{
-    return json {
-        {"content", content}
-    };
-}
-
-
 static void log_server_request(const httplib::Request &req, const httplib::Response &res)
 {
     // skip GH copilot requests when using default port
@@ -3218,34 +3203,6 @@ int main(int argc, char **argv) {
         }
     });
 
-    svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                const json body = json::parse(req.body);
-                std::vector<llama_token> tokens;
-                if (body.count("content") != 0)
-                {
-                    tokens = llama.tokenize(body["content"], false);
-                }
-                const json data = format_tokenizer_response(tokens);
-                return res.set_content(data.dump(), "application/json; charset=utf-8");
-            });
-
-    svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                const json body = json::parse(req.body);
-                std::string content;
-                if (body.count("tokens") != 0)
-                {
-                    const std::vector<llama_token> tokens = body["tokens"];
-                    content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
-                }
-
-                const json data = format_detokenized_response(content);
-                return res.set_content(data.dump(), "application/json; charset=utf-8");
-            });
-
     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
diff --git a/llm/llm.go b/llm/llm.go
index 2a0c4b91..82a2d72d 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -12,6 +12,7 @@ package llm
 import "C"
 
 import (
 	"fmt"
+	"strings"
 	"unsafe"
 )
@@ -37,3 +38,47 @@ func Quantize(infile, outfile string, ftype fileType) error {
 
 	return nil
 }
+
+type llamaModel struct {
+	m *C.struct_llama_model
+}
+
+func newLlamaModel(p string) *llamaModel {
+	cs := C.CString(p)
+	defer C.free(unsafe.Pointer(cs))
+
+	return &llamaModel{
+		C.llama_load_model_from_file(
+			cs,
+			C.llama_model_default_params(),
+		),
+	}
+}
+
+func (llm *llamaModel) Close() {
+	C.llama_free_model(llm.m)
+}
+
+func (llm *llamaModel) Tokenize(s string) []int {
+	cs := C.CString(s)
+	defer C.free(unsafe.Pointer(cs))
+
+	tokens := make([]int, len(s)+2)
+	if n := C.llama_tokenize(llm.m, cs, C.int(len(s)), (*C.llama_token)(unsafe.Pointer(&tokens[0])), C.int(len(s)+2), false, true); n > 0 {
+		return tokens[:n]
+	}
+
+	return nil
+}
+
+func (llm *llamaModel) Detokenize(i32s []int) string {
+	var sb strings.Builder
+	for _, i32 := range i32s {
+		c := make([]byte, 512)
+		if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
+			sb.WriteString(unsafe.String(&c[0], n))
+		}
+	}
+
+	return sb.String()
+}
diff --git a/llm/server.go b/llm/server.go
index 462f8484..9b5d0f06 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -57,6 +57,8 @@ type llmServer struct {
 	loadDuration time.Duration // Record how long it took the model to load
 	loadProgress float32
 
+	*llamaModel
+
 	sem *semaphore.Weighted
 }
 
@@ -306,6 +308,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			totalLayers: ggml.KV().BlockCount() + 1,
 			gpuCount:    gpuCount,
 			done:        make(chan error, 1),
+			llamaModel:  newLlamaModel(model),
 		}
 
 		s.cmd.Env = os.Environ()
@@ -843,12 +846,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
-	if err != nil {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
@@ -878,108 +881,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 	return embedding.Embedding, nil
 }
 
-type TokenizeRequest struct {
-	Content string `json:"content"`
-}
-
-type TokenizeResponse struct {
-	Tokens []int `json:"tokens"`
-}
-
 func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(TokenizeRequest{Content: content})
-	if err != nil {
-		return nil, fmt.Errorf("marshaling encode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return nil, fmt.Errorf("encode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("do encode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("read encode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm encode error: %s", body)
-		return nil, fmt.Errorf("%s", body)
-	}
-
-	var encoded TokenizeResponse
-	if err := json.Unmarshal(body, &encoded); err != nil {
-		return nil, fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return encoded.Tokens, nil
-}
-
-type DetokenizeRequest struct {
-	Tokens []int `json:"tokens"`
-}
-
-type DetokenizeResponse struct {
-	Content string `json:"content"`
+	return s.llamaModel.Tokenize(content), nil
 }
 
 func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
-	if err != nil {
-		return "", fmt.Errorf("marshaling decode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return "", fmt.Errorf("decode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return "", fmt.Errorf("do decode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return "", fmt.Errorf("read decode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm decode error: %s", body)
-		return "", fmt.Errorf("%s", body)
-	}
-
-	var decoded DetokenizeResponse
-	if err := json.Unmarshal(body, &decoded); err != nil {
-		return "", fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return decoded.Content, nil
+	return s.llamaModel.Detokenize(tokens), nil
 }
 
 func (s *llmServer) Close() error {
@@ -997,6 +904,10 @@ func (s *llmServer) Close() error {
 		slog.Debug("llama server stopped")
 	}
 
+	if s.llamaModel != nil {
+		s.llamaModel.Close()
+	}
+
 	return nil
 }