Compare commits

..

52 commits

Author SHA1 Message Date
9e08c23ba9
Merge https://github.com/ollama/ollama 2024-08-14 21:04:15 +05:30
longtao
0a8d6ea86d
Fix typo and improve readability (#5964)
* Fix typo and improve readability

Summary:
* Rename updatAvailableMenuID to updateAvailableMenuID
* Replace unused cmd parameter with _ in RunServer function
* Fix typos in comments

(cherry picked from commit 5b8715f0b04773369e8eb1f9e6737995a0ab3ba7)

* Update api/client.go

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
2024-08-13 17:54:19 -07:00
Blake Mizerany
8e1050f366
server: reduce max connections used in download (#6347)
The previous value of 64 was WAY too high and unnecessary. It reached
diminishing returns and blew past it. This is a more reasonable number
for _most_ normal cases. For users on cloud servers with excellent
network quality, this will keep screaming for them, without hitting our
CDN limits. For users with relatively poor network quality, this will
keep them from saturating their network and causing other issues.
2024-08-13 16:47:35 -07:00
Bruce MacDonald
eda8a32a09
update chatml template format to latest in docs (#6344) 2024-08-13 16:39:18 -07:00
Michael Yang
a0a40aa20c
Merge pull request #6346 from ollama/mxyng/lint 2024-08-13 14:58:35 -07:00
Michael Yang
2697d7f5aa lint
- fixes printf: non-constant format string in call to fmt.Printf
- fixes SA1032: arguments have the wrong order
- disables testifylint
2024-08-13 14:36:33 -07:00
Pamela Fox
1f32276178
Update openai.md to remove extra checkbox (#6345) 2024-08-13 13:36:05 -07:00
Daniel Hiltgen
4c4fe3f87f
Merge pull request #6343 from dhiltgen/revert_win_go_version
Go back to a pinned Go version
2024-08-13 11:53:49 -07:00
Daniel Hiltgen
feedf49c71 Go back to a pinned Go version
Go version 1.22.6 is triggering AV false positives, so go back to 1.22.5
2024-08-13 11:45:44 -07:00
royjhan
8b00a415ab
Load Embedding Model on Empty Input (#6325)
* load on empty input

* no load on invalid input
2024-08-13 10:19:56 -07:00
Michael Yang
01b80e9ffc
Merge pull request #5443 from ollama/mxyng/convert-phi3
add conversion for microsoft phi 3 mini/medium 4k, 128k
2024-08-12 15:47:58 -07:00
Michael Yang
bd5e432630 update import.md 2024-08-12 15:13:29 -07:00
Bruce MacDonald
aec77d6a05 support new "longrope" attention factor 2024-08-12 15:13:29 -07:00
Michael Yang
6ffb5cb017 add conversion for microsoft phi 3 mini/medium 4k, 128 2024-08-12 15:13:29 -07:00
Josh
f7e3b9190f
cmd: spinner progress for transfer model data (#6100) 2024-08-12 11:46:32 -07:00
Josh
980dd15f81
cmd: speed up gguf creates (#6324) 2024-08-12 11:46:09 -07:00
royjhan
01d544d373
OpenAI: Simplify input output in testing (#5858)
* simplify input output

* direct comp

* in line image

* rm error pointer type

* update response testing

* lint
2024-08-12 10:33:34 -07:00
Josh
1dc3ef3aa9
Revert "server: speed up single gguf creates (#5898)" (#6323)
This reverts commit 8aac22438e.
2024-08-12 09:57:51 -07:00
Josh
8aac22438e
server: speed up single gguf creates (#5898) 2024-08-12 09:28:55 -07:00
Jeffrey Morgan
15c2d8fe14
server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)
For simplicity, perform parallelization of embedding requests in the API handler instead of offloading this to the subprocess runner. This keeps the scheduling story simpler as it builds on existing parallel requests, similar to existing text completion functionality.
2024-08-11 11:57:10 -07:00
Daniel Hiltgen
25906d72d1
llm: prevent loading too large models on windows (#5926)
Don't allow loading models that would lead to memory exhaustion (across vram, system memory and disk paging). This check was already applied on Linux but should also be applied on Windows as well.
2024-08-11 11:30:20 -07:00
CognitiveTech
023451ce47
add integration obook-summary (#6305) 2024-08-10 18:43:08 -07:00
Jesse Gross
9b53e39d8e
Merge pull request #6258 from coolljt0725/fix_typo
server/download.go: Fix a typo in log
2024-08-09 17:19:48 -07:00
Michael Yang
97fae2df95
Merge pull request #6235 from Nicholas42/fix_line_endings
Set *.png and *.ico to be treated as binary files.
2024-08-09 17:06:30 -07:00
Michael Yang
160d9d4900
Merge pull request #6171 from ollama/mxyng/remove-temp
removeall to remove non-empty temp dirs
2024-08-09 15:47:13 -07:00
Nicholas Schwab
d4e6407464 Restrict text files with explicit line feeds to *.go.
This partially reverts b732beba6a. It
seems like explicitly setting all files to use line feeds was done due
to issues with the go linter, hence it can be restricted to those files
(https://github.com/ollama/ollama/pull/6235#issuecomment-2278745953).
2024-08-09 23:14:13 +02:00
Daniel Hiltgen
b7f7d8cd15
Merge pull request #6291 from dhiltgen/no_sparse_fail
Don't hard fail on sparse setup error
2024-08-09 12:30:25 -07:00
Daniel Hiltgen
2fa1db4345 Don't hard fail on sparse setup error
It seems this can fail in some casees, but proceed
with the download anyway.
2024-08-09 12:16:19 -07:00
Daniel Hiltgen
71b0945fc6
Merge pull request #6290 from dhiltgen/intel_npe
Harden intel boostrap for nil pointers
2024-08-09 12:14:42 -07:00
Daniel Hiltgen
5bca2e60a7 Harden intel boostrap for nil pointers 2024-08-09 11:31:38 -07:00
Nicholas42
67472e0e89
Also flag *.icns as binary 2024-08-09 13:41:20 +02:00
Daniel Hiltgen
e9aa5117c4
Merge pull request #6133 from dhiltgen/cuda_repo
Adjust arm cuda repo paths
2024-08-08 12:33:35 -07:00
Daniel Hiltgen
2473bdba5e
Merge pull request #6182 from dhiltgen/more_patterns
Catch one more error log
2024-08-08 12:33:17 -07:00
Jesse Gross
7d1c0047fa
Merge pull request #6247 from ollama/jessegross/layers
Store layers inside manifests consistently as values.
2024-08-08 10:46:43 -07:00
Jitang Lei
7b61eba471 server/download.go: Fix a typo in log
Signed-off-by: Jitang Lei <leijitang@outlook.com>
2024-08-08 20:28:01 +08:00
Jesse Gross
7edaf6e7e8 manifest: Store layers inside manifests consistently as values.
Commit 1829fb61 ("manifest: Fix crash on startup when trying to clean up
unused files (#5840)") changed the config layer stored in manifests
from a pointer to a value. This was done in order to avoid potential
nil pointer dereferences after it is deserialized from JSON in the
event that the field is missing.

This changes the Layers slice to also be stored by value. This enables
consistency in handling across the two objects.
2024-08-07 17:03:06 -07:00
Jesse Gross
97ec8cfd4e image: Clarify argument to WriteManifest is config
When creating a model the config layer is appended to the list of
layers and then the last layer is used as the config when writing the
manifest. This change directly uses the config layer to write the
manifest. There is no behavior change but it is less error prone.
2024-08-07 16:58:42 -07:00
royjhan
5b3a21b578
add metrics to docs (#6079) 2024-08-07 14:43:44 -07:00
Kyle Kelley
ad0c19dde4
Use llama3.1 in tools example (#5985)
* Use llama3.1 in tools example

* Update api.md
2024-08-07 17:20:50 -04:00
Jesse Gross
69eb06c40e
Merge pull request #6145 from ollama/jessegross/bug5840
Fix crash on startup when trying to clean up unused files (#5840)
2024-08-07 11:24:15 -07:00
Jesse Gross
1829fb61bd manifest: Fix crash on startup when trying to clean up unused files (#5840)
Currently if the config field is missing in the manifest file (or
corrupted), Ollama will crash when it tries to read it. This can
happen at startup or when pulling new models.

This data is mostly just used for showing model information so we
can be tolerant of it not being present - it is not required to
run the models. Besides avoiding crashing, this also gives us the
ability to restructure the config in the future by pulling it
into the main manifest file.
2024-08-07 10:30:44 -07:00
Nicholas Schwab
ce67706037 Set *.png and *.ico to be treated as binary files.
The change b732beba6 makes all files text files and sets lf as eol. This
will automatically change all files to have lf if they are touched by
git (e.g. via git status). This change cannot be stashed and makes it
hard to work with the repo (rebase and checkout don't really work). See
also #6183.

Here, we set the offending files (*.png and *.ico, but that might be
more in the future) to be treated as binary files and not be changed by
git.
2024-08-07 18:20:11 +02:00
Jesse Gross
685a53534b manifest: Don't prune layers if we can't open a manifest file
If there is an error when opening a manifest file (corrupted, permission denied, etc.)
then the referenced layers will not be included in the list of active
layers. This causes them to be deleted when pruning happens at startup
or a model is pulled.

In such a situation, we should prefer to preserve data in the hopes that
it can be recovered rather than being agressive about deletion.
2024-08-06 23:11:19 -07:00
Jeffrey Morgan
de4fc29773
llm: reserve required number of slots for embeddings (#6219) 2024-08-06 23:20:49 -04:00
Jeffrey Morgan
e04c7012c2
update llama.cpp submodule to 1e6f6554 (#6208) 2024-08-06 15:11:45 -04:00
Chua Chee Seng
d4a7216c82
Fixed invalid option provided not displaying the invalid option name problem. (#6202) 2024-08-06 14:37:16 -04:00
Daniel Hiltgen
a4fdd03c3b
Merge pull request #6207 from dhiltgen/sparse_win
Ensure sparse files on windows during download
2024-08-06 11:06:06 -07:00
Daniel Hiltgen
fc85f50a2b Ensure sparse files on windows during download
The file.Truncate call on windows will write the whole file
unless you set the sparse flag, leading to heavy I/O at the
beginning of download.  This should improve our
I/O behavior on windows and put less stress on the users disk.
2024-08-06 10:58:08 -07:00
Daniel Hiltgen
04210aa6dd Catch one more error log 2024-08-05 09:28:07 -07:00
Michael Yang
43f9d92008 close pid file 2024-08-05 00:41:16 -07:00
Michael Yang
ed6c8bfe57 removeall to remove non-empty temp dirs 2024-08-05 00:41:16 -07:00
Daniel Hiltgen
df3802a65f Adjust arm cuda repo paths
Ubuntu distros fail to install cuda drivers since aarch64 isn't valid
2024-08-01 17:22:25 -07:00
50 changed files with 1245 additions and 771 deletions

3
.gitattributes vendored
View file

@ -1,2 +1,3 @@
llm/ext_server/* linguist-vendored llm/ext_server/* linguist-vendored
* text eol=lf * text=auto
*.go text eol=lf

View file

@ -31,7 +31,7 @@ jobs:
security set-keychain-settings -lut 3600 build.keychain security set-keychain-settings -lut 3600 build.keychain
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: Build Darwin - name: Build Darwin
env: env:
@ -87,7 +87,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@ -141,7 +141,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
@ -218,7 +218,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@ -306,7 +306,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get - run: go get
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4

View file

@ -63,7 +63,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@ -163,7 +163,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
@ -200,7 +200,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@ -255,7 +255,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: false cache: false
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in
@ -297,7 +297,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in

View file

@ -24,7 +24,6 @@ linters:
- nosprintfhostport - nosprintfhostport
- staticcheck - staticcheck
- tenv - tenv
- testifylint
- unconvert - unconvert
- unused - unused
- usestdlibvars - usestdlibvars

View file

@ -325,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [tlm](https://github.com/yusufcanb/tlm) - [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama) - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama) - [gollama](https://github.com/sammcj/gollama)
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
### Database ### Database

View file

@ -298,7 +298,7 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil return &lr, nil
} }
// List running models. // ListRunning lists running models.
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) { func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
var lr ProcessResponse var lr ProcessResponse
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil { if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
@ -333,7 +333,7 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
return &resp, nil return &resp, nil
} }
// Hearbeat checks if the server has started and is responsive; if yes, it // Heartbeat checks if the server has started and is responsive; if yes, it
// returns nil, otherwise an error. // returns nil, otherwise an error.
func (c *Client) Heartbeat(ctx context.Context) error { func (c *Client) Heartbeat(ctx context.Context) error {
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {

View file

@ -504,7 +504,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
for key, val := range m { for key, val := range m {
opt, ok := jsonOpts[key] opt, ok := jsonOpts[key]
if !ok { if !ok {
slog.Warn("invalid option provided", "option", opt.Name) slog.Warn("invalid option provided", "option", key)
continue continue
} }

View file

@ -11,12 +11,12 @@ import (
) )
const ( const (
updatAvailableMenuID = 1 updateAvailableMenuID = 1
updateMenuID = updatAvailableMenuID + 1 updateMenuID = updateAvailableMenuID + 1
separatorMenuID = updateMenuID + 1 separatorMenuID = updateMenuID + 1
diagLogsMenuID = separatorMenuID + 1 diagLogsMenuID = separatorMenuID + 1
diagSeparatorMenuID = diagLogsMenuID + 1 diagSeparatorMenuID = diagLogsMenuID + 1
quitMenuID = diagSeparatorMenuID + 1 quitMenuID = diagSeparatorMenuID + 1
) )
func (t *winTray) initMenus() error { func (t *winTray) initMenus() error {
@ -35,7 +35,7 @@ func (t *winTray) initMenus() error {
func (t *winTray) UpdateAvailable(ver string) error { func (t *winTray) UpdateAvailable(ver string) error {
if !t.updateNotified { if !t.updateNotified {
slog.Debug("updating menu and sending notification for new update") slog.Debug("updating menu and sending notification for new update")
if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
return fmt.Errorf("unable to create menu entries %w", err) return fmt.Errorf("unable to create menu entries %w", err)
} }
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {

View file

@ -22,6 +22,7 @@ import (
"runtime" "runtime"
"slices" "slices"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@ -78,6 +79,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data" status := "transferring model data"
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
@ -112,7 +114,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) digest, err := createBlob(cmd, client, path, spinner)
if err != nil { if err != nil {
return err return err
} }
@ -263,13 +265,20 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer bin.Close() defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New() hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil { if _, err := io.Copy(hash, bin); err != nil {
return "", err return "", err
@ -279,13 +288,43 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err return "", err
} }
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil { if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
type progressWriter struct {
n atomic.Int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n.Add(int64(len(p)))
return len(p), nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
@ -1086,7 +1125,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
return nil return nil
} }
func RunServer(cmd *cobra.Command, _ []string) error { func RunServer(_ *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil { if err := initializeKeypair(); err != nil {
return err return err
} }

View file

@ -27,6 +27,10 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
"tokenizer.ggml.token_type": t.Vocabulary.Types, "tokenizer.ggml.token_type": t.Vocabulary.Types,
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
if t.Template != "" { if t.Template != "" {
kv["tokenizer.chat_template"] = t.Template kv["tokenizer.chat_template"] = t.Template
} }
@ -89,6 +93,8 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
conv = &mixtral{} conv = &mixtral{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
conv = &gemma{} conv = &gemma{}
case "Phi3ForCausalLM":
conv = &phi3{}
default: default:
return errors.New("unsupported architecture") return errors.New("unsupported architecture")
} }

View file

@ -90,10 +90,6 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
kv["llama.attention.value_length"] = p.HeadDim kv["llama.attention.value_length"] = p.HeadDim
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
return kv return kv
} }

125
convert/convert_phi3.go Normal file
View file

@ -0,0 +1,125 @@
package convert
import (
"cmp"
"encoding/binary"
"io"
"math"
"strings"
"sync"
"github.com/ollama/ollama/llm"
)
type phi3 struct {
Parameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayers uint32 `json:"n_layers"`
HiddenSize uint32 `json:"hidden_size"`
NEmbd uint32 `json:"n_embd"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NHead uint32 `json:"n_head"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NHeadKV uint32 `json:"n_head_kv"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
LongFactor ropeFactor `json:"long_factor"`
ShortFactor ropeFactor `json:"short_factor"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
NPositions uint32 `json:"n_positions"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
}
var _ Converter = (*phi3)(nil)
func (p *phi3) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t)
kv["general.architecture"] = "phi3"
kv["general.name"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings
kv["phi3.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
kv["phi3.feed_forward_length"] = p.IntermediateSize
kv["phi3.block_count"] = cmp.Or(p.NumHiddenLayers, p.NLayers)
kv["phi3.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
kv["phi3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NHeadKV)
kv["phi3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["phi3.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NumAttentionHeads, p.NHead)
kv["phi3.rope.freq_base"] = p.RopeTheta
kv["phi3.rope.scaling.original_context_length"] = p.OriginalMaxPositionEmbeddings
kv["phi3.attention.sliding_window"] = p.SlidingWindow
scale := float64(p.MaxPositionEmbeddings) / float64(p.OriginalMaxPositionEmbeddings)
switch p.RopeScaling.Type {
case "":
// no scaling
case "su", "longrope":
kv["phi3.rope.scaling.attn_factor"] = float32(max(math.Sqrt(1+math.Log(scale)/math.Log(float64(p.OriginalMaxPositionEmbeddings))), 1.0))
case "yarn":
kv["phi3.rope.scaling.attn_factor"] = float32(max(0.1*math.Log(scale)+1.0, 1.0))
default:
panic("unknown rope scaling type")
}
return kv
}
func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
var addRopeFactors sync.Once
out := make([]llm.Tensor, 0, len(ts)+2)
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasPrefix(name, "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, llm.Tensor{
Name: "rope_factors_long.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor,
}, llm.Tensor{
Name: "rope_factors_short.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
WriterTo: p.RopeScaling.ShortFactor,
})
})
}
out = append(out, llm.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *phi3) tensorName(n string) string {
return strings.NewReplacer(
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.qkv_proj", "attn_qkv",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
).Replace(n)
}
type ropeFactor []float32
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, r)
return 0, err
}

View file

@ -65,6 +65,8 @@ func TestConvertFull(t *testing.T) {
"Mistral-7B-Instruct-v0.2", "Mistral-7B-Instruct-v0.2",
"Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x7B-Instruct-v0.1",
"gemma-2b-it", "gemma-2b-it",
// microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
"Phi-3-mini-128k-instruct",
} }
for i := range cases { for i := range cases {

View file

@ -0,0 +1,225 @@
{
"general.architecture": "phi3",
"general.file_type": "1",
"general.quantization_version": "2",
"phi3.block_count": "32",
"phi3.context_length": "131072",
"phi3.embedding_length": "3072",
"phi3.feed_forward_length": "8192",
"phi3.rope.scaling.original_context_length": "4096",
"phi3.rope.dimension_count": "96",
"phi3.rope.freq_base": "10000",
"phi3.rope.scaling.attn_factor": "1.1902381",
"phi3.attention.head_count": "32",
"phi3.attention.head_count_kv": "32",
"phi3.attention.layer_norm_rms_epsilon": "1e-05",
"phi3.attention.sliding_window": "262144",
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.pre": "default",
"tokenizer.ggml.add_bos_token": "false",
"tokenizer.ggml.add_eos_token": "false",
"tokenizer.ggml.bos_token_id": "1",
"tokenizer.ggml.eos_token_id": "32000",
"tokenizer.ggml.unknown_token_id": "0",
"tokenizer.ggml.padding_token_id": "32000",
"tokenizer.ggml.scores": "6e37bcde2adc7e350e87c496eddd7a2124329c1dc66c5bf3ad3997253e4f7a62",
"tokenizer.ggml.token_type": "b6ecf55ec64ee67d87750bdb8d757a2c58bf78377e9f4219f5689a6c4dea57ce",
"tokenizer.ggml.tokens": "d168da3ddd3eee820916945fcb9baf24dd3cde42f606cffa2d19e7c8a8743918",
"blk.0.attn_norm.weight": "216aeb2c9e0c271f899e1ef2a63cceeb8f41e97642e84fada54b1d3c1c11cf25",
"blk.0.attn_output.weight": "b597d56f7188ffc1fafc273fadc59d41738cffd677ae98c61a62c3285b3a3099",
"blk.0.attn_qkv.weight": "d28a6b44e13f59be5483e4be2bedb544e346168d720aca27f47d1a5a722be91e",
"blk.0.ffn_down.weight": "4a691370e5a61fcbbf540fbcbf4c0f1d15dec0364528c0e916d0744f6262b63b",
"blk.0.ffn_norm.weight": "0c00af2b4a3128bec64a0cbb1084b042fdbe13d9ad0d03bd577f9449dfead338",
"blk.0.ffn_up.weight": "b32b52f790c1c083bfb8a3126dc1111cfeeb28dc8c584a930a1e5334cb176bf4",
"blk.1.attn_norm.weight": "68748011503c6c029e8e69a84a8e5a89338f378769627b6dbf7f93d715c292e1",
"blk.1.attn_output.weight": "2267344add13b048ca59e4377c86dc512be8046a57156901fa32a20fa74e4ee0",
"blk.1.attn_qkv.weight": "9109d2e3d7a2eacfda5226587b8be124a3bf44b972da7ebb17aa15795897eacc",
"blk.1.ffn_down.weight": "d675df4df4dd039c0c339ad6445d39eddd2004db6bf35bed6314c7497245a633",
"blk.1.ffn_norm.weight": "3b5767ae977bc8baaa06b06efdbea193b6b3ba605ce76d77a76ce317e935500c",
"blk.1.ffn_up.weight": "80dfd6d9d234b00334c89b8e0a02f81899c2efd377321c34ba5ba51a5f61b5ff",
"blk.2.attn_norm.weight": "6a6743b057e5088f145bc179e92c9bfb41163e7295d7b81c62e23dd89d2b59c4",
"blk.2.attn_output.weight": "bc5491ea54e0db81462d7d9b7d25cbdda380c2db8de041bd1c4ab7b76a1d19c3",
"blk.2.attn_qkv.weight": "a61287a9852e2f5aca9c100b471d98398b2913a3497c743de3c70ec9ddd7087f",
"blk.2.ffn_down.weight": "4fddcc382c8dceeab027fe43d8d44e67edb5e8ce4b9a1b7f773c87770380ade1",
"blk.2.ffn_norm.weight": "07e05f82b3f63f711db3b684ca79aed25c0657917e66f88af47348a82065c227",
"blk.2.ffn_up.weight": "4835a682ef1826c12df01ae7663fc45f9c82bc8e64b665f13fb7da8e201ec0fb",
"blk.3.attn_norm.weight": "f22aba7c03999ba7136f39cda747a39715e498699dc1716cd97fc5dfc58d1b1c",
"blk.3.attn_output.weight": "53b579855366fd786c5126b2b30aac4d583ca7bda56833c4865f5cadb5c18c6d",
"blk.3.attn_qkv.weight": "bb56aba78158123140fcea59c69ac562ca208f6d3086819417cdad8c50f333ad",
"blk.3.ffn_down.weight": "97280897a7cd86db2830c004bccc5bc094f50e293baded0189159a2019145a6e",
"blk.3.ffn_norm.weight": "10a8c99f8b57a960e8e0a1133c4a26f9148403d1b9bff2eff114917de996f3b5",
"blk.3.ffn_up.weight": "7324046c915e75d621b2043597a245a428d8eea31869135e6257a861491d8dcc",
"blk.4.attn_norm.weight": "507d8e164de94646edbfe33def8e8fbf7c9a6ee3fbaedb5000f72d9f51ec5e36",
"blk.4.attn_output.weight": "bbb3429e6efa98c150e0fdbf48c16180cbf0d0cbc1b3c253c6c319d78f4593a2",
"blk.4.attn_qkv.weight": "b95ee5be0786d3901273d806c339fe6c20e6bfffd2a20672a9f56af80921e8ab",
"blk.4.ffn_down.weight": "806bbf91df92a5a22bd5aa1ffb7fc2869f7293ffc7704771c290ecc583b27975",
"blk.4.ffn_norm.weight": "cfc2930a81df7aee3a5e7f726a15c1182233e868bf0d9d37f6b6ae6d8c15c234",
"blk.4.ffn_up.weight": "c3390c69533de2c8424e8069323ccc5d0c4543111535da04cf2c7d26745576aa",
"blk.5.attn_norm.weight": "0d71c4fbcefabbd021569442853d2fe90668b19409ae2805a718a829ca60beab",
"blk.5.attn_output.weight": "10ebd93629112bf2df5c30dd0953a4a5e9020306768283181ed426934d47e14f",
"blk.5.attn_qkv.weight": "5cb05633369f12d4b00e0ff787736bd846856682115720ebc6cce05270c334f6",
"blk.5.ffn_down.weight": "e28bcc5094212eafc7476dbc5b7a520d25b79578cbf4229d698e2655956a80ad",
"blk.5.ffn_norm.weight": "b6f2c4cf9f34bb4d59989f96165c14a67dc1e266ad0a6d0fcc49f1add929e6ff",
"blk.5.ffn_up.weight": "0f9ef99423cc07ebedc0e9cfa95809f2d7108d910bb4ef97ebc0b0309c440750",
"blk.6.attn_norm.weight": "b3edcc47a42218234f7564d7470611b49401a41ae8cd42123f86557c69f5d7f2",
"blk.6.attn_output.weight": "eb9b7d257b388bb5b8fe0515e5c6873317239cb94cda236e4b6ada2a6c57c65c",
"blk.6.attn_qkv.weight": "eb968081f478c52f07bd9c2761741e982dba33cc4eeadeea3557d391b9ac2106",
"blk.6.ffn_down.weight": "1b8588bb7463206290322695577dcfced300895d6e6f4b26966c53a9ae2f0f84",
"blk.6.ffn_norm.weight": "1219c04b7770983c77814200eefe743f46d15328ea2b12711e44f8103eab08d3",
"blk.6.ffn_up.weight": "197ef287239fec47c55677f0fbb66eaf0644f775bc382de843971730721394f6",
"blk.7.attn_norm.weight": "b630ad08c80d564ed1c024384818e9fd3f22a36cd7a14aa96e7e2759a8285099",
"blk.7.attn_output.weight": "970255aa750828a47d6b9d399f9612b5bf25aefe7dadbcba41fc416d0d4067c1",
"blk.7.attn_qkv.weight": "ebb157c880293e6de8d629f263ba8853ed1dbdc02c311d43432bb8cfbb310739",
"blk.7.ffn_down.weight": "24bcd4db4cba844c89f878b81843c373dbbc0675e889d32c5b12e63384a7b670",
"blk.7.ffn_norm.weight": "b9c6f71001808ee873ce7db8056e4b53fb4cccec8b7f0f312899b575fae39d39",
"blk.7.ffn_up.weight": "979f1828d227455c26015a2a11afe9dd05f2bb97a8ba6b38c8dab3f50e627401",
"blk.8.attn_norm.weight": "4e8e347e3775010b7112ee630f2f4f2383be7ff64e6ca6154b9b22566552eaa6",
"blk.8.attn_output.weight": "65a44babf44a435a1829945211b3168f9ec78ac3cb7a049a733e93d11f0d6659",
"blk.8.attn_qkv.weight": "343ed07671da400b040812a4058482fa38284b5d9af9becfed07417fe26ce747",
"blk.8.ffn_down.weight": "7fb7e073e3c2c503c4e9d60efa0988fed7398d900cc003695fe3fffd3e188b82",
"blk.8.ffn_norm.weight": "b07c1f655d8593e3892a2cf73f8a0c19ce8e5cb613fafbe7cbd430da8ce4c57d",
"blk.8.ffn_up.weight": "8b26e14de54b3fdc2e2d3ea41720f9d9c236a93688c3b7fd7bf43f5fbb327c9b",
"blk.9.attn_norm.weight": "46394d408a8e316916177e6aa261de32e137a82d729c0b1800b072f0c38c39b6",
"blk.9.attn_output.weight": "d57f3d46107947a7073373a0b35d6ecf7759b5df15406f4a3590a60666af6b16",
"blk.9.attn_qkv.weight": "14bb8ace8c5453148f4b536e9f4279c813f31136716947256f5cca333448639c",
"blk.9.ffn_down.weight": "2b8d98e2b5ed68338f6e4de43bf7de0c4858cc69103cd5177725f7444eec7694",
"blk.9.ffn_norm.weight": "41a499dfd418cc4c6b8c12313f673f7e2cd4a3f9c4065eb6c4feb5eed02fb542",
"blk.9.ffn_up.weight": "143aab7533a64b17fbe201490a6f674bc7f0bd370c094500b2e100419073d1c2",
"blk.10.attn_norm.weight": "ebb670aafd36816a794347287269d8f1a5b19c1e3c0a1e38023bc19fdba9b073",
"blk.10.attn_output.weight": "b5d65bbc0ed5e49fdd9d754bc18163cd042a285024d0cf6f954c503bc8c877cb",
"blk.10.attn_qkv.weight": "f06b15bac88da798fa34a62b03eaac0dbe8b846020516603c387541f2d8dd672",
"blk.10.ffn_down.weight": "fb091fcd1b4de25d1bea94d1755e255cb02914a030d23e3a234e57b8d46bde6e",
"blk.10.ffn_norm.weight": "eb347bdf9c40414af87e13a8e72e40b31f004b50f7cb366f1a219ced60a61355",
"blk.10.ffn_up.weight": "ed2d52fc881a173f404fe8a1067862c9856d6c3e0d2e90a330a7aa394e3f84d1",
"blk.11.attn_norm.weight": "64e252603cf010a0e502ca39fdf8d0a196a79aec67c0d2bb9213fc0cb80c47d4",
"blk.11.attn_output.weight": "228e33e21c69f52efc74fdfc831bc9af271e44b2a29a3dced1d64e667ce36eb5",
"blk.11.attn_qkv.weight": "ab9ce6d4ef9e42ee0da3f20a7708a3bbc5e79e967b05fa86ba946a05e2eb63eb",
"blk.11.ffn_down.weight": "0ca133b7835c98dc77c25d64e4eb7873778bdb5e4d22d8b80f920f46865b43bd",
"blk.11.ffn_norm.weight": "02455741a0dfd161c79aa1ecc381901721f229fdcda5615622a629631fb61cfd",
"blk.11.ffn_up.weight": "9fecdcc099fbb8e23c6b1ea9294702a027f4a58d265543ec5e7be79b8f63b354",
"blk.12.attn_norm.weight": "783bb459911b1b3609a9b2bdfe272f1670add73b5471da738e07ac47e2e07dfd",
"blk.12.attn_output.weight": "1e1a914c9e48b857206ac5a1f7cead994bc1ea91d5d4fff8c834d73f2e38ef5d",
"blk.12.attn_qkv.weight": "5953e7185ccb87fb4dae8f9426ec86315d4c7794326e8ab59b3a95d4af2189f0",
"blk.12.ffn_down.weight": "a3eecf0f394f86e2cfb48a5940a5c50ca86d71883b2f79fcc642a935fabce0d4",
"blk.12.ffn_norm.weight": "0a4272e41373c23bd72f10d2d82930aa3a1480aac75832bfbf01cebf0b86b6a4",
"blk.12.ffn_up.weight": "06f42776de3a7ceac3025f26a7a8bd20e062233cce2bdaa2183470dc4b30b87d",
"blk.13.attn_norm.weight": "5915da60fb03e201fa649faba780e5fdf1c761c262b206e5415cf83181f65780",
"blk.13.attn_output.weight": "4dbf6eab074fa3835fd32bd631a8208e511037d5056d2fd3015735cca7674ef7",
"blk.13.attn_qkv.weight": "d3d8339a1c4782d9e73d77fdebe154d3c5b83ac40c9175b3e91a4977d08f876b",
"blk.13.ffn_down.weight": "de6772b46a55e1fd42b007637dfbf68b6598e5d5b61622da0935002e1e192d3a",
"blk.13.ffn_norm.weight": "5a640ea3b8c7be49c95a58a2327e10d8e8d9d142504bde5c8091613e5b961d7a",
"blk.13.ffn_up.weight": "f35e3545e4bd3531b2e843b5efd31dee0c13c807ee6386e65473ba67bbec30d0",
"blk.14.attn_norm.weight": "9b34986450b7c98b4927e81e61a816f9e84b1addc7c14926402100037aad6678",
"blk.14.attn_output.weight": "155d52efb23d366016d861a251d4d1f4a0c13699188c50d50dba016a0d8bfcd9",
"blk.14.attn_qkv.weight": "8e1415084e1f33c73a777f19e752489f4dd312cca047733e5ea643cd4a955e04",
"blk.14.ffn_down.weight": "a2a142226b94baa01ccb65bdea2b7418e49085c1d9c3c63e544e3112c58a25da",
"blk.14.ffn_norm.weight": "8aecfd9b0ae6affaea31a80c5c9a4a14b31deaa0db7bd8f6da2a64d23447921c",
"blk.14.ffn_up.weight": "0c1407237b8c1bd02f193346b5681926fe698a5055eac6a7450451b0f991707c",
"blk.15.attn_norm.weight": "e037bd19880bfa83d983200fb0c7866f8ad16c3ff5cc4b4f3a37ca7373870ff6",
"blk.15.attn_output.weight": "045fe4fc95cc129a1b92771b179c11b12845c4c088786c607f17bd98857e68e1",
"blk.15.attn_qkv.weight": "7621b7559705cab1d4dea1c69f76dbf9dc1c8837a203b656f484703b9c1b70ce",
"blk.15.ffn_down.weight": "7e5ac20e290bc60761e1cd972354fde225b7fa861048d44d9a0dd9b046d55f58",
"blk.15.ffn_norm.weight": "b6d830d88f1db1825687973c8c2b1a24c6fa84f07af8d0e3ef9c86009baca0b2",
"blk.15.ffn_up.weight": "dcda0957cd04fc45476774dba2bbf9aa89d6b05d5ca7b10ae6f73ad2c49b1cd3",
"blk.16.attn_norm.weight": "4ee9b70ba15cb2a08240f93990e90f5068c48fceb481f8e2186bec8b7214eb3f",
"blk.16.attn_output.weight": "315cfe5536658d2498192b2980eade15b2c9a4ff220e4011911457b1727fa103",
"blk.16.attn_qkv.weight": "3c8122e3ad637583b9dcde8ff3a323267d3014bb1f0f9771e5322260ca9ecc8d",
"blk.16.ffn_down.weight": "3b5fbebd5ee2b86cad96fb8a9b45a8770d08f82c1c8b74d7061e866f7020a18d",
"blk.16.ffn_norm.weight": "ffab69f20bda372de6e5878f0539163e2fc6ba113621ded95705fc3b1465c9f0",
"blk.16.ffn_up.weight": "0935ea3d258da42d6258406365f39f58ddaabfe97ea5977580db3635188f24a1",
"blk.17.attn_norm.weight": "f030441733f3d147b4a06a1eb4aeb8465c7c24d9c53bf4c48fe7e134d3629803",
"blk.17.attn_output.weight": "07a955ef09e8dc766ac0df647d0b2c69f23c4c69a7137654b4aad80303ed0eda",
"blk.17.attn_qkv.weight": "1c10688061e21e2fe12ad0cb54bf03895c1f83c3b0df743a42f548b52cbca1b2",
"blk.17.ffn_down.weight": "ebb9cc9836f41d88fdae2aa9a4355514e4edaec8d1577ffeb947a35204e77f52",
"blk.17.ffn_norm.weight": "50aff44f6528b13db5389f2ddcdb7676244947610bd7ffbff3f881c968c2a0d4",
"blk.17.ffn_up.weight": "d716537949582be33bde6b02e38f5a70081c9642a9fb05a61312126718b8d148",
"blk.18.attn_norm.weight": "0ea695c4e53d637902f46663a6ee42adc493c36794476acc7dbddaa05b13840d",
"blk.18.attn_output.weight": "5fd35b500221a612eb4f4bddf0e9b6b7db4d7733032a75f8802fb2d884647c2e",
"blk.18.attn_qkv.weight": "b0da37fd030fe69581f990bf23bfd35467a1bbe558af6de7c0924f6b72e92317",
"blk.18.ffn_down.weight": "b355c33f44b328f4bb977567de8f7544db4b005d7a8fbded658518ecf3c5a153",
"blk.18.ffn_norm.weight": "58b3fe9094079989a86e0387143259e1cc35952d24dc3df290c4ba6df44f5c51",
"blk.18.ffn_up.weight": "2ce530954c342c30ed2ead5353f931960bfae1d278868504c0efb973560fabbe",
"blk.19.attn_norm.weight": "533e9aed66feea8f0392aa81f9e293240e1f009a5334253915fb60c2749b615d",
"blk.19.attn_output.weight": "84f2d00f98a4113a779d3b5d1c3e7c914eb47784d3ab13b290367c124c2994aa",
"blk.19.attn_qkv.weight": "fbe6b9f53b07fa7537d3b3d452d20a9bc666f9fd41ec2091dd28bc2f70fc668f",
"blk.19.ffn_down.weight": "b30199e098c8bb3f890183d8b18471e80b62b604729b277ad62488dd71e1206b",
"blk.19.ffn_norm.weight": "c81373e41cd340b7badb19f9517c77c4250b4eb9a02dc758b8b49b652487d7ff",
"blk.19.ffn_up.weight": "5a5cb083ca7725720e3a890f7fa46354760e8007a8188849a092e305694a75e3",
"blk.20.attn_norm.weight": "4953091b4477e354357a8e743ba0a1900633e52f1599ee082a0c9b0b2b5cd978",
"blk.20.attn_output.weight": "62d54f7749cd6856097b2632066a322b0296df915fe66f382c5b5981be0d4f23",
"blk.20.attn_qkv.weight": "406de9e35b0729ebe902d7a47905cc7fb29a921431ed35dbef0c03e5690a1329",
"blk.20.ffn_down.weight": "62fb678b0d1261e19a4903a2b347d67afcc8acff01feb33a687a35a2d1e6f9a5",
"blk.20.ffn_norm.weight": "cd9d36b7e71e55c8925b97bb09c28219f182626bcff094878ae39c3db887a14b",
"blk.20.ffn_up.weight": "b9276771d79d3e932e73ccc520c3f8476342b9ef312ed2ee1e0da822e6e3ad18",
"blk.21.attn_norm.weight": "66d8c8a35e13ce9c2a0e75b670150e2c31484a55c2316df46075312196178ed3",
"blk.21.attn_output.weight": "12ab46c9382648f9b3350fdd92a6be6352743d62d6b520d7e2024e0c838588f5",
"blk.21.attn_qkv.weight": "a7909676ee1675ca23cd29a5fdd226df8dd9d68f94c6c9bbb51dd9fd38504008",
"blk.21.ffn_down.weight": "6fb317279c6542e82f97d5a12a60fac1bd0fa0405154f9fbe265e2fe39bd49cc",
"blk.21.ffn_norm.weight": "c0f703eb3ff161b5ba4490d87d8684b8a6c47a8f433e12f418333b9db439010a",
"blk.21.ffn_up.weight": "6dbdb80ef0c35e364bbce12d40d5e74c7963c7b55d58d9579567a07ffce7b863",
"blk.22.attn_norm.weight": "f94237433bf03d675cb2f655b81ca91a1ce2447bc6b00b13d6b0ccfe2d411eff",
"blk.22.attn_output.weight": "e821f95995ce497c01e63ca64f737713b1b65f11df1903e51d444aa516f33f71",
"blk.22.attn_qkv.weight": "1b0f717c73afb5eb4c82a1708c4e85c969e8a2a8770d9ddb78b1870a2d8a781e",
"blk.22.ffn_down.weight": "0f33f7a3cdc685484be99aa0c03642b0b20850a27d1fddbe054b13a9382f3ccb",
"blk.22.ffn_norm.weight": "9df285cf211ddd7df2b36a50489af574755c7d4d98b29a05cd04566ae613c8dc",
"blk.22.ffn_up.weight": "63ac300e1efb34041dd0136cf43ea622fac6f0caccce1cd9262f5e08d2cf179c",
"blk.23.attn_norm.weight": "5f72d9e88689b4027b28f5f8f26cd3abb03635ceea7ec98a4c91a9fc691f6707",
"blk.23.attn_output.weight": "6ecf04ff61125c5fc768f8656497152149373daf321ee9c957e8f7245a1184d1",
"blk.23.attn_qkv.weight": "a9d9978806724c2959f2cf386c233831f08e1e933dbf2b32665e788d9d512ea4",
"blk.23.ffn_down.weight": "72c7d17886a3da17fa0daa456aa5e877b2ef5b8b403182b870d9ca5ca9c70347",
"blk.23.ffn_norm.weight": "971e4b712e3025a13419b5b57d674b5e4ab7f18f74b57b9afc4671623da90c4b",
"blk.23.ffn_up.weight": "df2b5c7dbd5834545b815073af0c7355b065124e6d6f0fee78d8fa5b2076dc3e",
"blk.24.attn_norm.weight": "c41957c4a79ad3b16f6e11daec1c7f530b9f3f4b618e1e4367c3b67787ac4ab6",
"blk.24.attn_output.weight": "ef7d61f5fc88ac6f31bf60cb5f4d2d6b8df42d38825807112361a7224b0dee3b",
"blk.24.attn_qkv.weight": "3e6a58fe7d49c90bb6971efbad3371c32256881173ea5aee4b0c296cb206490f",
"blk.24.ffn_down.weight": "f43619144047de42fed81dfa495f1815d3cb771330e574043e2b67620819292c",
"blk.24.ffn_norm.weight": "5501d4a2a98c8ca6b42e77b53b221dbc08f530f6a067256d787534ec6fe028bd",
"blk.24.ffn_up.weight": "d64c8b0e509e2b1118f6000176f8956cacecdbb200c7e95ed93fb78b6e26c84a",
"blk.25.attn_norm.weight": "502fa3c302d371f61c5791f4615b73018ffb1daa09b6499b227116581244c5d4",
"blk.25.attn_output.weight": "ad8391d4e9c980856f2547aa945b2b6a407a6382158dc1ddd4f08d94ecc24be6",
"blk.25.attn_qkv.weight": "42e8983780d4a01a02c54ad23d4df21eea437f119a10af5a9c12a76a42d308c1",
"blk.25.ffn_down.weight": "302dd010d4e0ab4eeaee89090409ea0dddeeeed3236415eb8f97c942497eea91",
"blk.25.ffn_norm.weight": "fb34c1ee5bca96986c08834df0a0c047ba041c1123ac1f563e9d64312bf82d6a",
"blk.25.ffn_up.weight": "10739a8de156816d93c92b935386540bfa976bdbef204f0312960f6fc657582f",
"blk.26.attn_norm.weight": "7036c711609128c4e55968ff3681d3043338879a5737efd6c2ac9e1a2a61f1a0",
"blk.26.attn_output.weight": "db5db45dead5cb911fa01da59832f121b7c18b2d167bf53741c40819f24d346c",
"blk.26.attn_qkv.weight": "cae34c6b7f82ed14348d5ed30a79919c383737c1694a9cb9c0de609d3b0c1d0a",
"blk.26.ffn_down.weight": "491ec3a4da9b4f49f8ebc6be658ce397a9b801ae9fb35e82177e47808c65e5d0",
"blk.26.ffn_norm.weight": "fd7059d75d7f0e5288511ddeeb0f772eb3cae3ccfe4226b877015834edc3c386",
"blk.26.ffn_up.weight": "ea1ee1274c56458ce056d2205e5bb6e5422ce4cb0ad58006b8141749b97a0c39",
"blk.27.attn_norm.weight": "cc362c9a937609265052cd38544af17a1a7448cea086d4c801139e1fc865832d",
"blk.27.attn_output.weight": "ba757a81dabde9cb1b069d1bb616fe79649a1724f756567ec61caed1304fe6cf",
"blk.27.attn_qkv.weight": "1ab8d7d02d87756c12c2275636823aa5ede3d683178225c4cac4bd892c319bd4",
"blk.27.ffn_down.weight": "deb1c711c8a66acf4dcd2d088e1548f8e08f296f755e4067d6557fa55afde88c",
"blk.27.ffn_norm.weight": "fc6242d8cb8a4a37a8ddb7e41e7e60a63d4a89edf36acb35df052f10b9c91ece",
"blk.27.ffn_up.weight": "8df39b09c4801f343aca78f2918a1f6db78c8c55e591eda4c69eadb74c26e180",
"blk.28.attn_norm.weight": "75b539308f77e3cefdc6d98484d8b5cbf0538f0c2869a77b7373a145a18bc850",
"blk.28.attn_output.weight": "ae128940eb60a6d2e121762ef4b3e9dcf9eb3e105b249507fa7f12de0e19822c",
"blk.28.attn_qkv.weight": "bdda781c288e9326c240e33905f8e621b6a2ad902e620739d34f93fcd6f933de",
"blk.28.ffn_down.weight": "f1d6e6d1c286b1138bfd7e53fe477f399ae93bc2c04e35416f84218ed7247965",
"blk.28.ffn_norm.weight": "3f837ce82c8b9bde0d61d08b6f5fe5574886ea5328dbdc53f2929f18da8b4087",
"blk.28.ffn_up.weight": "2af027002e31d1b6cfedbdb30a2b9d7213f3aa691167c353913adfd48fda31e4",
"blk.29.attn_norm.weight": "61e8003b5329462ffe0fe172f2b160260de006aed858332d49d75504b6b6aa7a",
"blk.29.attn_output.weight": "ca44542a72a37476dc73dbdcc01f5b7497cb3ebc4ea230a55c9634ccd8e56ad4",
"blk.29.attn_qkv.weight": "abb3d9d6abe57872ae3daa51935d43264093ded5ce63b49d1e280ee5758be0e4",
"blk.29.ffn_down.weight": "6764b895fce881df097489c263446f0106de36217997660c15984b3ee22a5a06",
"blk.29.ffn_norm.weight": "89e03e9a33fc0e6e31ba9f0c2bd7c5734a118c5602bb90148793e08a80e8d0ae",
"blk.29.ffn_up.weight": "fa7ad57a84954f4121653152efed1a871d8adb20a1ea9086e3e849ce359d7d2e",
"blk.30.attn_norm.weight": "91a697aca1e42af54f806a20211031c3369e8d0bd58df1b0147fe24954e1f5a4",
"blk.30.attn_output.weight": "36063fcf766c89ac75be56f688cc63cefe5f2c733fbf4378ea9956ad386fa148",
"blk.30.attn_qkv.weight": "2cacd1161f1121a2c0b979930134f4666f73fb8d7237b3b0659ae091b15955a6",
"blk.30.ffn_down.weight": "9f3fcb6217100595850c05dc98f9ab2a263afdb6ab28df2fcb08aeff512057d7",
"blk.30.ffn_norm.weight": "6c600bc1fc7de39d4f8917b81fc7d1d5ed2a9b56492234c13a4bd6028c30d880",
"blk.30.ffn_up.weight": "73cabd1bb011956b2689ea3338bb76642ef3a57c197377d666d2ab5f56317668",
"blk.31.attn_norm.weight": "72d3e1cc771380645fa75a899858c95f39857a4f3f1ed60fe1578df383b8bc53",
"blk.31.attn_output.weight": "40089cdd29994dc19a1d89fa15902a89cfeca3540f12dc9bf4d00ef82506e456",
"blk.31.attn_qkv.weight": "1d0bb40e9258071ae14290a53c619a8e331dda07354d2a02ef45766c029ae5e4",
"blk.31.ffn_down.weight": "8defa0e06335b793fa8be03883f0a322d6c5b33f52c69c943c35c60d16e42c0a",
"blk.31.ffn_norm.weight": "33c55d9d0c496ccfb130361fe131649346e098abaaac39c0519507e5d846721d",
"blk.31.ffn_up.weight": "599f6503f61c692c1f82001973d35119f9688db5e6be9d9c298411491c93f09b",
"output.weight": "14b8dc662bfa3308ebb2e102c562d8e52c15670e538f20f3216a9c310ca9dd41",
"output_norm.weight": "7f2294ba94ce65681df6c7ddd8698799199b9d77dc83c10bdad5c3999f0fdb82",
"rope_factors_long.weight": "e34d378664e354652c38f47d10dafb0498ccc2fb042d39ff7fef768146fff22b",
"rope_factors_short.weight": "9379146a4988f373d362fe47b06c75e7fe7c54aa4dc9558758df79b7a87471fd",
"token_embd.weight": "19a03c1fb5ac0baee93b0a7d8b0f26e9a9b011e229b694afc50ebfc13d84f8bf"
}

View file

@ -669,7 +669,7 @@ curl http://localhost:11434/api/chat -d '{
``` ```
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "mistral", "model": "llama3.1",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -708,7 +708,7 @@ curl http://localhost:11434/api/chat -d '{
```json ```json
{ {
"model": "mistral:7b-instruct-v0.3-q4_K_M", "model": "llama3.1",
"created_at": "2024-07-22T20:33:28.123648Z", "created_at": "2024-07-22T20:33:28.123648Z",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -1175,7 +1175,10 @@ curl http://localhost:11434/api/embed -d '{
"embeddings": [[ "embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814, 0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348 0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
]] ]],
"total_duration": 14143917,
"load_duration": 1019500,
"prompt_eval_count": 8
} }
``` ```

View file

@ -16,7 +16,9 @@ If the model being imported is one of these architectures, it can be imported di
- LlamaForCausalLM - LlamaForCausalLM
- MistralForCausalLM - MistralForCausalLM
- MixtralForCausalLM
- GemmaForCausalLM - GemmaForCausalLM
- Phi3ForCausalLM
```dockerfile ```dockerfile
FROM /path/to/safetensors/directory FROM /path/to/safetensors/directory

View file

@ -182,7 +182,6 @@ curl http://localhost:11434/v1/embeddings \
- [x] Reproducible outputs - [x] Reproducible outputs
- [x] Vision - [x] Vision
- [x] Tools (streaming support coming soon) - [x] Tools (streaming support coming soon)
- [ ] Vision
- [ ] Logprobs - [ ] Logprobs
#### Supported request fields #### Supported request fields

View file

@ -112,15 +112,9 @@ Keep the following tips and best practices in mind when working with Go template
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2. ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
```gotmpl ```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }} {{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|> {{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
``` ```
### Example Tools ### Example Tools

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/ollama/ollama module github.com/ollama/ollama
go 1.22.0 go 1.22.5
require ( require (
github.com/containerd/console v1.0.3 github.com/containerd/console v1.0.3

View file

@ -49,13 +49,9 @@ func PayloadsDir() (string, error) {
} }
// Track our pid so we can clean up orphaned tmpdirs // Track our pid so we can clean up orphaned tmpdirs
pidFilePath := filepath.Join(tmpDir, "ollama.pid") n := filepath.Join(tmpDir, "ollama.pid")
pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm) if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil {
if err != nil { return "", fmt.Errorf("failed to write pid file %s: %w", n, err)
return "", err
}
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
return "", err
} }
// We create a distinct subdirectory for payloads within the tmpdir // We create a distinct subdirectory for payloads within the tmpdir
@ -67,37 +63,44 @@ func PayloadsDir() (string, error) {
// Best effort to clean up prior tmpdirs // Best effort to clean up prior tmpdirs
func cleanupTmpDirs() { func cleanupTmpDirs() {
dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*")) matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid"))
if err != nil { if err != nil {
return return
} }
for _, d := range dirs {
info, err := os.Stat(d) for _, match := range matches {
if err != nil || !info.IsDir() { raw, err := os.ReadFile(match)
if errors.Is(err, os.ErrNotExist) {
slog.Debug("not a ollama runtime directory, skipping", "path", match)
continue continue
} } else if err != nil {
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid")) slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err)
if err != nil {
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
// No pid, ignore this tmpdir
continue continue
} }
pid, err := strconv.Atoi(string(raw)) pid, err := strconv.Atoi(string(raw))
if err != nil { if err != nil {
slog.Warn("failed to parse pid", "path", d, "error", err) slog.Warn("invalid pid, skipping", "path", match, "error", err)
continue continue
} }
proc, err := os.FindProcess(pid) p, err := os.FindProcess(pid)
if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) {
slog.Warn("found running ollama", "pid", pid, "path", d) slog.Warn("process still running, skipping", "pid", pid, "path", match)
// Another running ollama, ignore this tmpdir
continue continue
} }
if err := os.Remove(d); err != nil { if err := os.Remove(match); err != nil {
slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err) slog.Warn("could not cleanup stale pidfile", "path", match, "error", err)
}
runners := filepath.Join(filepath.Dir(match), "runners")
if err := os.RemoveAll(runners); err != nil {
slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
}
if err := os.Remove(filepath.Dir(match)); err != nil {
slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
} }
} }
} }

View file

@ -305,38 +305,41 @@ func GetGPUInfo() GpuInfoList {
// Intel // Intel
if envconfig.IntelGPU() { if envconfig.IntelGPU() {
oHandles = initOneAPIHandles() oHandles = initOneAPIHandles()
// On windows we bundle the oneapi library one level above the runner dir if oHandles != nil && oHandles.oneapi != nil {
depPath = ""
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
}
for d := range oHandles.oneapi.num_drivers { // On windows we bundle the oneapi library one level above the runner dir
if oHandles.oneapi == nil { depPath = ""
// shouldn't happen if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers)) depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
continue
} }
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := range devCount { for d := range oHandles.oneapi.num_drivers {
gpuInfo := OneapiGPUInfo{ if oHandles.oneapi == nil {
GpuInfo: GpuInfo{ // shouldn't happen
Library: "oneapi", slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
}, continue
driverIndex: int(d), }
gpuIndex: int(i), devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := range devCount {
gpuInfo := OneapiGPUInfo{
GpuInfo: GpuInfo{
Library: "oneapi",
},
driverIndex: int(d),
gpuIndex: int(i),
}
// TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DependencyPath = depPath
oneapiGPUs = append(oneapiGPUs, gpuInfo)
} }
// TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DependencyPath = depPath
oneapiGPUs = append(oneapiGPUs, gpuInfo)
} }
} }
} }

View file

@ -403,7 +403,9 @@ struct llama_server_context
} }
} }
std::tie(model, ctx) = llama_init_from_gpt_params(params); auto init_result = llama_init_from_gpt_params(params);
model = init_result.model;
ctx = init_result.context;
if (model == nullptr) if (model == nullptr)
{ {
LOG_ERROR("unable to load model", {{"model", params.model}}); LOG_ERROR("unable to load model", {{"model", params.model}});
@ -1221,9 +1223,7 @@ struct llama_server_context
res.result_json = json res.result_json = json
{ {
{"id", res.id},
{"embedding", std::vector<float>(embd, embd + n_embd)}, {"embedding", std::vector<float>(embd, embd + n_embd)},
{"timings", slot.get_formated_timings()},
}; };
} }
} }
@ -2422,7 +2422,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.emplace_back(argv[i], 1.0f); params.lora_adapters.push_back({
std::string(argv[i]),
1.0,
});
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-scaled") else if (arg == "--lora-scaled")
@ -2438,7 +2441,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapters.push_back({
lora_adapter,
std::stof(argv[i])
});
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "-v" || arg == "--verbose") else if (arg == "-v" || arg == "--verbose")
@ -3186,41 +3192,17 @@ int main(int argc, char **argv) {
prompt = ""; prompt = "";
} }
if (prompt.size() == 1) {
prompt = prompt[0];
}
// create and queue the task // create and queue the task
json responses; const int task_id = llama.queue_tasks.get_new_id();
{ llama.queue_results.add_waiting_task_id(task_id);
const int id_task = llama.queue_tasks.get_new_id(); llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(id_task); task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(id_task); llama.queue_results.remove_waiting_task_id(task_id);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
responses = result.result_json.value("results", std::vector<json>{result.result_json}); // send the result
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) { return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
return a["id"] < b["id"];
});
json embeddings = json::array();
int prompt_n = 0;
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
prompt_n += elem.at("timings").at("prompt_n").get<int>();
}
// send the result
json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
}); });
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View file

@ -157,6 +157,14 @@ type Tensor struct {
io.WriterTo `json:"-"` io.WriterTo `json:"-"`
} }
func (t Tensor) block() (n int) {
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
return -1
}
return
}
func (t Tensor) blockSize() uint64 { func (t Tensor) blockSize() uint64 {
switch t.Kind { switch t.Kind {
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16 case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16

View file

@ -532,15 +532,14 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
} }
} }
slices.SortFunc(ts, func(a, b Tensor) int { slices.SortStableFunc(ts, func(a, b Tensor) int {
var i, j int if i, j := a.block(), b.block(); i < 0 && j > 0 {
if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 { return 1
return cmp.Compare(a.Name, b.Name) } else if i > 0 && j < 0 {
} else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 { return -1
return cmp.Compare(a.Name, b.Name) } else {
return cmp.Compare(i, j)
} }
return cmp.Compare(i, j)
}) })
var s uint64 var s uint64

@ -1 +1 @@
Subproject commit 6eeaeba126ff701f3e8f79f246805b7023709972 Subproject commit 1e6f6554aa11fa10160a5fda689e736c3c34169f

View file

@ -1,40 +1,32 @@
diff --git a/common/common.cpp b/common/common.cpp diff --git a/common/common.cpp b/common/common.cpp
index dbb724fb..c26fe6ee 100644 index 2e8374d5..70d0afde 100644
--- a/common/common.cpp --- a/common/common.cpp
+++ b/common/common.cpp +++ b/common/common.cpp
@@ -2087,14 +2087,27 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par @@ -2110,9 +2110,21 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); if (loaded_la.adapter == nullptr) {
float lora_scale = std::get<1>(params.lora_adapter[i]); fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+
+ // try to load as gguf
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx); - llama_free(lctx);
- llama_free_model(model); - llama_free_model(model);
- return std::make_tuple(nullptr, nullptr); - return iparams;
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+ +
+ // if that fails, try loading as ggla for compatibility + // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model, + int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(), + la.path.c_str(),
+ lora_scale, + la.scale,
+ nullptr, + nullptr,
+ params.n_threads); + params.n_threads);
+ if (err != 0) { + if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx); + llama_free(lctx);
+ llama_free_model(model); + llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr); + return iparams;
+ } else {
+ break;
+ } + }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
} }
- llama_lora_adapter_set(lctx, adapter, lora_scale); iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
} }
if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644 index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h --- a/include/llama.h

View file

@ -1,20 +0,0 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index a207451f..fba6b175 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
hparams.attn_soft_cap = true;
switch (hparams.n_layer) {
+ case 26: model.type = e_model::MODEL_2B; break;
case 42: model.type = e_model::MODEL_9B; break;
case 46: model.type = e_model::MODEL_27B; break;
default: model.type = e_model::MODEL_UNKNOWN;
@@ -11736,6 +11737,7 @@ struct llm_build_context {
// ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
switch (model.type) {
+ case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
default: GGML_ABORT("fatal error");

View file

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) (*EmbedResponse, error) Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@ -44,11 +44,12 @@ type LlamaServer interface {
// llmServer is an instance of the llama.cpp server // llmServer is an instance of the llama.cpp server
type llmServer struct { type llmServer struct {
port int port int
cmd *exec.Cmd cmd *exec.Cmd
done chan error // Channel to signal when the process exits done chan error // Channel to signal when the process exits
status *StatusWriter status *StatusWriter
options api.Options options api.Options
numParallel int
estimate MemoryEstimate estimate MemoryEstimate
totalLayers uint64 totalLayers uint64
@ -124,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
} }
// On linux, over-allocating CPU memory will almost always result in an error // On linux and windows, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" { // Darwin has fully dynamic swap so has no direct concept of free swap space
if runtime.GOOS != "darwin" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available { if systemMemoryRequired > available {
@ -343,6 +345,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
status: NewStatusWriter(os.Stderr), status: NewStatusWriter(os.Stderr),
options: opts, options: opts,
estimate: estimate, estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)), sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: ggml.KV().BlockCount() + 1, totalLayers: ggml.KV().BlockCount() + 1,
gpus: gpus, gpus: gpus,
@ -880,16 +883,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil return nil
} }
type EmbedRequest struct { type EmbeddingRequest struct {
Content []string `json:"content"` Content string `json:"content"`
} }
type EmbedResponse struct { type EmbeddingResponse struct {
Embedding [][]float32 `json:"embedding"` Embedding []float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_n"`
} }
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) { func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@ -904,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(EmbedRequest{Content: input}) data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err) return nil, fmt.Errorf("error creating embed request: %w", err)
} }
req.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(r)
if err != nil { if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err) return nil, fmt.Errorf("do embedding request: %w", err)
} }
@ -931,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var e EmbedResponse var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil { if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }
return &e, nil return e.Embedding, nil
} }
type TokenizeRequest struct { type TokenizeRequest struct {

View file

@ -26,6 +26,7 @@ var errorPrefixes = []string{
"cudaMalloc failed", "cudaMalloc failed",
"\"ERR\"", "\"ERR\"",
"error loading model", "error loading model",
"GGML_ASSERT",
} }
func (w *StatusWriter) Write(b []byte) (int, error) { func (w *StatusWriter) Write(b []byte) (int, error) {

View file

@ -7,27 +7,22 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
const ( const (
prefix = `data:image/jpeg;base64,` prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
) )
func prepareRequest(req *http.Request, body any) { var False = false
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
func TestChatMiddleware(t *testing.T) { func TestChatMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) req api.ChatRequest
err ErrorResponse
} }
var capturedRequest *api.ChatRequest var capturedRequest *api.ChatRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "chat handler", name: "chat handler",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
Model: "test-model", "messages": [
Messages: []Message{{Role: "user", Content: "Hello"}}, {"role": "user", "content": "Hello"}
} ]
prepareRequest(req, body) }`,
}, req: api.ChatRequest{
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { Model: "test-model",
if resp.Code != http.StatusOK { Messages: []api.Message{
t.Fatalf("expected 200, got %d", resp.Code) {
} Role: "user",
Content: "Hello",
if req.Messages[0].Role != "user" { },
t.Fatalf("expected 'user', got %s", req.Messages[0].Role) },
} Options: map[string]any{
"temperature": 1.0,
if req.Messages[0].Content != "Hello" { "top_p": 1.0,
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) },
} Stream: &False,
}, },
}, },
{ {
Name: "chat handler with image content", name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
Model: "test-model", "messages": [
Messages: []Message{ {
{ "role": "user",
Role: "user", Content: []map[string]any{ "content": [
{"type": "text", "text": "Hello"}, {
{"type": "image_url", "image_url": map[string]string{"url": imageURL}}, "type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
},
},
}, },
}, },
}, },
} },
prepareRequest(req, body) Options: map[string]any{
}, "temperature": 1.0,
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { "top_p": 1.0,
if resp.Code != http.StatusOK { },
t.Fatalf("expected 200, got %d", resp.Code) Stream: &False,
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
}, },
}, },
{ {
Name: "chat handler with tools", name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
Model: "test-model", "messages": [
Messages: []Message{ {"role": "user", "content": 2}
{Role: "user", Content: "What's the weather like in Paris Today?"}, ]
{Role: "assistant", ToolCalls: []ToolCall{{ }`,
ID: "id", err: ErrorResponse{
Type: "function", Error: Error{
Function: struct { Message: "invalid message content type: float64",
Name string `json:"name"` Type: "invalid_request_error",
Arguments string `json:"arguments"` },
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/chat", endpoint) router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
func TestCompletionsMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) req api.GenerateRequest
err ErrorResponse
} }
var capturedRequest *api.GenerateRequest var capturedRequest *api.GenerateRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "completions handler", name: "completions handler",
Setup: func(t *testing.T, req *http.Request) { body: `{
temp := float32(0.8) "model": "test-model",
body := CompletionRequest{ "prompt": "Hello",
Model: "test-model", "temperature": 0.8,
Prompt: "Hello", "stop": ["\n", "stop"],
Temperature: &temp, "suffix": "suffix"
Stop: []string{"\n", "stop"}, }`,
Suffix: "suffix", req: api.GenerateRequest{
} Model: "test-model",
prepareRequest(req, body) Prompt: "Hello",
}, Options: map[string]any{
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { "frequency_penalty": 0.0,
if req.Prompt != "Hello" { "presence_penalty": 0.0,
t.Fatalf("expected 'Hello', got %s", req.Prompt) "temperature": 1.6,
} "top_p": 1.0,
"stop": []any{"\n", "stop"},
if req.Options["temperature"] != 1.6 { },
t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) Suffix: "suffix",
} Stream: &False,
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
}, },
}, },
{ {
Name: "completions handler error forwarding", name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := CompletionRequest{ "model": "test-model",
Model: "test-model", "prompt": "Hello",
Prompt: "Hello", "temperature": null,
Temperature: nil, "stop": [1, 2],
Stop: []int{1, 2}, "suffix": "suffix"
Suffix: "suffix", }`,
} err: ErrorResponse{
prepareRequest(req, body) Error: Error{
}, Message: "invalid type for 'stop' field: float64",
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { Type: "invalid_request_error",
if resp.Code != http.StatusBadRequest { },
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/generate", endpoint) router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
func TestEmbeddingsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) req api.EmbedRequest
err ErrorResponse
} }
var capturedRequest *api.EmbedRequest var capturedRequest *api.EmbedRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "embed handler single input", name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "input": "Hello",
Input: "Hello", "model": "test-model"
Model: "test-model", }`,
} req: api.EmbedRequest{
prepareRequest(req, body) Input: "Hello",
}, Model: "test-model",
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
Name: "embed handler batch input", name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "input": ["Hello", "World"],
Input: []string{"Hello", "World"}, "model": "test-model"
Model: "test-model", }`,
} req: api.EmbedRequest{
prepareRequest(req, body) Input: []any{"Hello", "World"},
}, Model: "test-model",
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
Name: "embed handler error forwarding", name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "model": "test-model"
Model: "test-model", }`,
} err: ErrorResponse{
prepareRequest(req, body) Error: Error{
}, Message: "invalid input",
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { Type: "invalid_request_error",
if resp.Code != http.StatusBadRequest { },
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/embed", endpoint) router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
} }
func TestMiddlewareResponses(t *testing.T) { func TestListMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Method string endpoint func(c *gin.Context)
Path string resp string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
} }
testCases := []testCase{ testCases := []testCase{
{ {
Name: "list handler", name: "list handler",
Method: http.MethodGet, endpoint: func(c *gin.Context) {
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{ c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{ Models: []api.ListModelResponse{
{ {
Name: "Test Model", Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
}, },
}, },
}) })
}, },
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { resp: `{
var listResp ListCompletion "object": "list",
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { "data": [
t.Fatal(err) {
} "id": "test-model",
"object": "model",
if listResp.Object != "list" { "created": 1686935002,
t.Fatalf("expected list, got %s", listResp.Object) "owned_by": "library"
} }
]
if len(listResp.Data) != 1 { }`,
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
}, },
{ {
Name: "retrieve model", name: "list handler empty output",
Method: http.MethodGet, endpoint: func(c *gin.Context) {
Path: "/api/show/:model", c.JSON(http.StatusOK, api.ListResponse{})
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
}, },
resp: `{
"object": "list",
"data": null
}`,
}, },
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { router := gin.New()
router = gin.New() router.Use(ListMiddleware())
router.Use(tc.Handler()) router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
router.Handle(tc.Method, tc.Path, tc.Endpoint) req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil { resp := httptest.NewRecorder()
tc.Setup(t, req) router.ServeHTTP(resp, req)
}
resp := httptest.NewRecorder() var expected, actual map[string]any
router.ServeHTTP(resp, req) err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
assert.Equal(t, http.StatusOK, resp.Code) err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
tc.Expected(t, resp) if !reflect.DeepEqual(expected, actual) {
}) t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
resp: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
} }
} }

View file

@ -3,11 +3,12 @@ package progress
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
type Spinner struct { type Spinner struct {
message string message atomic.Value
messageWidth int messageWidth int
parts []string parts []string
@ -21,20 +22,25 @@ type Spinner struct {
func NewSpinner(message string) *Spinner { func NewSpinner(message string) *Spinner {
s := &Spinner{ s := &Spinner{
message: message,
parts: []string{ parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
}, },
started: time.Now(), started: time.Now(),
} }
s.SetMessage(message)
go s.start() go s.start()
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message.Store(message)
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 {
message := strings.TrimSpace(s.message) if message, ok := s.message.Load().(string); ok && len(message) > 0 {
message := strings.TrimSpace(message)
if s.messageWidth > 0 && len(message) > s.messageWidth { if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth] message = message[:s.messageWidth]
} }

View file

@ -62,7 +62,7 @@ func (b *Buffer) MoveLeft() {
rLength := runewidth.RuneWidth(r) rLength := runewidth.RuneWidth(r)
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width)) fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width))
if rLength == 2 { if rLength == 2 {
fmt.Print(CursorLeft) fmt.Print(CursorLeft)
} }
@ -74,7 +74,7 @@ func (b *Buffer) MoveLeft() {
fmt.Print(CursorLeft) fmt.Print(CursorLeft)
} }
} else { } else {
fmt.Print(cursorLeftN(rLength)) fmt.Print(CursorLeftN(rLength))
} }
b.Pos -= 1 b.Pos -= 1
@ -115,15 +115,15 @@ func (b *Buffer) MoveRight() {
b.DisplayPos += rLength b.DisplayPos += rLength
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength)) fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())+rLength))
b.DisplayPos += 1 b.DisplayPos += 1
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())))
b.DisplayPos += 1 b.DisplayPos += 1
} else { } else {
fmt.Print(cursorRightN(rLength)) fmt.Print(CursorRightN(rLength))
} }
} }
} }
@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() {
fmt.Print(CursorUp) fmt.Print(CursorUp)
} }
} }
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt()))) fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())))
b.Pos = 0 b.Pos = 0
b.DisplayPos = 0 b.DisplayPos = 0
} }
@ -169,9 +169,9 @@ func (b *Buffer) MoveToEnd() {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
remainder := b.DisplaySize() % b.LineWidth remainder := b.DisplaySize() % b.LineWidth
fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder)) fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())+remainder))
} else { } else {
fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos)) fmt.Print(CursorRightN(b.DisplaySize() - b.DisplayPos))
} }
b.Pos = b.Buf.Size() b.Pos = b.Buf.Size()
@ -286,8 +286,7 @@ func (b *Buffer) drawRemaining() {
remLength := runewidth.StringWidth(remainingText) remLength := runewidth.StringWidth(remainingText)
if len(currLine) > 0 { if len(currLine) > 0 {
fmt.Printf(ClearToEOL + currLine) fmt.Print(ClearToEOL + currLine + CursorLeftN(currLineSpace))
fmt.Print(cursorLeftN(currLineSpace))
} else { } else {
fmt.Print(ClearToEOL) fmt.Print(ClearToEOL)
} }
@ -301,9 +300,9 @@ func (b *Buffer) drawRemaining() {
} }
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText { if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
fmt.Print(cursorRightN(currLineSpace)) fmt.Print(CursorRightN(currLineSpace))
fmt.Printf("\n%s", b.Prompt.AltPrompt) fmt.Printf("\n%s", b.Prompt.AltPrompt)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace)) fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width-currLineSpace))
} }
// render the other lines // render the other lines
@ -333,9 +332,7 @@ func (b *Buffer) drawRemaining() {
lineLength += runewidth.RuneWidth(c) lineLength += runewidth.RuneWidth(c)
fmt.Printf("%c", c) fmt.Printf("%c", c)
} }
fmt.Print(ClearToEOL) fmt.Print(ClearToEOL + CursorUpN(totalLines) + CursorBOL + CursorRightN(b.Width-currLineSpace))
fmt.Print(cursorUpN(totalLines))
fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth) hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
@ -357,8 +354,7 @@ func (b *Buffer) Remove() {
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
// if the user backspaces over the word boundary, do this magic to clear the line // if the user backspaces over the word boundary, do this magic to clear the line
// and move to the end of the previous line // and move to the end of the previous line
fmt.Printf(CursorBOL + ClearToEOL) fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth { if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
@ -370,24 +366,23 @@ func (b *Buffer) Remove() {
} }
if rLength == 2 { if rLength == 2 {
fmt.Print(CursorLeft + " " + cursorLeftN(2)) fmt.Print(CursorLeft + " " + CursorLeftN(2))
} else { } else {
fmt.Print(" " + CursorLeft) fmt.Print(" " + CursorLeft)
} }
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
fmt.Printf(CursorBOL + ClearToEOL) fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width))
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.Pos == b.Buf.Size() { if b.Pos == b.Buf.Size() {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
} }
b.DisplayPos -= 1 b.DisplayPos -= 1
} else { } else {
fmt.Print(cursorLeftN(rLength)) fmt.Print(CursorLeftN(rLength))
for range rLength { for range rLength {
fmt.Print(" ") fmt.Print(" ")
} }
fmt.Print(cursorLeftN(rLength)) fmt.Print(CursorLeftN(rLength))
} }
var eraseExtraLine bool var eraseExtraLine bool
@ -405,9 +400,9 @@ func (b *Buffer) Remove() {
// are trailing characters which go over the line width boundary // are trailing characters which go over the line width boundary
if eraseExtraLine { if eraseExtraLine {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL) fmt.Print(CursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth place := b.DisplayPos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt()))) fmt.Print(CursorUpN(remainingLines+1) + CursorRightN(place+len(b.Prompt.prompt())))
} }
} }
} }
@ -422,9 +417,9 @@ func (b *Buffer) Delete() {
if b.DisplaySize()%b.LineWidth == 0 { if b.DisplaySize()%b.LineWidth == 0 {
if b.DisplayPos != b.DisplaySize() { if b.DisplayPos != b.DisplaySize() {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL) fmt.Print(CursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth place := b.DisplayPos % b.LineWidth
fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt()))) fmt.Print(CursorUpN(remainingLines) + CursorRightN(place+len(b.Prompt.prompt())))
} }
} }
} }
@ -471,17 +466,17 @@ func (b *Buffer) DeleteWord() {
} }
func (b *Buffer) ClearScreen() { func (b *Buffer) ClearScreen() {
fmt.Printf(ClearScreen + CursorReset + b.Prompt.prompt()) fmt.Print(ClearScreen + CursorReset + b.Prompt.prompt())
if b.IsEmpty() { if b.IsEmpty() {
ph := b.Prompt.placeholder() ph := b.Prompt.placeholder()
fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault) fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
} else { } else {
currPos := b.DisplayPos currPos := b.DisplayPos
currIndex := b.Pos currIndex := b.Pos
b.Pos = 0 b.Pos = 0
b.DisplayPos = 0 b.DisplayPos = 0
b.drawRemaining() b.drawRemaining()
fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt()))) fmt.Print(CursorReset + CursorRightN(len(b.Prompt.prompt())))
if currPos > 0 { if currPos > 0 {
targetLine := currPos / b.LineWidth targetLine := currPos / b.LineWidth
if targetLine > 0 { if targetLine > 0 {
@ -491,10 +486,10 @@ func (b *Buffer) ClearScreen() {
} }
remainder := currPos % b.LineWidth remainder := currPos % b.LineWidth
if remainder > 0 { if remainder > 0 {
fmt.Print(cursorRightN(remainder)) fmt.Print(CursorRightN(remainder))
} }
if currPos%b.LineWidth == 0 { if currPos%b.LineWidth == 0 {
fmt.Printf(CursorBOL + b.Prompt.AltPrompt) fmt.Print(CursorBOL + b.Prompt.AltPrompt)
} }
} }
b.Pos = currIndex b.Pos = currIndex
@ -513,13 +508,13 @@ func (b *Buffer) Replace(r []rune) {
b.Buf.Clear() b.Buf.Clear()
fmt.Printf(CursorBOL + ClearToEOL) fmt.Print(CursorBOL + ClearToEOL)
for range lineNums { for range lineNums {
fmt.Print(CursorUp + CursorBOL + ClearToEOL) fmt.Print(CursorUp + CursorBOL + ClearToEOL)
} }
fmt.Printf(CursorBOL + b.Prompt.prompt()) fmt.Print(CursorBOL + b.Prompt.prompt())
for _, c := range r { for _, c := range r {
b.Add(c) b.Add(c)
@ -545,19 +540,3 @@ func (b *Buffer) StringNM(n, m int) string {
} }
return s return s
} }
func cursorLeftN(n int) string {
return fmt.Sprintf(CursorLeftN, n)
}
func cursorRightN(n int) string {
return fmt.Sprintf(CursorRightN, n)
}
func cursorUpN(n int) string {
return fmt.Sprintf(CursorUpN, n)
}
func cursorDownN(n int) string {
return fmt.Sprintf(CursorDownN, n)
}

View file

@ -98,7 +98,7 @@ func (i *Instance) Readline() (string, error) {
showPlaceholder := !i.Pasting || i.Prompt.UseAlt showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder { if buf.IsEmpty() && showPlaceholder {
ph := i.Prompt.placeholder() ph := i.Prompt.placeholder()
fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault) fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault)
} }
r, err := i.Terminal.Read() r, err := i.Terminal.Read()

View file

@ -1,5 +1,7 @@
package readline package readline
import "strconv"
const ( const (
CharNull = 0 CharNull = 0
CharLineStart = 1 CharLineStart = 1
@ -41,34 +43,49 @@ const (
) )
const ( const (
CursorUp = "\033[1A" Esc = "\x1b"
CursorDown = "\033[1B"
CursorRight = "\033[1C"
CursorLeft = "\033[1D"
CursorSave = "\033[s" CursorSave = Esc + "[s"
CursorRestore = "\033[u" CursorRestore = Esc + "[u"
CursorUpN = "\033[%dA" CursorEOL = Esc + "[E"
CursorDownN = "\033[%dB" CursorBOL = Esc + "[1G"
CursorRightN = "\033[%dC" CursorHide = Esc + "[?25l"
CursorLeftN = "\033[%dD" CursorShow = Esc + "[?25h"
CursorEOL = "\033[E" ClearToEOL = Esc + "[K"
CursorBOL = "\033[1G" ClearLine = Esc + "[2K"
CursorHide = "\033[?25l" ClearScreen = Esc + "[2J"
CursorShow = "\033[?25h" CursorReset = Esc + "[0;0f"
ClearToEOL = "\033[K" ColorGrey = Esc + "[38;5;245m"
ClearLine = "\033[2K" ColorDefault = Esc + "[0m"
ClearScreen = "\033[2J"
CursorReset = "\033[0;0f"
ColorGrey = "\033[38;5;245m" StartBracketedPaste = Esc + "[?2004h"
ColorDefault = "\033[0m" EndBracketedPaste = Esc + "[?2004l"
)
StartBracketedPaste = "\033[?2004h" func CursorUpN(n int) string {
EndBracketedPaste = "\033[?2004l" return Esc + "[" + strconv.Itoa(n) + "A"
}
func CursorDownN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "B"
}
func CursorRightN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "C"
}
func CursorLeftN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "D"
}
var (
CursorUp = CursorUpN(1)
CursorDown = CursorDownN(1)
CursorRight = CursorRightN(1)
CursorLeft = CursorLeftN(1)
) )
const ( const (

View file

@ -209,15 +209,15 @@ install_cuda_driver_yum() {
case $PACKAGE_MANAGER in case $PACKAGE_MANAGER in
yum) yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils $SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi
;; ;;
dnf) dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi
@ -245,8 +245,8 @@ install_cuda_driver_yum() {
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() { install_cuda_driver_apt() {
status 'Installing NVIDIA repository...' status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi

View file

@ -94,7 +94,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
} }
const ( const (
numDownloadParts = 64 numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte maxDownloadPartSize int64 = 1000 * format.MegaByte
) )
@ -216,6 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err return err
} }
defer file.Close() defer file.Close()
setSparse(file)
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
@ -232,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error { newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 { if len(via) > 10 {
return errors.New("maxium redirects exceeded (10) for directURL") return errors.New("maximum redirects exceeded (10) for directURL")
} }
// if the hostname is the same, allow the redirect // if the hostname is the same, allow the redirect

View file

@ -250,19 +250,21 @@ func GetModel(name string) (*Model, error) {
Template: template.DefaultTemplate, Template: template.DefaultTemplate,
} }
filename, err := GetBlobsPath(manifest.Config.Digest) if manifest.Config.Digest != "" {
if err != nil { filename, err := GetBlobsPath(manifest.Config.Digest)
return nil, err if err != nil {
} return nil, err
}
configFile, err := os.Open(filename) configFile, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer configFile.Close() defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil { if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err return nil, err
}
} }
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
@ -371,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
var messages []*api.Message var messages []*api.Message
parameters := make(map[string]any) parameters := make(map[string]any)
var layers []*Layer var layers []Layer
for _, c := range modelfile.Commands { for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@ -497,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if c.Name != "license" { if c.Name != "license" {
// replace // replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool { layers = slices.DeleteFunc(layers, func(layer Layer) bool {
if layer.MediaType != mediatype { if layer.MediaType != mediatype {
return false return false
} }
@ -543,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
} }
var err2 error var err2 error
layers = slices.DeleteFunc(layers, func(layer *Layer) bool { layers = slices.DeleteFunc(layers, func(layer Layer) bool {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.message": case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones // if there are new messages, remove the inherited ones
@ -623,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err return err
} }
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json") configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil { if err != nil {
return err return err
} }
for _, layer := range append(layers, layer) { for _, layer := range append(layers, configLayer) {
if layer.status != "" { if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status}) fn(api.ProgressResponse{Status: layer.status})
} }
@ -637,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
old, _ := ParseNamedManifest(name) old, _ := ParseNamedManifest(name)
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, layer, layers); err != nil { if err := WriteManifest(name, configLayer, layers); err != nil {
return err return err
} }
@ -714,8 +716,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests // save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp) manifest, _, err := GetManifest(fmp)
if err != nil { if err != nil {
//nolint:nilerr return err
return nil
} }
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
@ -782,7 +783,8 @@ func PruneLayers() error {
err = deleteUnusedLayers(nil, deleteMap) err = deleteUnusedLayers(nil, deleteMap)
if err != nil { if err != nil {
return err slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
return nil
} }
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap))) slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
@ -837,9 +839,11 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return err return err
} }
var layers []*Layer var layers []Layer
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config) if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
for _, layer := range layers { for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@ -890,7 +894,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{} deleteMap[l.Digest] = struct{}{}
} }
deleteMap[manifest.Config.Digest] = struct{}{} if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
} }
} }
@ -905,9 +911,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %s", err)
} }
var layers []*Layer var layers []Layer
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config) if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool) skipVerify := make(map[string]bool)
for _, layer := range layers { for _, layer := range layers {
@ -971,7 +979,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "removing any unused layers"}) fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap) err = deleteUnusedLayers(nil, deleteMap)
if err != nil { if err != nil {
return err slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
} }
} }

View file

@ -2,6 +2,7 @@ package server
import ( import (
"crypto/sha256" "crypto/sha256"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -15,15 +16,15 @@ type Layer struct {
status string status string
} }
func NewLayer(r io.Reader, mediatype string) (*Layer, error) { func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("") blobs, err := GetBlobsPath("")
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
temp, err := os.CreateTemp(blobs, "sha256-") temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
defer temp.Close() defer temp.Close()
defer os.Remove(temp.Name()) defer os.Remove(temp.Name())
@ -31,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
sha256sum := sha256.New() sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
if err := temp.Close(); err != nil { if err := temp.Close(); err != nil {
return nil, err return Layer{}, err
} }
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)) digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest) blob, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
status := "using existing layer" status := "using existing layer"
if _, err := os.Stat(blob); err != nil { if _, err := os.Stat(blob); err != nil {
status = "creating new layer" status = "creating new layer"
if err := os.Rename(temp.Name(), blob); err != nil { if err := os.Rename(temp.Name(), blob); err != nil {
return nil, err return Layer{}, err
} }
} }
return &Layer{ return Layer{
MediaType: mediatype, MediaType: mediatype,
Digest: digest, Digest: digest,
Size: n, Size: n,
@ -60,18 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
}, nil }, nil
} }
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) { func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if digest == "" {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest) blob, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
fi, err := os.Stat(blob) fi, err := os.Stat(blob)
if err != nil { if err != nil {
return nil, err return Layer{}, err
} }
return &Layer{ return Layer{
MediaType: mediatype, MediaType: mediatype,
Digest: digest, Digest: digest,
Size: fi.Size(), Size: fi.Size(),
@ -81,6 +86,10 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
} }
func (l *Layer) Open() (io.ReadSeekCloser, error) { func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest) blob, err := GetBlobsPath(l.Digest)
if err != nil { if err != nil {
return nil, err return nil, err
@ -90,6 +99,10 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
} }
func (l *Layer) Remove() error { func (l *Layer) Remove() error {
if l.Digest == "" {
return nil
}
ms, err := Manifests() ms, err := Manifests()
if err != nil { if err != nil {
return err return err

View file

@ -14,10 +14,10 @@ import (
) )
type Manifest struct { type Manifest struct {
SchemaVersion int `json:"schemaVersion"` SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
Config *Layer `json:"config"` Config Layer `json:"config"`
Layers []*Layer `json:"layers"` Layers []Layer `json:"layers"`
filepath string filepath string
fi os.FileInfo fi os.FileInfo
@ -47,10 +47,12 @@ func (m *Manifest) Remove() error {
func (m *Manifest) RemoveLayers() error { func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) { for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) { if layer.Digest != "" {
slog.Debug("layer does not exist", "digest", layer.Digest) if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
} else if err != nil { slog.Debug("layer does not exist", "digest", layer.Digest)
return err } else if err != nil {
return err
}
} }
} }
@ -93,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return &m, nil return &m, nil
} }
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
if err != nil { if err != nil {
return err return err

View file

@ -26,7 +26,7 @@ import (
var intermediateBlobs map[string]string = make(map[string]string) var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct { type layerGGML struct {
*Layer Layer
*llm.GGML *llm.GGML
} }
@ -176,9 +176,20 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
mediatype = "application/vnd.ollama.image.projector" mediatype = "application/vnd.ollama.image.projector"
} }
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype) var layer Layer
if err != nil { if digest != "" && n == stat.Size() && offset == 0 {
return nil, err layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if err != nil {
return nil, err
}
} }
layers = append(layers, &layerGGML{layer, ggml}) layers = append(layers, &layerGGML{layer, ggml})

View file

@ -2,8 +2,10 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -11,6 +13,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@ -133,3 +136,82 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}) })
} }
} }
func TestParseFromFileFromLayer(t *testing.T) {
tempModels := t.TempDir()
file, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file.Close()
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 1 {
t.Fatalf("got %d != want 1", len(layers))
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers2) != 1 {
t.Fatalf("got %d != want 1", len(layers2))
}
if layers[0].Digest != layers2[0].Digest {
t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
}
if layers[0].Size != layers2[0].Size {
t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
}
if layers[0].MediaType != layers2[0].MediaType {
t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
}
}
func TestParseLayerFromCopy(t *testing.T) {
tempModels := t.TempDir()
file2, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file2.Close()
for range 5 {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
}
if _, err := file2.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 5 {
t.Fatalf("got %d != want 5", len(layers))
}
}

View file

@ -23,6 +23,7 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@ -323,13 +324,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input = append(input, v.(string)) input = append(input, v.(string))
} }
default: default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) if req.Input != nil {
return c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
} return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@ -340,12 +338,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
kvData, err := getKVData(m.ModelPath, false) kvData, err := getKVData(m.ModelPath, false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var count int
for i, s := range input { for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s) tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil { if err != nil {
@ -368,25 +372,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
} }
count += len(tokens)
input[i] = s input[i] = s
} }
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil { var g errgroup.Group
slog.Error("embedding generation failed", "error", err) embeddings := make([][]float32, len(input))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) for i, text := range input {
return g.Go(func() error {
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
embeddings[i] = normalize(embedding)
return nil
})
} }
for i, e := range embeddings.Embedding { if err := g.Wait(); err != nil {
embeddings.Embedding[i] = normalize(e) slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
return
} }
resp := api.EmbedResponse{ resp := api.EmbedResponse{
Model: req.Model, Model: req.Model,
Embeddings: embeddings.Embedding, Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart), TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: embeddings.PromptEvalCount, PromptEvalCount: count,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@ -430,21 +445,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
embedding := make([]float64, len(embeddings.Embedding[0])) var e []float64
for _, v := range embedding {
for i, v := range embeddings.Embedding[0] { e = append(e, float64(v))
embedding[i] = float64(v)
} }
resp := api.EmbeddingResponse{ resp := api.EmbeddingResponse{
Embedding: embedding, Embedding: e,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@ -824,17 +838,20 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
models := []api.ListModelResponse{} models := []api.ListModelResponse{}
for n, m := range ms { for n, m := range ms {
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
var cf ConfigV2 var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err) if m.Config.Digest != "" {
continue f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err)
continue
}
defer f.Close()
if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
continue
}
} }
// tag should never be masked // tag should never be masked

View file

@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
} }
// create a manifest with duplicate layers // create a manifest with duplicate layers
if err := WriteManifest(n, config, []*Layer{config}); err != nil { if err := WriteManifest(n, config, []Layer{config}); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -272,76 +272,6 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy) assert.Equal(t, "library", retrieveResp.OwnedBy)
}, },
}, },
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())

View file

@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
// some older models are not compatible with newer versions of llama.cpp // some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to // show a generalized compatibility error until there is a better way to
// check for model compatibility // check for model compatibility
if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") { if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
} }
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)

View file

@ -708,8 +708,8 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embedResp *llm.EmbedResponse embeddingResp []float32
embedRespErr error embeddingRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
detokenizeResp string detokenizeResp string
@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) { func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
return s.embedResp, s.embedRespErr return s.embeddingResp, s.embeddingRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {

8
server/sparse_common.go Normal file
View file

@ -0,0 +1,8 @@
//go:build !windows
package server
import "os"
func setSparse(*os.File) {
}

17
server/sparse_windows.go Normal file
View file

@ -0,0 +1,17 @@
package server
import (
"os"
"golang.org/x/sys/windows"
)
func setSparse(file *os.File) {
// exFat (and other FS types) don't support sparse files, so ignore errors
windows.DeviceIoControl( //nolint:errcheck
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
nil, nil,
)
}

View file

@ -26,7 +26,7 @@ import (
var blobUploadManager sync.Map var blobUploadManager sync.Map
type blobUpload struct { type blobUpload struct {
*Layer Layer
Total int64 Total int64
Completed atomic.Int64 Completed atomic.Int64
@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
p.written = 0 p.written = 0
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL() requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)

View file

@ -219,7 +219,7 @@ func (n Name) String() string {
return b.String() return b.String()
} }
// DisplayShort returns a short string version of the name. // DisplayShortest returns a short string version of the name.
func (n Name) DisplayShortest() string { func (n Name) DisplayShortest() string {
var sb strings.Builder var sb strings.Builder