69 lines
1.7 KiB
Go
69 lines
1.7 KiB
Go
package vector
|
|
|
|
import (
|
|
"container/heap"
|
|
"sort"
|
|
|
|
"gonum.org/v1/gonum/mat"
|
|
)
|
|
|
|
// Embedding pairs an embedding vector with the data it represents.
type Embedding struct {
	// Vector holds the components of the embedding.
	Vector []float64

	// Data is the data represented by the embedding.
	Data string
}
|
|
|
|
type EmbeddingSimilarity struct {
|
|
Embedding Embedding // the embedding that was used to calculate the similarity
|
|
Similarity float64 // the similarity between the embedding and the query
|
|
}
|
|
|
|
type Heap []EmbeddingSimilarity
|
|
|
|
func (h Heap) Len() int { return len(h) }
|
|
func (h Heap) Less(i, j int) bool { return h[i].Similarity < h[j].Similarity }
|
|
func (h Heap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
|
func (h *Heap) Push(e any) {
|
|
*h = append(*h, e.(EmbeddingSimilarity))
|
|
}
|
|
|
|
func (h *Heap) Pop() interface{} {
|
|
old := *h
|
|
n := len(old)
|
|
x := old[n-1]
|
|
*h = old[0 : n-1]
|
|
return x
|
|
}
|
|
|
|
// cosineSimilarity is a measure that calculates the cosine of the angle between two vectors.
|
|
// This value will range from -1 to 1, where 1 means the vectors are identical.
|
|
func cosineSimilarity(vec1, vec2 *mat.VecDense) float64 {
|
|
dotProduct := mat.Dot(vec1, vec2)
|
|
norms := mat.Norm(vec1, 2) * mat.Norm(vec2, 2)
|
|
|
|
if norms == 0 {
|
|
return 0
|
|
}
|
|
return dotProduct / norms
|
|
}
|
|
|
|
func TopK(k int, query *mat.VecDense, embeddings []Embedding) []EmbeddingSimilarity {
|
|
h := &Heap{}
|
|
heap.Init(h)
|
|
for _, emb := range embeddings {
|
|
similarity := cosineSimilarity(query, mat.NewVecDense(len(emb.Vector), emb.Vector))
|
|
heap.Push(h, EmbeddingSimilarity{Embedding: emb, Similarity: similarity})
|
|
if h.Len() > k {
|
|
heap.Pop(h)
|
|
}
|
|
}
|
|
|
|
topK := make([]EmbeddingSimilarity, 0, h.Len())
|
|
for h.Len() > 0 {
|
|
topK = append(topK, heap.Pop(h).(EmbeddingSimilarity))
|
|
}
|
|
sort.Slice(topK, func(i, j int) bool {
|
|
return topK[i].Similarity > topK[j].Similarity
|
|
})
|
|
|
|
return topK
|
|
}
|