llama server wrapper

Bruce MacDonald 2023-06-23 13:10:13 -04:00
parent 8fa91332fa
commit 0758cb2d4b
7 changed files with 83 additions and 138 deletions

server/README.md (new file)

@@ -0,0 +1,34 @@
# Server
🙊
## Installation
If you are using Apple silicon, you need a Python version that supports arm64:
```bash
wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
bash Miniforge3-MacOSX-arm64.sh
```
Get the dependencies:
```bash
pip install llama-cpp-python
pip install -r requirements.txt
```
## Running
Put your model in `models/` and run:
```bash
python server.py
```
## API
### `POST /generate`
- `model`: `string` - The name of the model to use in the `models` folder.
- `prompt`: `string` - The prompt to use.
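
For example, here is a minimal client sketch using only the Python standard library. The model name is a placeholder for whichever `.bin` file you put in `models/`, and the port assumes the Flask default used by `server.py`:
```python
import json
import urllib.request

# Placeholder model name; the server looks for models/<name>.bin.
payload = {"model": "ggml-model-q4_0", "prompt": "Q: What is a llama? A:"}

req = urllib.request.Request(
    "http://127.0.0.1:5000/generate",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

# The endpoint streams its output, so read and print it as it arrives.
with urllib.request.urlopen(req) as resp:
    while True:
        chunk = resp.read(1024)
        if not chunk:
            break
        print(chunk.decode("utf-8", errors="replace"), end="", flush=True)
```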

(deleted file)

@@ -1,2 +0,0 @@
LIBRARY_PATH=$PWD/go-llama.cpp C_INCLUDE_PATH=$PWD/go-llama.cpp go build .

go.mod (deleted file)

@@ -1,8 +0,0 @@
module github.com/keypairdev/keypair

go 1.20

require (
	github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1
	github.com/sashabaranov/go-openai v1.11.3
)

go.sum (deleted file)

@@ -1,15 +0,0 @@
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1 h1:UQ8y3kHxBgh3BnaW06y/X97fEN48yHPwWobMz8/aztU=
github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

(deleted file)

@@ -1,113 +0,0 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"runtime"

	"github.com/sashabaranov/go-openai"
	llama "github.com/go-skynet/go-llama.cpp"
)

type Model interface {
	Name() string
	Handler(w http.ResponseWriter, r *http.Request)
}

type LLama7B struct {
	llama *llama.LLama
}
func NewLLama7B() *LLama7B {
	llama, err := llama.New("./models/7B/ggml-model-q4_0.bin", llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(128))
	if err != nil {
		fmt.Println("Loading the model failed:", err.Error())
		os.Exit(1)
	}
	return &LLama7B{
		llama: llama,
	}
}

func (l *LLama7B) Name() string {
	return "LLaMA 7B"
}
func (m *LLama7B) Handler(w http.ResponseWriter, r *http.Request) {
	// Headers must be set before the first write to the response body.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	var text bytes.Buffer
	io.Copy(&text, r.Body)

	// Stream each generated token back to the client as it is produced.
	_, err := m.llama.Predict(text.String(), llama.Debug, llama.SetTokenCallback(func(token string) bool {
		w.Write([]byte(token))
		return true
	}), llama.SetTokens(512), llama.SetThreads(runtime.NumCPU()), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
	if err != nil {
		fmt.Println("Predict failed:", err.Error())
		os.Exit(1)
	}

	embeds, err := m.llama.Embeddings(text.String())
	if err != nil {
		fmt.Printf("Embeddings: error %s \n", err.Error())
	}
	fmt.Printf("Embeddings: %v", embeds)
}
type GPT4 struct {
	apiKey string
}

func (g *GPT4) Name() string {
	return "OpenAI GPT-4"
}

func (g *GPT4) Handler(w http.ResponseWriter, r *http.Request) {
	// Set headers before writing the response body.
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")

	client := openai.NewClient(g.apiKey)
	resp, err := client.CreateChatCompletion(
		context.Background(),
		openai.ChatCompletionRequest{
			Model: openai.GPT3Dot5Turbo,
			Messages: []openai.ChatCompletionMessage{
				{
					Role:    openai.ChatMessageRoleUser,
					Content: "Hello!",
				},
			},
		},
	)
	if err != nil {
		http.Error(w, fmt.Sprintf("chat completion error: %v", err), http.StatusInternalServerError)
		return
	}

	fmt.Fprintln(w, resp.Choices[0].Message.Content)
}
// TODO: add subcommands to spawn different models
func main() {
	model := NewLLama7B()
	http.HandleFunc("/generate", model.Handler)

	fmt.Println("Starting server on :8080")
	if err := http.ListenAndServe(":8080", nil); err != nil {
		fmt.Printf("Error starting server: %s\n", err)
		return
	}
}

server/requirements.txt (new file)

@@ -0,0 +1,2 @@
Flask==2.3.2
flask_cors==3.0.10

server/server.py (new file)

@@ -0,0 +1,47 @@
import json
import os

from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS, cross_origin

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")

    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model file does not exist."}, 400

    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5000)
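
Note that `stream_response` yields each chunk as `json.dumps(output)` with no delimiter, so a client receives back-to-back JSON objects. Below is a rough sketch of splitting them with an incremental decoder; the `choices[0]["text"]` layout mirrors llama-cpp-python's completion chunks and should be treated as an assumption rather than something defined in this commit:
```python
import json

def iter_json_objects(buffer: str):
    """Yield each JSON object from a string of concatenated objects."""
    decoder = json.JSONDecoder()
    idx = 0
    while idx < len(buffer):
        obj, idx = decoder.raw_decode(buffer, idx)
        yield obj

# Two fabricated chunks in llama-cpp-python's assumed completion shape.
raw = '{"choices": [{"text": "Hel"}]}{"choices": [{"text": "lo"}]}'
print("".join(c["choices"][0]["text"] for c in iter_json_objects(raw)))  # -> Hello
```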