2024-02-05 12:59:52 -08:00
|
|
|
package auth
|
2023-08-10 11:34:25 -07:00
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2023-08-11 15:41:55 -07:00
|
|
|
"context"
|
2023-08-10 11:34:25 -07:00
|
|
|
"crypto/rand"
|
|
|
|
"encoding/base64"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
2024-01-18 10:52:01 -08:00
|
|
|
"log/slog"
|
2023-08-10 11:34:25 -07:00
|
|
|
"os"
|
2023-09-19 09:36:30 -07:00
|
|
|
"path/filepath"
|
2024-04-30 11:02:08 -07:00
|
|
|
"strings"
|
2023-08-10 11:34:25 -07:00
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
2024-02-05 12:59:52 -08:00
|
|
|
)
|
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
const defaultPrivateKey = "id_ed25519"
|
2023-08-10 11:34:25 -07:00
|
|
|
|
2024-04-30 11:02:08 -07:00
|
|
|
func keyPath() (string, error) {
|
|
|
|
home, err := os.UserHomeDir()
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GetPublicKey() (string, error) {
|
|
|
|
keyPath, err := keyPath()
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
privateKeyFile, err := os.ReadFile(keyPath)
|
|
|
|
if err != nil {
|
|
|
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
|
|
|
|
|
|
|
return strings.TrimSpace(string(publicKey)), nil
|
|
|
|
}
|
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
// NewNonce reads exactly length bytes from r and returns them encoded with
// unpadded, URL-safe base64 (base64.RawURLEncoding).
//
// It returns an error if length is negative, or if r cannot supply length
// bytes; in the latter case the error comes from io.ReadFull (for example
// io.EOF or io.ErrUnexpectedEOF). A length of 0 yields "" without reading.
func NewNonce(r io.Reader, length int) (string, error) {
	// make panics on a negative size; report it as an error instead.
	if length < 0 {
		return "", fmt.Errorf("invalid nonce length %d", length)
	}

	nonce := make([]byte, length)
	if _, err := io.ReadFull(r, nonce); err != nil {
		return "", err
	}

	return base64.RawURLEncoding.EncodeToString(nonce), nil
}
|
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
func Sign(ctx context.Context, bts []byte) (string, error) {
|
2024-04-30 11:02:08 -07:00
|
|
|
keyPath, err := keyPath()
|
2023-08-10 11:34:25 -07:00
|
|
|
if err != nil {
|
2023-10-20 16:52:48 -07:00
|
|
|
return "", err
|
2023-08-10 11:34:25 -07:00
|
|
|
}
|
2024-02-07 11:00:06 -08:00
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
privateKeyFile, err := os.ReadFile(keyPath)
|
2023-08-10 11:34:25 -07:00
|
|
|
if err != nil {
|
2024-02-14 11:29:49 -08:00
|
|
|
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
2023-08-10 11:34:25 -07:00
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
|
2023-08-10 11:34:25 -07:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// get the pubkey, but remove the type
|
2024-02-14 11:29:49 -08:00
|
|
|
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
|
|
|
|
parts := bytes.Split(publicKey, []byte(" "))
|
2023-08-10 11:34:25 -07:00
|
|
|
if len(parts) < 2 {
|
|
|
|
return "", fmt.Errorf("malformed public key")
|
|
|
|
}
|
|
|
|
|
2024-02-14 11:29:49 -08:00
|
|
|
signedData, err := privateKey.Sign(rand.Reader, bts)
|
2023-08-10 11:34:25 -07:00
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
|
|
|
|
// signature is <pubkey>:<signature>
|
2024-02-14 11:29:49 -08:00
|
|
|
return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
|
2023-08-10 11:34:25 -07:00
|
|
|
}
|