rerefactor
This commit is contained in:
parent
823a520266
commit
e43648afe5
9 changed files with 224 additions and 251 deletions
|
@ -2,6 +2,7 @@ package lifecycle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -21,7 +23,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
UpdateCheckURLBase = "https://ollama.ai/api/update"
|
UpdateCheckURLBase = "https://ollama.com/api/update"
|
||||||
UpdateDownloaded = false
|
UpdateDownloaded = false
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,22 +49,42 @@ func getClient(req *http.Request) http.Client {
|
||||||
|
|
||||||
func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||||
var updateResp UpdateResponse
|
var updateResp UpdateResponse
|
||||||
updateCheckURL := UpdateCheckURLBase + "?os=" + runtime.GOOS + "&arch=" + runtime.GOARCH + "&version=" + version.Version
|
|
||||||
headers := make(http.Header)
|
requestURL, err := url.Parse(UpdateCheckURLBase)
|
||||||
err := auth.SignRequest(http.MethodGet, updateCheckURL, nil, headers)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("failed to sign update request %s", err))
|
return false, updateResp
|
||||||
}
|
}
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, updateCheckURL, nil)
|
|
||||||
|
query := requestURL.Query()
|
||||||
|
query.Add("os", runtime.GOOS)
|
||||||
|
query.Add("arch", runtime.GOARCH)
|
||||||
|
query.Add("version", version.Version)
|
||||||
|
query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))
|
||||||
|
|
||||||
|
nonce, err := auth.NewNonce(rand.Reader, 16)
|
||||||
|
if err != nil {
|
||||||
|
return false, updateResp
|
||||||
|
}
|
||||||
|
|
||||||
|
query.Add("nonce", nonce)
|
||||||
|
requestURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
|
||||||
|
signature, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return false, updateResp
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
|
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
|
||||||
return false, updateResp
|
return false, updateResp
|
||||||
}
|
}
|
||||||
req.Header = headers
|
req.Header.Set("Authorization", signature)
|
||||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
client := getClient(req)
|
client := getClient(req)
|
||||||
|
|
||||||
slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", updateCheckURL, headers))
|
slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", requestURL, req.Header))
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
|
slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#define MyAppVersion "0.0.0"
|
#define MyAppVersion "0.0.0"
|
||||||
#endif
|
#endif
|
||||||
#define MyAppPublisher "Ollama, Inc."
|
#define MyAppPublisher "Ollama, Inc."
|
||||||
#define MyAppURL "https://ollama.ai/"
|
#define MyAppURL "https://ollama.com/"
|
||||||
#define MyAppExeName "ollama app.exe"
|
#define MyAppExeName "ollama app.exe"
|
||||||
#define MyIcon ".\assets\app.ico"
|
#define MyIcon ".\assets\app.ico"
|
||||||
|
|
||||||
|
|
153
auth/auth.go
153
auth/auth.go
|
@ -4,185 +4,58 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const defaultPrivateKey = "id_ed25519"
|
||||||
KeyType = "id_ed25519"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthRedirect struct {
|
func NewNonce(r io.Reader, length int) (string, error) {
|
||||||
Realm string
|
|
||||||
Service string
|
|
||||||
Scope string
|
|
||||||
}
|
|
||||||
|
|
||||||
type SignatureData struct {
|
|
||||||
Method string
|
|
||||||
Path string
|
|
||||||
Data []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateNonce(length int) (string, error) {
|
|
||||||
nonce := make([]byte, length)
|
nonce := make([]byte, length)
|
||||||
_, err := rand.Read(nonce)
|
if _, err := io.ReadFull(r, nonce); err != nil {
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return base64.RawURLEncoding.EncodeToString(nonce), nil
|
return base64.RawURLEncoding.EncodeToString(nonce), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r AuthRedirect) URL() (*url.URL, error) {
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||||
redirectURL, err := url.Parse(r.Realm)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
values := redirectURL.Query()
|
|
||||||
|
|
||||||
values.Add("service", r.Service)
|
|
||||||
|
|
||||||
for _, s := range strings.Split(r.Scope, " ") {
|
|
||||||
values.Add("scope", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
|
||||||
|
|
||||||
nonce, err := generateNonce(16)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
values.Add("nonce", nonce)
|
|
||||||
|
|
||||||
redirectURL.RawQuery = values.Encode()
|
|
||||||
return redirectURL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SignRequest(method, url string, data []byte, headers http.Header) error {
|
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
keyPath := filepath.Join(home, ".ollama", KeyType)
|
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||||
|
|
||||||
rawKey, err := os.ReadFile(keyPath)
|
privateKeyFile, err := os.ReadFile(keyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s := SignatureData{
|
|
||||||
Method: method,
|
|
||||||
Path: url,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := s.Sign(rawKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
headers.Set("Authorization", sig)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
|
|
||||||
redirectURL, err := redirData.URL()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := make(http.Header)
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
||||||
err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode >= http.StatusBadRequest {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
|
|
||||||
} else if len(responseBody) > 0 {
|
|
||||||
return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("%s", resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var tok api.TokenResponse
|
|
||||||
if err := json.Unmarshal(respBody, &tok); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return tok.Token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes returns a byte slice of the data to sign for the request
|
|
||||||
func (s SignatureData) Bytes() []byte {
|
|
||||||
// We first derive the content hash of the request body using:
|
|
||||||
// base64(hex(sha256(request body)))
|
|
||||||
|
|
||||||
hash := sha256.Sum256(s.Data)
|
|
||||||
hashHex := make([]byte, hex.EncodedLen(len(hash)))
|
|
||||||
hex.Encode(hashHex, hash[:])
|
|
||||||
contentHash := base64.StdEncoding.EncodeToString(hashHex)
|
|
||||||
|
|
||||||
// We then put the entire request together in a serialize string using:
|
|
||||||
// "<method>,<uri>,<content hash>"
|
|
||||||
// e.g. "GET,http://localhost,OTdkZjM1O..."
|
|
||||||
|
|
||||||
return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignData takes a SignatureData object and signs it with a raw private key
|
|
||||||
func (s SignatureData) Sign(rawKey []byte) (string, error) {
|
|
||||||
signer, err := ssh.ParsePrivateKey(rawKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the pubkey, but remove the type
|
// get the pubkey, but remove the type
|
||||||
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
||||||
parts := bytes.Split(pubKey, []byte(" "))
|
parts := bytes.Split(publicKey, []byte(" "))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return "", fmt.Errorf("malformed public key")
|
return "", fmt.Errorf("malformed public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
signedData, err := signer.Sign(nil, s.Bytes())
|
signedData, err := privateKey.Sign(rand.Reader, bts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// signature is <pubkey>:<signature>
|
// signature is <pubkey>:<signature>
|
||||||
sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
||||||
return sig, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/version"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RegistryOptions struct {
|
|
||||||
Insecure bool
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
Token string
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
|
||||||
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
|
||||||
requestURL.Scheme = "http"
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if headers != nil {
|
|
||||||
req.Header = headers
|
|
||||||
}
|
|
||||||
|
|
||||||
if regOpts != nil {
|
|
||||||
if regOpts.Token != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
|
||||||
} else if regOpts.Username != "" && regOpts.Password != "" {
|
|
||||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
|
||||||
|
|
||||||
if s := req.Header.Get("Content-Length"); s != "" {
|
|
||||||
contentLength, err := strconv.ParseInt(s, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
req.ContentLength = contentLength
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyURL, err := http.ProxyFromEnvironment(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := http.Client{
|
|
||||||
Transport: &http.Transport{
|
|
||||||
Proxy: http.ProxyURL(proxyURL),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
95
server/auth.go
Normal file
95
server/auth.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmorganca/ollama/api"
|
||||||
|
"github.com/jmorganca/ollama/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type registryChallenge struct {
|
||||||
|
Realm string
|
||||||
|
Service string
|
||||||
|
Scope string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r registryChallenge) URL() (*url.URL, error) {
|
||||||
|
redirectURL, err := url.Parse(r.Realm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
values := redirectURL.Query()
|
||||||
|
values.Add("service", r.Service)
|
||||||
|
for _, s := range strings.Split(r.Scope, " ") {
|
||||||
|
values.Add("scope", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
|
nonce, err := auth.NewNonce(rand.Reader, 16)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
values.Add("nonce", nonce)
|
||||||
|
|
||||||
|
redirectURL.RawQuery = values.Encode()
|
||||||
|
return redirectURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
|
||||||
|
redirectURL, err := challenge.URL()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
sha256sum := sha256.Sum256(nil)
|
||||||
|
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
signature, err := auth.Sign(ctx, data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
headers.Add("Authorization", signature)
|
||||||
|
|
||||||
|
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("%d: %v", response.StatusCode, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.StatusCode >= http.StatusBadRequest {
|
||||||
|
if len(body) > 0 {
|
||||||
|
return "", fmt.Errorf("%d: %s", response.StatusCode, body)
|
||||||
|
} else {
|
||||||
|
return "", fmt.Errorf("%d", response.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var token api.TokenResponse
|
||||||
|
if err := json.Unmarshal(body, &token); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token.Token, nil
|
||||||
|
}
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/auth"
|
|
||||||
"github.com/jmorganca/ollama/format"
|
"github.com/jmorganca/ollama/format"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -86,7 +85,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -138,11 +137,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *a
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
|
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
|
||||||
b.err = b.run(ctx, requestURL, opts)
|
b.err = b.run(ctx, requestURL, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||||
defer blobDownloadManager.Delete(b.Digest)
|
defer blobDownloadManager.Delete(b.Digest)
|
||||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
@ -211,7 +210,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
|
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
g, ctx := errgroup.WithContext(ctx)
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
|
@ -335,7 +334,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||||
type downloadOpts struct {
|
type downloadOpts struct {
|
||||||
mp ModelPath
|
mp ModelPath
|
||||||
digest string
|
digest string
|
||||||
regOpts *auth.RegistryOptions
|
regOpts *registryOptions
|
||||||
fn func(api.ProgressResponse)
|
fn func(api.ProgressResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,17 +16,25 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/auth"
|
|
||||||
"github.com/jmorganca/ollama/llm"
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/parser"
|
"github.com/jmorganca/ollama/parser"
|
||||||
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type registryOptions struct {
|
||||||
|
Insecure bool
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Config ConfigV2
|
Config ConfigV2
|
||||||
|
@ -312,7 +320,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, os.ErrNotExist):
|
case errors.Is(err, os.ErrNotExist):
|
||||||
fn(api.ProgressResponse{Status: "pulling model"})
|
fn(api.ProgressResponse{Status: "pulling model"})
|
||||||
if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
|
if err := PullModel(ctx, c.Args, ®istryOptions{}, fn); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -832,7 +840,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||||
|
|
||||||
|
@ -882,7 +890,7 @@ func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *ManifestV2
|
||||||
|
@ -988,7 +996,7 @@ func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
|
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
|
||||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
|
@ -1020,9 +1028,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||||
|
|
||||||
var errUnauthorized = fmt.Errorf("unauthorized")
|
var errUnauthorized = fmt.Errorf("unauthorized")
|
||||||
|
|
||||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
slog.Info(fmt.Sprintf("request failed: %v", err))
|
slog.Info(fmt.Sprintf("request failed: %v", err))
|
||||||
|
@ -1034,9 +1042,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||||
switch {
|
switch {
|
||||||
case resp.StatusCode == http.StatusUnauthorized:
|
case resp.StatusCode == http.StatusUnauthorized:
|
||||||
// Handle authentication error with one retry
|
// Handle authentication error with one retry
|
||||||
authenticate := resp.Header.Get("www-authenticate")
|
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||||
authRedir := ParseAuthRedirectString(authenticate)
|
token, err := getAuthorizationToken(ctx, challenge)
|
||||||
token, err := auth.GetAuthToken(ctx, authRedir)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -1063,6 +1070,58 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||||
return nil, errUnauthorized
|
return nil, errUnauthorized
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
||||||
|
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||||
|
requestURL.Scheme = "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers != nil {
|
||||||
|
req.Header = headers
|
||||||
|
}
|
||||||
|
|
||||||
|
if regOpts != nil {
|
||||||
|
if regOpts.Token != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
||||||
|
} else if regOpts.Username != "" && regOpts.Password != "" {
|
||||||
|
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||||
|
|
||||||
|
if s := req.Header.Get("Content-Length"); s != "" {
|
||||||
|
contentLength, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.ContentLength = contentLength
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, err := http.ProxyFromEnvironment(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
Proxy: http.ProxyURL(proxyURL),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getValue(header, key string) string {
|
func getValue(header, key string) string {
|
||||||
startIdx := strings.Index(header, key+"=")
|
startIdx := strings.Index(header, key+"=")
|
||||||
if startIdx == -1 {
|
if startIdx == -1 {
|
||||||
|
@ -1086,10 +1145,10 @@ func getValue(header, key string) string {
|
||||||
return header[startIdx:endIdx]
|
return header[startIdx:endIdx]
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
|
func parseRegistryChallenge(authStr string) registryChallenge {
|
||||||
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
||||||
|
|
||||||
return auth.AuthRedirect{
|
return registryChallenge{
|
||||||
Realm: getValue(authStr, "realm"),
|
Realm: getValue(authStr, "realm"),
|
||||||
Service: getValue(authStr, "service"),
|
Service: getValue(authStr, "service"),
|
||||||
Scope: getValue(authStr, "scope"),
|
Scope: getValue(authStr, "scope"),
|
||||||
|
|
|
@ -25,7 +25,6 @@ import (
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/auth"
|
|
||||||
"github.com/jmorganca/ollama/gpu"
|
"github.com/jmorganca/ollama/gpu"
|
||||||
"github.com/jmorganca/ollama/llm"
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/openai"
|
"github.com/jmorganca/ollama/openai"
|
||||||
|
@ -480,7 +479,7 @@ func PullModelHandler(c *gin.Context) {
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
regOpts := &auth.RegistryOptions{
|
regOpts := ®istryOptions{
|
||||||
Insecure: req.Insecure,
|
Insecure: req.Insecure,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -529,7 +528,7 @@ func PushModelHandler(c *gin.Context) {
|
||||||
ch <- r
|
ch <- r
|
||||||
}
|
}
|
||||||
|
|
||||||
regOpts := &auth.RegistryOptions{
|
regOpts := ®istryOptions{
|
||||||
Insecure: req.Insecure,
|
Insecure: req.Insecure,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/auth"
|
|
||||||
"github.com/jmorganca/ollama/format"
|
"github.com/jmorganca/ollama/format"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
@ -50,7 +49,7 @@ const (
|
||||||
maxUploadPartSize int64 = 1000 * format.MegaByte
|
maxUploadPartSize int64 = 1000 * format.MegaByte
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
|
||||||
p, err := GetBlobsPath(b.Digest)
|
p, err := GetBlobsPath(b.Digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -122,7 +121,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *aut
|
||||||
|
|
||||||
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
||||||
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
||||||
func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
|
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
|
||||||
defer blobUploadManager.Delete(b.Digest)
|
defer blobUploadManager.Delete(b.Digest)
|
||||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
@ -213,7 +212,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
|
||||||
b.done = true
|
b.done = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
|
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
headers.Set("Content-Type", "application/octet-stream")
|
headers.Set("Content-Type", "application/octet-stream")
|
||||||
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
|
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
|
||||||
|
@ -228,7 +227,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||||
md5sum := md5.New()
|
md5sum := md5.New()
|
||||||
w := &progressWriter{blobUpload: b}
|
w := &progressWriter{blobUpload: b}
|
||||||
|
|
||||||
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
|
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Rollback()
|
w.Rollback()
|
||||||
return err
|
return err
|
||||||
|
@ -278,9 +277,8 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||||
|
|
||||||
case resp.StatusCode == http.StatusUnauthorized:
|
case resp.StatusCode == http.StatusUnauthorized:
|
||||||
w.Rollback()
|
w.Rollback()
|
||||||
authenticate := resp.Header.Get("www-authenticate")
|
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
||||||
authRedir := ParseAuthRedirectString(authenticate)
|
token, err := getAuthorizationToken(ctx, challenge)
|
||||||
token, err := auth.GetAuthToken(ctx, authRedir)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -365,7 +363,7 @@ func (p *progressWriter) Rollback() {
|
||||||
p.written = 0
|
p.written = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
requestURL := mp.BaseURL()
|
requestURL := mp.BaseURL()
|
||||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue