package server

import (
	"bytes"
	"cmp"
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"runtime"
	"slices"
	"strconv"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/auth"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
)

// Sentinel errors used by CheckCapabilities. errCapabilities is the outer
// error; the others name the individual capability that is missing.
var (
	errCapabilities         = errors.New("does not support")
	errCapabilityCompletion = errors.New("completion")
	errCapabilityTools      = errors.New("tools")
	errCapabilityInsert     = errors.New("insert")
)

// Capability names a feature a model may or may not support.
type Capability string

const (
	CapabilityCompletion = Capability("completion")
	CapabilityTools      = Capability("tools")
	CapabilityInsert     = Capability("insert")
)

// registryOptions carries the credentials and transport settings used when
// talking to a model registry.
type registryOptions struct {
	Insecure bool
	Username string
	Password string
	Token    string

	CheckRedirect func(req *http.Request, via []*http.Request) error
}

// Model is the in-memory description of a model assembled from its manifest
// layers by GetModel.
type Model struct {
	Name           string `json:"name"`
	Config         ConfigV2
	ShortName      string
	ModelPath      string
	ParentModel    string
	AdapterPaths   []string
	ProjectorPaths []string
	System         string
	License        []string
	Digest         string
	Options        map[string]interface{}
	Messages       []api.Message

	Template *template.Template
}

// CheckCapabilities checks if the model has the specified capabilities returning an error describing
// any missing or unknown capabilities
func (m *Model) CheckCapabilities(caps ...Capability) error {
	var errs []error
	for _, cap := range caps {
		switch cap {
		case CapabilityCompletion:
			f, err := os.Open(m.ModelPath)
			if err != nil {
				// Failure to inspect the model is logged, not reported as a
				// missing capability.
				slog.Error("couldn't open model file", "error", err)
				continue
			}
			defer f.Close()

			// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
			ggml, _, err := llm.DecodeGGML(f, 0)
			if err != nil {
				slog.Error("couldn't decode ggml", "error", err)
				continue
			}

			// A model that declares a pooling type is treated as not
			// supporting completion.
			if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
				errs = append(errs, errCapabilityCompletion)
			}
		case CapabilityTools:
			if !slices.Contains(m.Template.Vars(), "tools") {
				errs = append(errs, errCapabilityTools)
			}
		case CapabilityInsert:
			vars := m.Template.Vars()
			if !slices.Contains(vars, "suffix") {
				errs = append(errs, errCapabilityInsert)
			}
		default:
			slog.Error("unknown capability", "capability", cap)
			return fmt.Errorf("unknown capability: %s", cap)
		}
	}

	// errors.Join returns nil when errs is empty; reuse the joined value
	// rather than joining a second time.
	if err := errors.Join(errs...); err != nil {
		return fmt.Errorf("%w %w", errCapabilities, err)
	}

	return nil
}

// String renders the model as Modelfile text by emitting one parser command
// per model path, adapter, projector, template, system prompt, option,
// license and message.
func (m *Model) String() string {
	var modelfile parser.File

	modelfile.Commands = append(modelfile.Commands, parser.Command{
		Name: "model",
		Args: m.ModelPath,
	})

	for _, adapter := range m.AdapterPaths {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "adapter",
			Args: adapter,
		})
	}

	// Projectors are emitted under the "model" command name.
	for _, projector := range m.ProjectorPaths {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "model",
			Args: projector,
		})
	}

	if m.Template != nil {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "template",
			Args: m.Template.String(),
		})
	}

	if m.System != "" {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "system",
			Args: m.System,
		})
	}

	for k, v := range m.Options {
		switch v := v.(type) {
		case []any:
			// Multi-valued options become one command per value.
			for _, s := range v {
				modelfile.Commands = append(modelfile.Commands, parser.Command{
					Name: k,
					Args: fmt.Sprintf("%v", s),
				})
			}
		default:
			modelfile.Commands = append(modelfile.Commands, parser.Command{
				Name: k,
				Args: fmt.Sprintf("%v", v),
			})
		}
	}

	for _, license := range m.License {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "license",
			Args: license,
		})
	}

	for _, msg := range m.Messages {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "message",
			Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content),
		})
	}

	return modelfile.String()
}

// ConfigV2 is the JSON configuration blob stored alongside a model's layers.
type ConfigV2 struct {
	ModelFormat   string   `json:"model_format"`
	ModelFamily   string   `json:"model_family"`
	ModelFamilies []string `json:"model_families"`
	ModelType     string   `json:"model_type"`
	FileType      string   `json:"file_type"`

	// required by spec
	Architecture string `json:"architecture"`
	OS           string `json:"os"`
	RootFS       RootFS `json:"rootfs"`
}

// RootFS lists the layer digests that make up the image.
type RootFS struct {
	Type    string   `json:"type"`
	DiffIDs []string `json:"diff_ids"`
}

// GetManifest reads and decodes the manifest file for mp. It returns the
// manifest together with the hex-encoded SHA-256 of the bytes consumed while
// decoding it.
func GetManifest(mp ModelPath) (*Manifest, string, error) {
	fp, err := mp.GetManifestPath()
	if err != nil {
		return nil, "", err
	}

	f, err := os.Open(fp)
	if err != nil {
		return nil, "", err
	}
	defer f.Close()

	// Hash the bytes as the decoder reads them so the file is read once.
	sha256sum := sha256.New()

	var manifest Manifest
	if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
		return nil, "", err
	}

	return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
}

// GetModel loads the manifest for name and builds a Model from its config
// blob and layers. The template defaults to template.DefaultTemplate unless a
// prompt/template layer overrides it.
func GetModel(name string) (*Model, error) {
	mp := ParseModelPath(name)
	manifest, digest, err := GetManifest(mp)
	if err != nil {
		return nil, err
	}

	model := &Model{
		Name:      mp.GetFullTagname(),
		ShortName: mp.GetShortTagname(),
		Digest:    digest,
		Template:  template.DefaultTemplate,
	}

	if manifest.Config.Digest != "" {
		filename, err := GetBlobsPath(manifest.Config.Digest)
		if err != nil {
			return nil, err
		}

		configFile, err := os.Open(filename)
		if err != nil {
			return nil, err
		}
		defer configFile.Close()

		if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
			return nil, err
		}
	}

	for _, layer := range manifest.Layers {
		filename, err := GetBlobsPath(layer.Digest)
		if err != nil {
			return nil, err
		}

		switch layer.MediaType {
		case "application/vnd.ollama.image.model":
			model.ModelPath = filename
			model.ParentModel = layer.From
		case "application/vnd.ollama.image.embed":
			// Deprecated in versions > 0.1.2
			// TODO: remove this warning in a future version
			slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
		case "application/vnd.ollama.image.adapter":
			model.AdapterPaths = append(model.AdapterPaths, filename)
		case "application/vnd.ollama.image.projector":
			model.ProjectorPaths = append(model.ProjectorPaths, filename)
		case "application/vnd.ollama.image.prompt",
			"application/vnd.ollama.image.template":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.Template, err = template.Parse(string(bts))
			if err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.system":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}

			model.System = string(bts)
		case "application/vnd.ollama.image.params":
			params, err := os.Open(filename)
			if err != nil {
				return nil, err
			}
			defer params.Close()

			// parse model options parameters into a map so that we can see which fields have been specified explicitly
			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.messages":
			msgs, err := os.Open(filename)
			if err != nil {
				return nil, err
			}
			defer msgs.Close()

			if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
				return nil, err
			}
		case "application/vnd.ollama.image.license":
			bts, err := os.ReadFile(filename)
			if err != nil {
				return nil, err
			}
			model.License = append(model.License, string(bts))
		}
	}

	return model, nil
}

// realpath resolves from to a filesystem path. "~" and "~/..." expand to the
// user's home directory; otherwise a file that exists relative to rel wins,
// and the absolute form of from is the fallback.
func realpath(rel, from string) string {
	abspath, err := filepath.Abs(from)
	if err != nil {
		return from
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return abspath
	}

	if from == "~" {
		return home
	} else if strings.HasPrefix(from, "~/") {
		return filepath.Join(home, from[2:])
	}

	if _, err := os.Stat(filepath.Join(rel, from)); err == nil {
		// this is a file relative to the Modelfile
		return filepath.Join(rel, from)
	}

	return abspath
}

// CreateModel builds a model named name from the commands in modelfile,
// reporting progress through fn. Base layers come from an existing model, a
// blob digest ("@sha256:...") or a file resolved relative to modelFileDir.
// When quantization is non-empty, F16/F32 GGUF model layers are quantized to
// that file type. Template, system, license, message and parameter commands
// become additional layers, and the resulting manifest is written under name.
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
	config := ConfigV2{
		OS:           "linux",
		Architecture: "amd64",
		RootFS: RootFS{
			Type: "layers",
		},
	}

	var messages []*api.Message
	parameters := make(map[string]any)

	var layers []Layer
	var baseLayers []*layerGGML
	for _, c := range modelfile.Commands {
		mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
		command := c.Name

		switch command {
		case "model", "adapter":
			if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
				baseLayers, err = parseFromModel(ctx, name, fn)
				if err != nil {
					return err
				}
			} else if strings.HasPrefix(c.Args, "@") {
				digest := strings.TrimPrefix(c.Args, "@")
				if ib, ok := intermediateBlobs[digest]; ok {
					p, err := GetBlobsPath(ib)
					if err != nil {
						return err
					}

					if _, err := os.Stat(p); errors.Is(err, os.ErrNotExist) {
						// pass
					} else if err != nil {
						return err
					} else {
						fn(api.ProgressResponse{Status: fmt.Sprintf("using cached layer %s", ib)})
						digest = ib
					}
				}

				blobpath, err := GetBlobsPath(digest)
				if err != nil {
					return err
				}

				blob, err := os.Open(blobpath)
				if err != nil {
					return err
				}
				defer blob.Close()

				baseLayers, err = parseFromFile(ctx, command, baseLayers, blob, digest, fn)
				if err != nil {
					return err
				}
			} else if file, err := os.Open(realpath(modelFileDir, c.Args)); err == nil {
				defer file.Close()

				baseLayers, err = parseFromFile(ctx, command, baseLayers, file, "", fn)
				if err != nil {
					return err
				}
			} else {
				return fmt.Errorf("invalid model reference: %s", c.Args)
			}

			for _, baseLayer := range baseLayers {
				if quantization != "" &&
					baseLayer.MediaType == "application/vnd.ollama.image.model" &&
					baseLayer.GGML != nil &&
					baseLayer.GGML.Name() == "gguf" {
					want, err := llm.ParseFileType(quantization)
					if err != nil {
						return err
					}

					ft := baseLayer.GGML.KV().FileType()
					if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
						return errors.New("quantization is only supported for F16 and F32 models")
					} else if want != ft {
						fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})

						blob, err := GetBlobsPath(baseLayer.Digest)
						if err != nil {
							return err
						}

						// The temp file is created in the blob's directory.
						temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
						if err != nil {
							return err
						}
						defer temp.Close()
						defer os.Remove(temp.Name())

						if err := llama.Quantize(blob, temp.Name(), uint32(want)); err != nil {
							return err
						}

						layer, err := NewLayer(temp, baseLayer.MediaType)
						if err != nil {
							return err
						}

						// Rewind before decoding the quantized file's metadata.
						if _, err := temp.Seek(0, io.SeekStart); err != nil {
							return err
						}

						ggml, _, err := llm.DecodeGGML(temp, 0)
						if err != nil {
							return err
						}

						baseLayer.Layer = layer
						baseLayer.GGML = ggml
					}
				}

				if baseLayer.GGML != nil {
					// cmp.Or keeps the first non-empty value, so the first
					// GGML layer seen decides these fields.
					config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
					config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
					config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
					config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
					config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
				}

				layers = append(layers, baseLayer.Layer)
			}
		case "license", "template", "system":
			if c.Name == "template" {
				if _, err := template.Parse(c.Args); err != nil {
					return fmt.Errorf("%w: %s", errBadTemplate, err)
				}
			}

			if c.Name != "license" {
				// replace
				layers = slices.DeleteFunc(layers, func(layer Layer) bool {
					if layer.MediaType != mediatype {
						return false
					}

					if err := layer.Remove(); err != nil {
						return false
					}

					return true
				})
			}

			blob := strings.NewReader(c.Args)
			layer, err := NewLayer(blob, mediatype)
			if err != nil {
				return err
			}

			layers = append(layers, layer)
		case "message":
			role, content, ok := strings.Cut(c.Args, ": ")
			if !ok {
				return fmt.Errorf("invalid message: %s", c.Args)
			}

			messages = append(messages, &api.Message{Role: role, Content: content})
		default:
			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
			if err != nil {
				return err
			}

			for k, v := range ps {
				if ks, ok := parameters[k].([]string); ok {
					parameters[k] = append(ks, v.([]string)...)
				} else if vs, ok := v.([]string); ok {
					parameters[k] = vs
				} else {
					parameters[k] = v
				}
			}
		}
	}

	var err2 error
	layers = slices.DeleteFunc(layers, func(layer Layer) bool {
		switch layer.MediaType {
		case "application/vnd.ollama.image.message",
			"application/vnd.ollama.image.messages":
			// if there are new messages, remove the inherited ones.
			// Message layers are written and read with the plural
			// "messages" media type; the singular form is kept as well so
			// anything previously matched still is.
			return len(messages) > 0
		case "application/vnd.ollama.image.params":
			// merge inherited parameters with new ones
			r, err := layer.Open()
			if err != nil {
				err2 = err
				return false
			}
			defer r.Close()

			var ps map[string]any
			if err := json.NewDecoder(r).Decode(&ps); err != nil {
				err2 = err
				return false
			}

			// Explicitly specified parameters take precedence over
			// inherited ones.
			for k, v := range ps {
				if _, ok := parameters[k]; !ok {
					parameters[k] = v
				}
			}

			return true
		default:
			return false
		}
	})

	if err2 != nil {
		return err2
	}

	if len(messages) > 0 {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(messages); err != nil {
			return err
		}

		layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
		if err != nil {
			return err
		}

		layers = append(layers, layer)
	}

	if len(parameters) > 0 {
		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(parameters); err != nil {
			return err
		}

		layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
		if err != nil {
			return err
		}

		layers = append(layers, layer)
	}

	digests := make([]string, len(layers))
	for i, layer := range layers {
		digests[i] = layer.Digest
	}

	config.RootFS.DiffIDs = digests

	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(config); err != nil {
		return err
	}

	configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		return err
	}

	for _, layer := range append(layers, configLayer) {
		if layer.status != "" {
			fn(api.ProgressResponse{Status: layer.status})
		}
	}

	// A failure to load a previous manifest is ignored; old is then nil and
	// nothing is pruned below.
	old, _ := ParseNamedManifest(name)

	fn(api.ProgressResponse{Status: "writing manifest"})
	if err := WriteManifest(name, configLayer, layers); err != nil {
		return err
	}

	if !envconfig.NoPrune() && old != nil {
		if err := old.RemoveLayers(); err != nil {
			return err
		}
	}

	fn(api.ProgressResponse{Status: "success"})
	return nil
}

// CopyModel copies the manifest file of src to dst. Both names must be fully
// qualified. Copying a model onto itself is a no-op.
func CopyModel(src, dst model.Name) error {
	if !dst.IsFullyQualified() {
		return model.Unqualified(dst)
	}
	if !src.IsFullyQualified() {
		return model.Unqualified(src)
	}

	if src.Filepath() == dst.Filepath() {
		return nil
	}

	manifests, err := GetManifestPath()
	if err != nil {
		return err
	}

	dstpath := filepath.Join(manifests, dst.Filepath())
	if err := os.MkdirAll(filepath.Dir(dstpath), 0o755); err != nil {
		return err
	}

	srcpath := filepath.Join(manifests, src.Filepath())
	srcfile, err := os.Open(srcpath)
	if err != nil {
		return err
	}
	defer srcfile.Close()

	dstfile, err := os.Create(dstpath)
	if err != nil {
		return err
	}

	if _, err := io.Copy(dstfile, srcfile); err != nil {
		dstfile.Close()
		return err
	}

	// Report a failed close: for a file opened for writing it can mean the
	// data never reached disk.
	return dstfile.Close()
}

// deleteUnusedLayers removes from deleteMap every digest still referenced by
// a manifest, then deletes the blob files for whatever remains. Failures to
// remove individual files are logged and skipped.
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
	// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
	manifests, err := Manifests(true)
	if err != nil {
		return err
	}

	for _, manifest := range manifests {
		for _, layer := range manifest.Layers {
			delete(deleteMap, layer.Digest)
		}

		delete(deleteMap, manifest.Config.Digest)
	}

	// only delete the files which are still in the deleteMap
	for k := range deleteMap {
		fp, err := GetBlobsPath(k)
		if err != nil {
			slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
			continue
		}

		if err := os.Remove(fp); err != nil {
			slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
			continue
		}
	}

	return nil
}

// PruneLayers scans the blobs directory, removes entries whose names are not
// valid digests, and deletes every remaining blob that no manifest
// references. A failure in the unused-layer pass is logged and not returned.
func PruneLayers() error {
	deleteMap := make(map[string]struct{})
	p, err := GetBlobsPath("")
	if err != nil {
		return err
	}

	blobs, err := os.ReadDir(p)
	if err != nil {
		slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
		return err
	}

	for _, blob := range blobs {
		// Dashes in the file name are mapped to colons before the name is
		// validated as a digest.
		name := blob.Name()
		name = strings.ReplaceAll(name, "-", ":")

		_, err := GetBlobsPath(name)
		if err != nil {
			if errors.Is(err, ErrInvalidDigestFormat) {
				// remove invalid blobs (e.g. partial downloads)
				if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
					slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
				}
			}

			continue
		}

		deleteMap[name] = struct{}{}
	}

	slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))

	if err := deleteUnusedLayers(deleteMap); err != nil {
		slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
		return nil
	}

	slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))

	return nil
}

// PruneDirectory recursively removes empty directories beneath and including
// path. Symlinks and regular files are left untouched.
func PruneDirectory(path string) error {
	info, err := os.Lstat(path)
	if err != nil {
		return err
	}

	if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
		entries, err := os.ReadDir(path)
		if err != nil {
			return err
		}

		for _, entry := range entries {
			if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
				return err
			}
		}

		// Re-read: pruning children may have emptied this directory.
		entries, err = os.ReadDir(path)
		if err != nil {
			return err
		}

		if len(entries) > 0 {
			return nil
		}

		return os.Remove(path)
	}

	return nil
}

// PushModel uploads every layer of the local model name, followed by its
// manifest, to the registry derived from the model path. Progress is reported
// through fn. Plain http is rejected unless regOpts.Insecure is set.
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	mp := ParseModelPath(name)
	fn(api.ProgressResponse{Status: "retrieving manifest"})

	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
		return errors.New("insecure protocol http")
	}

	manifest, _, err := GetManifest(mp)
	if err != nil {
		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
		return err
	}

	var layers []Layer
	layers = append(layers, manifest.Layers...)
	if manifest.Config.Digest != "" {
		layers = append(layers, manifest.Config)
	}

	for _, layer := range layers {
		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
			return err
		}
	}

	fn(api.ProgressResponse{Status: "pushing manifest"})
	requestURL := mp.BaseURL()
	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

	headers := make(http.Header)
	headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	fn(api.ProgressResponse{Status: "success"})

	return nil
}

