Compare commits
52 commits
f2d1c842ad
...
9e08c23ba9
Author | SHA1 | Date | |
---|---|---|---|
9e08c23ba9 | |||
|
0a8d6ea86d | ||
|
8e1050f366 | ||
|
eda8a32a09 | ||
|
a0a40aa20c | ||
|
2697d7f5aa | ||
|
1f32276178 | ||
|
4c4fe3f87f | ||
|
feedf49c71 | ||
|
8b00a415ab | ||
|
01b80e9ffc | ||
|
bd5e432630 | ||
|
aec77d6a05 | ||
|
6ffb5cb017 | ||
|
f7e3b9190f | ||
|
980dd15f81 | ||
|
01d544d373 | ||
|
1dc3ef3aa9 | ||
|
8aac22438e | ||
|
15c2d8fe14 | ||
|
25906d72d1 | ||
|
023451ce47 | ||
|
9b53e39d8e | ||
|
97fae2df95 | ||
|
160d9d4900 | ||
|
d4e6407464 | ||
|
b7f7d8cd15 | ||
|
2fa1db4345 | ||
|
71b0945fc6 | ||
|
5bca2e60a7 | ||
|
67472e0e89 | ||
|
e9aa5117c4 | ||
|
2473bdba5e | ||
|
7d1c0047fa | ||
|
7b61eba471 | ||
|
7edaf6e7e8 | ||
|
97ec8cfd4e | ||
|
5b3a21b578 | ||
|
ad0c19dde4 | ||
|
69eb06c40e | ||
|
1829fb61bd | ||
|
ce67706037 | ||
|
685a53534b | ||
|
de4fc29773 | ||
|
e04c7012c2 | ||
|
d4a7216c82 | ||
|
a4fdd03c3b | ||
|
fc85f50a2b | ||
|
04210aa6dd | ||
|
43f9d92008 | ||
|
ed6c8bfe57 | ||
|
df3802a65f |
50 changed files with 1245 additions and 771 deletions
3
.gitattributes
vendored
3
.gitattributes
vendored
|
@ -1,2 +1,3 @@
|
||||||
llm/ext_server/* linguist-vendored
|
llm/ext_server/* linguist-vendored
|
||||||
* text eol=lf
|
* text=auto
|
||||||
|
*.go text eol=lf
|
||||||
|
|
10
.github/workflows/release.yaml
vendored
10
.github/workflows/release.yaml
vendored
|
@ -31,7 +31,7 @@ jobs:
|
||||||
security set-keychain-settings -lut 3600 build.keychain
|
security set-keychain-settings -lut 3600 build.keychain
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- name: Build Darwin
|
- name: Build Darwin
|
||||||
env:
|
env:
|
||||||
|
@ -87,7 +87,7 @@ jobs:
|
||||||
write-host "plugin installed"
|
write-host "plugin installed"
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- run: go get ./...
|
- run: go get ./...
|
||||||
- run: |
|
- run: |
|
||||||
|
@ -141,7 +141,7 @@ jobs:
|
||||||
write-host "plugin installed"
|
write-host "plugin installed"
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- name: 'Install ROCm'
|
- name: 'Install ROCm'
|
||||||
run: |
|
run: |
|
||||||
|
@ -218,7 +218,7 @@ jobs:
|
||||||
write-host "plugin installed"
|
write-host "plugin installed"
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- name: 'Install CUDA'
|
- name: 'Install CUDA'
|
||||||
run: |
|
run: |
|
||||||
|
@ -306,7 +306,7 @@ jobs:
|
||||||
write-host "plugin installed"
|
write-host "plugin installed"
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- run: go get
|
- run: go get
|
||||||
- uses: actions/download-artifact@v4
|
- uses: actions/download-artifact@v4
|
||||||
|
|
10
.github/workflows/test.yaml
vendored
10
.github/workflows/test.yaml
vendored
|
@ -63,7 +63,7 @@ jobs:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- run: go get ./...
|
- run: go get ./...
|
||||||
- run: |
|
- run: |
|
||||||
|
@ -163,7 +163,7 @@ jobs:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- name: 'Install ROCm'
|
- name: 'Install ROCm'
|
||||||
run: |
|
run: |
|
||||||
|
@ -200,7 +200,7 @@ jobs:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- name: 'Install CUDA'
|
- name: 'Install CUDA'
|
||||||
run: |
|
run: |
|
||||||
|
@ -255,7 +255,7 @@ jobs:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: false
|
cache: false
|
||||||
- run: |
|
- run: |
|
||||||
case ${{ matrix.arch }} in
|
case ${{ matrix.arch }} in
|
||||||
|
@ -297,7 +297,7 @@ jobs:
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "stable"
|
go-version-file: go.mod
|
||||||
cache: true
|
cache: true
|
||||||
- run: |
|
- run: |
|
||||||
case ${{ matrix.arch }} in
|
case ${{ matrix.arch }} in
|
||||||
|
|
|
@ -24,7 +24,6 @@ linters:
|
||||||
- nosprintfhostport
|
- nosprintfhostport
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- tenv
|
- tenv
|
||||||
- testifylint
|
|
||||||
- unconvert
|
- unconvert
|
||||||
- unused
|
- unused
|
||||||
- usestdlibvars
|
- usestdlibvars
|
||||||
|
|
|
@ -325,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||||
- [tlm](https://github.com/yusufcanb/tlm)
|
- [tlm](https://github.com/yusufcanb/tlm)
|
||||||
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
||||||
- [gollama](https://github.com/sammcj/gollama)
|
- [gollama](https://github.com/sammcj/gollama)
|
||||||
|
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
|
||||||
|
|
||||||
### Database
|
### Database
|
||||||
|
|
||||||
|
|
|
@ -298,7 +298,7 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
||||||
return &lr, nil
|
return &lr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// List running models.
|
// ListRunning lists running models.
|
||||||
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||||
var lr ProcessResponse
|
var lr ProcessResponse
|
||||||
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
|
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
|
||||||
|
@ -333,7 +333,7 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
|
||||||
return &resp, nil
|
return &resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hearbeat checks if the server has started and is responsive; if yes, it
|
// Heartbeat checks if the server has started and is responsive; if yes, it
|
||||||
// returns nil, otherwise an error.
|
// returns nil, otherwise an error.
|
||||||
func (c *Client) Heartbeat(ctx context.Context) error {
|
func (c *Client) Heartbeat(ctx context.Context) error {
|
||||||
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
|
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {
|
||||||
|
|
|
@ -504,7 +504,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||||
for key, val := range m {
|
for key, val := range m {
|
||||||
opt, ok := jsonOpts[key]
|
opt, ok := jsonOpts[key]
|
||||||
if !ok {
|
if !ok {
|
||||||
slog.Warn("invalid option provided", "option", opt.Name)
|
slog.Warn("invalid option provided", "option", key)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,12 +11,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
updatAvailableMenuID = 1
|
updateAvailableMenuID = 1
|
||||||
updateMenuID = updatAvailableMenuID + 1
|
updateMenuID = updateAvailableMenuID + 1
|
||||||
separatorMenuID = updateMenuID + 1
|
separatorMenuID = updateMenuID + 1
|
||||||
diagLogsMenuID = separatorMenuID + 1
|
diagLogsMenuID = separatorMenuID + 1
|
||||||
diagSeparatorMenuID = diagLogsMenuID + 1
|
diagSeparatorMenuID = diagLogsMenuID + 1
|
||||||
quitMenuID = diagSeparatorMenuID + 1
|
quitMenuID = diagSeparatorMenuID + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *winTray) initMenus() error {
|
func (t *winTray) initMenus() error {
|
||||||
|
@ -35,7 +35,7 @@ func (t *winTray) initMenus() error {
|
||||||
func (t *winTray) UpdateAvailable(ver string) error {
|
func (t *winTray) UpdateAvailable(ver string) error {
|
||||||
if !t.updateNotified {
|
if !t.updateNotified {
|
||||||
slog.Debug("updating menu and sending notification for new update")
|
slog.Debug("updating menu and sending notification for new update")
|
||||||
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
|
if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
|
||||||
return fmt.Errorf("unable to create menu entries %w", err)
|
return fmt.Errorf("unable to create menu entries %w", err)
|
||||||
}
|
}
|
||||||
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
|
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {
|
||||||
|
|
47
cmd/cmd.go
47
cmd/cmd.go
|
@ -22,6 +22,7 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
status := "transferring model data"
|
status := "transferring model data"
|
||||||
spinner := progress.NewSpinner(status)
|
spinner := progress.NewSpinner(status)
|
||||||
p.Add(status, spinner)
|
p.Add(status, spinner)
|
||||||
|
defer p.Stop()
|
||||||
|
|
||||||
for i := range modelfile.Commands {
|
for i := range modelfile.Commands {
|
||||||
switch modelfile.Commands[i].Name {
|
switch modelfile.Commands[i].Name {
|
||||||
|
@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||||
path = tempfile
|
path = tempfile
|
||||||
}
|
}
|
||||||
|
|
||||||
digest, err := createBlob(cmd, client, path)
|
digest, err := createBlob(cmd, client, path, spinner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
|
||||||
return tempfile.Name(), nil
|
return tempfile.Name(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
|
||||||
bin, err := os.Open(path)
|
bin, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer bin.Close()
|
defer bin.Close()
|
||||||
|
|
||||||
|
// Get file info to retrieve the size
|
||||||
|
fileInfo, err := bin.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
fileSize := fileInfo.Size()
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
if _, err := io.Copy(hash, bin); err != nil {
|
if _, err := io.Copy(hash, bin); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var pw progressWriter
|
||||||
|
status := "transferring model data 0%"
|
||||||
|
spinner.SetMessage(status)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(60 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
|
||||||
|
case <-done:
|
||||||
|
spinner.SetMessage("transferring model data 100%")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type progressWriter struct {
|
||||||
|
n atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *progressWriter) Write(p []byte) (n int, err error) {
|
||||||
|
w.n.Add(int64(len(p)))
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
interactive := true
|
interactive := true
|
||||||
|
|
||||||
|
@ -1086,7 +1125,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
func RunServer(_ *cobra.Command, _ []string) error {
|
||||||
if err := initializeKeypair(); err != nil {
|
if err := initializeKeypair(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
|
||||||
"tokenizer.ggml.token_type": t.Vocabulary.Types,
|
"tokenizer.ggml.token_type": t.Vocabulary.Types,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(t.Merges) > 0 {
|
||||||
|
kv["tokenizer.ggml.merges"] = t.Merges
|
||||||
|
}
|
||||||
|
|
||||||
if t.Template != "" {
|
if t.Template != "" {
|
||||||
kv["tokenizer.chat_template"] = t.Template
|
kv["tokenizer.chat_template"] = t.Template
|
||||||
}
|
}
|
||||||
|
@ -89,6 +93,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
|
||||||
conv = &mixtral{}
|
conv = &mixtral{}
|
||||||
case "GemmaForCausalLM":
|
case "GemmaForCausalLM":
|
||||||
conv = &gemma{}
|
conv = &gemma{}
|
||||||
|
case "Phi3ForCausalLM":
|
||||||
|
conv = &phi3{}
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported architecture")
|
return errors.New("unsupported architecture")
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,10 +90,6 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
|
||||||
kv["llama.attention.value_length"] = p.HeadDim
|
kv["llama.attention.value_length"] = p.HeadDim
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(t.Merges) > 0 {
|
|
||||||
kv["tokenizer.ggml.merges"] = t.Merges
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
125
convert/convert_phi3.go
Normal file
125
convert/convert_phi3.go
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type phi3 struct {
|
||||||
|
Parameters
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
NLayers uint32 `json:"n_layers"`
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NEmbd uint32 `json:"n_embd"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NHead uint32 `json:"n_head"`
|
||||||
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
|
NHeadKV uint32 `json:"n_head_kv"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RopeScaling struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
LongFactor ropeFactor `json:"long_factor"`
|
||||||
|
ShortFactor ropeFactor `json:"short_factor"`
|
||||||
|
} `json:"rope_scaling"`
|
||||||
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
NPositions uint32 `json:"n_positions"`
|
||||||
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
|
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||||
|
SlidingWindow uint32 `json:"sliding_window"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Converter = (*phi3)(nil)
|
||||||
|
|
||||||
|
func (p *phi3) KV(t *Tokenizer) llm.KV {
|
||||||
|
kv := p.Parameters.KV(t)
|
||||||
|
kv["general.architecture"] = "phi3"
|
||||||
|
kv["general.name"] = "phi3"
|
||||||
|
kv["phi3.context_length"] = p.MaxPositionEmbeddings
|
||||||
|
kv["phi3.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
|
||||||
|
kv["phi3.feed_forward_length"] = p.IntermediateSize
|
||||||
|
kv["phi3.block_count"] = cmp.Or(p.NumHiddenLayers, p.NLayers)
|
||||||
|
kv["phi3.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
|
||||||
|
kv["phi3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NHeadKV)
|
||||||
|
kv["phi3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||||
|
kv["phi3.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NumAttentionHeads, p.NHead)
|
||||||
|
kv["phi3.rope.freq_base"] = p.RopeTheta
|
||||||
|
kv["phi3.rope.scaling.original_context_length"] = p.OriginalMaxPositionEmbeddings
|
||||||
|
kv["phi3.attention.sliding_window"] = p.SlidingWindow
|
||||||
|
|
||||||
|
scale := float64(p.MaxPositionEmbeddings) / float64(p.OriginalMaxPositionEmbeddings)
|
||||||
|
|
||||||
|
switch p.RopeScaling.Type {
|
||||||
|
case "":
|
||||||
|
// no scaling
|
||||||
|
case "su", "longrope":
|
||||||
|
kv["phi3.rope.scaling.attn_factor"] = float32(max(math.Sqrt(1+math.Log(scale)/math.Log(float64(p.OriginalMaxPositionEmbeddings))), 1.0))
|
||||||
|
case "yarn":
|
||||||
|
kv["phi3.rope.scaling.attn_factor"] = float32(max(0.1*math.Log(scale)+1.0, 1.0))
|
||||||
|
default:
|
||||||
|
panic("unknown rope scaling type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
|
||||||
|
var addRopeFactors sync.Once
|
||||||
|
|
||||||
|
out := make([]llm.Tensor, 0, len(ts)+2)
|
||||||
|
for _, t := range ts {
|
||||||
|
name := p.tensorName(t.Name())
|
||||||
|
if strings.HasPrefix(name, "blk.0.") {
|
||||||
|
addRopeFactors.Do(func() {
|
||||||
|
out = append(out, llm.Tensor{
|
||||||
|
Name: "rope_factors_long.weight",
|
||||||
|
Kind: 0,
|
||||||
|
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
|
||||||
|
WriterTo: p.RopeScaling.LongFactor,
|
||||||
|
}, llm.Tensor{
|
||||||
|
Name: "rope_factors_short.weight",
|
||||||
|
Kind: 0,
|
||||||
|
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
|
||||||
|
WriterTo: p.RopeScaling.ShortFactor,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, llm.Tensor{
|
||||||
|
Name: name,
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: t.Shape(),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *phi3) tensorName(n string) string {
|
||||||
|
return strings.NewReplacer(
|
||||||
|
"lm_head", "output",
|
||||||
|
"model.embed_tokens", "token_embd",
|
||||||
|
"model.norm", "output_norm",
|
||||||
|
"model.layers", "blk",
|
||||||
|
"input_layernorm", "attn_norm",
|
||||||
|
"self_attn.qkv_proj", "attn_qkv",
|
||||||
|
"self_attn.o_proj", "attn_output",
|
||||||
|
"mlp.down_proj", "ffn_down",
|
||||||
|
"mlp.gate_up_proj", "ffn_up",
|
||||||
|
"post_attention_layernorm", "ffn_norm",
|
||||||
|
).Replace(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ropeFactor []float32
|
||||||
|
|
||||||
|
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
err := binary.Write(w, binary.LittleEndian, r)
|
||||||
|
return 0, err
|
||||||
|
}
|
|
@ -65,6 +65,8 @@ func TestConvertFull(t *testing.T) {
|
||||||
"Mistral-7B-Instruct-v0.2",
|
"Mistral-7B-Instruct-v0.2",
|
||||||
"Mixtral-8x7B-Instruct-v0.1",
|
"Mixtral-8x7B-Instruct-v0.1",
|
||||||
"gemma-2b-it",
|
"gemma-2b-it",
|
||||||
|
// microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
|
||||||
|
"Phi-3-mini-128k-instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range cases {
|
for i := range cases {
|
||||||
|
|
225
convert/testdata/Phi-3-mini-128k-instruct.json
vendored
Normal file
225
convert/testdata/Phi-3-mini-128k-instruct.json
vendored
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
{
|
||||||
|
"general.architecture": "phi3",
|
||||||
|
"general.file_type": "1",
|
||||||
|
"general.quantization_version": "2",
|
||||||
|
"phi3.block_count": "32",
|
||||||
|
"phi3.context_length": "131072",
|
||||||
|
"phi3.embedding_length": "3072",
|
||||||
|
"phi3.feed_forward_length": "8192",
|
||||||
|
"phi3.rope.scaling.original_context_length": "4096",
|
||||||
|
"phi3.rope.dimension_count": "96",
|
||||||
|
"phi3.rope.freq_base": "10000",
|
||||||
|
"phi3.rope.scaling.attn_factor": "1.1902381",
|
||||||
|
"phi3.attention.head_count": "32",
|
||||||
|
"phi3.attention.head_count_kv": "32",
|
||||||
|
"phi3.attention.layer_norm_rms_epsilon": "1e-05",
|
||||||
|
"phi3.attention.sliding_window": "262144",
|
||||||
|
"tokenizer.ggml.model": "llama",
|
||||||
|
"tokenizer.ggml.pre": "default",
|
||||||
|
"tokenizer.ggml.add_bos_token": "false",
|
||||||
|
"tokenizer.ggml.add_eos_token": "false",
|
||||||
|
"tokenizer.ggml.bos_token_id": "1",
|
||||||
|
"tokenizer.ggml.eos_token_id": "32000",
|
||||||
|
"tokenizer.ggml.unknown_token_id": "0",
|
||||||
|
"tokenizer.ggml.padding_token_id": "32000",
|
||||||
|
"tokenizer.ggml.scores": "6e37bcde2adc7e350e87c496eddd7a2124329c1dc66c5bf3ad3997253e4f7a62",
|
||||||
|
"tokenizer.ggml.token_type": "b6ecf55ec64ee67d87750bdb8d757a2c58bf78377e9f4219f5689a6c4dea57ce",
|
||||||
|
"tokenizer.ggml.tokens": "d168da3ddd3eee820916945fcb9baf24dd3cde42f606cffa2d19e7c8a8743918",
|
||||||
|
"blk.0.attn_norm.weight": "216aeb2c9e0c271f899e1ef2a63cceeb8f41e97642e84fada54b1d3c1c11cf25",
|
||||||
|
"blk.0.attn_output.weight": "b597d56f7188ffc1fafc273fadc59d41738cffd677ae98c61a62c3285b3a3099",
|
||||||
|
"blk.0.attn_qkv.weight": "d28a6b44e13f59be5483e4be2bedb544e346168d720aca27f47d1a5a722be91e",
|
||||||
|
"blk.0.ffn_down.weight": "4a691370e5a61fcbbf540fbcbf4c0f1d15dec0364528c0e916d0744f6262b63b",
|
||||||
|
"blk.0.ffn_norm.weight": "0c00af2b4a3128bec64a0cbb1084b042fdbe13d9ad0d03bd577f9449dfead338",
|
||||||
|
"blk.0.ffn_up.weight": "b32b52f790c1c083bfb8a3126dc1111cfeeb28dc8c584a930a1e5334cb176bf4",
|
||||||
|
"blk.1.attn_norm.weight": "68748011503c6c029e8e69a84a8e5a89338f378769627b6dbf7f93d715c292e1",
|
||||||
|
"blk.1.attn_output.weight": "2267344add13b048ca59e4377c86dc512be8046a57156901fa32a20fa74e4ee0",
|
||||||
|
"blk.1.attn_qkv.weight": "9109d2e3d7a2eacfda5226587b8be124a3bf44b972da7ebb17aa15795897eacc",
|
||||||
|
"blk.1.ffn_down.weight": "d675df4df4dd039c0c339ad6445d39eddd2004db6bf35bed6314c7497245a633",
|
||||||
|
"blk.1.ffn_norm.weight": "3b5767ae977bc8baaa06b06efdbea193b6b3ba605ce76d77a76ce317e935500c",
|
||||||
|
"blk.1.ffn_up.weight": "80dfd6d9d234b00334c89b8e0a02f81899c2efd377321c34ba5ba51a5f61b5ff",
|
||||||
|
"blk.2.attn_norm.weight": "6a6743b057e5088f145bc179e92c9bfb41163e7295d7b81c62e23dd89d2b59c4",
|
||||||
|
"blk.2.attn_output.weight": "bc5491ea54e0db81462d7d9b7d25cbdda380c2db8de041bd1c4ab7b76a1d19c3",
|
||||||
|
"blk.2.attn_qkv.weight": "a61287a9852e2f5aca9c100b471d98398b2913a3497c743de3c70ec9ddd7087f",
|
||||||
|
"blk.2.ffn_down.weight": "4fddcc382c8dceeab027fe43d8d44e67edb5e8ce4b9a1b7f773c87770380ade1",
|
||||||
|
"blk.2.ffn_norm.weight": "07e05f82b3f63f711db3b684ca79aed25c0657917e66f88af47348a82065c227",
|
||||||
|
"blk.2.ffn_up.weight": "4835a682ef1826c12df01ae7663fc45f9c82bc8e64b665f13fb7da8e201ec0fb",
|
||||||
|
"blk.3.attn_norm.weight": "f22aba7c03999ba7136f39cda747a39715e498699dc1716cd97fc5dfc58d1b1c",
|
||||||
|
"blk.3.attn_output.weight": "53b579855366fd786c5126b2b30aac4d583ca7bda56833c4865f5cadb5c18c6d",
|
||||||
|
"blk.3.attn_qkv.weight": "bb56aba78158123140fcea59c69ac562ca208f6d3086819417cdad8c50f333ad",
|
||||||
|
"blk.3.ffn_down.weight": "97280897a7cd86db2830c004bccc5bc094f50e293baded0189159a2019145a6e",
|
||||||
|
"blk.3.ffn_norm.weight": "10a8c99f8b57a960e8e0a1133c4a26f9148403d1b9bff2eff114917de996f3b5",
|
||||||
|
"blk.3.ffn_up.weight": "7324046c915e75d621b2043597a245a428d8eea31869135e6257a861491d8dcc",
|
||||||
|
"blk.4.attn_norm.weight": "507d8e164de94646edbfe33def8e8fbf7c9a6ee3fbaedb5000f72d9f51ec5e36",
|
||||||
|
"blk.4.attn_output.weight": "bbb3429e6efa98c150e0fdbf48c16180cbf0d0cbc1b3c253c6c319d78f4593a2",
|
||||||
|
"blk.4.attn_qkv.weight": "b95ee5be0786d3901273d806c339fe6c20e6bfffd2a20672a9f56af80921e8ab",
|
||||||
|
"blk.4.ffn_down.weight": "806bbf91df92a5a22bd5aa1ffb7fc2869f7293ffc7704771c290ecc583b27975",
|
||||||
|
"blk.4.ffn_norm.weight": "cfc2930a81df7aee3a5e7f726a15c1182233e868bf0d9d37f6b6ae6d8c15c234",
|
||||||
|
"blk.4.ffn_up.weight": "c3390c69533de2c8424e8069323ccc5d0c4543111535da04cf2c7d26745576aa",
|
||||||
|
"blk.5.attn_norm.weight": "0d71c4fbcefabbd021569442853d2fe90668b19409ae2805a718a829ca60beab",
|
||||||
|
"blk.5.attn_output.weight": "10ebd93629112bf2df5c30dd0953a4a5e9020306768283181ed426934d47e14f",
|
||||||
|
"blk.5.attn_qkv.weight": "5cb05633369f12d4b00e0ff787736bd846856682115720ebc6cce05270c334f6",
|
||||||
|
"blk.5.ffn_down.weight": "e28bcc5094212eafc7476dbc5b7a520d25b79578cbf4229d698e2655956a80ad",
|
||||||
|
"blk.5.ffn_norm.weight": "b6f2c4cf9f34bb4d59989f96165c14a67dc1e266ad0a6d0fcc49f1add929e6ff",
|
||||||
|
"blk.5.ffn_up.weight": "0f9ef99423cc07ebedc0e9cfa95809f2d7108d910bb4ef97ebc0b0309c440750",
|
||||||
|
"blk.6.attn_norm.weight": "b3edcc47a42218234f7564d7470611b49401a41ae8cd42123f86557c69f5d7f2",
|
||||||
|
"blk.6.attn_output.weight": "eb9b7d257b388bb5b8fe0515e5c6873317239cb94cda236e4b6ada2a6c57c65c",
|
||||||
|
"blk.6.attn_qkv.weight": "eb968081f478c52f07bd9c2761741e982dba33cc4eeadeea3557d391b9ac2106",
|
||||||
|
"blk.6.ffn_down.weight": "1b8588bb7463206290322695577dcfced300895d6e6f4b26966c53a9ae2f0f84",
|
||||||
|
"blk.6.ffn_norm.weight": "1219c04b7770983c77814200eefe743f46d15328ea2b12711e44f8103eab08d3",
|
||||||
|
"blk.6.ffn_up.weight": "197ef287239fec47c55677f0fbb66eaf0644f775bc382de843971730721394f6",
|
||||||
|
"blk.7.attn_norm.weight": "b630ad08c80d564ed1c024384818e9fd3f22a36cd7a14aa96e7e2759a8285099",
|
||||||
|
"blk.7.attn_output.weight": "970255aa750828a47d6b9d399f9612b5bf25aefe7dadbcba41fc416d0d4067c1",
|
||||||
|
"blk.7.attn_qkv.weight": "ebb157c880293e6de8d629f263ba8853ed1dbdc02c311d43432bb8cfbb310739",
|
||||||
|
"blk.7.ffn_down.weight": "24bcd4db4cba844c89f878b81843c373dbbc0675e889d32c5b12e63384a7b670",
|
||||||
|
"blk.7.ffn_norm.weight": "b9c6f71001808ee873ce7db8056e4b53fb4cccec8b7f0f312899b575fae39d39",
|
||||||
|
"blk.7.ffn_up.weight": "979f1828d227455c26015a2a11afe9dd05f2bb97a8ba6b38c8dab3f50e627401",
|
||||||
|
"blk.8.attn_norm.weight": "4e8e347e3775010b7112ee630f2f4f2383be7ff64e6ca6154b9b22566552eaa6",
|
||||||
|
"blk.8.attn_output.weight": "65a44babf44a435a1829945211b3168f9ec78ac3cb7a049a733e93d11f0d6659",
|
||||||
|
"blk.8.attn_qkv.weight": "343ed07671da400b040812a4058482fa38284b5d9af9becfed07417fe26ce747",
|
||||||
|
"blk.8.ffn_down.weight": "7fb7e073e3c2c503c4e9d60efa0988fed7398d900cc003695fe3fffd3e188b82",
|
||||||
|
"blk.8.ffn_norm.weight": "b07c1f655d8593e3892a2cf73f8a0c19ce8e5cb613fafbe7cbd430da8ce4c57d",
|
||||||
|
"blk.8.ffn_up.weight": "8b26e14de54b3fdc2e2d3ea41720f9d9c236a93688c3b7fd7bf43f5fbb327c9b",
|
||||||
|
"blk.9.attn_norm.weight": "46394d408a8e316916177e6aa261de32e137a82d729c0b1800b072f0c38c39b6",
|
||||||
|
"blk.9.attn_output.weight": "d57f3d46107947a7073373a0b35d6ecf7759b5df15406f4a3590a60666af6b16",
|
||||||
|
"blk.9.attn_qkv.weight": "14bb8ace8c5453148f4b536e9f4279c813f31136716947256f5cca333448639c",
|
||||||
|
"blk.9.ffn_down.weight": "2b8d98e2b5ed68338f6e4de43bf7de0c4858cc69103cd5177725f7444eec7694",
|
||||||
|
"blk.9.ffn_norm.weight": "41a499dfd418cc4c6b8c12313f673f7e2cd4a3f9c4065eb6c4feb5eed02fb542",
|
||||||
|
"blk.9.ffn_up.weight": "143aab7533a64b17fbe201490a6f674bc7f0bd370c094500b2e100419073d1c2",
|
||||||
|
"blk.10.attn_norm.weight": "ebb670aafd36816a794347287269d8f1a5b19c1e3c0a1e38023bc19fdba9b073",
|
||||||
|
"blk.10.attn_output.weight": "b5d65bbc0ed5e49fdd9d754bc18163cd042a285024d0cf6f954c503bc8c877cb",
|
||||||
|
"blk.10.attn_qkv.weight": "f06b15bac88da798fa34a62b03eaac0dbe8b846020516603c387541f2d8dd672",
|
||||||
|
"blk.10.ffn_down.weight": "fb091fcd1b4de25d1bea94d1755e255cb02914a030d23e3a234e57b8d46bde6e",
|
||||||
|
"blk.10.ffn_norm.weight": "eb347bdf9c40414af87e13a8e72e40b31f004b50f7cb366f1a219ced60a61355",
|
||||||
|
"blk.10.ffn_up.weight": "ed2d52fc881a173f404fe8a1067862c9856d6c3e0d2e90a330a7aa394e3f84d1",
|
||||||
|
"blk.11.attn_norm.weight": "64e252603cf010a0e502ca39fdf8d0a196a79aec67c0d2bb9213fc0cb80c47d4",
|
||||||
|
"blk.11.attn_output.weight": "228e33e21c69f52efc74fdfc831bc9af271e44b2a29a3dced1d64e667ce36eb5",
|
||||||
|
"blk.11.attn_qkv.weight": "ab9ce6d4ef9e42ee0da3f20a7708a3bbc5e79e967b05fa86ba946a05e2eb63eb",
|
||||||
|
"blk.11.ffn_down.weight": "0ca133b7835c98dc77c25d64e4eb7873778bdb5e4d22d8b80f920f46865b43bd",
|
||||||
|
"blk.11.ffn_norm.weight": "02455741a0dfd161c79aa1ecc381901721f229fdcda5615622a629631fb61cfd",
|
||||||
|
"blk.11.ffn_up.weight": "9fecdcc099fbb8e23c6b1ea9294702a027f4a58d265543ec5e7be79b8f63b354",
|
||||||
|
"blk.12.attn_norm.weight": "783bb459911b1b3609a9b2bdfe272f1670add73b5471da738e07ac47e2e07dfd",
|
||||||
|
"blk.12.attn_output.weight": "1e1a914c9e48b857206ac5a1f7cead994bc1ea91d5d4fff8c834d73f2e38ef5d",
|
||||||
|
"blk.12.attn_qkv.weight": "5953e7185ccb87fb4dae8f9426ec86315d4c7794326e8ab59b3a95d4af2189f0",
|
||||||
|
"blk.12.ffn_down.weight": "a3eecf0f394f86e2cfb48a5940a5c50ca86d71883b2f79fcc642a935fabce0d4",
|
||||||
|
"blk.12.ffn_norm.weight": "0a4272e41373c23bd72f10d2d82930aa3a1480aac75832bfbf01cebf0b86b6a4",
|
||||||
|
"blk.12.ffn_up.weight": "06f42776de3a7ceac3025f26a7a8bd20e062233cce2bdaa2183470dc4b30b87d",
|
||||||
|
"blk.13.attn_norm.weight": "5915da60fb03e201fa649faba780e5fdf1c761c262b206e5415cf83181f65780",
|
||||||
|
"blk.13.attn_output.weight": "4dbf6eab074fa3835fd32bd631a8208e511037d5056d2fd3015735cca7674ef7",
|
||||||
|
"blk.13.attn_qkv.weight": "d3d8339a1c4782d9e73d77fdebe154d3c5b83ac40c9175b3e91a4977d08f876b",
|
||||||
|
"blk.13.ffn_down.weight": "de6772b46a55e1fd42b007637dfbf68b6598e5d5b61622da0935002e1e192d3a",
|
||||||
|
"blk.13.ffn_norm.weight": "5a640ea3b8c7be49c95a58a2327e10d8e8d9d142504bde5c8091613e5b961d7a",
|
||||||
|
"blk.13.ffn_up.weight": "f35e3545e4bd3531b2e843b5efd31dee0c13c807ee6386e65473ba67bbec30d0",
|
||||||
|
"blk.14.attn_norm.weight": "9b34986450b7c98b4927e81e61a816f9e84b1addc7c14926402100037aad6678",
|
||||||
|
"blk.14.attn_output.weight": "155d52efb23d366016d861a251d4d1f4a0c13699188c50d50dba016a0d8bfcd9",
|
||||||
|
"blk.14.attn_qkv.weight": "8e1415084e1f33c73a777f19e752489f4dd312cca047733e5ea643cd4a955e04",
|
||||||
|
"blk.14.ffn_down.weight": "a2a142226b94baa01ccb65bdea2b7418e49085c1d9c3c63e544e3112c58a25da",
|
||||||
|
"blk.14.ffn_norm.weight": "8aecfd9b0ae6affaea31a80c5c9a4a14b31deaa0db7bd8f6da2a64d23447921c",
|
||||||
|
"blk.14.ffn_up.weight": "0c1407237b8c1bd02f193346b5681926fe698a5055eac6a7450451b0f991707c",
|
||||||
|
"blk.15.attn_norm.weight": "e037bd19880bfa83d983200fb0c7866f8ad16c3ff5cc4b4f3a37ca7373870ff6",
|
||||||
|
"blk.15.attn_output.weight": "045fe4fc95cc129a1b92771b179c11b12845c4c088786c607f17bd98857e68e1",
|
||||||
|
"blk.15.attn_qkv.weight": "7621b7559705cab1d4dea1c69f76dbf9dc1c8837a203b656f484703b9c1b70ce",
|
||||||
|
"blk.15.ffn_down.weight": "7e5ac20e290bc60761e1cd972354fde225b7fa861048d44d9a0dd9b046d55f58",
|
||||||
|
"blk.15.ffn_norm.weight": "b6d830d88f1db1825687973c8c2b1a24c6fa84f07af8d0e3ef9c86009baca0b2",
|
||||||
|
"blk.15.ffn_up.weight": "dcda0957cd04fc45476774dba2bbf9aa89d6b05d5ca7b10ae6f73ad2c49b1cd3",
|
||||||
|
"blk.16.attn_norm.weight": "4ee9b70ba15cb2a08240f93990e90f5068c48fceb481f8e2186bec8b7214eb3f",
|
||||||
|
"blk.16.attn_output.weight": "315cfe5536658d2498192b2980eade15b2c9a4ff220e4011911457b1727fa103",
|
||||||
|
"blk.16.attn_qkv.weight": "3c8122e3ad637583b9dcde8ff3a323267d3014bb1f0f9771e5322260ca9ecc8d",
|
||||||
|
"blk.16.ffn_down.weight": "3b5fbebd5ee2b86cad96fb8a9b45a8770d08f82c1c8b74d7061e866f7020a18d",
|
||||||
|
"blk.16.ffn_norm.weight": "ffab69f20bda372de6e5878f0539163e2fc6ba113621ded95705fc3b1465c9f0",
|
||||||
|
"blk.16.ffn_up.weight": "0935ea3d258da42d6258406365f39f58ddaabfe97ea5977580db3635188f24a1",
|
||||||
|
"blk.17.attn_norm.weight": "f030441733f3d147b4a06a1eb4aeb8465c7c24d9c53bf4c48fe7e134d3629803",
|
||||||
|
"blk.17.attn_output.weight": "07a955ef09e8dc766ac0df647d0b2c69f23c4c69a7137654b4aad80303ed0eda",
|
||||||
|
"blk.17.attn_qkv.weight": "1c10688061e21e2fe12ad0cb54bf03895c1f83c3b0df743a42f548b52cbca1b2",
|
||||||
|
"blk.17.ffn_down.weight": "ebb9cc9836f41d88fdae2aa9a4355514e4edaec8d1577ffeb947a35204e77f52",
|
||||||
|
"blk.17.ffn_norm.weight": "50aff44f6528b13db5389f2ddcdb7676244947610bd7ffbff3f881c968c2a0d4",
|
||||||
|
"blk.17.ffn_up.weight": "d716537949582be33bde6b02e38f5a70081c9642a9fb05a61312126718b8d148",
|
||||||
|
"blk.18.attn_norm.weight": "0ea695c4e53d637902f46663a6ee42adc493c36794476acc7dbddaa05b13840d",
|
||||||
|
"blk.18.attn_output.weight": "5fd35b500221a612eb4f4bddf0e9b6b7db4d7733032a75f8802fb2d884647c2e",
|
||||||
|
"blk.18.attn_qkv.weight": "b0da37fd030fe69581f990bf23bfd35467a1bbe558af6de7c0924f6b72e92317",
|
||||||
|
"blk.18.ffn_down.weight": "b355c33f44b328f4bb977567de8f7544db4b005d7a8fbded658518ecf3c5a153",
|
||||||
|
"blk.18.ffn_norm.weight": "58b3fe9094079989a86e0387143259e1cc35952d24dc3df290c4ba6df44f5c51",
|
||||||
|
"blk.18.ffn_up.weight": "2ce530954c342c30ed2ead5353f931960bfae1d278868504c0efb973560fabbe",
|
||||||
|
"blk.19.attn_norm.weight": "533e9aed66feea8f0392aa81f9e293240e1f009a5334253915fb60c2749b615d",
|
||||||
|
"blk.19.attn_output.weight": "84f2d00f98a4113a779d3b5d1c3e7c914eb47784d3ab13b290367c124c2994aa",
|
||||||
|
"blk.19.attn_qkv.weight": "fbe6b9f53b07fa7537d3b3d452d20a9bc666f9fd41ec2091dd28bc2f70fc668f",
|
||||||
|
"blk.19.ffn_down.weight": "b30199e098c8bb3f890183d8b18471e80b62b604729b277ad62488dd71e1206b",
|
||||||
|
"blk.19.ffn_norm.weight": "c81373e41cd340b7badb19f9517c77c4250b4eb9a02dc758b8b49b652487d7ff",
|
||||||
|
"blk.19.ffn_up.weight": "5a5cb083ca7725720e3a890f7fa46354760e8007a8188849a092e305694a75e3",
|
||||||
|
"blk.20.attn_norm.weight": "4953091b4477e354357a8e743ba0a1900633e52f1599ee082a0c9b0b2b5cd978",
|
||||||
|
"blk.20.attn_output.weight": "62d54f7749cd6856097b2632066a322b0296df915fe66f382c5b5981be0d4f23",
|
||||||
|
"blk.20.attn_qkv.weight": "406de9e35b0729ebe902d7a47905cc7fb29a921431ed35dbef0c03e5690a1329",
|
||||||
|
"blk.20.ffn_down.weight": "62fb678b0d1261e19a4903a2b347d67afcc8acff01feb33a687a35a2d1e6f9a5",
|
||||||
|
"blk.20.ffn_norm.weight": "cd9d36b7e71e55c8925b97bb09c28219f182626bcff094878ae39c3db887a14b",
|
||||||
|
"blk.20.ffn_up.weight": "b9276771d79d3e932e73ccc520c3f8476342b9ef312ed2ee1e0da822e6e3ad18",
|
||||||
|
"blk.21.attn_norm.weight": "66d8c8a35e13ce9c2a0e75b670150e2c31484a55c2316df46075312196178ed3",
|
||||||
|
"blk.21.attn_output.weight": "12ab46c9382648f9b3350fdd92a6be6352743d62d6b520d7e2024e0c838588f5",
|
||||||
|
"blk.21.attn_qkv.weight": "a7909676ee1675ca23cd29a5fdd226df8dd9d68f94c6c9bbb51dd9fd38504008",
|
||||||
|
"blk.21.ffn_down.weight": "6fb317279c6542e82f97d5a12a60fac1bd0fa0405154f9fbe265e2fe39bd49cc",
|
||||||
|
"blk.21.ffn_norm.weight": "c0f703eb3ff161b5ba4490d87d8684b8a6c47a8f433e12f418333b9db439010a",
|
||||||
|
"blk.21.ffn_up.weight": "6dbdb80ef0c35e364bbce12d40d5e74c7963c7b55d58d9579567a07ffce7b863",
|
||||||
|
"blk.22.attn_norm.weight": "f94237433bf03d675cb2f655b81ca91a1ce2447bc6b00b13d6b0ccfe2d411eff",
|
||||||
|
"blk.22.attn_output.weight": "e821f95995ce497c01e63ca64f737713b1b65f11df1903e51d444aa516f33f71",
|
||||||
|
"blk.22.attn_qkv.weight": "1b0f717c73afb5eb4c82a1708c4e85c969e8a2a8770d9ddb78b1870a2d8a781e",
|
||||||
|
"blk.22.ffn_down.weight": "0f33f7a3cdc685484be99aa0c03642b0b20850a27d1fddbe054b13a9382f3ccb",
|
||||||
|
"blk.22.ffn_norm.weight": "9df285cf211ddd7df2b36a50489af574755c7d4d98b29a05cd04566ae613c8dc",
|
||||||
|
"blk.22.ffn_up.weight": "63ac300e1efb34041dd0136cf43ea622fac6f0caccce1cd9262f5e08d2cf179c",
|
||||||
|
"blk.23.attn_norm.weight": "5f72d9e88689b4027b28f5f8f26cd3abb03635ceea7ec98a4c91a9fc691f6707",
|
||||||
|
"blk.23.attn_output.weight": "6ecf04ff61125c5fc768f8656497152149373daf321ee9c957e8f7245a1184d1",
|
||||||
|
"blk.23.attn_qkv.weight": "a9d9978806724c2959f2cf386c233831f08e1e933dbf2b32665e788d9d512ea4",
|
||||||
|
"blk.23.ffn_down.weight": "72c7d17886a3da17fa0daa456aa5e877b2ef5b8b403182b870d9ca5ca9c70347",
|
||||||
|
"blk.23.ffn_norm.weight": "971e4b712e3025a13419b5b57d674b5e4ab7f18f74b57b9afc4671623da90c4b",
|
||||||
|
"blk.23.ffn_up.weight": "df2b5c7dbd5834545b815073af0c7355b065124e6d6f0fee78d8fa5b2076dc3e",
|
||||||
|
"blk.24.attn_norm.weight": "c41957c4a79ad3b16f6e11daec1c7f530b9f3f4b618e1e4367c3b67787ac4ab6",
|
||||||
|
"blk.24.attn_output.weight": "ef7d61f5fc88ac6f31bf60cb5f4d2d6b8df42d38825807112361a7224b0dee3b",
|
||||||
|
"blk.24.attn_qkv.weight": "3e6a58fe7d49c90bb6971efbad3371c32256881173ea5aee4b0c296cb206490f",
|
||||||
|
"blk.24.ffn_down.weight": "f43619144047de42fed81dfa495f1815d3cb771330e574043e2b67620819292c",
|
||||||
|
"blk.24.ffn_norm.weight": "5501d4a2a98c8ca6b42e77b53b221dbc08f530f6a067256d787534ec6fe028bd",
|
||||||
|
"blk.24.ffn_up.weight": "d64c8b0e509e2b1118f6000176f8956cacecdbb200c7e95ed93fb78b6e26c84a",
|
||||||
|
"blk.25.attn_norm.weight": "502fa3c302d371f61c5791f4615b73018ffb1daa09b6499b227116581244c5d4",
|
||||||
|
"blk.25.attn_output.weight": "ad8391d4e9c980856f2547aa945b2b6a407a6382158dc1ddd4f08d94ecc24be6",
|
||||||
|
"blk.25.attn_qkv.weight": "42e8983780d4a01a02c54ad23d4df21eea437f119a10af5a9c12a76a42d308c1",
|
||||||
|
"blk.25.ffn_down.weight": "302dd010d4e0ab4eeaee89090409ea0dddeeeed3236415eb8f97c942497eea91",
|
||||||
|
"blk.25.ffn_norm.weight": "fb34c1ee5bca96986c08834df0a0c047ba041c1123ac1f563e9d64312bf82d6a",
|
||||||
|
"blk.25.ffn_up.weight": "10739a8de156816d93c92b935386540bfa976bdbef204f0312960f6fc657582f",
|
||||||
|
"blk.26.attn_norm.weight": "7036c711609128c4e55968ff3681d3043338879a5737efd6c2ac9e1a2a61f1a0",
|
||||||
|
"blk.26.attn_output.weight": "db5db45dead5cb911fa01da59832f121b7c18b2d167bf53741c40819f24d346c",
|
||||||
|
"blk.26.attn_qkv.weight": "cae34c6b7f82ed14348d5ed30a79919c383737c1694a9cb9c0de609d3b0c1d0a",
|
||||||
|
"blk.26.ffn_down.weight": "491ec3a4da9b4f49f8ebc6be658ce397a9b801ae9fb35e82177e47808c65e5d0",
|
||||||
|
"blk.26.ffn_norm.weight": "fd7059d75d7f0e5288511ddeeb0f772eb3cae3ccfe4226b877015834edc3c386",
|
||||||
|
"blk.26.ffn_up.weight": "ea1ee1274c56458ce056d2205e5bb6e5422ce4cb0ad58006b8141749b97a0c39",
|
||||||
|
"blk.27.attn_norm.weight": "cc362c9a937609265052cd38544af17a1a7448cea086d4c801139e1fc865832d",
|
||||||
|
"blk.27.attn_output.weight": "ba757a81dabde9cb1b069d1bb616fe79649a1724f756567ec61caed1304fe6cf",
|
||||||
|
"blk.27.attn_qkv.weight": "1ab8d7d02d87756c12c2275636823aa5ede3d683178225c4cac4bd892c319bd4",
|
||||||
|
"blk.27.ffn_down.weight": "deb1c711c8a66acf4dcd2d088e1548f8e08f296f755e4067d6557fa55afde88c",
|
||||||
|
"blk.27.ffn_norm.weight": "fc6242d8cb8a4a37a8ddb7e41e7e60a63d4a89edf36acb35df052f10b9c91ece",
|
||||||
|
"blk.27.ffn_up.weight": "8df39b09c4801f343aca78f2918a1f6db78c8c55e591eda4c69eadb74c26e180",
|
||||||
|
"blk.28.attn_norm.weight": "75b539308f77e3cefdc6d98484d8b5cbf0538f0c2869a77b7373a145a18bc850",
|
||||||
|
"blk.28.attn_output.weight": "ae128940eb60a6d2e121762ef4b3e9dcf9eb3e105b249507fa7f12de0e19822c",
|
||||||
|
"blk.28.attn_qkv.weight": "bdda781c288e9326c240e33905f8e621b6a2ad902e620739d34f93fcd6f933de",
|
||||||
|
"blk.28.ffn_down.weight": "f1d6e6d1c286b1138bfd7e53fe477f399ae93bc2c04e35416f84218ed7247965",
|
||||||
|
"blk.28.ffn_norm.weight": "3f837ce82c8b9bde0d61d08b6f5fe5574886ea5328dbdc53f2929f18da8b4087",
|
||||||
|
"blk.28.ffn_up.weight": "2af027002e31d1b6cfedbdb30a2b9d7213f3aa691167c353913adfd48fda31e4",
|
||||||
|
"blk.29.attn_norm.weight": "61e8003b5329462ffe0fe172f2b160260de006aed858332d49d75504b6b6aa7a",
|
||||||
|
"blk.29.attn_output.weight": "ca44542a72a37476dc73dbdcc01f5b7497cb3ebc4ea230a55c9634ccd8e56ad4",
|
||||||
|
"blk.29.attn_qkv.weight": "abb3d9d6abe57872ae3daa51935d43264093ded5ce63b49d1e280ee5758be0e4",
|
||||||
|
"blk.29.ffn_down.weight": "6764b895fce881df097489c263446f0106de36217997660c15984b3ee22a5a06",
|
||||||
|
"blk.29.ffn_norm.weight": "89e03e9a33fc0e6e31ba9f0c2bd7c5734a118c5602bb90148793e08a80e8d0ae",
|
||||||
|
"blk.29.ffn_up.weight": "fa7ad57a84954f4121653152efed1a871d8adb20a1ea9086e3e849ce359d7d2e",
|
||||||
|
"blk.30.attn_norm.weight": "91a697aca1e42af54f806a20211031c3369e8d0bd58df1b0147fe24954e1f5a4",
|
||||||
|
"blk.30.attn_output.weight": "36063fcf766c89ac75be56f688cc63cefe5f2c733fbf4378ea9956ad386fa148",
|
||||||
|
"blk.30.attn_qkv.weight": "2cacd1161f1121a2c0b979930134f4666f73fb8d7237b3b0659ae091b15955a6",
|
||||||
|
"blk.30.ffn_down.weight": "9f3fcb6217100595850c05dc98f9ab2a263afdb6ab28df2fcb08aeff512057d7",
|
||||||
|
"blk.30.ffn_norm.weight": "6c600bc1fc7de39d4f8917b81fc7d1d5ed2a9b56492234c13a4bd6028c30d880",
|
||||||
|
"blk.30.ffn_up.weight": "73cabd1bb011956b2689ea3338bb76642ef3a57c197377d666d2ab5f56317668",
|
||||||
|
"blk.31.attn_norm.weight": "72d3e1cc771380645fa75a899858c95f39857a4f3f1ed60fe1578df383b8bc53",
|
||||||
|
"blk.31.attn_output.weight": "40089cdd29994dc19a1d89fa15902a89cfeca3540f12dc9bf4d00ef82506e456",
|
||||||
|
"blk.31.attn_qkv.weight": "1d0bb40e9258071ae14290a53c619a8e331dda07354d2a02ef45766c029ae5e4",
|
||||||
|
"blk.31.ffn_down.weight": "8defa0e06335b793fa8be03883f0a322d6c5b33f52c69c943c35c60d16e42c0a",
|
||||||
|
"blk.31.ffn_norm.weight": "33c55d9d0c496ccfb130361fe131649346e098abaaac39c0519507e5d846721d",
|
||||||
|
"blk.31.ffn_up.weight": "599f6503f61c692c1f82001973d35119f9688db5e6be9d9c298411491c93f09b",
|
||||||
|
"output.weight": "14b8dc662bfa3308ebb2e102c562d8e52c15670e538f20f3216a9c310ca9dd41",
|
||||||
|
"output_norm.weight": "7f2294ba94ce65681df6c7ddd8698799199b9d77dc83c10bdad5c3999f0fdb82",
|
||||||
|
"rope_factors_long.weight": "e34d378664e354652c38f47d10dafb0498ccc2fb042d39ff7fef768146fff22b",
|
||||||
|
"rope_factors_short.weight": "9379146a4988f373d362fe47b06c75e7fe7c54aa4dc9558758df79b7a87471fd",
|
||||||
|
"token_embd.weight": "19a03c1fb5ac0baee93b0a7d8b0f26e9a9b011e229b694afc50ebfc13d84f8bf"
|
||||||
|
}
|
|
@ -669,7 +669,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||||
|
|
||||||
```
|
```
|
||||||
curl http://localhost:11434/api/chat -d '{
|
curl http://localhost:11434/api/chat -d '{
|
||||||
"model": "mistral",
|
"model": "llama3.1",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
@ -708,7 +708,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"model": "mistral:7b-instruct-v0.3-q4_K_M",
|
"model": "llama3.1",
|
||||||
"created_at": "2024-07-22T20:33:28.123648Z",
|
"created_at": "2024-07-22T20:33:28.123648Z",
|
||||||
"message": {
|
"message": {
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
@ -1175,7 +1175,10 @@ curl http://localhost:11434/api/embed -d '{
|
||||||
"embeddings": [[
|
"embeddings": [[
|
||||||
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
|
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
|
||||||
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
|
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
|
||||||
]]
|
]],
|
||||||
|
"total_duration": 14143917,
|
||||||
|
"load_duration": 1019500,
|
||||||
|
"prompt_eval_count": 8
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,9 @@ If the model being imported is one of these architectures, it can be imported di
|
||||||
|
|
||||||
- LlamaForCausalLM
|
- LlamaForCausalLM
|
||||||
- MistralForCausalLM
|
- MistralForCausalLM
|
||||||
|
- MixtralForCausalLM
|
||||||
- GemmaForCausalLM
|
- GemmaForCausalLM
|
||||||
|
- Phi3ForCausalLM
|
||||||
|
|
||||||
```dockerfile
|
```dockerfile
|
||||||
FROM /path/to/safetensors/directory
|
FROM /path/to/safetensors/directory
|
||||||
|
|
|
@ -182,7 +182,6 @@ curl http://localhost:11434/v1/embeddings \
|
||||||
- [x] Reproducible outputs
|
- [x] Reproducible outputs
|
||||||
- [x] Vision
|
- [x] Vision
|
||||||
- [x] Tools (streaming support coming soon)
|
- [x] Tools (streaming support coming soon)
|
||||||
- [ ] Vision
|
|
||||||
- [ ] Logprobs
|
- [ ] Logprobs
|
||||||
|
|
||||||
#### Supported request fields
|
#### Supported request fields
|
||||||
|
|
|
@ -112,15 +112,9 @@ Keep the following tips and best practices in mind when working with Go template
|
||||||
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
|
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
|
||||||
|
|
||||||
```gotmpl
|
```gotmpl
|
||||||
{{- if .System }}<|im_start|>system
|
|
||||||
{{ .System }}<|im_end|>
|
|
||||||
{{ end }}
|
|
||||||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||||||
{{ .Content }}<|im_end|>
|
{{ .Content }}<|im_end|>
|
||||||
{{ end }}<|im_start|>assistant
|
{{ end }}<|im_start|>assistant
|
||||||
{{ else }}
|
|
||||||
{{ if .System }}<|im_start|>system
|
|
||||||
{{ .System }}<|im_end|>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example Tools
|
### Example Tools
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module github.com/ollama/ollama
|
module github.com/ollama/ollama
|
||||||
|
|
||||||
go 1.22.0
|
go 1.22.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/containerd/console v1.0.3
|
github.com/containerd/console v1.0.3
|
||||||
|
|
|
@ -49,13 +49,9 @@ func PayloadsDir() (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track our pid so we can clean up orphaned tmpdirs
|
// Track our pid so we can clean up orphaned tmpdirs
|
||||||
pidFilePath := filepath.Join(tmpDir, "ollama.pid")
|
n := filepath.Join(tmpDir, "ollama.pid")
|
||||||
pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
|
if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil {
|
||||||
if err != nil {
|
return "", fmt.Errorf("failed to write pid file %s: %w", n, err)
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We create a distinct subdirectory for payloads within the tmpdir
|
// We create a distinct subdirectory for payloads within the tmpdir
|
||||||
|
@ -67,37 +63,44 @@ func PayloadsDir() (string, error) {
|
||||||
|
|
||||||
// Best effort to clean up prior tmpdirs
|
// Best effort to clean up prior tmpdirs
|
||||||
func cleanupTmpDirs() {
|
func cleanupTmpDirs() {
|
||||||
dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*"))
|
matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, d := range dirs {
|
|
||||||
info, err := os.Stat(d)
|
for _, match := range matches {
|
||||||
if err != nil || !info.IsDir() {
|
raw, err := os.ReadFile(match)
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
slog.Debug("not a ollama runtime directory, skipping", "path", match)
|
||||||
continue
|
continue
|
||||||
}
|
} else if err != nil {
|
||||||
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
|
slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err)
|
||||||
if err != nil {
|
|
||||||
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
|
|
||||||
// No pid, ignore this tmpdir
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pid, err := strconv.Atoi(string(raw))
|
pid, err := strconv.Atoi(string(raw))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to parse pid", "path", d, "error", err)
|
slog.Warn("invalid pid, skipping", "path", match, "error", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
proc, err := os.FindProcess(pid)
|
p, err := os.FindProcess(pid)
|
||||||
if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||||
slog.Warn("found running ollama", "pid", pid, "path", d)
|
slog.Warn("process still running, skipping", "pid", pid, "path", match)
|
||||||
// Another running ollama, ignore this tmpdir
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Remove(d); err != nil {
|
if err := os.Remove(match); err != nil {
|
||||||
slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
|
slog.Warn("could not cleanup stale pidfile", "path", match, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runners := filepath.Join(filepath.Dir(match), "runners")
|
||||||
|
if err := os.RemoveAll(runners); err != nil {
|
||||||
|
slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Remove(filepath.Dir(match)); err != nil {
|
||||||
|
slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
61
gpu/gpu.go
61
gpu/gpu.go
|
@ -305,38 +305,41 @@ func GetGPUInfo() GpuInfoList {
|
||||||
// Intel
|
// Intel
|
||||||
if envconfig.IntelGPU() {
|
if envconfig.IntelGPU() {
|
||||||
oHandles = initOneAPIHandles()
|
oHandles = initOneAPIHandles()
|
||||||
// On windows we bundle the oneapi library one level above the runner dir
|
if oHandles != nil && oHandles.oneapi != nil {
|
||||||
depPath = ""
|
|
||||||
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
|
|
||||||
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
|
|
||||||
}
|
|
||||||
|
|
||||||
for d := range oHandles.oneapi.num_drivers {
|
// On windows we bundle the oneapi library one level above the runner dir
|
||||||
if oHandles.oneapi == nil {
|
depPath = ""
|
||||||
// shouldn't happen
|
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
|
||||||
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
|
||||||
for i := range devCount {
|
for d := range oHandles.oneapi.num_drivers {
|
||||||
gpuInfo := OneapiGPUInfo{
|
if oHandles.oneapi == nil {
|
||||||
GpuInfo: GpuInfo{
|
// shouldn't happen
|
||||||
Library: "oneapi",
|
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
||||||
},
|
continue
|
||||||
driverIndex: int(d),
|
}
|
||||||
gpuIndex: int(i),
|
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
||||||
|
for i := range devCount {
|
||||||
|
gpuInfo := OneapiGPUInfo{
|
||||||
|
GpuInfo: GpuInfo{
|
||||||
|
Library: "oneapi",
|
||||||
|
},
|
||||||
|
driverIndex: int(d),
|
||||||
|
gpuIndex: int(i),
|
||||||
|
}
|
||||||
|
// TODO - split bootstrapping from updating free memory
|
||||||
|
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
||||||
|
// TODO - convert this to MinimumMemory based on testing...
|
||||||
|
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||||
|
memInfo.free = C.uint64_t(totalFreeMem)
|
||||||
|
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||||
|
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||||
|
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||||
|
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||||
|
gpuInfo.DependencyPath = depPath
|
||||||
|
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||||
}
|
}
|
||||||
// TODO - split bootstrapping from updating free memory
|
|
||||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
|
||||||
// TODO - convert this to MinimumMemory based on testing...
|
|
||||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
|
||||||
memInfo.free = C.uint64_t(totalFreeMem)
|
|
||||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
|
||||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
|
||||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
|
||||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
|
||||||
gpuInfo.DependencyPath = depPath
|
|
||||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
56
llm/ext_server/server.cpp
vendored
56
llm/ext_server/server.cpp
vendored
|
@ -403,7 +403,9 @@ struct llama_server_context
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
auto init_result = llama_init_from_gpt_params(params);
|
||||||
|
model = init_result.model;
|
||||||
|
ctx = init_result.context;
|
||||||
if (model == nullptr)
|
if (model == nullptr)
|
||||||
{
|
{
|
||||||
LOG_ERROR("unable to load model", {{"model", params.model}});
|
LOG_ERROR("unable to load model", {{"model", params.model}});
|
||||||
|
@ -1221,9 +1223,7 @@ struct llama_server_context
|
||||||
|
|
||||||
res.result_json = json
|
res.result_json = json
|
||||||
{
|
{
|
||||||
{"id", res.id},
|
|
||||||
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||||
{"timings", slot.get_formated_timings()},
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2422,7 +2422,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.lora_adapter.emplace_back(argv[i], 1.0f);
|
params.lora_adapters.push_back({
|
||||||
|
std::string(argv[i]),
|
||||||
|
1.0,
|
||||||
|
});
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
}
|
}
|
||||||
else if (arg == "--lora-scaled")
|
else if (arg == "--lora-scaled")
|
||||||
|
@ -2438,7 +2441,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
|
params.lora_adapters.push_back({
|
||||||
|
lora_adapter,
|
||||||
|
std::stof(argv[i])
|
||||||
|
});
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
}
|
}
|
||||||
else if (arg == "-v" || arg == "--verbose")
|
else if (arg == "-v" || arg == "--verbose")
|
||||||
|
@ -3186,41 +3192,17 @@ int main(int argc, char **argv) {
|
||||||
prompt = "";
|
prompt = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prompt.size() == 1) {
|
|
||||||
prompt = prompt[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// create and queue the task
|
// create and queue the task
|
||||||
json responses;
|
const int task_id = llama.queue_tasks.get_new_id();
|
||||||
{
|
llama.queue_results.add_waiting_task_id(task_id);
|
||||||
const int id_task = llama.queue_tasks.get_new_id();
|
llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
|
||||||
llama.queue_results.add_waiting_task_id(id_task);
|
|
||||||
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
|
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
task_result result = llama.queue_results.recv(id_task);
|
task_result result = llama.queue_results.recv(task_id);
|
||||||
llama.queue_results.remove_waiting_task_id(id_task);
|
llama.queue_results.remove_waiting_task_id(task_id);
|
||||||
if (result.error) {
|
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
|
||||||
}
|
|
||||||
|
|
||||||
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
// send the result
|
||||||
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
|
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||||
return a["id"] < b["id"];
|
|
||||||
});
|
|
||||||
|
|
||||||
json embeddings = json::array();
|
|
||||||
|
|
||||||
int prompt_n = 0;
|
|
||||||
for (auto & elem : responses) {
|
|
||||||
embeddings.push_back(elem.at("embedding"));
|
|
||||||
prompt_n += elem.at("timings").at("prompt_n").get<int>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the result
|
|
||||||
json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
|
|
||||||
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
||||||
|
|
|
@ -157,6 +157,14 @@ type Tensor struct {
|
||||||
io.WriterTo `json:"-"`
|
io.WriterTo `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t Tensor) block() (n int) {
|
||||||
|
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (t Tensor) blockSize() uint64 {
|
func (t Tensor) blockSize() uint64 {
|
||||||
switch t.Kind {
|
switch t.Kind {
|
||||||
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
|
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16
|
||||||
|
|
15
llm/gguf.go
15
llm/gguf.go
|
@ -532,15 +532,14 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.SortFunc(ts, func(a, b Tensor) int {
|
slices.SortStableFunc(ts, func(a, b Tensor) int {
|
||||||
var i, j int
|
if i, j := a.block(), b.block(); i < 0 && j > 0 {
|
||||||
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
|
return 1
|
||||||
return cmp.Compare(a.Name, b.Name)
|
} else if i > 0 && j < 0 {
|
||||||
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
|
return -1
|
||||||
return cmp.Compare(a.Name, b.Name)
|
} else {
|
||||||
|
return cmp.Compare(i, j)
|
||||||
}
|
}
|
||||||
|
|
||||||
return cmp.Compare(i, j)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
var s uint64
|
var s uint64
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 6eeaeba126ff701f3e8f79f246805b7023709972
|
Subproject commit 1e6f6554aa11fa10160a5fda689e736c3c34169f
|
|
@ -1,40 +1,32 @@
|
||||||
diff --git a/common/common.cpp b/common/common.cpp
|
diff --git a/common/common.cpp b/common/common.cpp
|
||||||
index dbb724fb..c26fe6ee 100644
|
index 2e8374d5..70d0afde 100644
|
||||||
--- a/common/common.cpp
|
--- a/common/common.cpp
|
||||||
+++ b/common/common.cpp
|
+++ b/common/common.cpp
|
||||||
@@ -2087,14 +2087,27 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
@@ -2110,9 +2110,21 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
|
||||||
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
|
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
|
||||||
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
|
if (loaded_la.adapter == nullptr) {
|
||||||
float lora_scale = std::get<1>(params.lora_adapter[i]);
|
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||||
+
|
|
||||||
+ // try to load as gguf
|
|
||||||
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
|
|
||||||
if (adapter == nullptr) {
|
|
||||||
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
|
||||||
- llama_free(lctx);
|
- llama_free(lctx);
|
||||||
- llama_free_model(model);
|
- llama_free_model(model);
|
||||||
- return std::make_tuple(nullptr, nullptr);
|
- return iparams;
|
||||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
|
|
||||||
+
|
+
|
||||||
+ // if that fails, try loading as ggla for compatibility
|
+ // if that fails, try loading as ggla for compatibility
|
||||||
+ int err = llama_model_apply_lora_from_file(model,
|
+ int err = llama_model_apply_lora_from_file(model,
|
||||||
+ lora_adapter.c_str(),
|
+ la.path.c_str(),
|
||||||
+ lora_scale,
|
+ la.scale,
|
||||||
+ nullptr,
|
+ nullptr,
|
||||||
+ params.n_threads);
|
+ params.n_threads);
|
||||||
+ if (err != 0) {
|
+ if (err != 0) {
|
||||||
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||||
+ llama_free(lctx);
|
+ llama_free(lctx);
|
||||||
+ llama_free_model(model);
|
+ llama_free_model(model);
|
||||||
+ return std::make_tuple(nullptr, nullptr);
|
+ return iparams;
|
||||||
|
+ } else {
|
||||||
|
+ break;
|
||||||
+ }
|
+ }
|
||||||
+ } else {
|
|
||||||
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
|
|
||||||
}
|
}
|
||||||
- llama_lora_adapter_set(lctx, adapter, lora_scale);
|
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.ignore_eos) {
|
|
||||||
diff --git a/include/llama.h b/include/llama.h
|
diff --git a/include/llama.h b/include/llama.h
|
||||||
index 93fd77ca..b0fb37a6 100644
|
index 93fd77ca..b0fb37a6 100644
|
||||||
--- a/include/llama.h
|
--- a/include/llama.h
|
||||||
|
@ -355,4 +347,4 @@ index 80a0dd0f..9d7b0e17 100644
|
||||||
+ return 1;
|
+ return 1;
|
||||||
+ }
|
+ }
|
||||||
+}
|
+}
|
||||||
\ No newline at end of file
|
\ No newline at end of file
|
|
@ -1,20 +0,0 @@
|
||||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
|
||||||
index a207451f..fba6b175 100644
|
|
||||||
--- a/src/llama.cpp
|
|
||||||
+++ b/src/llama.cpp
|
|
||||||
@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
|
|
||||||
hparams.attn_soft_cap = true;
|
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
|
||||||
+ case 26: model.type = e_model::MODEL_2B; break;
|
|
||||||
case 42: model.type = e_model::MODEL_9B; break;
|
|
||||||
case 46: model.type = e_model::MODEL_27B; break;
|
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
|
||||||
@@ -11736,6 +11737,7 @@ struct llm_build_context {
|
|
||||||
|
|
||||||
// ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
|
||||||
switch (model.type) {
|
|
||||||
+ case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
|
|
||||||
case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
|
|
||||||
case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
|
|
||||||
default: GGML_ABORT("fatal error");
|
|
|
@ -33,7 +33,7 @@ type LlamaServer interface {
|
||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
|
Embedding(ctx context.Context, input string) ([]float32, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
|
@ -44,11 +44,12 @@ type LlamaServer interface {
|
||||||
|
|
||||||
// llmServer is an instance of the llama.cpp server
|
// llmServer is an instance of the llama.cpp server
|
||||||
type llmServer struct {
|
type llmServer struct {
|
||||||
port int
|
port int
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
done chan error // Channel to signal when the process exits
|
done chan error // Channel to signal when the process exits
|
||||||
status *StatusWriter
|
status *StatusWriter
|
||||||
options api.Options
|
options api.Options
|
||||||
|
numParallel int
|
||||||
|
|
||||||
estimate MemoryEstimate
|
estimate MemoryEstimate
|
||||||
totalLayers uint64
|
totalLayers uint64
|
||||||
|
@ -124,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// On linux, over-allocating CPU memory will almost always result in an error
|
// On linux and windows, over-allocating CPU memory will almost always result in an error
|
||||||
if runtime.GOOS == "linux" {
|
// Darwin has fully dynamic swap so has no direct concept of free swap space
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
||||||
available := systemFreeMemory + systemSwapFreeMemory
|
available := systemFreeMemory + systemSwapFreeMemory
|
||||||
if systemMemoryRequired > available {
|
if systemMemoryRequired > available {
|
||||||
|
@ -343,6 +345,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||||
status: NewStatusWriter(os.Stderr),
|
status: NewStatusWriter(os.Stderr),
|
||||||
options: opts,
|
options: opts,
|
||||||
estimate: estimate,
|
estimate: estimate,
|
||||||
|
numParallel: numParallel,
|
||||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
totalLayers: ggml.KV().BlockCount() + 1,
|
totalLayers: ggml.KV().BlockCount() + 1,
|
||||||
gpus: gpus,
|
gpus: gpus,
|
||||||
|
@ -880,16 +883,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbedRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Content []string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbedResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding [][]float32 `json:"embedding"`
|
Embedding []float32 `json:"embedding"`
|
||||||
PromptEvalCount int `json:"prompt_n"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
|
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -904,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
|
||||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(EmbedRequest{Content: input})
|
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -931,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
|
||||||
return nil, fmt.Errorf("%s", body)
|
return nil, fmt.Errorf("%s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var e EmbedResponse
|
var e EmbeddingResponse
|
||||||
if err := json.Unmarshal(body, &e); err != nil {
|
if err := json.Unmarshal(body, &e); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &e, nil
|
return e.Embedding, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizeRequest struct {
|
type TokenizeRequest struct {
|
||||||
|
|
|
@ -26,6 +26,7 @@ var errorPrefixes = []string{
|
||||||
"cudaMalloc failed",
|
"cudaMalloc failed",
|
||||||
"\"ERR\"",
|
"\"ERR\"",
|
||||||
"error loading model",
|
"error loading model",
|
||||||
|
"GGML_ASSERT",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||||
|
|
|
@ -7,27 +7,22 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
prefix = `data:image/jpeg;base64,`
|
prefix = `data:image/jpeg;base64,`
|
||||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
imageURL = prefix + image
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func prepareRequest(req *http.Request, body any) {
|
var False = false
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
|
|
||||||
func TestChatMiddleware(t *testing.T) {
|
func TestChatMiddleware(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
name string
|
||||||
Setup func(t *testing.T, req *http.Request)
|
body string
|
||||||
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
req api.ChatRequest
|
||||||
|
err ErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
var capturedRequest *api.ChatRequest
|
var capturedRequest *api.ChatRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
name: "chat handler",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := ChatCompletionRequest{
|
"model": "test-model",
|
||||||
Model: "test-model",
|
"messages": [
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
{"role": "user", "content": "Hello"}
|
||||||
}
|
]
|
||||||
prepareRequest(req, body)
|
}`,
|
||||||
},
|
req: api.ChatRequest{
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
Model: "test-model",
|
||||||
if resp.Code != http.StatusOK {
|
Messages: []api.Message{
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
{
|
||||||
}
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
if req.Messages[0].Role != "user" {
|
},
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
},
|
||||||
}
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
if req.Messages[0].Content != "Hello" {
|
"top_p": 1.0,
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
},
|
||||||
}
|
Stream: &False,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "chat handler with image content",
|
name: "chat handler with image content",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := ChatCompletionRequest{
|
"model": "test-model",
|
||||||
Model: "test-model",
|
"messages": [
|
||||||
Messages: []Message{
|
{
|
||||||
{
|
"role": "user",
|
||||||
Role: "user", Content: []map[string]any{
|
"content": [
|
||||||
{"type": "text", "text": "Hello"},
|
{
|
||||||
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
"type": "text",
|
||||||
|
"text": "Hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "` + prefix + image + `"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Images: []api.ImageData{
|
||||||
|
func() []byte {
|
||||||
|
img, _ := base64.StdEncoding.DecodeString(image)
|
||||||
|
return img
|
||||||
|
}(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 1.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
},
|
||||||
|
Stream: &False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chat handler with tools",
|
||||||
|
body: `{
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||||
|
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||||
|
]
|
||||||
|
}`,
|
||||||
|
req: api.ChatRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "What's the weather like in Paris Today?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: map[string]interface{}{
|
||||||
|
"location": "Paris, France",
|
||||||
|
"format": "celsius",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
},
|
||||||
prepareRequest(req, body)
|
Options: map[string]any{
|
||||||
},
|
"temperature": 1.0,
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
"top_p": 1.0,
|
||||||
if resp.Code != http.StatusOK {
|
},
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
Stream: &False,
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[0].Role != "user" {
|
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[0].Content != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
|
||||||
|
|
||||||
if req.Messages[1].Role != "user" {
|
|
||||||
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
|
||||||
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
{
|
{
|
||||||
Name: "chat handler with tools",
|
name: "chat handler error forwarding",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := ChatCompletionRequest{
|
"model": "test-model",
|
||||||
Model: "test-model",
|
"messages": [
|
||||||
Messages: []Message{
|
{"role": "user", "content": 2}
|
||||||
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
]
|
||||||
{Role: "assistant", ToolCalls: []ToolCall{{
|
}`,
|
||||||
ID: "id",
|
err: ErrorResponse{
|
||||||
Type: "function",
|
Error: Error{
|
||||||
Function: struct {
|
Message: "invalid message content type: float64",
|
||||||
Name string `json:"name"`
|
Type: "invalid_request_error",
|
||||||
Arguments string `json:"arguments"`
|
},
|
||||||
}{
|
|
||||||
Name: "get_current_weather",
|
|
||||||
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
|
||||||
},
|
|
||||||
}}},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != 200 {
|
|
||||||
t.Fatalf("expected 200, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
|
||||||
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
|
||||||
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
|
||||||
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "chat handler error forwarding",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := ChatCompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Messages: []Message{{Role: "user", Content: 2}},
|
|
||||||
}
|
|
||||||
prepareRequest(req, body)
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
tc.Setup(t, req)
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
var errResp ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||||
|
t.Fatal("requests did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tc.err, errResp) {
|
||||||
|
t.Fatal("errors did not match")
|
||||||
|
}
|
||||||
capturedRequest = nil
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
|
|
||||||
func TestCompletionsMiddleware(t *testing.T) {
|
func TestCompletionsMiddleware(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
name string
|
||||||
Setup func(t *testing.T, req *http.Request)
|
body string
|
||||||
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
req api.GenerateRequest
|
||||||
|
err ErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
var capturedRequest *api.GenerateRequest
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "completions handler",
|
name: "completions handler",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
temp := float32(0.8)
|
"model": "test-model",
|
||||||
body := CompletionRequest{
|
"prompt": "Hello",
|
||||||
Model: "test-model",
|
"temperature": 0.8,
|
||||||
Prompt: "Hello",
|
"stop": ["\n", "stop"],
|
||||||
Temperature: &temp,
|
"suffix": "suffix"
|
||||||
Stop: []string{"\n", "stop"},
|
}`,
|
||||||
Suffix: "suffix",
|
req: api.GenerateRequest{
|
||||||
}
|
Model: "test-model",
|
||||||
prepareRequest(req, body)
|
Prompt: "Hello",
|
||||||
},
|
Options: map[string]any{
|
||||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
"frequency_penalty": 0.0,
|
||||||
if req.Prompt != "Hello" {
|
"presence_penalty": 0.0,
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
"temperature": 1.6,
|
||||||
}
|
"top_p": 1.0,
|
||||||
|
"stop": []any{"\n", "stop"},
|
||||||
if req.Options["temperature"] != 1.6 {
|
},
|
||||||
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
Suffix: "suffix",
|
||||||
}
|
Stream: &False,
|
||||||
|
|
||||||
stopTokens, ok := req.Options["stop"].([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected stop tokens to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Suffix != "suffix" {
|
|
||||||
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "completions handler error forwarding",
|
name: "completions handler error forwarding",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := CompletionRequest{
|
"model": "test-model",
|
||||||
Model: "test-model",
|
"prompt": "Hello",
|
||||||
Prompt: "Hello",
|
"temperature": null,
|
||||||
Temperature: nil,
|
"stop": [1, 2],
|
||||||
Stop: []int{1, 2},
|
"suffix": "suffix"
|
||||||
Suffix: "suffix",
|
}`,
|
||||||
}
|
err: ErrorResponse{
|
||||||
prepareRequest(req, body)
|
Error: Error{
|
||||||
},
|
Message: "invalid type for 'stop' field: float64",
|
||||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
Type: "invalid_request_error",
|
||||||
if resp.Code != http.StatusBadRequest {
|
},
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
tc.Setup(t, req)
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
var errResp ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||||
|
t.Fatal("requests did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tc.err, errResp) {
|
||||||
|
t.Fatal("errors did not match")
|
||||||
|
}
|
||||||
|
|
||||||
capturedRequest = nil
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
|
@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||||
|
|
||||||
func TestEmbeddingsMiddleware(t *testing.T) {
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
name string
|
||||||
Setup func(t *testing.T, req *http.Request)
|
body string
|
||||||
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
req api.EmbedRequest
|
||||||
|
err ErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
var capturedRequest *api.EmbedRequest
|
var capturedRequest *api.EmbedRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "embed handler single input",
|
name: "embed handler single input",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := EmbedRequest{
|
"input": "Hello",
|
||||||
Input: "Hello",
|
"model": "test-model"
|
||||||
Model: "test-model",
|
}`,
|
||||||
}
|
req: api.EmbedRequest{
|
||||||
prepareRequest(req, body)
|
Input: "Hello",
|
||||||
},
|
Model: "test-model",
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
if req.Input != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", req.Input)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Model != "test-model" {
|
|
||||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "embed handler batch input",
|
name: "embed handler batch input",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := EmbedRequest{
|
"input": ["Hello", "World"],
|
||||||
Input: []string{"Hello", "World"},
|
"model": "test-model"
|
||||||
Model: "test-model",
|
}`,
|
||||||
}
|
req: api.EmbedRequest{
|
||||||
prepareRequest(req, body)
|
Input: []any{"Hello", "World"},
|
||||||
},
|
Model: "test-model",
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
|
||||||
input, ok := req.Input.([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected input to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if input[0].(string) != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", input[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if input[1].(string) != "World" {
|
|
||||||
t.Fatalf("expected 'World', got %s", input[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Model != "test-model" {
|
|
||||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "embed handler error forwarding",
|
name: "embed handler error forwarding",
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
body: `{
|
||||||
body := EmbedRequest{
|
"model": "test-model"
|
||||||
Model: "test-model",
|
}`,
|
||||||
}
|
err: ErrorResponse{
|
||||||
prepareRequest(req, body)
|
Error: Error{
|
||||||
},
|
Message: "invalid input",
|
||||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
Type: "invalid_request_error",
|
||||||
if resp.Code != http.StatusBadRequest {
|
},
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), "invalid input") {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
tc.Setup(t, req)
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest, resp)
|
var errResp ErrorResponse
|
||||||
|
if resp.Code != http.StatusOK {
|
||||||
|
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||||
|
t.Fatal("requests did not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(tc.err, errResp) {
|
||||||
|
t.Fatal("errors did not match")
|
||||||
|
}
|
||||||
|
|
||||||
capturedRequest = nil
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMiddlewareResponses(t *testing.T) {
|
func TestListMiddleware(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
Name string
|
name string
|
||||||
Method string
|
endpoint func(c *gin.Context)
|
||||||
Path string
|
resp string
|
||||||
TestPath string
|
|
||||||
Handler func() gin.HandlerFunc
|
|
||||||
Endpoint func(c *gin.Context)
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "list handler",
|
name: "list handler",
|
||||||
Method: http.MethodGet,
|
endpoint: func(c *gin.Context) {
|
||||||
Path: "/api/tags",
|
|
||||||
TestPath: "/api/tags",
|
|
||||||
Handler: ListMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, api.ListResponse{
|
c.JSON(http.StatusOK, api.ListResponse{
|
||||||
Models: []api.ListModelResponse{
|
Models: []api.ListModelResponse{
|
||||||
{
|
{
|
||||||
Name: "Test Model",
|
Name: "test-model",
|
||||||
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
resp: `{
|
||||||
var listResp ListCompletion
|
"object": "list",
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
"data": [
|
||||||
t.Fatal(err)
|
{
|
||||||
}
|
"id": "test-model",
|
||||||
|
"object": "model",
|
||||||
if listResp.Object != "list" {
|
"created": 1686935002,
|
||||||
t.Fatalf("expected list, got %s", listResp.Object)
|
"owned_by": "library"
|
||||||
}
|
}
|
||||||
|
]
|
||||||
if len(listResp.Data) != 1 {
|
}`,
|
||||||
t.Fatalf("expected 1, got %d", len(listResp.Data))
|
|
||||||
}
|
|
||||||
|
|
||||||
if listResp.Data[0].Id != "Test Model" {
|
|
||||||
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "retrieve model",
|
name: "list handler empty output",
|
||||||
Method: http.MethodGet,
|
endpoint: func(c *gin.Context) {
|
||||||
Path: "/api/show/:model",
|
c.JSON(http.StatusOK, api.ListResponse{})
|
||||||
TestPath: "/api/show/test-model",
|
|
||||||
Handler: RetrieveMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, api.ShowResponse{
|
|
||||||
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
|
|
||||||
})
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
var retrieveResp Model
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieveResp.Object != "model" {
|
|
||||||
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
|
|
||||||
}
|
|
||||||
|
|
||||||
if retrieveResp.Id != "test-model" {
|
|
||||||
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
resp: `{
|
||||||
|
"object": "list",
|
||||||
|
"data": null
|
||||||
|
}`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
router := gin.New()
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
router := gin.New()
|
||||||
router = gin.New()
|
router.Use(ListMiddleware())
|
||||||
router.Use(tc.Handler())
|
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
||||||
router.Handle(tc.Method, tc.Path, tc.Endpoint)
|
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||||
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
|
|
||||||
|
|
||||||
if tc.Setup != nil {
|
resp := httptest.NewRecorder()
|
||||||
tc.Setup(t, req)
|
router.ServeHTTP(resp, req)
|
||||||
}
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
var expected, actual map[string]any
|
||||||
router.ServeHTTP(resp, req)
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
tc.Expected(t, resp)
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
})
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrieveMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
endpoint func(c *gin.Context)
|
||||||
|
resp string
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
name: "retrieve handler",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, api.ShowResponse{
|
||||||
|
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||||
|
})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"id":"test-model",
|
||||||
|
"object":"model",
|
||||||
|
"created":1686935002,
|
||||||
|
"owned_by":"library"}
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve handler error forwarding",
|
||||||
|
endpoint: func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
||||||
|
},
|
||||||
|
resp: `{
|
||||||
|
"error": {
|
||||||
|
"code": null,
|
||||||
|
"message": "model not found",
|
||||||
|
"param": null,
|
||||||
|
"type": "api_error"
|
||||||
|
}
|
||||||
|
}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(RetrieveMiddleware())
|
||||||
|
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
var expected, actual map[string]any
|
||||||
|
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(expected, actual) {
|
||||||
|
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,11 +3,12 @@ package progress
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Spinner struct {
|
type Spinner struct {
|
||||||
message string
|
message atomic.Value
|
||||||
messageWidth int
|
messageWidth int
|
||||||
|
|
||||||
parts []string
|
parts []string
|
||||||
|
@ -21,20 +22,25 @@ type Spinner struct {
|
||||||
|
|
||||||
func NewSpinner(message string) *Spinner {
|
func NewSpinner(message string) *Spinner {
|
||||||
s := &Spinner{
|
s := &Spinner{
|
||||||
message: message,
|
|
||||||
parts: []string{
|
parts: []string{
|
||||||
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
|
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
|
||||||
},
|
},
|
||||||
started: time.Now(),
|
started: time.Now(),
|
||||||
}
|
}
|
||||||
|
s.SetMessage(message)
|
||||||
go s.start()
|
go s.start()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Spinner) SetMessage(message string) {
|
||||||
|
s.message.Store(message)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Spinner) String() string {
|
func (s *Spinner) String() string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
if len(s.message) > 0 {
|
|
||||||
message := strings.TrimSpace(s.message)
|
if message, ok := s.message.Load().(string); ok && len(message) > 0 {
|
||||||
|
message := strings.TrimSpace(message)
|
||||||
if s.messageWidth > 0 && len(message) > s.messageWidth {
|
if s.messageWidth > 0 && len(message) > s.messageWidth {
|
||||||
message = message[:s.messageWidth]
|
message = message[:s.messageWidth]
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ func (b *Buffer) MoveLeft() {
|
||||||
rLength := runewidth.RuneWidth(r)
|
rLength := runewidth.RuneWidth(r)
|
||||||
|
|
||||||
if b.DisplayPos%b.LineWidth == 0 {
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width))
|
||||||
if rLength == 2 {
|
if rLength == 2 {
|
||||||
fmt.Print(CursorLeft)
|
fmt.Print(CursorLeft)
|
||||||
}
|
}
|
||||||
|
@ -74,7 +74,7 @@ func (b *Buffer) MoveLeft() {
|
||||||
fmt.Print(CursorLeft)
|
fmt.Print(CursorLeft)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorLeftN(rLength))
|
fmt.Print(CursorLeftN(rLength))
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Pos -= 1
|
b.Pos -= 1
|
||||||
|
@ -115,15 +115,15 @@ func (b *Buffer) MoveRight() {
|
||||||
b.DisplayPos += rLength
|
b.DisplayPos += rLength
|
||||||
|
|
||||||
if b.DisplayPos%b.LineWidth == 0 {
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
|
||||||
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
|
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
|
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())+rLength))
|
||||||
b.DisplayPos += 1
|
b.DisplayPos += 1
|
||||||
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
|
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
|
||||||
b.DisplayPos += 1
|
b.DisplayPos += 1
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorRightN(rLength))
|
fmt.Print(CursorRightN(rLength))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() {
|
||||||
fmt.Print(CursorUp)
|
fmt.Print(CursorUp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())))
|
||||||
b.Pos = 0
|
b.Pos = 0
|
||||||
b.DisplayPos = 0
|
b.DisplayPos = 0
|
||||||
}
|
}
|
||||||
|
@ -169,9 +169,9 @@ func (b *Buffer) MoveToEnd() {
|
||||||
fmt.Print(CursorDown)
|
fmt.Print(CursorDown)
|
||||||
}
|
}
|
||||||
remainder := b.DisplaySize() % b.LineWidth
|
remainder := b.DisplaySize() % b.LineWidth
|
||||||
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
|
fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())+remainder))
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
|
fmt.Print(CursorRightN(b.DisplaySize() - b.DisplayPos))
|
||||||
}
|
}
|
||||||
|
|
||||||
b.Pos = b.Buf.Size()
|
b.Pos = b.Buf.Size()
|
||||||
|
@ -286,8 +286,7 @@ func (b *Buffer) drawRemaining() {
|
||||||
remLength := runewidth.StringWidth(remainingText)
|
remLength := runewidth.StringWidth(remainingText)
|
||||||
|
|
||||||
if len(currLine) > 0 {
|
if len(currLine) > 0 {
|
||||||
fmt.Printf(ClearToEOL + currLine)
|
fmt.Print(ClearToEOL + currLine + CursorLeftN(currLineSpace))
|
||||||
fmt.Print(cursorLeftN(currLineSpace))
|
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(ClearToEOL)
|
fmt.Print(ClearToEOL)
|
||||||
}
|
}
|
||||||
|
@ -301,9 +300,9 @@ func (b *Buffer) drawRemaining() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
|
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
|
||||||
fmt.Print(cursorRightN(currLineSpace))
|
fmt.Print(CursorRightN(currLineSpace))
|
||||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
|
fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width-currLineSpace))
|
||||||
}
|
}
|
||||||
|
|
||||||
// render the other lines
|
// render the other lines
|
||||||
|
@ -333,9 +332,7 @@ func (b *Buffer) drawRemaining() {
|
||||||
lineLength += runewidth.RuneWidth(c)
|
lineLength += runewidth.RuneWidth(c)
|
||||||
fmt.Printf("%c", c)
|
fmt.Printf("%c", c)
|
||||||
}
|
}
|
||||||
fmt.Print(ClearToEOL)
|
fmt.Print(ClearToEOL + CursorUpN(totalLines) + CursorBOL + CursorRightN(b.Width-currLineSpace))
|
||||||
fmt.Print(cursorUpN(totalLines))
|
|
||||||
fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
|
|
||||||
|
|
||||||
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
|
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
|
||||||
|
|
||||||
|
@ -357,8 +354,7 @@ func (b *Buffer) Remove() {
|
||||||
if b.DisplayPos%b.LineWidth == 0 {
|
if b.DisplayPos%b.LineWidth == 0 {
|
||||||
// if the user backspaces over the word boundary, do this magic to clear the line
|
// if the user backspaces over the word boundary, do this magic to clear the line
|
||||||
// and move to the end of the previous line
|
// and move to the end of the previous line
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
|
||||||
|
|
||||||
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
|
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
|
||||||
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||||
|
@ -370,24 +366,23 @@ func (b *Buffer) Remove() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if rLength == 2 {
|
if rLength == 2 {
|
||||||
fmt.Print(CursorLeft + " " + cursorLeftN(2))
|
fmt.Print(CursorLeft + " " + CursorLeftN(2))
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(" " + CursorLeft)
|
fmt.Print(" " + CursorLeft)
|
||||||
}
|
}
|
||||||
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))
|
||||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
|
||||||
|
|
||||||
if b.Pos == b.Buf.Size() {
|
if b.Pos == b.Buf.Size() {
|
||||||
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||||
}
|
}
|
||||||
b.DisplayPos -= 1
|
b.DisplayPos -= 1
|
||||||
} else {
|
} else {
|
||||||
fmt.Print(cursorLeftN(rLength))
|
fmt.Print(CursorLeftN(rLength))
|
||||||
for range rLength {
|
for range rLength {
|
||||||
fmt.Print(" ")
|
fmt.Print(" ")
|
||||||
}
|
}
|
||||||
fmt.Print(cursorLeftN(rLength))
|
fmt.Print(CursorLeftN(rLength))
|
||||||
}
|
}
|
||||||
|
|
||||||
var eraseExtraLine bool
|
var eraseExtraLine bool
|
||||||
|
@ -405,9 +400,9 @@ func (b *Buffer) Remove() {
|
||||||
// are trailing characters which go over the line width boundary
|
// are trailing characters which go over the line width boundary
|
||||||
if eraseExtraLine {
|
if eraseExtraLine {
|
||||||
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
||||||
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
|
fmt.Print(CursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
|
||||||
place := b.DisplayPos % b.LineWidth
|
place := b.DisplayPos % b.LineWidth
|
||||||
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
|
fmt.Print(CursorUpN(remainingLines+1) + CursorRightN(place+len(b.Prompt.prompt())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -422,9 +417,9 @@ func (b *Buffer) Delete() {
|
||||||
if b.DisplaySize()%b.LineWidth == 0 {
|
if b.DisplaySize()%b.LineWidth == 0 {
|
||||||
if b.DisplayPos != b.DisplaySize() {
|
if b.DisplayPos != b.DisplaySize() {
|
||||||
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
|
||||||
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
|
fmt.Print(CursorDownN(remainingLines) + CursorBOL + ClearToEOL)
|
||||||
place := b.DisplayPos % b.LineWidth
|
place := b.DisplayPos % b.LineWidth
|
||||||
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
|
fmt.Print(CursorUpN(remainingLines) + CursorRightN(place+len(b.Prompt.prompt())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -471,17 +466,17 @@ func (b *Buffer) DeleteWord() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) ClearScreen() {
|
func (b *Buffer) ClearScreen() {
|
||||||
fmt.Printf(ClearScreen + CursorReset + b.Prompt.prompt())
|
fmt.Print(ClearScreen + CursorReset + b.Prompt.prompt())
|
||||||
if b.IsEmpty() {
|
if b.IsEmpty() {
|
||||||
ph := b.Prompt.placeholder()
|
ph := b.Prompt.placeholder()
|
||||||
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
|
fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
|
||||||
} else {
|
} else {
|
||||||
currPos := b.DisplayPos
|
currPos := b.DisplayPos
|
||||||
currIndex := b.Pos
|
currIndex := b.Pos
|
||||||
b.Pos = 0
|
b.Pos = 0
|
||||||
b.DisplayPos = 0
|
b.DisplayPos = 0
|
||||||
b.drawRemaining()
|
b.drawRemaining()
|
||||||
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
|
fmt.Print(CursorReset + CursorRightN(len(b.Prompt.prompt())))
|
||||||
if currPos > 0 {
|
if currPos > 0 {
|
||||||
targetLine := currPos / b.LineWidth
|
targetLine := currPos / b.LineWidth
|
||||||
if targetLine > 0 {
|
if targetLine > 0 {
|
||||||
|
@ -491,10 +486,10 @@ func (b *Buffer) ClearScreen() {
|
||||||
}
|
}
|
||||||
remainder := currPos % b.LineWidth
|
remainder := currPos % b.LineWidth
|
||||||
if remainder > 0 {
|
if remainder > 0 {
|
||||||
fmt.Print(cursorRightN(remainder))
|
fmt.Print(CursorRightN(remainder))
|
||||||
}
|
}
|
||||||
if currPos%b.LineWidth == 0 {
|
if currPos%b.LineWidth == 0 {
|
||||||
fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
|
fmt.Print(CursorBOL + b.Prompt.AltPrompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b.Pos = currIndex
|
b.Pos = currIndex
|
||||||
|
@ -513,13 +508,13 @@ func (b *Buffer) Replace(r []rune) {
|
||||||
|
|
||||||
b.Buf.Clear()
|
b.Buf.Clear()
|
||||||
|
|
||||||
fmt.Printf(CursorBOL + ClearToEOL)
|
fmt.Print(CursorBOL + ClearToEOL)
|
||||||
|
|
||||||
for range lineNums {
|
for range lineNums {
|
||||||
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf(CursorBOL + b.Prompt.prompt())
|
fmt.Print(CursorBOL + b.Prompt.prompt())
|
||||||
|
|
||||||
for _, c := range r {
|
for _, c := range r {
|
||||||
b.Add(c)
|
b.Add(c)
|
||||||
|
@ -545,19 +540,3 @@ func (b *Buffer) StringNM(n, m int) string {
|
||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func cursorLeftN(n int) string {
|
|
||||||
return fmt.Sprintf(CursorLeftN, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cursorRightN(n int) string {
|
|
||||||
return fmt.Sprintf(CursorRightN, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cursorUpN(n int) string {
|
|
||||||
return fmt.Sprintf(CursorUpN, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cursorDownN(n int) string {
|
|
||||||
return fmt.Sprintf(CursorDownN, n)
|
|
||||||
}
|
|
||||||
|
|
|
@ -98,7 +98,7 @@ func (i *Instance) Readline() (string, error) {
|
||||||
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
|
||||||
if buf.IsEmpty() && showPlaceholder {
|
if buf.IsEmpty() && showPlaceholder {
|
||||||
ph := i.Prompt.placeholder()
|
ph := i.Prompt.placeholder()
|
||||||
fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
|
fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
|
||||||
}
|
}
|
||||||
|
|
||||||
r, err := i.Terminal.Read()
|
r, err := i.Terminal.Read()
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package readline
|
package readline
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CharNull = 0
|
CharNull = 0
|
||||||
CharLineStart = 1
|
CharLineStart = 1
|
||||||
|
@ -41,34 +43,49 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
CursorUp = "\033[1A"
|
Esc = "\x1b"
|
||||||
CursorDown = "\033[1B"
|
|
||||||
CursorRight = "\033[1C"
|
|
||||||
CursorLeft = "\033[1D"
|
|
||||||
|
|
||||||
CursorSave = "\033[s"
|
CursorSave = Esc + "[s"
|
||||||
CursorRestore = "\033[u"
|
CursorRestore = Esc + "[u"
|
||||||
|
|
||||||
CursorUpN = "\033[%dA"
|
CursorEOL = Esc + "[E"
|
||||||
CursorDownN = "\033[%dB"
|
CursorBOL = Esc + "[1G"
|
||||||
CursorRightN = "\033[%dC"
|
CursorHide = Esc + "[?25l"
|
||||||
CursorLeftN = "\033[%dD"
|
CursorShow = Esc + "[?25h"
|
||||||
|
|
||||||
CursorEOL = "\033[E"
|
ClearToEOL = Esc + "[K"
|
||||||
CursorBOL = "\033[1G"
|
ClearLine = Esc + "[2K"
|
||||||
CursorHide = "\033[?25l"
|
ClearScreen = Esc + "[2J"
|
||||||
CursorShow = "\033[?25h"
|
CursorReset = Esc + "[0;0f"
|
||||||
|
|
||||||
ClearToEOL = "\033[K"
|
ColorGrey = Esc + "[38;5;245m"
|
||||||
ClearLine = "\033[2K"
|
ColorDefault = Esc + "[0m"
|
||||||
ClearScreen = "\033[2J"
|
|
||||||
CursorReset = "\033[0;0f"
|
|
||||||
|
|
||||||
ColorGrey = "\033[38;5;245m"
|
StartBracketedPaste = Esc + "[?2004h"
|
||||||
ColorDefault = "\033[0m"
|
EndBracketedPaste = Esc + "[?2004l"
|
||||||
|
)
|
||||||
|
|
||||||
StartBracketedPaste = "\033[?2004h"
|
func CursorUpN(n int) string {
|
||||||
EndBracketedPaste = "\033[?2004l"
|
return Esc + "[" + strconv.Itoa(n) + "A"
|
||||||
|
}
|
||||||
|
|
||||||
|
func CursorDownN(n int) string {
|
||||||
|
return Esc + "[" + strconv.Itoa(n) + "B"
|
||||||
|
}
|
||||||
|
|
||||||
|
func CursorRightN(n int) string {
|
||||||
|
return Esc + "[" + strconv.Itoa(n) + "C"
|
||||||
|
}
|
||||||
|
|
||||||
|
func CursorLeftN(n int) string {
|
||||||
|
return Esc + "[" + strconv.Itoa(n) + "D"
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
CursorUp = CursorUpN(1)
|
||||||
|
CursorDown = CursorDownN(1)
|
||||||
|
CursorRight = CursorRightN(1)
|
||||||
|
CursorLeft = CursorLeftN(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -209,15 +209,15 @@ install_cuda_driver_yum() {
|
||||||
case $PACKAGE_MANAGER in
|
case $PACKAGE_MANAGER in
|
||||||
yum)
|
yum)
|
||||||
$SUDO $PACKAGE_MANAGER -y install yum-utils
|
$SUDO $PACKAGE_MANAGER -y install yum-utils
|
||||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
|
||||||
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
|
||||||
else
|
else
|
||||||
error $CUDA_REPO_ERR_MSG
|
error $CUDA_REPO_ERR_MSG
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
dnf)
|
dnf)
|
||||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
|
||||||
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
|
||||||
else
|
else
|
||||||
error $CUDA_REPO_ERR_MSG
|
error $CUDA_REPO_ERR_MSG
|
||||||
fi
|
fi
|
||||||
|
@ -245,8 +245,8 @@ install_cuda_driver_yum() {
|
||||||
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
|
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
|
||||||
install_cuda_driver_apt() {
|
install_cuda_driver_apt() {
|
||||||
status 'Installing NVIDIA repository...'
|
status 'Installing NVIDIA repository...'
|
||||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
|
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
|
||||||
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
|
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
|
||||||
else
|
else
|
||||||
error $CUDA_REPO_ERR_MSG
|
error $CUDA_REPO_ERR_MSG
|
||||||
fi
|
fi
|
||||||
|
|
|
@ -94,7 +94,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
numDownloadParts = 64
|
numDownloadParts = 16
|
||||||
minDownloadPartSize int64 = 100 * format.MegaByte
|
minDownloadPartSize int64 = 100 * format.MegaByte
|
||||||
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
maxDownloadPartSize int64 = 1000 * format.MegaByte
|
||||||
)
|
)
|
||||||
|
@ -216,6 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
setSparse(file)
|
||||||
|
|
||||||
_ = file.Truncate(b.Total)
|
_ = file.Truncate(b.Total)
|
||||||
|
|
||||||
|
@ -232,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||||
|
|
||||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||||
if len(via) > 10 {
|
if len(via) > 10 {
|
||||||
return errors.New("maxium redirects exceeded (10) for directURL")
|
return errors.New("maximum redirects exceeded (10) for directURL")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the hostname is the same, allow the redirect
|
// if the hostname is the same, allow the redirect
|
||||||
|
|
|
@ -250,19 +250,21 @@ func GetModel(name string) (*Model, error) {
|
||||||
Template: template.DefaultTemplate,
|
Template: template.DefaultTemplate,
|
||||||
}
|
}
|
||||||
|
|
||||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
if manifest.Config.Digest != "" {
|
||||||
if err != nil {
|
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||||
return nil, err
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
configFile, err := os.Open(filename)
|
configFile, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer configFile.Close()
|
defer configFile.Close()
|
||||||
|
|
||||||
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range manifest.Layers {
|
for _, layer := range manifest.Layers {
|
||||||
|
@ -371,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
var messages []*api.Message
|
var messages []*api.Message
|
||||||
parameters := make(map[string]any)
|
parameters := make(map[string]any)
|
||||||
|
|
||||||
var layers []*Layer
|
var layers []Layer
|
||||||
for _, c := range modelfile.Commands {
|
for _, c := range modelfile.Commands {
|
||||||
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||||
|
|
||||||
|
@ -497,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
|
|
||||||
if c.Name != "license" {
|
if c.Name != "license" {
|
||||||
// replace
|
// replace
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||||
if layer.MediaType != mediatype {
|
if layer.MediaType != mediatype {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -543,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
}
|
}
|
||||||
|
|
||||||
var err2 error
|
var err2 error
|
||||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||||
switch layer.MediaType {
|
switch layer.MediaType {
|
||||||
case "application/vnd.ollama.image.message":
|
case "application/vnd.ollama.image.message":
|
||||||
// if there are new messages, remove the inherited ones
|
// if there are new messages, remove the inherited ones
|
||||||
|
@ -623,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range append(layers, layer) {
|
for _, layer := range append(layers, configLayer) {
|
||||||
if layer.status != "" {
|
if layer.status != "" {
|
||||||
fn(api.ProgressResponse{Status: layer.status})
|
fn(api.ProgressResponse{Status: layer.status})
|
||||||
}
|
}
|
||||||
|
@ -637,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
old, _ := ParseNamedManifest(name)
|
old, _ := ParseNamedManifest(name)
|
||||||
|
|
||||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||||
if err := WriteManifest(name, layer, layers); err != nil {
|
if err := WriteManifest(name, configLayer, layers); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -714,8 +716,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
|
||||||
// save (i.e. delete from the deleteMap) any files used in other manifests
|
// save (i.e. delete from the deleteMap) any files used in other manifests
|
||||||
manifest, _, err := GetManifest(fmp)
|
manifest, _, err := GetManifest(fmp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//nolint:nilerr
|
return err
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, layer := range manifest.Layers {
|
for _, layer := range manifest.Layers {
|
||||||
|
@ -782,7 +783,8 @@ func PruneLayers() error {
|
||||||
|
|
||||||
err = deleteUnusedLayers(nil, deleteMap)
|
err = deleteUnusedLayers(nil, deleteMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
|
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
|
||||||
|
@ -837,9 +839,11 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var layers []*Layer
|
var layers []Layer
|
||||||
layers = append(layers, manifest.Layers...)
|
layers = append(layers, manifest.Layers...)
|
||||||
layers = append(layers, manifest.Config)
|
if manifest.Config.Digest != "" {
|
||||||
|
layers = append(layers, manifest.Config)
|
||||||
|
}
|
||||||
|
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
||||||
|
@ -890,7 +894,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
for _, l := range manifest.Layers {
|
for _, l := range manifest.Layers {
|
||||||
deleteMap[l.Digest] = struct{}{}
|
deleteMap[l.Digest] = struct{}{}
|
||||||
}
|
}
|
||||||
deleteMap[manifest.Config.Digest] = struct{}{}
|
if manifest.Config.Digest != "" {
|
||||||
|
deleteMap[manifest.Config.Digest] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -905,9 +911,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
return fmt.Errorf("pull model manifest: %s", err)
|
return fmt.Errorf("pull model manifest: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var layers []*Layer
|
var layers []Layer
|
||||||
layers = append(layers, manifest.Layers...)
|
layers = append(layers, manifest.Layers...)
|
||||||
layers = append(layers, manifest.Config)
|
if manifest.Config.Digest != "" {
|
||||||
|
layers = append(layers, manifest.Config)
|
||||||
|
}
|
||||||
|
|
||||||
skipVerify := make(map[string]bool)
|
skipVerify := make(map[string]bool)
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
|
@ -971,7 +979,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
fn(api.ProgressResponse{Status: "removing any unused layers"})
|
fn(api.ProgressResponse{Status: "removing any unused layers"})
|
||||||
err = deleteUnusedLayers(nil, deleteMap)
|
err = deleteUnusedLayers(nil, deleteMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
|
||||||
|
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
@ -15,15 +16,15 @@ type Layer struct {
|
||||||
status string
|
status string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||||
blobs, err := GetBlobsPath("")
|
blobs, err := GetBlobsPath("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
temp, err := os.CreateTemp(blobs, "sha256-")
|
temp, err := os.CreateTemp(blobs, "sha256-")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
defer temp.Close()
|
defer temp.Close()
|
||||||
defer os.Remove(temp.Name())
|
defer os.Remove(temp.Name())
|
||||||
|
@ -31,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||||
sha256sum := sha256.New()
|
sha256sum := sha256.New()
|
||||||
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
|
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := temp.Close(); err != nil {
|
if err := temp.Close(); err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||||
blob, err := GetBlobsPath(digest)
|
blob, err := GetBlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
status := "using existing layer"
|
status := "using existing layer"
|
||||||
if _, err := os.Stat(blob); err != nil {
|
if _, err := os.Stat(blob); err != nil {
|
||||||
status = "creating new layer"
|
status = "creating new layer"
|
||||||
if err := os.Rename(temp.Name(), blob); err != nil {
|
if err := os.Rename(temp.Name(), blob); err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Layer{
|
return Layer{
|
||||||
MediaType: mediatype,
|
MediaType: mediatype,
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Size: n,
|
Size: n,
|
||||||
|
@ -60,18 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||||
|
if digest == "" {
|
||||||
|
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||||
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(digest)
|
blob, err := GetBlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
fi, err := os.Stat(blob)
|
fi, err := os.Stat(blob)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return Layer{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Layer{
|
return Layer{
|
||||||
MediaType: mediatype,
|
MediaType: mediatype,
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Size: fi.Size(),
|
Size: fi.Size(),
|
||||||
|
@ -81,6 +86,10 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||||
|
if l.Digest == "" {
|
||||||
|
return nil, errors.New("opening layer with empty digest")
|
||||||
|
}
|
||||||
|
|
||||||
blob, err := GetBlobsPath(l.Digest)
|
blob, err := GetBlobsPath(l.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -90,6 +99,10 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Layer) Remove() error {
|
func (l *Layer) Remove() error {
|
||||||
|
if l.Digest == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
ms, err := Manifests()
|
ms, err := Manifests()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -14,10 +14,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manifest struct {
|
type Manifest struct {
|
||||||
SchemaVersion int `json:"schemaVersion"`
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
MediaType string `json:"mediaType"`
|
MediaType string `json:"mediaType"`
|
||||||
Config *Layer `json:"config"`
|
Config Layer `json:"config"`
|
||||||
Layers []*Layer `json:"layers"`
|
Layers []Layer `json:"layers"`
|
||||||
|
|
||||||
filepath string
|
filepath string
|
||||||
fi os.FileInfo
|
fi os.FileInfo
|
||||||
|
@ -47,10 +47,12 @@ func (m *Manifest) Remove() error {
|
||||||
|
|
||||||
func (m *Manifest) RemoveLayers() error {
|
func (m *Manifest) RemoveLayers() error {
|
||||||
for _, layer := range append(m.Layers, m.Config) {
|
for _, layer := range append(m.Layers, m.Config) {
|
||||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
if layer.Digest != "" {
|
||||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
||||||
} else if err != nil {
|
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||||
return err
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||||
manifests, err := GetManifestPath()
|
manifests, err := GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -26,7 +26,7 @@ import (
|
||||||
var intermediateBlobs map[string]string = make(map[string]string)
|
var intermediateBlobs map[string]string = make(map[string]string)
|
||||||
|
|
||||||
type layerGGML struct {
|
type layerGGML struct {
|
||||||
*Layer
|
Layer
|
||||||
*llm.GGML
|
*llm.GGML
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,9 +176,20 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
|
||||||
mediatype = "application/vnd.ollama.image.projector"
|
mediatype = "application/vnd.ollama.image.projector"
|
||||||
}
|
}
|
||||||
|
|
||||||
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
|
var layer Layer
|
||||||
if err != nil {
|
if digest != "" && n == stat.Size() && offset == 0 {
|
||||||
return nil, err
|
layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
|
||||||
|
if err != nil {
|
||||||
|
slog.Debug("could not create new layer from layer", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
|
||||||
|
if layer.Digest == "" {
|
||||||
|
layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
layers = append(layers, &layerGGML{layer, ggml})
|
layers = append(layers, &layerGGML{layer, ggml})
|
||||||
|
|
|
@ -2,8 +2,10 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -11,6 +13,7 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -133,3 +136,82 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseFromFileFromLayer(t *testing.T) {
|
||||||
|
tempModels := t.TempDir()
|
||||||
|
|
||||||
|
file, err := os.CreateTemp(tempModels, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
|
||||||
|
t.Fatalf("failed to write gguf: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := file.Seek(0, io.SeekStart); err != nil {
|
||||||
|
t.Fatalf("failed to seek to start: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse from file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(layers) != 1 {
|
||||||
|
t.Fatalf("got %d != want 1", len(layers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := file.Seek(0, io.SeekStart); err != nil {
|
||||||
|
t.Fatalf("failed to seek to start: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse from file: %v", err)
|
||||||
|
}
|
||||||
|
if len(layers2) != 1 {
|
||||||
|
t.Fatalf("got %d != want 1", len(layers2))
|
||||||
|
}
|
||||||
|
|
||||||
|
if layers[0].Digest != layers2[0].Digest {
|
||||||
|
t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if layers[0].Size != layers2[0].Size {
|
||||||
|
t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
if layers[0].MediaType != layers2[0].MediaType {
|
||||||
|
t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLayerFromCopy(t *testing.T) {
|
||||||
|
tempModels := t.TempDir()
|
||||||
|
|
||||||
|
file2, err := os.CreateTemp(tempModels, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
defer file2.Close()
|
||||||
|
|
||||||
|
for range 5 {
|
||||||
|
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
|
||||||
|
t.Fatalf("failed to write gguf: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := file2.Seek(0, io.SeekStart); err != nil {
|
||||||
|
t.Fatalf("failed to seek to start: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to parse from file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(layers) != 5 {
|
||||||
|
t.Fatalf("got %d != want 5", len(layers))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
|
|
||||||
"github.com/gin-contrib/cors"
|
"github.com/gin-contrib/cors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
@ -323,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
input = append(input, v.(string))
|
input = append(input, v.(string))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
if req.Input != nil {
|
||||||
return
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
|
||||||
}
|
return
|
||||||
|
}
|
||||||
if len(input) == 0 {
|
|
||||||
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||||
|
@ -340,12 +338,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
|
if len(input) == 0 {
|
||||||
|
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
kvData, err := getKVData(m.ModelPath, false)
|
kvData, err := getKVData(m.ModelPath, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var count int
|
||||||
for i, s := range input {
|
for i, s := range input {
|
||||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -368,25 +372,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
count += len(tokens)
|
||||||
|
|
||||||
input[i] = s
|
input[i] = s
|
||||||
}
|
}
|
||||||
embeddings, err := r.Embed(c.Request.Context(), input)
|
|
||||||
if err != nil {
|
var g errgroup.Group
|
||||||
slog.Error("embedding generation failed", "error", err)
|
embeddings := make([][]float32, len(input))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
for i, text := range input {
|
||||||
return
|
g.Go(func() error {
|
||||||
|
embedding, err := r.Embedding(c.Request.Context(), text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
embeddings[i] = normalize(embedding)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, e := range embeddings.Embedding {
|
if err := g.Wait(); err != nil {
|
||||||
embeddings.Embedding[i] = normalize(e)
|
slog.Error("embedding generation failed", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := api.EmbedResponse{
|
resp := api.EmbedResponse{
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Embeddings: embeddings.Embedding,
|
Embeddings: embeddings,
|
||||||
TotalDuration: time.Since(checkpointStart),
|
TotalDuration: time.Since(checkpointStart),
|
||||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||||
PromptEvalCount: embeddings.PromptEvalCount,
|
PromptEvalCount: count,
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
@ -430,21 +445,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
embedding := make([]float64, len(embeddings.Embedding[0]))
|
var e []float64
|
||||||
|
for _, v := range embedding {
|
||||||
for i, v := range embeddings.Embedding[0] {
|
e = append(e, float64(v))
|
||||||
embedding[i] = float64(v)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := api.EmbeddingResponse{
|
resp := api.EmbeddingResponse{
|
||||||
Embedding: embedding,
|
Embedding: e,
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
@ -824,17 +838,20 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||||
|
|
||||||
models := []api.ListModelResponse{}
|
models := []api.ListModelResponse{}
|
||||||
for n, m := range ms {
|
for n, m := range ms {
|
||||||
f, err := m.Config.Open()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("bad manifest filepath", "name", n, "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
var cf ConfigV2
|
var cf ConfigV2
|
||||||
if err := json.NewDecoder(f).Decode(&cf); err != nil {
|
|
||||||
slog.Warn("bad manifest config", "name", n, "error", err)
|
if m.Config.Digest != "" {
|
||||||
continue
|
f, err := m.Config.Open()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("bad manifest filepath", "name", n, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err := json.NewDecoder(f).Decode(&cf); err != nil {
|
||||||
|
slog.Warn("bad manifest config", "name", n, "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tag should never be masked
|
// tag should never be masked
|
||||||
|
|
|
@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a manifest with duplicate layers
|
// create a manifest with duplicate layers
|
||||||
if err := WriteManifest(n, config, []*Layer{config}); err != nil {
|
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) {
|
||||||
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
assert.Equal(t, "library", retrieveResp.OwnedBy)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Name: "Embed Handler Empty Input",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
embedReq := api.EmbedRequest{
|
|
||||||
Model: "t-bone",
|
|
||||||
Input: "",
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(embedReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
if contentType != "application/json; charset=utf-8" {
|
|
||||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
|
||||||
}
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var embedResp api.EmbedResponse
|
|
||||||
err = json.Unmarshal(body, &embedResp)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if embedResp.Model != "t-bone" {
|
|
||||||
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
|
|
||||||
}
|
|
||||||
|
|
||||||
if embedResp.Embeddings == nil {
|
|
||||||
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embedResp.Embeddings) != 0 {
|
|
||||||
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "Embed Handler Invalid Input",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
embedReq := api.EmbedRequest{
|
|
||||||
Model: "t-bone",
|
|
||||||
Input: 2,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(embedReq)
|
|
||||||
require.NoError(t, err)
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *http.Response) {
|
|
||||||
contentType := resp.Header.Get("Content-Type")
|
|
||||||
if contentType != "application/json; charset=utf-8" {
|
|
||||||
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
|
|
||||||
}
|
|
||||||
_, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||||
|
|
|
@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
|
||||||
// some older models are not compatible with newer versions of llama.cpp
|
// some older models are not compatible with newer versions of llama.cpp
|
||||||
// show a generalized compatibility error until there is a better way to
|
// show a generalized compatibility error until there is a better way to
|
||||||
// check for model compatibility
|
// check for model compatibility
|
||||||
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
|
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||||
}
|
}
|
||||||
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
||||||
|
|
|
@ -708,8 +708,8 @@ type mockLlm struct {
|
||||||
pingResp error
|
pingResp error
|
||||||
waitResp error
|
waitResp error
|
||||||
completionResp error
|
completionResp error
|
||||||
embedResp *llm.EmbedResponse
|
embeddingResp []float32
|
||||||
embedRespErr error
|
embeddingRespErr error
|
||||||
tokenizeResp []int
|
tokenizeResp []int
|
||||||
tokenizeRespErr error
|
tokenizeRespErr error
|
||||||
detokenizeResp string
|
detokenizeResp string
|
||||||
|
@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
||||||
return s.completionResp
|
return s.completionResp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
|
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||||
return s.embedResp, s.embedRespErr
|
return s.embeddingResp, s.embeddingRespErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
|
|
8
server/sparse_common.go
Normal file
8
server/sparse_common.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
func setSparse(*os.File) {
|
||||||
|
}
|
17
server/sparse_windows.go
Normal file
17
server/sparse_windows.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setSparse(file *os.File) {
|
||||||
|
// exFat (and other FS types) don't support sparse files, so ignore errors
|
||||||
|
windows.DeviceIoControl( //nolint:errcheck
|
||||||
|
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
|
||||||
|
nil, 0,
|
||||||
|
nil, 0,
|
||||||
|
nil, nil,
|
||||||
|
)
|
||||||
|
}
|
|
@ -26,7 +26,7 @@ import (
|
||||||
var blobUploadManager sync.Map
|
var blobUploadManager sync.Map
|
||||||
|
|
||||||
type blobUpload struct {
|
type blobUpload struct {
|
||||||
*Layer
|
Layer
|
||||||
|
|
||||||
Total int64
|
Total int64
|
||||||
Completed atomic.Int64
|
Completed atomic.Int64
|
||||||
|
@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
|
||||||
p.written = 0
|
p.written = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
requestURL := mp.BaseURL()
|
requestURL := mp.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||||
|
|
||||||
|
|
|
@ -219,7 +219,7 @@ func (n Name) String() string {
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DisplayShort returns a short string version of the name.
|
// DisplayShortest returns a short string version of the name.
|
||||||
func (n Name) DisplayShortest() string {
|
func (n Name) DisplayShortest() string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue