diff --git a/llm/ext_server/server.cpp b/llm/ext_server/server.cpp
index c1a803f1..8a0dffea 100644
--- a/llm/ext_server/server.cpp
+++ b/llm/ext_server/server.cpp
@@ -2625,6 +2625,21 @@ static json format_partial_response(
     return res;
 }
 
+static json format_tokenizer_response(const std::vector<llama_token> &tokens)
+{
+    return json {
+        {"tokens", tokens}
+    };
+}
+
+static json format_detokenized_response(std::string content)
+{
+    return json {
+        {"content", content}
+    };
+}
+
+
 static void log_server_request(const httplib::Request &req, const httplib::Response &res)
 {
     // skip GH copilot requests when using default port
@@ -3114,6 +3129,34 @@ int main(int argc, char **argv) {
             }
         });
 
+    svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+                std::vector<llama_token> tokens;
+                if (body.count("content") != 0)
+                {
+                    tokens = llama.tokenize(body["content"], false);
+                }
+                const json data = format_tokenizer_response(tokens);
+                return res.set_content(data.dump(), "application/json; charset=utf-8");
+            });
+
+    svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+                std::string content;
+                if (body.count("tokens") != 0)
+                {
+                    const std::vector<llama_token> tokens = body["tokens"];
+                    content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
+                }
+
+                const json data = format_detokenized_response(content);
+                return res.set_content(data.dump(), "application/json; charset=utf-8");
+            });
+
     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
diff --git a/llm/llm.go b/llm/llm.go
index 4492d39f..2a0c4b91 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -12,7 +12,6 @@ package llm
 import "C"
 import (
 	"fmt"
-	"strings"
 	"unsafe"
 )
 
@@ -38,62 +37,3 @@ func Quantize(infile, outfile string, ftype fileType) error {
 
 	return nil
 }
-
-type llamaModel struct {
-	m *C.struct_llama_model
-}
-
-func newLlamaModel(p string) *llamaModel {
-	cs := C.CString(p)
-	defer C.free(unsafe.Pointer(cs))
-
-	params := C.llama_model_default_params()
-	params.vocab_only = true
-
-	return &llamaModel{
-		C.llama_load_model_from_file(cs, params),
-	}
-}
-
-func (llm *llamaModel) Close() {
-	C.llama_free_model(llm.m)
-}
-
-func (llm *llamaModel) Tokenize(s string) []int {
-	cs := C.CString(s)
-	defer C.free(unsafe.Pointer(cs))
-
-	ltokens := make([]C.llama_token, len(s)+2)
-	n := C.llama_tokenize(
-		llm.m,
-		cs,
-		C.int32_t(len(s)),
-		&ltokens[0],
-		C.int32_t(len(ltokens)),
-		false,
-		true,
-	)
-
-	if n < 0 {
-		return nil
-	}
-
-	tokens := make([]int, n)
-	for i := 0; i < int(n); i++ {
-		tokens[i] = int(ltokens[i])
-	}
-
-	return tokens
-}
-
-func (llm *llamaModel) Detokenize(i32s []int) string {
-	var sb strings.Builder
-	for _, i32 := range i32s {
-		c := make([]byte, 512)
-		if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
-			sb.WriteString(unsafe.String(&c[0], n))
-		}
-	}
-
-	return sb.String()
-}
diff --git a/llm/server.go b/llm/server.go
index 97aa2a15..3af8a329 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -57,8 +57,6 @@ type llmServer struct {
 	loadDuration time.Duration // Record how long it took the model to load
 	loadProgress float32
 
-	*llamaModel
-
 	sem *semaphore.Weighted
 }
 
@@ -311,7 +309,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 		totalLayers: ggml.KV().BlockCount() + 1,
 		gpuCount:    gpuCount,
 		done:        make(chan error, 1),
-		llamaModel:  newLlamaModel(model),
 	}
 
 	s.cmd.Env = os.Environ()
@@ -849,12 +846,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	var b bytes.Buffer
-	if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
+	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
@@ -884,12 +881,108 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 	return embedding.Embedding, nil
 }
 
+type TokenizeRequest struct {
+	Content string `json:"content"`
+}
+
+type TokenizeResponse struct {
+	Tokens []int `json:"tokens"`
+}
+
 func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
-	return s.llamaModel.Tokenize(content), nil
+	// Make sure the server is ready
+	status, err := s.getServerStatus(ctx)
+	if err != nil {
+		return nil, err
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(TokenizeRequest{Content: content})
+	if err != nil {
+		return nil, fmt.Errorf("marshaling encode data: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return nil, fmt.Errorf("encode request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("do encode request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("read encode request: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm encode error: %s", body)
+		return nil, fmt.Errorf("%s", body)
+	}
+
+	var encoded TokenizeResponse
+	if err := json.Unmarshal(body, &encoded); err != nil {
+		return nil, fmt.Errorf("unmarshal encode response: %w", err)
+	}
+
+	return encoded.Tokens, nil
+}
+
+type DetokenizeRequest struct {
+	Tokens []int `json:"tokens"`
+}
+
+type DetokenizeResponse struct {
+	Content string `json:"content"`
 }
 
 func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	return s.llamaModel.Detokenize(tokens), nil
+	// Make sure the server is ready
+	status, err := s.getServerStatus(ctx)
+	if err != nil {
+		return "", err
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
+	if err != nil {
+		return "", fmt.Errorf("marshaling decode data: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return "", fmt.Errorf("decode request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("do decode request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("read decode request: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm decode error: %s", body)
+		return "", fmt.Errorf("%s", body)
+	}
+
+	var decoded DetokenizeResponse
+	if err := json.Unmarshal(body, &decoded); err != nil {
+		return "", fmt.Errorf("unmarshal encode response: %w", err)
+	}
+
+	return decoded.Content, nil
 }
 
 func (s *llmServer) Close() error {
@@ -907,10 +1000,6 @@ func (s *llmServer) Close() error {
 		slog.Debug("llama server stopped")
 	}
 
-	if s.llamaModel != nil {
-		s.llamaModel.Close()
-	}
-
 	return nil
 }
 