// PullModel downloads the manifest and all layers for name, verifies the
// digests of freshly downloaded blobs, writes the manifest locally, and
// (unless pruning is disabled) removes layers that only the previous local
// manifest referenced. Progress is reported through fn.
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
	mp := ParseModelPath(name)

	// build deleteMap to prune unused layers
	deleteMap := make(map[string]struct{})
	manifest, _, err := GetManifest(mp)
	if errors.Is(err, os.ErrNotExist) {
		// noop
	} else if err != nil {
		slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
	} else {
		for _, l := range manifest.Layers {
			deleteMap[l.Digest] = struct{}{}
		}
		if manifest.Config.Digest != "" {
			deleteMap[manifest.Config.Digest] = struct{}{}
		}
	}

	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
		return errors.New("insecure protocol http")
	}

	fn(api.ProgressResponse{Status: "pulling manifest"})

	manifest, err = pullModelManifest(ctx, mp, regOpts)
	if err != nil {
		// %w keeps the cause (e.g. os.ErrNotExist) visible to errors.Is.
		return fmt.Errorf("pull model manifest: %w", err)
	}

	var layers []Layer
	layers = append(layers, manifest.Layers...)
	if manifest.Config.Digest != "" {
		layers = append(layers, manifest.Config)
	}

	skipVerify := make(map[string]bool)
	for _, layer := range layers {
		cacheHit, err := downloadBlob(ctx, downloadOpts{
			mp:      mp,
			digest:  layer.Digest,
			regOpts: regOpts,
			fn:      fn,
		})
		if err != nil {
			return err
		}
		// Blobs that were already cached are not re-verified below.
		skipVerify[layer.Digest] = cacheHit
		delete(deleteMap, layer.Digest)
	}
	delete(deleteMap, manifest.Config.Digest)

	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
	for _, layer := range layers {
		if skipVerify[layer.Digest] {
			continue
		}
		if err := verifyBlob(layer.Digest); err != nil {
			if errors.Is(err, errDigestMismatch) {
				// something went wrong, delete the blob
				fp, err := GetBlobsPath(layer.Digest)
				if err != nil {
					return err
				}
				if err := os.Remove(fp); err != nil {
					// log this, but return the original error
					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
				}
			}
			return err
		}
	}

	fn(api.ProgressResponse{Status: "writing manifest"})

	manifestJSON, err := json.Marshal(manifest)
	if err != nil {
		return err
	}

	fp, err := mp.GetManifestPath()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
		return err
	}

	err = os.WriteFile(fp, manifestJSON, 0o644)
	if err != nil {
		slog.Info(fmt.Sprintf("couldn't write to %s", fp))
		return err
	}

	if !envconfig.NoPrune() && len(deleteMap) > 0 {
		fn(api.ProgressResponse{Status: "removing unused layers"})
		if err := deleteUnusedLayers(deleteMap); err != nil {
			// Pruning is best effort; the pull itself has succeeded.
			fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
		}
	}

	fn(api.ProgressResponse{Status: "success"})

	return nil
}

// pullModelManifest fetches and decodes the manifest for mp from its
// registry.
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

	headers := make(http.Header)
	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var m Manifest
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return nil, err
	}

	return &m, nil
}

// GetSHA256Digest reads r to the end and returns its digest formatted as
// "sha256:<hex>" together with the number of bytes read.
//
// NOTE(review): a read error terminates the process via log.Fatal. The
// signature has no error return, so this is left unchanged here.
func GetSHA256Digest(r io.Reader) (string, int64) {
	h := sha256.New()
	n, err := io.Copy(h, r)
	if err != nil {
		log.Fatal(err)
	}

	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}

var errUnauthorized = errors.New("unauthorized: access denied")

// getTokenSubject returns the subject of a JWT token, it does not validate the token
func getTokenSubject(token string) string {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return ""
	}

	payload := parts[1]
	payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
	if err != nil {
		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
		return ""
	}

	var payloadMap map[string]interface{}
	if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
		return ""
	}

	sub, ok := payloadMap["sub"]
	if !ok {
		slog.Error("jwt does not contain 'sub' field")
		return ""
	}

	return fmt.Sprintf("%s", sub)
}

