llama server wrapper
parent 8fa91332fa, commit 0758cb2d4b
7 changed files with 83 additions and 138 deletions
server/README.md (new file, +34)
@@ -0,0 +1,34 @@
# Server

🙊

## Installation

If you are using Apple silicon, you need a Python build that supports arm64 (Miniforge provides one):

```bash
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
bash Miniforge3-MacOSX-arm64.sh
```

Install the dependencies:

```bash
pip install llama-cpp-python
pip install -r requirements.txt
```

## Running

Put your model in `models/` and run:

```bash
python server.py
```

## API

### `POST /generate`

- `model`: `string` - The name of the model to use; the server looks for `models/<model>.bin`.
- `prompt`: `string` - The prompt to complete.
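A minimal request sketch, assuming the defaults in `server.py` (Flask serving on port 5000, the `model` field resolved to `models/<model>.bin`) and a hypothetical model file name:

```bash
# Hypothetical example: assumes models/llama-2-7b.Q4_0.bin exists and server.py is running.
curl -N -X POST http://localhost:5000/generate \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-2-7b.Q4_0", "prompt": "Q: What is a llama? A:"}'
```

The response is streamed as JSON-encoded completion chunks over `text/event-stream`; `-N` keeps curl from buffering the output.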
@@ -1,2 +0,0 @@
LIBRARY_PATH=$PWD/go-llama.cpp C_INCLUDE_PATH=$PWD/go-llama.cpp go build .
@@ -1,8 +0,0 @@
module github.com/keypairdev/keypair

go 1.20

require (
	github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1
	github.com/sashabaranov/go-openai v1.11.3
)
@@ -1,15 +0,0 @@
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1 h1:UQ8y3kHxBgh3BnaW06y/X97fEN48yHPwWobMz8/aztU=
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
server/main.go (deleted file, -113)
@@ -1,113 +0,0 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"runtime"

	"github.com/sashabaranov/go-openai"

	llama "github.com/go-skynet/go-llama.cpp"
)

type Model interface {
	Name() string
	Handler(w http.ResponseWriter, r *http.Request)
}

type LLama7B struct {
	llama *llama.LLama
}

func NewLLama7B() *LLama7B {
	llama, err := llama.New("./models/7B/ggml-model-q4_0.bin", llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(128))
	if err != nil {
		fmt.Println("Loading the model failed:", err.Error())
		os.Exit(1)
	}

	return &LLama7B{
		llama: llama,
	}
}

func (l *LLama7B) Name() string {
	return "LLaMA 7B"
}

func (m *LLama7B) Handler(w http.ResponseWriter, r *http.Request) {
	var text bytes.Buffer
	io.Copy(&text, r.Body)

	_, err := m.llama.Predict(text.String(), llama.Debug, llama.SetTokenCallback(func(token string) bool {
		w.Write([]byte(token))
		return true
	}), llama.SetTokens(512), llama.SetThreads(runtime.NumCPU()), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))

	if err != nil {
		fmt.Println("Predict failed:", err.Error())
		os.Exit(1)
	}

	embeds, err := m.llama.Embeddings(text.String())
	if err != nil {
		fmt.Printf("Embeddings: error %s \n", err.Error())
	}
	fmt.Printf("Embeddings: %v", embeds)

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
}

type GPT4 struct {
	apiKey string
}

func (g *GPT4) Name() string {
	return "OpenAI GPT-4"
}

func (g *GPT4) Handler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
	client := openai.NewClient("your token")
	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model: openai.GPT3Dot5Turbo,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
		},
	)
	if err != nil {
		fmt.Printf("chat completion error: %v\n", err)
		return
	}

	fmt.Println(resp.Choices[0].Message.Content)

	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(http.StatusOK)
}

// TODO: add subcommands to spawn different models
func main() {
	model := &LLama7B{}

	http.HandleFunc("/generate", model.Handler)

	fmt.Println("Starting server on :8080")
	if err := http.ListenAndServe(":8080", nil); err != nil {
		fmt.Printf("Error starting server: %s\n", err)
		return
	}
}
server/requirements.txt (new file, +2)
@@ -0,0 +1,2 @@
Flask==2.3.2
flask_cors==3.0.10
server/server.py (new file, +47)
@@ -0,0 +1,47 @@
import json
import os
from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS, cross_origin

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")

    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model file does not exist."}, 400

    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5000)