Compare commits

..

No commits in common. "9e08c23ba931bf86687865ac5ebf83e8c38f83bf" and "f2d1c842add2d1edd081d862b44bbb2f8a0cbc19" have entirely different histories.

50 changed files with 773 additions and 1247 deletions

3
.gitattributes vendored
View file

@ -1,3 +1,2 @@
llm/ext_server/* linguist-vendored llm/ext_server/* linguist-vendored
* text=auto * text eol=lf
*.go text eol=lf

View file

@ -31,7 +31,7 @@ jobs:
security set-keychain-settings -lut 3600 build.keychain security set-keychain-settings -lut 3600 build.keychain
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- name: Build Darwin - name: Build Darwin
env: env:
@ -87,7 +87,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@ -141,7 +141,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
@ -218,7 +218,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@ -306,7 +306,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- run: go get - run: go get
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4

View file

@ -63,7 +63,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@ -163,7 +163,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
@ -200,7 +200,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@ -255,7 +255,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: false cache: false
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in
@ -297,7 +297,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version-file: go.mod go-version: "stable"
cache: true cache: true
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in

View file

@ -24,6 +24,7 @@ linters:
- nosprintfhostport - nosprintfhostport
- staticcheck - staticcheck
- tenv - tenv
- testifylint
- unconvert - unconvert
- unused - unused
- usestdlibvars - usestdlibvars

View file

@ -325,7 +325,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [tlm](https://github.com/yusufcanb/tlm) - [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama) - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama) - [gollama](https://github.com/sammcj/gollama)
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
### Database ### Database

View file

@ -298,7 +298,7 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
return &lr, nil return &lr, nil
} }
// ListRunning lists running models. // List running models.
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) { func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
var lr ProcessResponse var lr ProcessResponse
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil { if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
@ -333,7 +333,7 @@ func (c *Client) Show(ctx context.Context, req *ShowRequest) (*ShowResponse, err
return &resp, nil return &resp, nil
} }
// Heartbeat checks if the server has started and is responsive; if yes, it // Hearbeat checks if the server has started and is responsive; if yes, it
// returns nil, otherwise an error. // returns nil, otherwise an error.
func (c *Client) Heartbeat(ctx context.Context) error { func (c *Client) Heartbeat(ctx context.Context) error {
if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil { if err := c.do(ctx, http.MethodHead, "/", nil, nil); err != nil {

View file

@ -504,7 +504,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
for key, val := range m { for key, val := range m {
opt, ok := jsonOpts[key] opt, ok := jsonOpts[key]
if !ok { if !ok {
slog.Warn("invalid option provided", "option", key) slog.Warn("invalid option provided", "option", opt.Name)
continue continue
} }

View file

@ -11,8 +11,8 @@ import (
) )
const ( const (
updateAvailableMenuID = 1 updatAvailableMenuID = 1
updateMenuID = updateAvailableMenuID + 1 updateMenuID = updatAvailableMenuID + 1
separatorMenuID = updateMenuID + 1 separatorMenuID = updateMenuID + 1
diagLogsMenuID = separatorMenuID + 1 diagLogsMenuID = separatorMenuID + 1
diagSeparatorMenuID = diagLogsMenuID + 1 diagSeparatorMenuID = diagLogsMenuID + 1
@ -35,7 +35,7 @@ func (t *winTray) initMenus() error {
func (t *winTray) UpdateAvailable(ver string) error { func (t *winTray) UpdateAvailable(ver string) error {
if !t.updateNotified { if !t.updateNotified {
slog.Debug("updating menu and sending notification for new update") slog.Debug("updating menu and sending notification for new update")
if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { if err := t.addOrUpdateMenuItem(updatAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
return fmt.Errorf("unable to create menu entries %w", err) return fmt.Errorf("unable to create menu entries %w", err)
} }
if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil {

View file

@ -22,7 +22,6 @@ import (
"runtime" "runtime"
"slices" "slices"
"strings" "strings"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@ -79,7 +78,6 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data" status := "transferring model data"
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
@ -114,7 +112,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path, spinner) digest, err := createBlob(cmd, client, path)
if err != nil { if err != nil {
return err return err
} }
@ -265,20 +263,13 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) { func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer bin.Close() defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New() hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil { if _, err := io.Copy(hash, bin); err != nil {
return "", err return "", err
@ -288,43 +279,13 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *pr
return "", err return "", err
} }
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
type progressWriter struct {
n atomic.Int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n.Add(int64(len(p)))
return len(p), nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
@ -1125,7 +1086,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
return nil return nil
} }
func RunServer(_ *cobra.Command, _ []string) error { func RunServer(cmd *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil { if err := initializeKeypair(); err != nil {
return err return err
} }

View file

@ -27,10 +27,6 @@ func (Parameters) KV(t *Tokenizer) llm.KV {
"tokenizer.ggml.token_type": t.Vocabulary.Types, "tokenizer.ggml.token_type": t.Vocabulary.Types,
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
if t.Template != "" { if t.Template != "" {
kv["tokenizer.chat_template"] = t.Template kv["tokenizer.chat_template"] = t.Template
} }
@ -93,8 +89,6 @@ func Convert(fsys fs.FS, ws io.WriteSeeker) error {
conv = &mixtral{} conv = &mixtral{}
case "GemmaForCausalLM": case "GemmaForCausalLM":
conv = &gemma{} conv = &gemma{}
case "Phi3ForCausalLM":
conv = &phi3{}
default: default:
return errors.New("unsupported architecture") return errors.New("unsupported architecture")
} }

View file

@ -90,6 +90,10 @@ func (p *llama) KV(t *Tokenizer) llm.KV {
kv["llama.attention.value_length"] = p.HeadDim kv["llama.attention.value_length"] = p.HeadDim
} }
if len(t.Merges) > 0 {
kv["tokenizer.ggml.merges"] = t.Merges
}
return kv return kv
} }

View file

@ -1,125 +0,0 @@
package convert
import (
"cmp"
"encoding/binary"
"io"
"math"
"strings"
"sync"
"github.com/ollama/ollama/llm"
)
type phi3 struct {
Parameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NLayers uint32 `json:"n_layers"`
HiddenSize uint32 `json:"hidden_size"`
NEmbd uint32 `json:"n_embd"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NHead uint32 `json:"n_head"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NHeadKV uint32 `json:"n_head_kv"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling struct {
Type string `json:"type"`
LongFactor ropeFactor `json:"long_factor"`
ShortFactor ropeFactor `json:"short_factor"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
NPositions uint32 `json:"n_positions"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
SlidingWindow uint32 `json:"sliding_window"`
}
var _ Converter = (*phi3)(nil)
func (p *phi3) KV(t *Tokenizer) llm.KV {
kv := p.Parameters.KV(t)
kv["general.architecture"] = "phi3"
kv["general.name"] = "phi3"
kv["phi3.context_length"] = p.MaxPositionEmbeddings
kv["phi3.embedding_length"] = cmp.Or(p.HiddenSize, p.NEmbd)
kv["phi3.feed_forward_length"] = p.IntermediateSize
kv["phi3.block_count"] = cmp.Or(p.NumHiddenLayers, p.NLayers)
kv["phi3.attention.head_count"] = cmp.Or(p.NumAttentionHeads, p.NHead)
kv["phi3.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NHeadKV)
kv["phi3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["phi3.rope.dimension_count"] = p.HiddenSize / cmp.Or(p.NumAttentionHeads, p.NHead)
kv["phi3.rope.freq_base"] = p.RopeTheta
kv["phi3.rope.scaling.original_context_length"] = p.OriginalMaxPositionEmbeddings
kv["phi3.attention.sliding_window"] = p.SlidingWindow
scale := float64(p.MaxPositionEmbeddings) / float64(p.OriginalMaxPositionEmbeddings)
switch p.RopeScaling.Type {
case "":
// no scaling
case "su", "longrope":
kv["phi3.rope.scaling.attn_factor"] = float32(max(math.Sqrt(1+math.Log(scale)/math.Log(float64(p.OriginalMaxPositionEmbeddings))), 1.0))
case "yarn":
kv["phi3.rope.scaling.attn_factor"] = float32(max(0.1*math.Log(scale)+1.0, 1.0))
default:
panic("unknown rope scaling type")
}
return kv
}
func (p *phi3) Tensors(ts []Tensor) []llm.Tensor {
var addRopeFactors sync.Once
out := make([]llm.Tensor, 0, len(ts)+2)
for _, t := range ts {
name := p.tensorName(t.Name())
if strings.HasPrefix(name, "blk.0.") {
addRopeFactors.Do(func() {
out = append(out, llm.Tensor{
Name: "rope_factors_long.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor,
}, llm.Tensor{
Name: "rope_factors_short.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
WriterTo: p.RopeScaling.ShortFactor,
})
})
}
out = append(out, llm.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *phi3) tensorName(n string) string {
return strings.NewReplacer(
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.qkv_proj", "attn_qkv",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_up_proj", "ffn_up",
"post_attention_layernorm", "ffn_norm",
).Replace(n)
}
type ropeFactor []float32
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
err := binary.Write(w, binary.LittleEndian, r)
return 0, err
}

View file

@ -65,8 +65,6 @@ func TestConvertFull(t *testing.T) {
"Mistral-7B-Instruct-v0.2", "Mistral-7B-Instruct-v0.2",
"Mixtral-8x7B-Instruct-v0.1", "Mixtral-8x7B-Instruct-v0.1",
"gemma-2b-it", "gemma-2b-it",
// microsoft/Phi-3-mini-128-instruct@d548c233192db00165d842bf8edff054bb3212f8
"Phi-3-mini-128k-instruct",
} }
for i := range cases { for i := range cases {

View file

@ -1,225 +0,0 @@
{
"general.architecture": "phi3",
"general.file_type": "1",
"general.quantization_version": "2",
"phi3.block_count": "32",
"phi3.context_length": "131072",
"phi3.embedding_length": "3072",
"phi3.feed_forward_length": "8192",
"phi3.rope.scaling.original_context_length": "4096",
"phi3.rope.dimension_count": "96",
"phi3.rope.freq_base": "10000",
"phi3.rope.scaling.attn_factor": "1.1902381",
"phi3.attention.head_count": "32",
"phi3.attention.head_count_kv": "32",
"phi3.attention.layer_norm_rms_epsilon": "1e-05",
"phi3.attention.sliding_window": "262144",
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.pre": "default",
"tokenizer.ggml.add_bos_token": "false",
"tokenizer.ggml.add_eos_token": "false",
"tokenizer.ggml.bos_token_id": "1",
"tokenizer.ggml.eos_token_id": "32000",
"tokenizer.ggml.unknown_token_id": "0",
"tokenizer.ggml.padding_token_id": "32000",
"tokenizer.ggml.scores": "6e37bcde2adc7e350e87c496eddd7a2124329c1dc66c5bf3ad3997253e4f7a62",
"tokenizer.ggml.token_type": "b6ecf55ec64ee67d87750bdb8d757a2c58bf78377e9f4219f5689a6c4dea57ce",
"tokenizer.ggml.tokens": "d168da3ddd3eee820916945fcb9baf24dd3cde42f606cffa2d19e7c8a8743918",
"blk.0.attn_norm.weight": "216aeb2c9e0c271f899e1ef2a63cceeb8f41e97642e84fada54b1d3c1c11cf25",
"blk.0.attn_output.weight": "b597d56f7188ffc1fafc273fadc59d41738cffd677ae98c61a62c3285b3a3099",
"blk.0.attn_qkv.weight": "d28a6b44e13f59be5483e4be2bedb544e346168d720aca27f47d1a5a722be91e",
"blk.0.ffn_down.weight": "4a691370e5a61fcbbf540fbcbf4c0f1d15dec0364528c0e916d0744f6262b63b",
"blk.0.ffn_norm.weight": "0c00af2b4a3128bec64a0cbb1084b042fdbe13d9ad0d03bd577f9449dfead338",
"blk.0.ffn_up.weight": "b32b52f790c1c083bfb8a3126dc1111cfeeb28dc8c584a930a1e5334cb176bf4",
"blk.1.attn_norm.weight": "68748011503c6c029e8e69a84a8e5a89338f378769627b6dbf7f93d715c292e1",
"blk.1.attn_output.weight": "2267344add13b048ca59e4377c86dc512be8046a57156901fa32a20fa74e4ee0",
"blk.1.attn_qkv.weight": "9109d2e3d7a2eacfda5226587b8be124a3bf44b972da7ebb17aa15795897eacc",
"blk.1.ffn_down.weight": "d675df4df4dd039c0c339ad6445d39eddd2004db6bf35bed6314c7497245a633",
"blk.1.ffn_norm.weight": "3b5767ae977bc8baaa06b06efdbea193b6b3ba605ce76d77a76ce317e935500c",
"blk.1.ffn_up.weight": "80dfd6d9d234b00334c89b8e0a02f81899c2efd377321c34ba5ba51a5f61b5ff",
"blk.2.attn_norm.weight": "6a6743b057e5088f145bc179e92c9bfb41163e7295d7b81c62e23dd89d2b59c4",
"blk.2.attn_output.weight": "bc5491ea54e0db81462d7d9b7d25cbdda380c2db8de041bd1c4ab7b76a1d19c3",
"blk.2.attn_qkv.weight": "a61287a9852e2f5aca9c100b471d98398b2913a3497c743de3c70ec9ddd7087f",
"blk.2.ffn_down.weight": "4fddcc382c8dceeab027fe43d8d44e67edb5e8ce4b9a1b7f773c87770380ade1",
"blk.2.ffn_norm.weight": "07e05f82b3f63f711db3b684ca79aed25c0657917e66f88af47348a82065c227",
"blk.2.ffn_up.weight": "4835a682ef1826c12df01ae7663fc45f9c82bc8e64b665f13fb7da8e201ec0fb",
"blk.3.attn_norm.weight": "f22aba7c03999ba7136f39cda747a39715e498699dc1716cd97fc5dfc58d1b1c",
"blk.3.attn_output.weight": "53b579855366fd786c5126b2b30aac4d583ca7bda56833c4865f5cadb5c18c6d",
"blk.3.attn_qkv.weight": "bb56aba78158123140fcea59c69ac562ca208f6d3086819417cdad8c50f333ad",
"blk.3.ffn_down.weight": "97280897a7cd86db2830c004bccc5bc094f50e293baded0189159a2019145a6e",
"blk.3.ffn_norm.weight": "10a8c99f8b57a960e8e0a1133c4a26f9148403d1b9bff2eff114917de996f3b5",
"blk.3.ffn_up.weight": "7324046c915e75d621b2043597a245a428d8eea31869135e6257a861491d8dcc",
"blk.4.attn_norm.weight": "507d8e164de94646edbfe33def8e8fbf7c9a6ee3fbaedb5000f72d9f51ec5e36",
"blk.4.attn_output.weight": "bbb3429e6efa98c150e0fdbf48c16180cbf0d0cbc1b3c253c6c319d78f4593a2",
"blk.4.attn_qkv.weight": "b95ee5be0786d3901273d806c339fe6c20e6bfffd2a20672a9f56af80921e8ab",
"blk.4.ffn_down.weight": "806bbf91df92a5a22bd5aa1ffb7fc2869f7293ffc7704771c290ecc583b27975",
"blk.4.ffn_norm.weight": "cfc2930a81df7aee3a5e7f726a15c1182233e868bf0d9d37f6b6ae6d8c15c234",
"blk.4.ffn_up.weight": "c3390c69533de2c8424e8069323ccc5d0c4543111535da04cf2c7d26745576aa",
"blk.5.attn_norm.weight": "0d71c4fbcefabbd021569442853d2fe90668b19409ae2805a718a829ca60beab",
"blk.5.attn_output.weight": "10ebd93629112bf2df5c30dd0953a4a5e9020306768283181ed426934d47e14f",
"blk.5.attn_qkv.weight": "5cb05633369f12d4b00e0ff787736bd846856682115720ebc6cce05270c334f6",
"blk.5.ffn_down.weight": "e28bcc5094212eafc7476dbc5b7a520d25b79578cbf4229d698e2655956a80ad",
"blk.5.ffn_norm.weight": "b6f2c4cf9f34bb4d59989f96165c14a67dc1e266ad0a6d0fcc49f1add929e6ff",
"blk.5.ffn_up.weight": "0f9ef99423cc07ebedc0e9cfa95809f2d7108d910bb4ef97ebc0b0309c440750",
"blk.6.attn_norm.weight": "b3edcc47a42218234f7564d7470611b49401a41ae8cd42123f86557c69f5d7f2",
"blk.6.attn_output.weight": "eb9b7d257b388bb5b8fe0515e5c6873317239cb94cda236e4b6ada2a6c57c65c",
"blk.6.attn_qkv.weight": "eb968081f478c52f07bd9c2761741e982dba33cc4eeadeea3557d391b9ac2106",
"blk.6.ffn_down.weight": "1b8588bb7463206290322695577dcfced300895d6e6f4b26966c53a9ae2f0f84",
"blk.6.ffn_norm.weight": "1219c04b7770983c77814200eefe743f46d15328ea2b12711e44f8103eab08d3",
"blk.6.ffn_up.weight": "197ef287239fec47c55677f0fbb66eaf0644f775bc382de843971730721394f6",
"blk.7.attn_norm.weight": "b630ad08c80d564ed1c024384818e9fd3f22a36cd7a14aa96e7e2759a8285099",
"blk.7.attn_output.weight": "970255aa750828a47d6b9d399f9612b5bf25aefe7dadbcba41fc416d0d4067c1",
"blk.7.attn_qkv.weight": "ebb157c880293e6de8d629f263ba8853ed1dbdc02c311d43432bb8cfbb310739",
"blk.7.ffn_down.weight": "24bcd4db4cba844c89f878b81843c373dbbc0675e889d32c5b12e63384a7b670",
"blk.7.ffn_norm.weight": "b9c6f71001808ee873ce7db8056e4b53fb4cccec8b7f0f312899b575fae39d39",
"blk.7.ffn_up.weight": "979f1828d227455c26015a2a11afe9dd05f2bb97a8ba6b38c8dab3f50e627401",
"blk.8.attn_norm.weight": "4e8e347e3775010b7112ee630f2f4f2383be7ff64e6ca6154b9b22566552eaa6",
"blk.8.attn_output.weight": "65a44babf44a435a1829945211b3168f9ec78ac3cb7a049a733e93d11f0d6659",
"blk.8.attn_qkv.weight": "343ed07671da400b040812a4058482fa38284b5d9af9becfed07417fe26ce747",
"blk.8.ffn_down.weight": "7fb7e073e3c2c503c4e9d60efa0988fed7398d900cc003695fe3fffd3e188b82",
"blk.8.ffn_norm.weight": "b07c1f655d8593e3892a2cf73f8a0c19ce8e5cb613fafbe7cbd430da8ce4c57d",
"blk.8.ffn_up.weight": "8b26e14de54b3fdc2e2d3ea41720f9d9c236a93688c3b7fd7bf43f5fbb327c9b",
"blk.9.attn_norm.weight": "46394d408a8e316916177e6aa261de32e137a82d729c0b1800b072f0c38c39b6",
"blk.9.attn_output.weight": "d57f3d46107947a7073373a0b35d6ecf7759b5df15406f4a3590a60666af6b16",
"blk.9.attn_qkv.weight": "14bb8ace8c5453148f4b536e9f4279c813f31136716947256f5cca333448639c",
"blk.9.ffn_down.weight": "2b8d98e2b5ed68338f6e4de43bf7de0c4858cc69103cd5177725f7444eec7694",
"blk.9.ffn_norm.weight": "41a499dfd418cc4c6b8c12313f673f7e2cd4a3f9c4065eb6c4feb5eed02fb542",
"blk.9.ffn_up.weight": "143aab7533a64b17fbe201490a6f674bc7f0bd370c094500b2e100419073d1c2",
"blk.10.attn_norm.weight": "ebb670aafd36816a794347287269d8f1a5b19c1e3c0a1e38023bc19fdba9b073",
"blk.10.attn_output.weight": "b5d65bbc0ed5e49fdd9d754bc18163cd042a285024d0cf6f954c503bc8c877cb",
"blk.10.attn_qkv.weight": "f06b15bac88da798fa34a62b03eaac0dbe8b846020516603c387541f2d8dd672",
"blk.10.ffn_down.weight": "fb091fcd1b4de25d1bea94d1755e255cb02914a030d23e3a234e57b8d46bde6e",
"blk.10.ffn_norm.weight": "eb347bdf9c40414af87e13a8e72e40b31f004b50f7cb366f1a219ced60a61355",
"blk.10.ffn_up.weight": "ed2d52fc881a173f404fe8a1067862c9856d6c3e0d2e90a330a7aa394e3f84d1",
"blk.11.attn_norm.weight": "64e252603cf010a0e502ca39fdf8d0a196a79aec67c0d2bb9213fc0cb80c47d4",
"blk.11.attn_output.weight": "228e33e21c69f52efc74fdfc831bc9af271e44b2a29a3dced1d64e667ce36eb5",
"blk.11.attn_qkv.weight": "ab9ce6d4ef9e42ee0da3f20a7708a3bbc5e79e967b05fa86ba946a05e2eb63eb",
"blk.11.ffn_down.weight": "0ca133b7835c98dc77c25d64e4eb7873778bdb5e4d22d8b80f920f46865b43bd",
"blk.11.ffn_norm.weight": "02455741a0dfd161c79aa1ecc381901721f229fdcda5615622a629631fb61cfd",
"blk.11.ffn_up.weight": "9fecdcc099fbb8e23c6b1ea9294702a027f4a58d265543ec5e7be79b8f63b354",
"blk.12.attn_norm.weight": "783bb459911b1b3609a9b2bdfe272f1670add73b5471da738e07ac47e2e07dfd",
"blk.12.attn_output.weight": "1e1a914c9e48b857206ac5a1f7cead994bc1ea91d5d4fff8c834d73f2e38ef5d",
"blk.12.attn_qkv.weight": "5953e7185ccb87fb4dae8f9426ec86315d4c7794326e8ab59b3a95d4af2189f0",
"blk.12.ffn_down.weight": "a3eecf0f394f86e2cfb48a5940a5c50ca86d71883b2f79fcc642a935fabce0d4",
"blk.12.ffn_norm.weight": "0a4272e41373c23bd72f10d2d82930aa3a1480aac75832bfbf01cebf0b86b6a4",
"blk.12.ffn_up.weight": "06f42776de3a7ceac3025f26a7a8bd20e062233cce2bdaa2183470dc4b30b87d",
"blk.13.attn_norm.weight": "5915da60fb03e201fa649faba780e5fdf1c761c262b206e5415cf83181f65780",
"blk.13.attn_output.weight": "4dbf6eab074fa3835fd32bd631a8208e511037d5056d2fd3015735cca7674ef7",
"blk.13.attn_qkv.weight": "d3d8339a1c4782d9e73d77fdebe154d3c5b83ac40c9175b3e91a4977d08f876b",
"blk.13.ffn_down.weight": "de6772b46a55e1fd42b007637dfbf68b6598e5d5b61622da0935002e1e192d3a",
"blk.13.ffn_norm.weight": "5a640ea3b8c7be49c95a58a2327e10d8e8d9d142504bde5c8091613e5b961d7a",
"blk.13.ffn_up.weight": "f35e3545e4bd3531b2e843b5efd31dee0c13c807ee6386e65473ba67bbec30d0",
"blk.14.attn_norm.weight": "9b34986450b7c98b4927e81e61a816f9e84b1addc7c14926402100037aad6678",
"blk.14.attn_output.weight": "155d52efb23d366016d861a251d4d1f4a0c13699188c50d50dba016a0d8bfcd9",
"blk.14.attn_qkv.weight": "8e1415084e1f33c73a777f19e752489f4dd312cca047733e5ea643cd4a955e04",
"blk.14.ffn_down.weight": "a2a142226b94baa01ccb65bdea2b7418e49085c1d9c3c63e544e3112c58a25da",
"blk.14.ffn_norm.weight": "8aecfd9b0ae6affaea31a80c5c9a4a14b31deaa0db7bd8f6da2a64d23447921c",
"blk.14.ffn_up.weight": "0c1407237b8c1bd02f193346b5681926fe698a5055eac6a7450451b0f991707c",
"blk.15.attn_norm.weight": "e037bd19880bfa83d983200fb0c7866f8ad16c3ff5cc4b4f3a37ca7373870ff6",
"blk.15.attn_output.weight": "045fe4fc95cc129a1b92771b179c11b12845c4c088786c607f17bd98857e68e1",
"blk.15.attn_qkv.weight": "7621b7559705cab1d4dea1c69f76dbf9dc1c8837a203b656f484703b9c1b70ce",
"blk.15.ffn_down.weight": "7e5ac20e290bc60761e1cd972354fde225b7fa861048d44d9a0dd9b046d55f58",
"blk.15.ffn_norm.weight": "b6d830d88f1db1825687973c8c2b1a24c6fa84f07af8d0e3ef9c86009baca0b2",
"blk.15.ffn_up.weight": "dcda0957cd04fc45476774dba2bbf9aa89d6b05d5ca7b10ae6f73ad2c49b1cd3",
"blk.16.attn_norm.weight": "4ee9b70ba15cb2a08240f93990e90f5068c48fceb481f8e2186bec8b7214eb3f",
"blk.16.attn_output.weight": "315cfe5536658d2498192b2980eade15b2c9a4ff220e4011911457b1727fa103",
"blk.16.attn_qkv.weight": "3c8122e3ad637583b9dcde8ff3a323267d3014bb1f0f9771e5322260ca9ecc8d",
"blk.16.ffn_down.weight": "3b5fbebd5ee2b86cad96fb8a9b45a8770d08f82c1c8b74d7061e866f7020a18d",
"blk.16.ffn_norm.weight": "ffab69f20bda372de6e5878f0539163e2fc6ba113621ded95705fc3b1465c9f0",
"blk.16.ffn_up.weight": "0935ea3d258da42d6258406365f39f58ddaabfe97ea5977580db3635188f24a1",
"blk.17.attn_norm.weight": "f030441733f3d147b4a06a1eb4aeb8465c7c24d9c53bf4c48fe7e134d3629803",
"blk.17.attn_output.weight": "07a955ef09e8dc766ac0df647d0b2c69f23c4c69a7137654b4aad80303ed0eda",
"blk.17.attn_qkv.weight": "1c10688061e21e2fe12ad0cb54bf03895c1f83c3b0df743a42f548b52cbca1b2",
"blk.17.ffn_down.weight": "ebb9cc9836f41d88fdae2aa9a4355514e4edaec8d1577ffeb947a35204e77f52",
"blk.17.ffn_norm.weight": "50aff44f6528b13db5389f2ddcdb7676244947610bd7ffbff3f881c968c2a0d4",
"blk.17.ffn_up.weight": "d716537949582be33bde6b02e38f5a70081c9642a9fb05a61312126718b8d148",
"blk.18.attn_norm.weight": "0ea695c4e53d637902f46663a6ee42adc493c36794476acc7dbddaa05b13840d",
"blk.18.attn_output.weight": "5fd35b500221a612eb4f4bddf0e9b6b7db4d7733032a75f8802fb2d884647c2e",
"blk.18.attn_qkv.weight": "b0da37fd030fe69581f990bf23bfd35467a1bbe558af6de7c0924f6b72e92317",
"blk.18.ffn_down.weight": "b355c33f44b328f4bb977567de8f7544db4b005d7a8fbded658518ecf3c5a153",
"blk.18.ffn_norm.weight": "58b3fe9094079989a86e0387143259e1cc35952d24dc3df290c4ba6df44f5c51",
"blk.18.ffn_up.weight": "2ce530954c342c30ed2ead5353f931960bfae1d278868504c0efb973560fabbe",
"blk.19.attn_norm.weight": "533e9aed66feea8f0392aa81f9e293240e1f009a5334253915fb60c2749b615d",
"blk.19.attn_output.weight": "84f2d00f98a4113a779d3b5d1c3e7c914eb47784d3ab13b290367c124c2994aa",
"blk.19.attn_qkv.weight": "fbe6b9f53b07fa7537d3b3d452d20a9bc666f9fd41ec2091dd28bc2f70fc668f",
"blk.19.ffn_down.weight": "b30199e098c8bb3f890183d8b18471e80b62b604729b277ad62488dd71e1206b",
"blk.19.ffn_norm.weight": "c81373e41cd340b7badb19f9517c77c4250b4eb9a02dc758b8b49b652487d7ff",
"blk.19.ffn_up.weight": "5a5cb083ca7725720e3a890f7fa46354760e8007a8188849a092e305694a75e3",
"blk.20.attn_norm.weight": "4953091b4477e354357a8e743ba0a1900633e52f1599ee082a0c9b0b2b5cd978",
"blk.20.attn_output.weight": "62d54f7749cd6856097b2632066a322b0296df915fe66f382c5b5981be0d4f23",
"blk.20.attn_qkv.weight": "406de9e35b0729ebe902d7a47905cc7fb29a921431ed35dbef0c03e5690a1329",
"blk.20.ffn_down.weight": "62fb678b0d1261e19a4903a2b347d67afcc8acff01feb33a687a35a2d1e6f9a5",
"blk.20.ffn_norm.weight": "cd9d36b7e71e55c8925b97bb09c28219f182626bcff094878ae39c3db887a14b",
"blk.20.ffn_up.weight": "b9276771d79d3e932e73ccc520c3f8476342b9ef312ed2ee1e0da822e6e3ad18",
"blk.21.attn_norm.weight": "66d8c8a35e13ce9c2a0e75b670150e2c31484a55c2316df46075312196178ed3",
"blk.21.attn_output.weight": "12ab46c9382648f9b3350fdd92a6be6352743d62d6b520d7e2024e0c838588f5",
"blk.21.attn_qkv.weight": "a7909676ee1675ca23cd29a5fdd226df8dd9d68f94c6c9bbb51dd9fd38504008",
"blk.21.ffn_down.weight": "6fb317279c6542e82f97d5a12a60fac1bd0fa0405154f9fbe265e2fe39bd49cc",
"blk.21.ffn_norm.weight": "c0f703eb3ff161b5ba4490d87d8684b8a6c47a8f433e12f418333b9db439010a",
"blk.21.ffn_up.weight": "6dbdb80ef0c35e364bbce12d40d5e74c7963c7b55d58d9579567a07ffce7b863",
"blk.22.attn_norm.weight": "f94237433bf03d675cb2f655b81ca91a1ce2447bc6b00b13d6b0ccfe2d411eff",
"blk.22.attn_output.weight": "e821f95995ce497c01e63ca64f737713b1b65f11df1903e51d444aa516f33f71",
"blk.22.attn_qkv.weight": "1b0f717c73afb5eb4c82a1708c4e85c969e8a2a8770d9ddb78b1870a2d8a781e",
"blk.22.ffn_down.weight": "0f33f7a3cdc685484be99aa0c03642b0b20850a27d1fddbe054b13a9382f3ccb",
"blk.22.ffn_norm.weight": "9df285cf211ddd7df2b36a50489af574755c7d4d98b29a05cd04566ae613c8dc",
"blk.22.ffn_up.weight": "63ac300e1efb34041dd0136cf43ea622fac6f0caccce1cd9262f5e08d2cf179c",
"blk.23.attn_norm.weight": "5f72d9e88689b4027b28f5f8f26cd3abb03635ceea7ec98a4c91a9fc691f6707",
"blk.23.attn_output.weight": "6ecf04ff61125c5fc768f8656497152149373daf321ee9c957e8f7245a1184d1",
"blk.23.attn_qkv.weight": "a9d9978806724c2959f2cf386c233831f08e1e933dbf2b32665e788d9d512ea4",
"blk.23.ffn_down.weight": "72c7d17886a3da17fa0daa456aa5e877b2ef5b8b403182b870d9ca5ca9c70347",
"blk.23.ffn_norm.weight": "971e4b712e3025a13419b5b57d674b5e4ab7f18f74b57b9afc4671623da90c4b",
"blk.23.ffn_up.weight": "df2b5c7dbd5834545b815073af0c7355b065124e6d6f0fee78d8fa5b2076dc3e",
"blk.24.attn_norm.weight": "c41957c4a79ad3b16f6e11daec1c7f530b9f3f4b618e1e4367c3b67787ac4ab6",
"blk.24.attn_output.weight": "ef7d61f5fc88ac6f31bf60cb5f4d2d6b8df42d38825807112361a7224b0dee3b",
"blk.24.attn_qkv.weight": "3e6a58fe7d49c90bb6971efbad3371c32256881173ea5aee4b0c296cb206490f",
"blk.24.ffn_down.weight": "f43619144047de42fed81dfa495f1815d3cb771330e574043e2b67620819292c",
"blk.24.ffn_norm.weight": "5501d4a2a98c8ca6b42e77b53b221dbc08f530f6a067256d787534ec6fe028bd",
"blk.24.ffn_up.weight": "d64c8b0e509e2b1118f6000176f8956cacecdbb200c7e95ed93fb78b6e26c84a",
"blk.25.attn_norm.weight": "502fa3c302d371f61c5791f4615b73018ffb1daa09b6499b227116581244c5d4",
"blk.25.attn_output.weight": "ad8391d4e9c980856f2547aa945b2b6a407a6382158dc1ddd4f08d94ecc24be6",
"blk.25.attn_qkv.weight": "42e8983780d4a01a02c54ad23d4df21eea437f119a10af5a9c12a76a42d308c1",
"blk.25.ffn_down.weight": "302dd010d4e0ab4eeaee89090409ea0dddeeeed3236415eb8f97c942497eea91",
"blk.25.ffn_norm.weight": "fb34c1ee5bca96986c08834df0a0c047ba041c1123ac1f563e9d64312bf82d6a",
"blk.25.ffn_up.weight": "10739a8de156816d93c92b935386540bfa976bdbef204f0312960f6fc657582f",
"blk.26.attn_norm.weight": "7036c711609128c4e55968ff3681d3043338879a5737efd6c2ac9e1a2a61f1a0",
"blk.26.attn_output.weight": "db5db45dead5cb911fa01da59832f121b7c18b2d167bf53741c40819f24d346c",
"blk.26.attn_qkv.weight": "cae34c6b7f82ed14348d5ed30a79919c383737c1694a9cb9c0de609d3b0c1d0a",
"blk.26.ffn_down.weight": "491ec3a4da9b4f49f8ebc6be658ce397a9b801ae9fb35e82177e47808c65e5d0",
"blk.26.ffn_norm.weight": "fd7059d75d7f0e5288511ddeeb0f772eb3cae3ccfe4226b877015834edc3c386",
"blk.26.ffn_up.weight": "ea1ee1274c56458ce056d2205e5bb6e5422ce4cb0ad58006b8141749b97a0c39",
"blk.27.attn_norm.weight": "cc362c9a937609265052cd38544af17a1a7448cea086d4c801139e1fc865832d",
"blk.27.attn_output.weight": "ba757a81dabde9cb1b069d1bb616fe79649a1724f756567ec61caed1304fe6cf",
"blk.27.attn_qkv.weight": "1ab8d7d02d87756c12c2275636823aa5ede3d683178225c4cac4bd892c319bd4",
"blk.27.ffn_down.weight": "deb1c711c8a66acf4dcd2d088e1548f8e08f296f755e4067d6557fa55afde88c",
"blk.27.ffn_norm.weight": "fc6242d8cb8a4a37a8ddb7e41e7e60a63d4a89edf36acb35df052f10b9c91ece",
"blk.27.ffn_up.weight": "8df39b09c4801f343aca78f2918a1f6db78c8c55e591eda4c69eadb74c26e180",
"blk.28.attn_norm.weight": "75b539308f77e3cefdc6d98484d8b5cbf0538f0c2869a77b7373a145a18bc850",
"blk.28.attn_output.weight": "ae128940eb60a6d2e121762ef4b3e9dcf9eb3e105b249507fa7f12de0e19822c",
"blk.28.attn_qkv.weight": "bdda781c288e9326c240e33905f8e621b6a2ad902e620739d34f93fcd6f933de",
"blk.28.ffn_down.weight": "f1d6e6d1c286b1138bfd7e53fe477f399ae93bc2c04e35416f84218ed7247965",
"blk.28.ffn_norm.weight": "3f837ce82c8b9bde0d61d08b6f5fe5574886ea5328dbdc53f2929f18da8b4087",
"blk.28.ffn_up.weight": "2af027002e31d1b6cfedbdb30a2b9d7213f3aa691167c353913adfd48fda31e4",
"blk.29.attn_norm.weight": "61e8003b5329462ffe0fe172f2b160260de006aed858332d49d75504b6b6aa7a",
"blk.29.attn_output.weight": "ca44542a72a37476dc73dbdcc01f5b7497cb3ebc4ea230a55c9634ccd8e56ad4",
"blk.29.attn_qkv.weight": "abb3d9d6abe57872ae3daa51935d43264093ded5ce63b49d1e280ee5758be0e4",
"blk.29.ffn_down.weight": "6764b895fce881df097489c263446f0106de36217997660c15984b3ee22a5a06",
"blk.29.ffn_norm.weight": "89e03e9a33fc0e6e31ba9f0c2bd7c5734a118c5602bb90148793e08a80e8d0ae",
"blk.29.ffn_up.weight": "fa7ad57a84954f4121653152efed1a871d8adb20a1ea9086e3e849ce359d7d2e",
"blk.30.attn_norm.weight": "91a697aca1e42af54f806a20211031c3369e8d0bd58df1b0147fe24954e1f5a4",
"blk.30.attn_output.weight": "36063fcf766c89ac75be56f688cc63cefe5f2c733fbf4378ea9956ad386fa148",
"blk.30.attn_qkv.weight": "2cacd1161f1121a2c0b979930134f4666f73fb8d7237b3b0659ae091b15955a6",
"blk.30.ffn_down.weight": "9f3fcb6217100595850c05dc98f9ab2a263afdb6ab28df2fcb08aeff512057d7",
"blk.30.ffn_norm.weight": "6c600bc1fc7de39d4f8917b81fc7d1d5ed2a9b56492234c13a4bd6028c30d880",
"blk.30.ffn_up.weight": "73cabd1bb011956b2689ea3338bb76642ef3a57c197377d666d2ab5f56317668",
"blk.31.attn_norm.weight": "72d3e1cc771380645fa75a899858c95f39857a4f3f1ed60fe1578df383b8bc53",
"blk.31.attn_output.weight": "40089cdd29994dc19a1d89fa15902a89cfeca3540f12dc9bf4d00ef82506e456",
"blk.31.attn_qkv.weight": "1d0bb40e9258071ae14290a53c619a8e331dda07354d2a02ef45766c029ae5e4",
"blk.31.ffn_down.weight": "8defa0e06335b793fa8be03883f0a322d6c5b33f52c69c943c35c60d16e42c0a",
"blk.31.ffn_norm.weight": "33c55d9d0c496ccfb130361fe131649346e098abaaac39c0519507e5d846721d",
"blk.31.ffn_up.weight": "599f6503f61c692c1f82001973d35119f9688db5e6be9d9c298411491c93f09b",
"output.weight": "14b8dc662bfa3308ebb2e102c562d8e52c15670e538f20f3216a9c310ca9dd41",
"output_norm.weight": "7f2294ba94ce65681df6c7ddd8698799199b9d77dc83c10bdad5c3999f0fdb82",
"rope_factors_long.weight": "e34d378664e354652c38f47d10dafb0498ccc2fb042d39ff7fef768146fff22b",
"rope_factors_short.weight": "9379146a4988f373d362fe47b06c75e7fe7c54aa4dc9558758df79b7a87471fd",
"token_embd.weight": "19a03c1fb5ac0baee93b0a7d8b0f26e9a9b011e229b694afc50ebfc13d84f8bf"
}

View file

@ -669,7 +669,7 @@ curl http://localhost:11434/api/chat -d '{
``` ```
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama3.1", "model": "mistral",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -708,7 +708,7 @@ curl http://localhost:11434/api/chat -d '{
```json ```json
{ {
"model": "llama3.1", "model": "mistral:7b-instruct-v0.3-q4_K_M",
"created_at": "2024-07-22T20:33:28.123648Z", "created_at": "2024-07-22T20:33:28.123648Z",
"message": { "message": {
"role": "assistant", "role": "assistant",
@ -1175,10 +1175,7 @@ curl http://localhost:11434/api/embed -d '{
"embeddings": [[ "embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814, 0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348 0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
]], ]]
"total_duration": 14143917,
"load_duration": 1019500,
"prompt_eval_count": 8
} }
``` ```

View file

@ -16,9 +16,7 @@ If the model being imported is one of these architectures, it can be imported di
- LlamaForCausalLM - LlamaForCausalLM
- MistralForCausalLM - MistralForCausalLM
- MixtralForCausalLM
- GemmaForCausalLM - GemmaForCausalLM
- Phi3ForCausalLM
```dockerfile ```dockerfile
FROM /path/to/safetensors/directory FROM /path/to/safetensors/directory

View file

@ -182,6 +182,7 @@ curl http://localhost:11434/v1/embeddings \
- [x] Reproducible outputs - [x] Reproducible outputs
- [x] Vision - [x] Vision
- [x] Tools (streaming support coming soon) - [x] Tools (streaming support coming soon)
- [ ] Vision
- [ ] Logprobs - [ ] Logprobs
#### Supported request fields #### Supported request fields

View file

@ -112,9 +112,15 @@ Keep the following tips and best practices in mind when working with Go template
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2. ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
```gotmpl ```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }} {{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|> {{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
``` ```
### Example Tools ### Example Tools

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/ollama/ollama module github.com/ollama/ollama
go 1.22.5 go 1.22.0
require ( require (
github.com/containerd/console v1.0.3 github.com/containerd/console v1.0.3

View file

@ -49,9 +49,13 @@ func PayloadsDir() (string, error) {
} }
// Track our pid so we can clean up orphaned tmpdirs // Track our pid so we can clean up orphaned tmpdirs
n := filepath.Join(tmpDir, "ollama.pid") pidFilePath := filepath.Join(tmpDir, "ollama.pid")
if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil { pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
return "", fmt.Errorf("failed to write pid file %s: %w", n, err) if err != nil {
return "", err
}
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
return "", err
} }
// We create a distinct subdirectory for payloads within the tmpdir // We create a distinct subdirectory for payloads within the tmpdir
@ -63,44 +67,37 @@ func PayloadsDir() (string, error) {
// Best effort to clean up prior tmpdirs // Best effort to clean up prior tmpdirs
func cleanupTmpDirs() { func cleanupTmpDirs() {
matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid")) dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*"))
if err != nil { if err != nil {
return return
} }
for _, d := range dirs {
for _, match := range matches { info, err := os.Stat(d)
raw, err := os.ReadFile(match) if err != nil || !info.IsDir() {
if errors.Is(err, os.ErrNotExist) {
slog.Debug("not a ollama runtime directory, skipping", "path", match)
continue continue
} else if err != nil { }
slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err) raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
if err != nil {
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
// No pid, ignore this tmpdir
continue continue
} }
pid, err := strconv.Atoi(string(raw)) pid, err := strconv.Atoi(string(raw))
if err != nil { if err != nil {
slog.Warn("invalid pid, skipping", "path", match, "error", err) slog.Warn("failed to parse pid", "path", d, "error", err)
continue continue
} }
p, err := os.FindProcess(pid) proc, err := os.FindProcess(pid)
if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) { if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
slog.Warn("process still running, skipping", "pid", pid, "path", match) slog.Warn("found running ollama", "pid", pid, "path", d)
// Another running ollama, ignore this tmpdir
continue continue
} }
if err := os.Remove(match); err != nil { if err := os.Remove(d); err != nil {
slog.Warn("could not cleanup stale pidfile", "path", match, "error", err) slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
}
runners := filepath.Join(filepath.Dir(match), "runners")
if err := os.RemoveAll(runners); err != nil {
slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
}
if err := os.Remove(filepath.Dir(match)); err != nil {
slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
} }
} }
} }

View file

@ -305,8 +305,6 @@ func GetGPUInfo() GpuInfoList {
// Intel // Intel
if envconfig.IntelGPU() { if envconfig.IntelGPU() {
oHandles = initOneAPIHandles() oHandles = initOneAPIHandles()
if oHandles != nil && oHandles.oneapi != nil {
// On windows we bundle the oneapi library one level above the runner dir // On windows we bundle the oneapi library one level above the runner dir
depPath = "" depPath = ""
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" { if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
@ -342,7 +340,6 @@ func GetGPUInfo() GpuInfoList {
} }
} }
} }
}
rocmGPUs = AMDGetGPUInfo() rocmGPUs = AMDGetGPUInfo()
bootstrapped = true bootstrapped = true

View file

@ -403,9 +403,7 @@ struct llama_server_context
} }
} }
auto init_result = llama_init_from_gpt_params(params); std::tie(model, ctx) = llama_init_from_gpt_params(params);
model = init_result.model;
ctx = init_result.context;
if (model == nullptr) if (model == nullptr)
{ {
LOG_ERROR("unable to load model", {{"model", params.model}}); LOG_ERROR("unable to load model", {{"model", params.model}});
@ -1223,7 +1221,9 @@ struct llama_server_context
res.result_json = json res.result_json = json
{ {
{"id", res.id},
{"embedding", std::vector<float>(embd, embd + n_embd)}, {"embedding", std::vector<float>(embd, embd + n_embd)},
{"timings", slot.get_formated_timings()},
}; };
} }
} }
@ -2422,10 +2422,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapters.push_back({ params.lora_adapter.emplace_back(argv[i], 1.0f);
std::string(argv[i]),
1.0,
});
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-scaled") else if (arg == "--lora-scaled")
@ -2441,10 +2438,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapters.push_back({ params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
lora_adapter,
std::stof(argv[i])
});
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "-v" || arg == "--verbose") else if (arg == "-v" || arg == "--verbose")
@ -3192,17 +3186,41 @@ int main(int argc, char **argv) {
prompt = ""; prompt = "";
} }
if (prompt.size() == 1) {
prompt = prompt[0];
}
// create and queue the task // create and queue the task
const int task_id = llama.queue_tasks.get_new_id(); json responses;
llama.queue_results.add_waiting_task_id(task_id); {
llama.request_completion(task_id, {{"prompt", prompt}}, true, -1); const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(task_id); llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
responses = result.result_json.value("results", std::vector<json>{result.result_json});
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
return a["id"] < b["id"];
});
json embeddings = json::array();
int prompt_n = 0;
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
prompt_n += elem.at("timings").at("prompt_n").get<int>();
}
// send the result // send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
}); });
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View file

@ -157,14 +157,6 @@ type Tensor struct {
io.WriterTo `json:"-"` io.WriterTo `json:"-"`
} }
func (t Tensor) block() (n int) {
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
return -1
}
return
}
func (t Tensor) blockSize() uint64 { func (t Tensor) blockSize() uint64 {
switch t.Kind { switch t.Kind {
case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16 case 0, 1, 24, 25, 26, 27, 28, 30: // F32, F16, I8, I16, I32, I64, F64, BF16

View file

@ -532,14 +532,15 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
} }
} }
slices.SortStableFunc(ts, func(a, b Tensor) int { slices.SortFunc(ts, func(a, b Tensor) int {
if i, j := a.block(), b.block(); i < 0 && j > 0 { var i, j int
return 1 if n, err := fmt.Sscanf(a.Name, "blk.%d", &i); err != nil || n != 1 {
} else if i > 0 && j < 0 { return cmp.Compare(a.Name, b.Name)
return -1 } else if n, err := fmt.Sscanf(b.Name, "blk.%d", &j); err != nil || n != 1 {
} else { return cmp.Compare(a.Name, b.Name)
return cmp.Compare(i, j)
} }
return cmp.Compare(i, j)
}) })
var s uint64 var s uint64

@ -1 +1 @@
Subproject commit 1e6f6554aa11fa10160a5fda689e736c3c34169f Subproject commit 6eeaeba126ff701f3e8f79f246805b7023709972

View file

@ -1,32 +1,40 @@
diff --git a/common/common.cpp b/common/common.cpp diff --git a/common/common.cpp b/common/common.cpp
index 2e8374d5..70d0afde 100644 index dbb724fb..c26fe6ee 100644
--- a/common/common.cpp --- a/common/common.cpp
+++ b/common/common.cpp +++ b/common/common.cpp
@@ -2110,9 +2110,21 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { @@ -2087,14 +2087,27 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
if (loaded_la.adapter == nullptr) { const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); float lora_scale = std::get<1>(params.lora_adapter[i]);
+
+ // try to load as gguf
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx); - llama_free(lctx);
- llama_free_model(model); - llama_free_model(model);
- return iparams; - return std::make_tuple(nullptr, nullptr);
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+ +
+ // if that fails, try loading as ggla for compatibility + // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model, + int err = llama_model_apply_lora_from_file(model,
+ la.path.c_str(), + lora_adapter.c_str(),
+ la.scale, + lora_scale,
+ nullptr, + nullptr,
+ params.n_threads); + params.n_threads);
+ if (err != 0) { + if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx); + llama_free(lctx);
+ llama_free_model(model); + llama_free_model(model);
+ return iparams; + return std::make_tuple(nullptr, nullptr);
+ } else {
+ break;
+ } + }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
} }
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters - llama_lora_adapter_set(lctx, adapter, lora_scale);
} }
if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644 index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h --- a/include/llama.h

View file

@ -0,0 +1,20 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index a207451f..fba6b175 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4969,6 +4969,7 @@ static void llm_load_hparams(
hparams.attn_soft_cap = true;
switch (hparams.n_layer) {
+ case 26: model.type = e_model::MODEL_2B; break;
case 42: model.type = e_model::MODEL_9B; break;
case 46: model.type = e_model::MODEL_27B; break;
default: model.type = e_model::MODEL_UNKNOWN;
@@ -11736,6 +11737,7 @@ struct llm_build_context {
// ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
switch (model.type) {
+ case e_model::MODEL_2B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_9B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); break;
case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
default: GGML_ABORT("fatal error");

View file

@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, input string) ([]float32, error) Embed(ctx context.Context, input []string) (*EmbedResponse, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@ -49,7 +49,6 @@ type llmServer struct {
done chan error // Channel to signal when the process exits done chan error // Channel to signal when the process exits
status *StatusWriter status *StatusWriter
options api.Options options api.Options
numParallel int
estimate MemoryEstimate estimate MemoryEstimate
totalLayers uint64 totalLayers uint64
@ -125,9 +124,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
} }
// On linux and windows, over-allocating CPU memory will almost always result in an error // On linux, over-allocating CPU memory will almost always result in an error
// Darwin has fully dynamic swap so has no direct concept of free swap space if runtime.GOOS == "linux" {
if runtime.GOOS != "darwin" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available { if systemMemoryRequired > available {
@ -345,7 +343,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
status: NewStatusWriter(os.Stderr), status: NewStatusWriter(os.Stderr),
options: opts, options: opts,
estimate: estimate, estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)), sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: ggml.KV().BlockCount() + 1, totalLayers: ggml.KV().BlockCount() + 1,
gpus: gpus, gpus: gpus,
@ -883,15 +880,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil return nil
} }
type EmbeddingRequest struct { type EmbedRequest struct {
Content string `json:"content"` Content []string `json:"content"`
} }
type EmbeddingResponse struct { type EmbedResponse struct {
Embedding []float32 `json:"embedding"` Embedding [][]float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_n"`
} }
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) { func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@ -906,18 +904,18 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(EmbeddingRequest{Content: input}) data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err) return nil, fmt.Errorf("error creating embed request: %w", err)
} }
r.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(r) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err) return nil, fmt.Errorf("do embedding request: %w", err)
} }
@ -933,12 +931,12 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var e EmbeddingResponse var e EmbedResponse
if err := json.Unmarshal(body, &e); err != nil { if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }
return e.Embedding, nil return &e, nil
} }
type TokenizeRequest struct { type TokenizeRequest struct {

View file

@ -26,7 +26,6 @@ var errorPrefixes = []string{
"cudaMalloc failed", "cudaMalloc failed",
"\"ERR\"", "\"ERR\"",
"error loading model", "error loading model",
"GGML_ASSERT",
} }
func (w *StatusWriter) Write(b []byte) (int, error) { func (w *StatusWriter) Write(b []byte) (int, error) {

View file

@ -7,12 +7,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -20,9 +20,14 @@ import (
const ( const (
prefix = `data:image/jpeg;base64,` prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
) )
var False = false func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
@ -38,137 +43,135 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
func TestChatMiddleware(t *testing.T) { func TestChatMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
name string Name string
body string Setup func(t *testing.T, req *http.Request)
req api.ChatRequest Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
err ErrorResponse
} }
var capturedRequest *api.ChatRequest var capturedRequest *api.ChatRequest
testCases := []testCase{ testCases := []testCase{
{ {
name: "chat handler", Name: "chat handler",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"model": "test-model", body := ChatCompletionRequest{
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model", Model: "test-model",
Messages: []api.Message{ Messages: []Message{{Role: "user", Content: "Hello"}},
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with image content",
body: `{
"model": "test-model",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
} }
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
} }
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{ if req.Messages[0].Role != "user" {
name: "chat handler error forwarding", t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
body: `{ }
"model": "test-model",
"messages": [ if req.Messages[0].Content != "Hello" {
{"role": "user", "content": 2} t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
] }
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
}, },
}, },
{
Name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
}, },
} }
@ -182,26 +185,16 @@ func TestChatMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/chat", endpoint) router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var errResp ErrorResponse tc.Expected(t, capturedRequest, resp)
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
@ -209,52 +202,71 @@ func TestChatMiddleware(t *testing.T) {
func TestCompletionsMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
name string Name string
body string Setup func(t *testing.T, req *http.Request)
req api.GenerateRequest Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
err ErrorResponse
} }
var capturedRequest *api.GenerateRequest var capturedRequest *api.GenerateRequest
testCases := []testCase{ testCases := []testCase{
{ {
name: "completions handler", Name: "completions handler",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"model": "test-model", temp := float32(0.8)
"prompt": "Hello", body := CompletionRequest{
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model", Model: "test-model",
Prompt: "Hello", Prompt: "Hello",
Options: map[string]any{ Temperature: &temp,
"frequency_penalty": 0.0, Stop: []string{"\n", "stop"},
"presence_penalty": 0.0,
"temperature": 1.6,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix", Suffix: "suffix",
Stream: &False, }
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
}, },
}, },
{ {
name: "completions handler error forwarding", Name: "completions handler error forwarding",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"model": "test-model", body := CompletionRequest{
"prompt": "Hello", Model: "test-model",
"temperature": null, Prompt: "Hello",
"stop": [1, 2], Temperature: nil,
"suffix": "suffix" Stop: []int{1, 2},
}`, Suffix: "suffix",
err: ErrorResponse{ }
Error: Error{ prepareRequest(req, body)
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
}, },
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -269,27 +281,15 @@ func TestCompletionsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/generate", endpoint) router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var errResp ErrorResponse tc.Expected(t, capturedRequest, resp)
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
@ -298,47 +298,78 @@ func TestCompletionsMiddleware(t *testing.T) {
func TestEmbeddingsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
name string Name string
body string Setup func(t *testing.T, req *http.Request)
req api.EmbedRequest Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
err ErrorResponse
} }
var capturedRequest *api.EmbedRequest var capturedRequest *api.EmbedRequest
testCases := []testCase{ testCases := []testCase{
{ {
name: "embed handler single input", Name: "embed handler single input",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"input": "Hello", body := EmbedRequest{
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello", Input: "Hello",
Model: "test-model", Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
name: "embed handler batch input", Name: "embed handler batch input",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"input": ["Hello", "World"], body := EmbedRequest{
"model": "test-model" Input: []string{"Hello", "World"},
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model", Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
name: "embed handler error forwarding", Name: "embed handler error forwarding",
body: `{ Setup: func(t *testing.T, req *http.Request) {
"model": "test-model" body := EmbedRequest{
}`, Model: "test-model",
err: ErrorResponse{ }
Error: Error{ prepareRequest(req, body)
Message: "invalid input",
Type: "invalid_request_error",
}, },
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -353,167 +384,116 @@ func TestEmbeddingsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/embed", endpoint) router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var errResp ErrorResponse tc.Expected(t, capturedRequest, resp)
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
} }
func TestListMiddleware(t *testing.T) { func TestMiddlewareResponses(t *testing.T) {
type testCase struct { type testCase struct {
name string Name string
endpoint func(c *gin.Context) Method string
resp string Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
} }
testCases := []testCase{ testCases := []testCase{
{ {
name: "list handler", Name: "list handler",
endpoint: func(c *gin.Context) { Method: http.MethodGet,
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{ c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{ Models: []api.ListModelResponse{
{ {
Name: "test-model", Name: "Test Model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
}, },
}, },
}) })
}, },
resp: `{ Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
"object": "list", var listResp ListCompletion
"data": [ if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
{ t.Fatal(err)
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
} }
]
}`, if listResp.Object != "list" {
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
}, },
{ {
name: "list handler empty output", Name: "retrieve model",
endpoint: func(c *gin.Context) { Method: http.MethodGet,
c.JSON(http.StatusOK, api.ListResponse{}) Path: "/api/show/:model",
}, TestPath: "/api/show/test-model",
resp: `{ Handler: RetrieveMiddleware,
"object": "list", Endpoint: func(c *gin.Context) {
"data": null
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{ c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
}) })
}, },
resp: `{ Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
"id":"test-model", var retrieveResp Model
"object":"model", if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
"created":1686935002, t.Fatal(err)
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
} }
}`,
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
},
}, },
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases { for _, tc := range testCases {
router := gin.New() t.Run(tc.Name, func(t *testing.T) {
router.Use(RetrieveMiddleware()) router = gin.New()
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) router.Use(tc.Handler())
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) router.Handle(tc.Method, tc.Path, tc.Endpoint)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var expected, actual map[string]any assert.Equal(t, http.StatusOK, resp.Code)
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual) tc.Expected(t, resp)
if err != nil { })
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
} }
} }

View file

@ -3,12 +3,11 @@ package progress
import ( import (
"fmt" "fmt"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
type Spinner struct { type Spinner struct {
message atomic.Value message string
messageWidth int messageWidth int
parts []string parts []string
@ -22,25 +21,20 @@ type Spinner struct {
func NewSpinner(message string) *Spinner { func NewSpinner(message string) *Spinner {
s := &Spinner{ s := &Spinner{
message: message,
parts: []string{ parts: []string{
"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏",
}, },
started: time.Now(), started: time.Now(),
} }
s.SetMessage(message)
go s.start() go s.start()
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message.Store(message)
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 {
if message, ok := s.message.Load().(string); ok && len(message) > 0 { message := strings.TrimSpace(s.message)
message := strings.TrimSpace(message)
if s.messageWidth > 0 && len(message) > s.messageWidth { if s.messageWidth > 0 && len(message) > s.messageWidth {
message = message[:s.messageWidth] message = message[:s.messageWidth]
} }

View file

@ -62,7 +62,7 @@ func (b *Buffer) MoveLeft() {
rLength := runewidth.RuneWidth(r) rLength := runewidth.RuneWidth(r)
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width)) fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if rLength == 2 { if rLength == 2 {
fmt.Print(CursorLeft) fmt.Print(CursorLeft)
} }
@ -74,7 +74,7 @@ func (b *Buffer) MoveLeft() {
fmt.Print(CursorLeft) fmt.Print(CursorLeft)
} }
} else { } else {
fmt.Print(CursorLeftN(rLength)) fmt.Print(cursorLeftN(rLength))
} }
b.Pos -= 1 b.Pos -= 1
@ -115,15 +115,15 @@ func (b *Buffer) MoveRight() {
b.DisplayPos += rLength b.DisplayPos += rLength
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt())+rLength)) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
b.DisplayPos += 1 b.DisplayPos += 1
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Print(CursorDown + CursorBOL + CursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.DisplayPos += 1 b.DisplayPos += 1
} else { } else {
fmt.Print(CursorRightN(rLength)) fmt.Print(cursorRightN(rLength))
} }
} }
} }
@ -154,7 +154,7 @@ func (b *Buffer) MoveToStart() {
fmt.Print(CursorUp) fmt.Print(CursorUp)
} }
} }
fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.Pos = 0 b.Pos = 0
b.DisplayPos = 0 b.DisplayPos = 0
} }
@ -169,9 +169,9 @@ func (b *Buffer) MoveToEnd() {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
remainder := b.DisplaySize() % b.LineWidth remainder := b.DisplaySize() % b.LineWidth
fmt.Print(CursorBOL + CursorRightN(len(b.Prompt.prompt())+remainder)) fmt.Printf(CursorBOL + cursorRightN(len(b.Prompt.prompt())+remainder))
} else { } else {
fmt.Print(CursorRightN(b.DisplaySize() - b.DisplayPos)) fmt.Print(cursorRightN(b.DisplaySize() - b.DisplayPos))
} }
b.Pos = b.Buf.Size() b.Pos = b.Buf.Size()
@ -286,7 +286,8 @@ func (b *Buffer) drawRemaining() {
remLength := runewidth.StringWidth(remainingText) remLength := runewidth.StringWidth(remainingText)
if len(currLine) > 0 { if len(currLine) > 0 {
fmt.Print(ClearToEOL + currLine + CursorLeftN(currLineSpace)) fmt.Printf(ClearToEOL + currLine)
fmt.Print(cursorLeftN(currLineSpace))
} else { } else {
fmt.Print(ClearToEOL) fmt.Print(ClearToEOL)
} }
@ -300,9 +301,9 @@ func (b *Buffer) drawRemaining() {
} }
if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText { if (b.DisplayPos+currLineSpace)%b.LineWidth == 0 && currLine == remainingText {
fmt.Print(CursorRightN(currLineSpace)) fmt.Print(cursorRightN(currLineSpace))
fmt.Printf("\n%s", b.Prompt.AltPrompt) fmt.Printf("\n%s", b.Prompt.AltPrompt)
fmt.Print(CursorUp + CursorBOL + CursorRightN(b.Width-currLineSpace)) fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width-currLineSpace))
} }
// render the other lines // render the other lines
@ -332,7 +333,9 @@ func (b *Buffer) drawRemaining() {
lineLength += runewidth.RuneWidth(c) lineLength += runewidth.RuneWidth(c)
fmt.Printf("%c", c) fmt.Printf("%c", c)
} }
fmt.Print(ClearToEOL + CursorUpN(totalLines) + CursorBOL + CursorRightN(b.Width-currLineSpace)) fmt.Print(ClearToEOL)
fmt.Print(cursorUpN(totalLines))
fmt.Printf(CursorBOL + cursorRightN(b.Width-currLineSpace))
hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth) hasSpace := b.GetLineSpacing(b.DisplayPos / b.LineWidth)
@ -354,7 +357,8 @@ func (b *Buffer) Remove() {
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
// if the user backspaces over the word boundary, do this magic to clear the line // if the user backspaces over the word boundary, do this magic to clear the line
// and move to the end of the previous line // and move to the end of the previous line
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width)) fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth { if b.DisplaySize()%b.LineWidth < (b.DisplaySize()-rLength)%b.LineWidth {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
@ -366,23 +370,24 @@ func (b *Buffer) Remove() {
} }
if rLength == 2 { if rLength == 2 {
fmt.Print(CursorLeft + " " + CursorLeftN(2)) fmt.Print(CursorLeft + " " + cursorLeftN(2))
} else { } else {
fmt.Print(" " + CursorLeft) fmt.Print(" " + CursorLeft)
} }
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
fmt.Print(CursorBOL + ClearToEOL + CursorUp + CursorBOL + CursorRightN(b.Width)) fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
if b.Pos == b.Buf.Size() { if b.Pos == b.Buf.Size() {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
} }
b.DisplayPos -= 1 b.DisplayPos -= 1
} else { } else {
fmt.Print(CursorLeftN(rLength)) fmt.Print(cursorLeftN(rLength))
for range rLength { for range rLength {
fmt.Print(" ") fmt.Print(" ")
} }
fmt.Print(CursorLeftN(rLength)) fmt.Print(cursorLeftN(rLength))
} }
var eraseExtraLine bool var eraseExtraLine bool
@ -400,9 +405,9 @@ func (b *Buffer) Remove() {
// are trailing characters which go over the line width boundary // are trailing characters which go over the line width boundary
if eraseExtraLine { if eraseExtraLine {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Print(CursorDownN(remainingLines+1) + CursorBOL + ClearToEOL) fmt.Printf(cursorDownN(remainingLines+1) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth place := b.DisplayPos % b.LineWidth
fmt.Print(CursorUpN(remainingLines+1) + CursorRightN(place+len(b.Prompt.prompt()))) fmt.Printf(cursorUpN(remainingLines+1) + cursorRightN(place+len(b.Prompt.prompt())))
} }
} }
} }
@ -417,9 +422,9 @@ func (b *Buffer) Delete() {
if b.DisplaySize()%b.LineWidth == 0 { if b.DisplaySize()%b.LineWidth == 0 {
if b.DisplayPos != b.DisplaySize() { if b.DisplayPos != b.DisplaySize() {
remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth remainingLines := (b.DisplaySize() - b.DisplayPos) / b.LineWidth
fmt.Print(CursorDownN(remainingLines) + CursorBOL + ClearToEOL) fmt.Printf(cursorDownN(remainingLines) + CursorBOL + ClearToEOL)
place := b.DisplayPos % b.LineWidth place := b.DisplayPos % b.LineWidth
fmt.Print(CursorUpN(remainingLines) + CursorRightN(place+len(b.Prompt.prompt()))) fmt.Printf(cursorUpN(remainingLines) + cursorRightN(place+len(b.Prompt.prompt())))
} }
} }
} }
@ -466,17 +471,17 @@ func (b *Buffer) DeleteWord() {
} }
func (b *Buffer) ClearScreen() { func (b *Buffer) ClearScreen() {
fmt.Print(ClearScreen + CursorReset + b.Prompt.prompt()) fmt.Printf(ClearScreen + CursorReset + b.Prompt.prompt())
if b.IsEmpty() { if b.IsEmpty() {
ph := b.Prompt.placeholder() ph := b.Prompt.placeholder()
fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault) fmt.Printf(ColorGrey + ph + cursorLeftN(len(ph)) + ColorDefault)
} else { } else {
currPos := b.DisplayPos currPos := b.DisplayPos
currIndex := b.Pos currIndex := b.Pos
b.Pos = 0 b.Pos = 0
b.DisplayPos = 0 b.DisplayPos = 0
b.drawRemaining() b.drawRemaining()
fmt.Print(CursorReset + CursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorReset + cursorRightN(len(b.Prompt.prompt())))
if currPos > 0 { if currPos > 0 {
targetLine := currPos / b.LineWidth targetLine := currPos / b.LineWidth
if targetLine > 0 { if targetLine > 0 {
@ -486,10 +491,10 @@ func (b *Buffer) ClearScreen() {
} }
remainder := currPos % b.LineWidth remainder := currPos % b.LineWidth
if remainder > 0 { if remainder > 0 {
fmt.Print(CursorRightN(remainder)) fmt.Print(cursorRightN(remainder))
} }
if currPos%b.LineWidth == 0 { if currPos%b.LineWidth == 0 {
fmt.Print(CursorBOL + b.Prompt.AltPrompt) fmt.Printf(CursorBOL + b.Prompt.AltPrompt)
} }
} }
b.Pos = currIndex b.Pos = currIndex
@ -508,13 +513,13 @@ func (b *Buffer) Replace(r []rune) {
b.Buf.Clear() b.Buf.Clear()
fmt.Print(CursorBOL + ClearToEOL) fmt.Printf(CursorBOL + ClearToEOL)
for range lineNums { for range lineNums {
fmt.Print(CursorUp + CursorBOL + ClearToEOL) fmt.Print(CursorUp + CursorBOL + ClearToEOL)
} }
fmt.Print(CursorBOL + b.Prompt.prompt()) fmt.Printf(CursorBOL + b.Prompt.prompt())
for _, c := range r { for _, c := range r {
b.Add(c) b.Add(c)
@ -540,3 +545,19 @@ func (b *Buffer) StringNM(n, m int) string {
} }
return s return s
} }
func cursorLeftN(n int) string {
return fmt.Sprintf(CursorLeftN, n)
}
func cursorRightN(n int) string {
return fmt.Sprintf(CursorRightN, n)
}
func cursorUpN(n int) string {
return fmt.Sprintf(CursorUpN, n)
}
func cursorDownN(n int) string {
return fmt.Sprintf(CursorDownN, n)
}

View file

@ -98,7 +98,7 @@ func (i *Instance) Readline() (string, error) {
showPlaceholder := !i.Pasting || i.Prompt.UseAlt showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder { if buf.IsEmpty() && showPlaceholder {
ph := i.Prompt.placeholder() ph := i.Prompt.placeholder()
fmt.Print(ColorGrey + ph + CursorLeftN(len(ph)) + ColorDefault) fmt.Printf(ColorGrey + ph + fmt.Sprintf(CursorLeftN, len(ph)) + ColorDefault)
} }
r, err := i.Terminal.Read() r, err := i.Terminal.Read()

View file

@ -1,7 +1,5 @@
package readline package readline
import "strconv"
const ( const (
CharNull = 0 CharNull = 0
CharLineStart = 1 CharLineStart = 1
@ -43,49 +41,34 @@ const (
) )
const ( const (
Esc = "\x1b" CursorUp = "\033[1A"
CursorDown = "\033[1B"
CursorRight = "\033[1C"
CursorLeft = "\033[1D"
CursorSave = Esc + "[s" CursorSave = "\033[s"
CursorRestore = Esc + "[u" CursorRestore = "\033[u"
CursorEOL = Esc + "[E" CursorUpN = "\033[%dA"
CursorBOL = Esc + "[1G" CursorDownN = "\033[%dB"
CursorHide = Esc + "[?25l" CursorRightN = "\033[%dC"
CursorShow = Esc + "[?25h" CursorLeftN = "\033[%dD"
ClearToEOL = Esc + "[K" CursorEOL = "\033[E"
ClearLine = Esc + "[2K" CursorBOL = "\033[1G"
ClearScreen = Esc + "[2J" CursorHide = "\033[?25l"
CursorReset = Esc + "[0;0f" CursorShow = "\033[?25h"
ColorGrey = Esc + "[38;5;245m" ClearToEOL = "\033[K"
ColorDefault = Esc + "[0m" ClearLine = "\033[2K"
ClearScreen = "\033[2J"
CursorReset = "\033[0;0f"
StartBracketedPaste = Esc + "[?2004h" ColorGrey = "\033[38;5;245m"
EndBracketedPaste = Esc + "[?2004l" ColorDefault = "\033[0m"
)
func CursorUpN(n int) string { StartBracketedPaste = "\033[?2004h"
return Esc + "[" + strconv.Itoa(n) + "A" EndBracketedPaste = "\033[?2004l"
}
func CursorDownN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "B"
}
func CursorRightN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "C"
}
func CursorLeftN(n int) string {
return Esc + "[" + strconv.Itoa(n) + "D"
}
var (
CursorUp = CursorUpN(1)
CursorDown = CursorDownN(1)
CursorRight = CursorRightN(1)
CursorLeft = CursorLeftN(1)
) )
const ( const (

View file

@ -209,15 +209,15 @@ install_cuda_driver_yum() {
case $PACKAGE_MANAGER in case $PACKAGE_MANAGER in
yum) yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils $SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi
;; ;;
dnf) dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi
@ -245,8 +245,8 @@ install_cuda_driver_yum() {
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() { install_cuda_driver_apt() {
status 'Installing NVIDIA repository...' status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
else else
error $CUDA_REPO_ERR_MSG error $CUDA_REPO_ERR_MSG
fi fi

View file

@ -94,7 +94,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
} }
const ( const (
numDownloadParts = 16 numDownloadParts = 64
minDownloadPartSize int64 = 100 * format.MegaByte minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte maxDownloadPartSize int64 = 1000 * format.MegaByte
) )
@ -216,7 +216,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err return err
} }
defer file.Close() defer file.Close()
setSparse(file)
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
@ -233,7 +232,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error { newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 { if len(via) > 10 {
return errors.New("maximum redirects exceeded (10) for directURL") return errors.New("maxium redirects exceeded (10) for directURL")
} }
// if the hostname is the same, allow the redirect // if the hostname is the same, allow the redirect

View file

@ -250,7 +250,6 @@ func GetModel(name string) (*Model, error) {
Template: template.DefaultTemplate, Template: template.DefaultTemplate,
} }
if manifest.Config.Digest != "" {
filename, err := GetBlobsPath(manifest.Config.Digest) filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil { if err != nil {
return nil, err return nil, err
@ -265,7 +264,6 @@ func GetModel(name string) (*Model, error) {
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil { if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err return nil, err
} }
}
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest) filename, err := GetBlobsPath(layer.Digest)
@ -373,7 +371,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
var messages []*api.Message var messages []*api.Message
parameters := make(map[string]any) parameters := make(map[string]any)
var layers []Layer var layers []*Layer
for _, c := range modelfile.Commands { for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@ -499,7 +497,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if c.Name != "license" { if c.Name != "license" {
// replace // replace
layers = slices.DeleteFunc(layers, func(layer Layer) bool { layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
if layer.MediaType != mediatype { if layer.MediaType != mediatype {
return false return false
} }
@ -545,7 +543,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
} }
var err2 error var err2 error
layers = slices.DeleteFunc(layers, func(layer Layer) bool { layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.message": case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones // if there are new messages, remove the inherited ones
@ -625,12 +623,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err return err
} }
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json") layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil { if err != nil {
return err return err
} }
for _, layer := range append(layers, configLayer) { for _, layer := range append(layers, layer) {
if layer.status != "" { if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status}) fn(api.ProgressResponse{Status: layer.status})
} }
@ -639,7 +637,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
old, _ := ParseNamedManifest(name) old, _ := ParseNamedManifest(name)
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, configLayer, layers); err != nil { if err := WriteManifest(name, layer, layers); err != nil {
return err return err
} }
@ -716,7 +714,8 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests // save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp) manifest, _, err := GetManifest(fmp)
if err != nil { if err != nil {
return err //nolint:nilerr
return nil
} }
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
@ -783,8 +782,7 @@ func PruneLayers() error {
err = deleteUnusedLayers(nil, deleteMap) err = deleteUnusedLayers(nil, deleteMap)
if err != nil { if err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err)) return err
return nil
} }
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap))) slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
@ -839,11 +837,9 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return err return err
} }
var layers []Layer var layers []*Layer
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config) layers = append(layers, manifest.Config)
}
for _, layer := range layers { for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil { if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@ -894,11 +890,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, l := range manifest.Layers { for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{} deleteMap[l.Digest] = struct{}{}
} }
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{} deleteMap[manifest.Config.Digest] = struct{}{}
} }
} }
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure { if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http") return errors.New("insecure protocol http")
@ -911,11 +905,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %s", err)
} }
var layers []Layer var layers []*Layer
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config) layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool) skipVerify := make(map[string]bool)
for _, layer := range layers { for _, layer := range layers {
@ -979,8 +971,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "removing any unused layers"}) fn(api.ProgressResponse{Status: "removing any unused layers"})
err = deleteUnusedLayers(nil, deleteMap) err = deleteUnusedLayers(nil, deleteMap)
if err != nil { if err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err)) return err
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
} }
} }

View file

@ -2,7 +2,6 @@ package server
import ( import (
"crypto/sha256" "crypto/sha256"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -16,15 +15,15 @@ type Layer struct {
status string status string
} }
func NewLayer(r io.Reader, mediatype string) (Layer, error) { func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
blobs, err := GetBlobsPath("") blobs, err := GetBlobsPath("")
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
temp, err := os.CreateTemp(blobs, "sha256-") temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
defer temp.Close() defer temp.Close()
defer os.Remove(temp.Name()) defer os.Remove(temp.Name())
@ -32,28 +31,28 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
sha256sum := sha256.New() sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
if err := temp.Close(); err != nil { if err := temp.Close(); err != nil {
return Layer{}, err return nil, err
} }
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)) digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest) blob, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
status := "using existing layer" status := "using existing layer"
if _, err := os.Stat(blob); err != nil { if _, err := os.Stat(blob); err != nil {
status = "creating new layer" status = "creating new layer"
if err := os.Rename(temp.Name(), blob); err != nil { if err := os.Rename(temp.Name(), blob); err != nil {
return Layer{}, err return nil, err
} }
} }
return Layer{ return &Layer{
MediaType: mediatype, MediaType: mediatype,
Digest: digest, Digest: digest,
Size: n, Size: n,
@ -61,22 +60,18 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
}, nil }, nil
} }
func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) { func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
if digest == "" {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest) blob, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
fi, err := os.Stat(blob) fi, err := os.Stat(blob)
if err != nil { if err != nil {
return Layer{}, err return nil, err
} }
return Layer{ return &Layer{
MediaType: mediatype, MediaType: mediatype,
Digest: digest, Digest: digest,
Size: fi.Size(), Size: fi.Size(),
@ -86,10 +81,6 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
} }
func (l *Layer) Open() (io.ReadSeekCloser, error) { func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest) blob, err := GetBlobsPath(l.Digest)
if err != nil { if err != nil {
return nil, err return nil, err
@ -99,10 +90,6 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
} }
func (l *Layer) Remove() error { func (l *Layer) Remove() error {
if l.Digest == "" {
return nil
}
ms, err := Manifests() ms, err := Manifests()
if err != nil { if err != nil {
return err return err

View file

@ -16,8 +16,8 @@ import (
type Manifest struct { type Manifest struct {
SchemaVersion int `json:"schemaVersion"` SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"` MediaType string `json:"mediaType"`
Config Layer `json:"config"` Config *Layer `json:"config"`
Layers []Layer `json:"layers"` Layers []*Layer `json:"layers"`
filepath string filepath string
fi os.FileInfo fi os.FileInfo
@ -47,14 +47,12 @@ func (m *Manifest) Remove() error {
func (m *Manifest) RemoveLayers() error { func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) { for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) { if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest) slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil { } else if err != nil {
return err return err
} }
} }
}
return nil return nil
} }
@ -95,7 +93,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return &m, nil return &m, nil
} }
func WriteManifest(name model.Name, config Layer, layers []Layer) error { func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
if err != nil { if err != nil {
return err return err

View file

@ -26,7 +26,7 @@ import (
var intermediateBlobs map[string]string = make(map[string]string) var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct { type layerGGML struct {
Layer *Layer
*llm.GGML *llm.GGML
} }
@ -176,21 +176,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
mediatype = "application/vnd.ollama.image.projector" mediatype = "application/vnd.ollama.image.projector"
} }
var layer Layer layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if digest != "" && n == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
layers = append(layers, &layerGGML{layer, ggml}) layers = append(layers, &layerGGML{layer, ggml})
offset = n offset = n

View file

@ -2,10 +2,8 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -13,7 +11,6 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@ -136,82 +133,3 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}) })
} }
} }
func TestParseFromFileFromLayer(t *testing.T) {
tempModels := t.TempDir()
file, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file.Close()
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 1 {
t.Fatalf("got %d != want 1", len(layers))
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers2) != 1 {
t.Fatalf("got %d != want 1", len(layers2))
}
if layers[0].Digest != layers2[0].Digest {
t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
}
if layers[0].Size != layers2[0].Size {
t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
}
if layers[0].MediaType != layers2[0].MediaType {
t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
}
}
func TestParseLayerFromCopy(t *testing.T) {
tempModels := t.TempDir()
file2, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file2.Close()
for range 5 {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
}
if _, err := file2.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 5 {
t.Fatalf("got %d != want 5", len(layers))
}
}

View file

@ -23,7 +23,6 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@ -324,10 +323,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
input = append(input, v.(string)) input = append(input, v.(string))
} }
default: default:
if req.Input != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return return
} }
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
@ -338,18 +340,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
kvData, err := getKVData(m.ModelPath, false) kvData, err := getKVData(m.ModelPath, false)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
var count int
for i, s := range input { for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s) tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil { if err != nil {
@ -372,36 +368,25 @@ func (s *Server) EmbedHandler(c *gin.Context) {
} }
} }
count += len(tokens)
input[i] = s input[i] = s
} }
embeddings, err := r.Embed(c.Request.Context(), input)
var g errgroup.Group
embeddings := make([][]float32, len(input))
for i, text := range input {
g.Go(func() error {
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil { if err != nil {
return err slog.Error("embedding generation failed", "error", err)
} c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
embeddings[i] = normalize(embedding) return
return nil
})
} }
if err := g.Wait(); err != nil { for i, e := range embeddings.Embedding {
slog.Error("embedding generation failed", "error", err) embeddings.Embedding[i] = normalize(e)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
return
} }
resp := api.EmbedResponse{ resp := api.EmbedResponse{
Model: req.Model, Model: req.Model,
Embeddings: embeddings, Embeddings: embeddings.Embedding,
TotalDuration: time.Since(checkpointStart), TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart), LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: count, PromptEvalCount: embeddings.PromptEvalCount,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@ -445,20 +430,21 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embedding, err := r.Embedding(c.Request.Context(), req.Prompt) embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
var e []float64 embedding := make([]float64, len(embeddings.Embedding[0]))
for _, v := range embedding {
e = append(e, float64(v)) for i, v := range embeddings.Embedding[0] {
embedding[i] = float64(v)
} }
resp := api.EmbeddingResponse{ resp := api.EmbeddingResponse{
Embedding: e, Embedding: embedding,
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
@ -838,9 +824,6 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
models := []api.ListModelResponse{} models := []api.ListModelResponse{}
for n, m := range ms { for n, m := range ms {
var cf ConfigV2
if m.Config.Digest != "" {
f, err := m.Config.Open() f, err := m.Config.Open()
if err != nil { if err != nil {
slog.Warn("bad manifest filepath", "name", n, "error", err) slog.Warn("bad manifest filepath", "name", n, "error", err)
@ -848,11 +831,11 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
} }
defer f.Close() defer f.Close()
var cf ConfigV2
if err := json.NewDecoder(f).Decode(&cf); err != nil { if err := json.NewDecoder(f).Decode(&cf); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err) slog.Warn("bad manifest config", "name", n, "error", err)
continue continue
} }
}
// tag should never be masked // tag should never be masked
models = append(models, api.ListModelResponse{ models = append(models, api.ListModelResponse{

View file

@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
} }
// create a manifest with duplicate layers // create a manifest with duplicate layers
if err := WriteManifest(n, config, []Layer{config}); err != nil { if err := WriteManifest(n, config, []*Layer{config}); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -272,6 +272,76 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy) assert.Equal(t, "library", retrieveResp.OwnedBy)
}, },
}, },
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())

View file

@ -418,7 +418,7 @@ func (s *Scheduler) load(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList,
// some older models are not compatible with newer versions of llama.cpp // some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to // show a generalized compatibility error until there is a better way to
// check for model compatibility // check for model compatibility
if errors.Is(err, llm.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { if errors.Is(llm.ErrUnsupportedFormat, err) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
} }
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)

View file

@ -708,8 +708,8 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embeddingResp []float32 embedResp *llm.EmbedResponse
embeddingRespErr error embedRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
detokenizeResp string detokenizeResp string
@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) { func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
return s.embeddingResp, s.embeddingRespErr return s.embedResp, s.embedRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {

View file

@ -1,8 +0,0 @@
//go:build !windows
package server
import "os"
func setSparse(*os.File) {
}

View file

@ -1,17 +0,0 @@
package server
import (
"os"
"golang.org/x/sys/windows"
)
func setSparse(file *os.File) {
// exFat (and other FS types) don't support sparse files, so ignore errors
windows.DeviceIoControl( //nolint:errcheck
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,
nil, nil,
)
}

View file

@ -26,7 +26,7 @@ import (
var blobUploadManager sync.Map var blobUploadManager sync.Map
type blobUpload struct { type blobUpload struct {
Layer *Layer
Total int64 Total int64
Completed atomic.Int64 Completed atomic.Int64
@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
p.written = 0 p.written = 0
} }
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error { func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL() requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest) requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)

View file

@ -219,7 +219,7 @@ func (n Name) String() string {
return b.String() return b.String()
} }
// DisplayShortest returns a short string version of the name. // DisplayShort returns a short string version of the name.
func (n Name) DisplayShortest() string { func (n Name) DisplayShortest() string {
var sb strings.Builder var sb strings.Builder