// makeRequestWithRetry issues the request and, on a 401 response, obtains a
// token from the registry's challenge and tries once more. A 404 is returned
// as os.ErrNotExist; any other status >= 400 is returned as an error holding
// the status code and response body. On success the caller owns resp.Body.
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
	anonymous := true // access will default to anonymous if no user is found associated with the public key
	for range 2 {
		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
		if err != nil {
			if !errors.Is(err, context.Canceled) {
				slog.Info(fmt.Sprintf("request failed: %v", err))
			}

			return nil, err
		}

		switch {
		case resp.StatusCode == http.StatusUnauthorized:
			resp.Body.Close()

			// Handle authentication error with one retry
			challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
			token, err := getAuthorizationToken(ctx, challenge)
			if err != nil {
				return nil, err
			}
			anonymous = getTokenSubject(token) == "anonymous"
			regOpts.Token = token
			if body != nil {
				// Rewind so the retry sends the full body again.
				_, err = body.Seek(0, io.SeekStart)
				if err != nil {
					return nil, err
				}
			}
		case resp.StatusCode == http.StatusNotFound:
			resp.Body.Close()
			return nil, os.ErrNotExist
		case resp.StatusCode >= http.StatusBadRequest:
			defer resp.Body.Close()
			responseBody, err := io.ReadAll(resp.Body)
			if err != nil {
				return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
			}
			return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
		default:
			return resp, nil
		}
	}

	if anonymous {
		// no user is associated with the public key, and the request requires non-anonymous access
		pubKey, nestedErr := auth.GetPublicKey()
		if nestedErr != nil {
			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
			return nil, errUnauthorized
		}
		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
	}
	// user is associated with the public key, but is not authorized to make the request
	return nil, errUnauthorized
}

// testMakeRequestDialContext specifies the dial function for the http client in
// makeRequest. It can be used to resolve hosts in model names to local
// addresses for testing. For example, the model name ("example.com/my/model")
// can be directed to push/pull from "127.0.0.1:1234".
//
// This is not safe to set across goroutines. It should be set in
// the main test goroutine, and not by tests marked to run in parallel with
// t.Parallel().
//
// It should be cleared after use, otherwise it will affect other tests.
//
// Ideally this would be set up further up the stack, but the code is not
// structured in a way that makes this easy, so this will have to do for now.
var testMakeRequestDialContext func(ctx context.Context, network, addr string) (net.Conn, error) func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) { if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure { requestURL.Scheme = "http" } req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body) if err != nil { return nil, err } if headers != nil { req.Header = headers } if regOpts != nil { if regOpts.Token != "" { req.Header.Set("Authorization", "Bearer "+regOpts.Token) } else if regOpts.Username != "" && regOpts.Password != "" { req.SetBasicAuth(regOpts.Username, regOpts.Password) } } req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) if s := req.Header.Get("Content-Length"); s != "" { contentLength, err := strconv.ParseInt(s, 10, 64) if err != nil { return nil, err } req.ContentLength = contentLength } resp, err := (&http.Client{ Transport: &http.Transport{ DialContext: testMakeRequestDialContext, }, CheckRedirect: regOpts.CheckRedirect, }).Do(req) if err != nil { return nil, err } return resp, nil } func getValue(header, key string) string { startIdx := strings.Index(header, key+"=") if startIdx == -1 { return "" } // Move the index to the starting quote after the key. 
startIdx += len(key) + 2 endIdx := startIdx for endIdx < len(header) { if header[endIdx] == '"' { if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue endIdx++ continue } break } endIdx++ } return header[startIdx:endIdx] } func parseRegistryChallenge(authStr string) registryChallenge { authStr = strings.TrimPrefix(authStr, "Bearer ") return registryChallenge{ Realm: getValue(authStr, "realm"), Service: getValue(authStr, "service"), Scope: getValue(authStr, "scope"), } } var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again") func verifyBlob(digest string) error { fp, err := GetBlobsPath(digest) if err != nil { return err } f, err := os.Open(fp) if err != nil { return err } defer f.Close() fileDigest, _ := GetSHA256Digest(f) if digest != fileDigest { return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest) } return nil }