diff --git a/centroid_index_test.go b/centroid_index_test.go
new file mode 100644
index 000000000..a7334236b
--- /dev/null
+++ b/centroid_index_test.go
@@ -0,0 +1,74 @@
+//go:build vectors
+// +build vectors
+
+package bleve
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"testing"
+
+	"github.com/blevesearch/bleve/v2/analysis/lang/en"
+	"github.com/blevesearch/bleve/v2/mapping"
+	index "github.com/blevesearch/bleve_index_api"
+)
+
+func loadSiftData() ([]map[string]interface{}, error) {
+	// os.ReadFile does not expand "~", so resolve the home directory explicitly
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+	fileContent, err := os.ReadFile(filepath.Join(home, "fts/data/datasets/vec-sift-bucket.json"))
+	if err != nil {
+		return nil, err
+	}
+	var documents []map[string]interface{}
+	err = json.Unmarshal(fileContent, &documents)
+	if err != nil {
+		return nil, err
+	}
+	return documents, nil
+}
+
+func TestCentroidIndex(t *testing.T) {
+	_, _, err := readDatasetAndQueries(testInputCompressedFile)
+	if err != nil {
+		t.Fatal(err)
+	}
+	documents, err := loadSiftData()
+	if err != nil {
+		t.Fatal(err)
+	}
+	contentFieldMapping := NewTextFieldMapping()
+	contentFieldMapping.Analyzer = en.AnalyzerName
+
+	vecFieldMappingL2 := mapping.NewVectorFieldMapping()
+	vecFieldMappingL2.Dims = 128
+	vecFieldMappingL2.Similarity = index.EuclideanDistance
+
+	indexMappingL2Norm := NewIndexMapping()
+	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping)
+	indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2)
+
+	idx, err := newIndexUsing(t.TempDir(), indexMappingL2Norm, Config.DefaultIndexType, Config.DefaultKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		err := idx.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch := idx.NewBatch()
+	// guard against datasets smaller than the requested training sample size
+	sampleSize := 100000
+	if len(documents) < sampleSize {
+		sampleSize = len(documents)
+	}
+	for _, doc := range documents[:sampleSize] {
+		docId := fmt.Sprintf("%s:%s", index.TrainDataPrefix, doc["id"])
+		err = batch.Index(docId, doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	err = idx.Train(batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/go.mod b/go.mod
index c4bc98254..61d2ae87d 100644
--- a/go.mod
+++ b/go.mod
@@ -44,3 +44,25 @@ require (
 	github.com/spf13/pflag v1.0.6 // indirect
 	golang.org/x/sys v0.29.0 // indirect
 )
+
+replace github.com/blevesearch/bleve/v2 => /Users/thejas.orkombu/fts/blevesearch/bleve
+
+replace github.com/blevesearch/zapx/v11 => /Users/thejas.orkombu/fts/blevesearch/zapx11
+
+replace github.com/blevesearch/zapx/v12 => /Users/thejas.orkombu/fts/blevesearch/zapx12
+
+replace github.com/blevesearch/zapx/v13 => /Users/thejas.orkombu/fts/blevesearch/zapx13
+
+replace github.com/blevesearch/zapx/v14 => /Users/thejas.orkombu/fts/blevesearch/zapx14
+
+replace github.com/blevesearch/zapx/v15 => /Users/thejas.orkombu/fts/blevesearch/zapx15
+
+replace github.com/blevesearch/zapx/v16 => /Users/thejas.orkombu/fts/blevesearch/zapx
+
+replace github.com/blevesearch/scorch_segment_api/v2 => /Users/thejas.orkombu/fts/blevesearch/scorch_segment_api
+
+replace github.com/blevesearch/go-faiss => /Users/thejas.orkombu/fts/blevesearch/go-faiss
+
+replace github.com/blevesearch/bleve_index_api => /Users/thejas.orkombu/fts/blevesearch/bleve_index_api
+
+replace github.com/blevesearch/sear => /Users/thejas.orkombu/fts/blevesearch/sear
diff --git a/go.sum b/go.sum
index b46bebcef..fcb958d67 100644
--- a/go.sum
+++ b/go.sum
@@ -3,12 +3,8 @@ github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/
 github.com/bits-and-blooms/bitset v1.12.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
 github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4=
 github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
-github.com/blevesearch/bleve_index_api v1.2.11 h1:bXQ54kVuwP8hdrXUSOnvTQfgK0KI1+f9A0ITJT8tX1s=
-github.com/blevesearch/bleve_index_api v1.2.11/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0=
 github.com/blevesearch/geo v0.2.4 h1:ECIGQhw+QALCZaDcogRTNSJYQXRtC8/m8IKiA706cqk=
 github.com/blevesearch/geo v0.2.4/go.mod h1:K56Q33AzXt2YExVHGObtmRSFYZKYGv0JEN5mdacJJR8=
-github.com/blevesearch/go-faiss v1.0.26 h1:4dRLolFgjPyjkaXwff4NfbZFdE/dfywbzDqporeQvXI=
-github.com/blevesearch/go-faiss v1.0.26/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk=
 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:kDy+zgJFJJoJYBvdfBSiZYBbdsUL0XcjHYWezpQBGPA=
 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A=
 github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo=
@@ -20,8 +16,6 @@ github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY
 github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA=
 github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc=
 github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs=
-github.com/blevesearch/scorch_segment_api/v2 v2.3.13 h1:ZPjv/4VwWvHJZKeMSgScCapOy8+DdmsmRyLmSB88UoY=
-github.com/blevesearch/scorch_segment_api/v2 v2.3.13/go.mod h1:ENk2LClTehOuMS8XzN3UxBEErYmtwkE7MAArFTXs9Vc=
 github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU=
 github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
 github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A=
@@ -34,18 +28,6 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG
 github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
 github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w=
 github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y=
-github.com/blevesearch/zapx/v11 v11.4.2 h1:l46SV+b0gFN+Rw3wUI1YdMWdSAVhskYuvxlcgpQFljs=
-github.com/blevesearch/zapx/v11 v11.4.2/go.mod h1:4gdeyy9oGa/lLa6D34R9daXNUvfMPZqUYjPwiLmekwc=
-github.com/blevesearch/zapx/v12 v12.4.2 h1:fzRbhllQmEMUuAQ7zBuMvKRlcPA5ESTgWlDEoB9uQNE=
-github.com/blevesearch/zapx/v12 v12.4.2/go.mod h1:TdFmr7afSz1hFh/SIBCCZvcLfzYvievIH6aEISCte58=
-github.com/blevesearch/zapx/v13 v13.4.2 h1:46PIZCO/ZuKZYgxI8Y7lOJqX3Irkc3N8W82QTK3MVks=
-github.com/blevesearch/zapx/v13 v13.4.2/go.mod h1:knK8z2NdQHlb5ot/uj8wuvOq5PhDGjNYQQy0QDnopZk=
-github.com/blevesearch/zapx/v14 v14.4.2 h1:2SGHakVKd+TrtEqpfeq8X+So5PShQ5nW6GNxT7fWYz0=
-github.com/blevesearch/zapx/v14 v14.4.2/go.mod h1:rz0XNb/OZSMjNorufDGSpFpjoFKhXmppH9Hi7a877D8=
-github.com/blevesearch/zapx/v15 v15.4.2 h1:sWxpDE0QQOTjyxYbAVjt3+0ieu8NCE0fDRaFxEsp31k=
-github.com/blevesearch/zapx/v15 v15.4.2/go.mod h1:1pssev/59FsuWcgSnTa0OeEpOzmhtmr/0/11H0Z8+Nw=
-github.com/blevesearch/zapx/v16 v16.2.7 h1:xcgFRa7f/tQXOwApVq7JWgPYSlzyUMmkuYa54tMDuR0=
-github.com/blevesearch/zapx/v16 v16.2.7/go.mod h1:murSoCJPCk25MqURrcJaBQ1RekuqSCSfMjXH4rHyA14=
 github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps=
 github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k=
 github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o=
diff --git a/index.go b/index.go
index 2f1ba5fbf..bd5421d85 100644
--- a/index.go
+++ b/index.go
@@ -353,6 +355,11 @@ type IndexCopyable interface {
 	CopyTo(d index.Directory) error
 }
 
+// IndexFileCopyable is implemented by indexes that can copy individual files
+// (tracked in the index's bolt metadata) to a destination directory.
+type IndexFileCopyable interface {
+	UpdateFileInBolt(key []byte, value []byte) error
+	CopyFile(file string, d index.IndexDirectory) error
+}
+
 // FileSystemDirectory is the default implementation for the
 // index.Directory interface.
 type FileSystemDirectory string
@@ -396,3 +403,7 @@ type InsightsIndex interface {
 	// CentroidCardinalities returns the centroids (clusters) from IVF indexes ordered by data density.
 	CentroidCardinalities(field string, limit int, desceding bool) ([]index.CentroidCardinality, error)
 }
+
+// VectorIndex is an Index that supports a one-time training phase ahead of
+// vector data ingestion.
+type VectorIndex interface {
+	Index
+	Train(*Batch) error
+}
diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index 9abcf2db6..e2aa8b03b 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -360,8 +360,9 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 	atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 	prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
-	newDocNums, _, err := s.segPlugin.Merge(segmentsToMerge, docsToDrop, path,
-		cw.cancelCh, s)
+
+	newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
+		cw.cancelCh, s, s.segmentConfig)
 	atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
 
 	fileMergeZapTime := uint64(time.Since(fileMergeZapStartTime))
@@ -379,7 +380,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		return fmt.Errorf("merging failed: %v", err)
 	}
 
-	seg, err = s.segPlugin.Open(path)
+	seg, err = s.segPlugin.OpenEx(path, s.segmentConfig)
 	if err != nil {
 		s.unmarkIneligibleForRemoval(filename)
 		atomic.AddUint64(&s.stats.TotFileMergePlanTasksErr, 1)
@@ -528,7 +529,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 		// the newly merged segment is already flushed out to disk, just needs
 		// to be opened using mmap.
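+		// NOTE: the Merge->MergeEx / Open->OpenEx switch in this file threads
+		// s.segmentConfig through to the segment plugin, so the training-related
+		// options set up in scorch.go also reach merged and reopened segments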
 		newDocIDs, _, err :=
-			s.segPlugin.Merge(segsBatch, dropsBatch, path, s.closeCh, s)
+			s.segPlugin.MergeEx(segsBatch, dropsBatch, path, s.closeCh, s, s.segmentConfig)
 		if err != nil {
 			em.Lock()
 			errs = append(errs, err)
@@ -543,7 +544,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 			s.markIneligibleForRemoval(filename)
 			newMergedSegmentIDs[id] = newSegmentID
 			newDocIDsSet[id] = newDocIDs
-			newMergedSegments[id], err = s.segPlugin.Open(path)
+			newMergedSegments[id], err = s.segPlugin.OpenEx(path, s.segmentConfig)
 			if err != nil {
 				em.Lock()
 				errs = append(errs, err)
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index d92c3a85b..622d3ead7 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -793,7 +793,7 @@ func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot, exclude map[uint
 		}
 	}()
 	for segmentID, path := range newSegmentPaths {
-		newSegments[segmentID], err = s.segPlugin.Open(path)
+		newSegments[segmentID], err = s.segPlugin.OpenEx(path, s.segmentConfig)
 		if err != nil {
 			return fmt.Errorf("error opening new segment at %s, %v", path, err)
 		}
@@ -842,6 +842,22 @@ func zapFileName(epoch uint64) string {
 	return fmt.Sprintf("%012x.zap", epoch)
 }
 
+// updateCentroidIndex loads the persisted centroid index segment from its
+// bolt bucket and installs it as the index's singleton centroid index.
+func (s *Scorch) updateCentroidIndex(bucket *bolt.Bucket) error {
+	if bucket == nil {
+		return nil
+	}
+	segmentSnapshot, err := s.loadSegment(bucket)
+	if err != nil {
+		return err
+	}
+	s.rootLock.Lock()
+	defer s.rootLock.Unlock()
+	s.centroidIndex = segmentSnapshot
+	return nil
+}
+
 // bolt snapshot code
 
 func (s *Scorch) loadFromBolt() error {
@@ -893,6 +915,12 @@ func (s *Scorch) loadFromBolt() error {
 
 			foundRoot = true
 		}
+
+		centroidIndexBucket := snapshots.Bucket(util.BoltCentroidIndexKey)
+		err := s.updateCentroidIndex(centroidIndexBucket)
+		if err != nil {
+			return err
+		}
 		return nil
 	})
 	if err != nil {
@@ -1005,7 +1033,7 @@ func (s *Scorch) loadSegment(segmentBucket *bolt.Bucket) (*SegmentSnapshot, erro
 		return nil, fmt.Errorf("segment path missing")
 	}
 	segmentPath := s.path + string(os.PathSeparator) + string(pathBytes)
-	seg, err := s.segPlugin.Open(segmentPath)
+	seg, err := s.segPlugin.OpenEx(segmentPath, s.segmentConfig)
 	if err != nil {
 		return nil, fmt.Errorf("error opening bolt segment: %v", err)
 	}
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 83924978e..365ef972a 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -15,10 +15,14 @@
 package scorch
 
 import (
+	"bytes"
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -27,6 +31,7 @@ import (
 	"github.com/blevesearch/bleve/v2/registry"
 	"github.com/blevesearch/bleve/v2/util"
 	index "github.com/blevesearch/bleve_index_api"
+	"github.com/blevesearch/go-faiss"
 	segment "github.com/blevesearch/scorch_segment_api/v2"
 	bolt "go.etcd.io/bbolt"
 )
@@ -45,6 +50,7 @@ type Scorch struct {
 	readOnly      bool
 	version       uint8
 	config        map[string]interface{}
+	segmentConfig map[string]interface{}
 	analysisQueue *index.AnalysisQueue
 	path          string
 
@@ -77,6 +83,9 @@ type Scorch struct {
 	persisterNotifier chan *epochWatcher
 	rootBolt          *bolt.DB
 	asyncTasks        sync.WaitGroup
+	// not a real searchable segment; a singleton holding the trained IVF template
+	centroidIndex *SegmentSnapshot
+	train         chan *trainRequest
 
 	onEvent      func(event Event) bool
 	onAsyncError func(err error, path string)
@@ -88,6 +97,12 @@ type Scorch struct {
 	spatialPlugin index.SpatialAnalyzerPlugin
 }
 
+// trainRequest asks the trainer loop to fold a sample segment into the
+// centroid index; the result is reported on ackCh.
+type trainRequest struct {
+	sample   segment.Segment
+	vecCount int
+	ackCh    chan error
+}
+
 // AsyncPanicError is passed to scorch asyncErrorHandler when panic occurs in scorch background process
 type AsyncPanicError struct {
 	Source string
@@ -123,6 +138,7 @@ func NewScorch(storeName string,
 		forceMergeRequestCh: make(chan *mergerCtrl, 1),
 		segPlugin:           defaultSegmentPlugin,
 		copyScheduled:       map[string]int{},
+		segmentConfig:       make(map[string]interface{}),
 	}
 
 	forcedSegmentType, forcedSegmentVersion, err := configForceSegmentTypeVersion(config)
@@ -137,6 +153,12 @@ func NewScorch(storeName string,
 		}
 	}
 
+	// e.g. config["segmentConfig"] = map[string]interface{}{"pretraining": true}
+	segConfig, ok := config["segmentConfig"].(map[string]interface{})
+	if ok {
+		rv.segmentConfig = segConfig
+	}
+
 	typ, ok := config["spatialPlugin"].(string)
 	if ok {
 		if err := rv.loadSpatialAnalyzerPlugin(typ); err != nil {
@@ -466,7 +488,7 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	stats := newFieldStats()
 
 	if len(analysisResults) > 0 {
-		newSegment, bufBytes, err = s.segPlugin.New(analysisResults)
+		newSegment, bufBytes, err = s.segPlugin.NewEx(analysisResults, s.segmentConfig)
 		if err != nil {
 			return err
 		}
@@ -501,6 +523,197 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	return err
 }
 
+func (s *Scorch) getInternal(key []byte) ([]byte, error) {
+	s.rootLock.RLock()
+	defer s.rootLock.RUnlock()
+	// todo: return the total number of vectors processed so far in training;
+	// cbft can use that as a checkpoint to resume training for the remaining n-x samples.
+	if string(key) == "_centroid_index_complete" {
+		return []byte(fmt.Sprintf("%t", s.centroidIndex != nil)), nil
+	}
+	return nil, nil
+}
+
+// trainerLoop is not a routine that runs throughout the lifetime of the index;
+// its only purpose is to train the vector index before data ingestion starts.
+func (s *Scorch) trainerLoop() {
+	// register the callback through which the segment plugin fetches the
+	// trained coarse quantizer for a field
+	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
+	var totalSamplesProcessed int
+	filename := "centroid_index"
+	path := filepath.Join(s.path, filename)
+	buf := make([]byte, binary.MaxVarintLen64)
+	for {
+		select {
+		case <-s.closeCh:
+			return
+		case trainReq := <-s.train:
+			sampleSeg := trainReq.sample
+			if s.centroidIndex == nil {
+				// first sample: persist it as the new centroid index
+				seg, ok := sampleSeg.(segment.UnpersistedSegment)
+				if !ok {
+					trainReq.ackCh <- fmt.Errorf("segment is not an unpersisted segment")
+					close(trainReq.ackCh)
+					continue
+				}
+				if err := persistToDirectory(seg, nil, path); err != nil {
+					trainReq.ackCh <- fmt.Errorf("error persisting segment: %v", err)
+					close(trainReq.ackCh)
+					continue
+				}
+				s.centroidIndex = &SegmentSnapshot{
+					segment: sampleSeg,
+				}
+			} else {
+				// merge the new sample with the existing centroid index; persist
+				// to a tmp file and then rename - is that a fair strategy?
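+				// (assumption: os.Rename within the same filesystem is atomic,
+				// so a crash mid-merge leaves the previous centroid_index file
+				// intact and recoverable from bolt)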
+				tmpPath := filepath.Join(s.path, "centroid_index.tmp")
+				_, _, err := s.segPlugin.MergeEx([]segment.Segment{s.centroidIndex.segment, sampleSeg},
+					[]*roaring.Bitmap{nil, nil}, tmpPath, s.closeCh, nil, s.segmentConfig)
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error merging centroid index: %v", err)
+					close(trainReq.ackCh)
+					continue
+				}
+
+				// close the existing centroid segment - it's supposed to be gc'd at this point
+				s.centroidIndex.segment.Close()
+				err = os.Rename(tmpPath, filepath.Join(s.path, "centroid_index"))
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error renaming centroid index: %v", err)
+					close(trainReq.ackCh)
+					continue
+				}
+			}
+
+			totalSamplesProcessed += trainReq.vecCount
+			// a bolt transaction is necessary for the failover-recovery scenario and also
+			// serves as a checkpoint where we can be sure that the centroid index is
+			// available for the indexing operations downstream.
+			//
+			// note: when the scale increases massively, especially with real world
+			// dimensions of 1536+, this API will have to be refactored to persist in a
+			// more resource-efficient way; having this bolt-related code helps track
+			// progress a lot better and avoids redundant data streaming operations.
+			tx, err := s.rootBolt.Begin(true)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error starting bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+			// a defer'd rollback here would accumulate until the goroutine exits,
+			// so roll back explicitly on each error path instead
+			snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error creating snapshots bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error creating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			err = centroidBucket.Put(util.BoltPathKey, []byte(filename))
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			// total number of vectors that have been processed so far for the training
+			n := binary.PutUvarint(buf, uint64(totalSamplesProcessed))
+			err = centroidBucket.Put(util.BoltVecSamplesProcessedKey, buf[:n])
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error updating vec samples processed: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			err = tx.Commit()
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			close(trainReq.ackCh)
+		}
+	}
+}
+
+func (s *Scorch) Train(batch *index.Batch) error {
+	// fire the index-start event so callers can regulate Train externally
+	s.FireIndexEvent()
+
+	// the lock guards the centroidIndex singleton check below
+	s.rootLock.Lock()
+	defer s.rootLock.Unlock()
+	if s.centroidIndex != nil {
+		// singleton API: training has already happened
+		return nil
+	}
+	var trainData []index.Document
+	for key, doc := range batch.IndexOps {
+		if doc != nil {
+			// insert the _id field; no need to track updates/deletes over here
+			// since the API is singleton
+			doc.AddIDField()
+		}
+		if strings.HasPrefix(key, index.TrainDataPrefix) {
+			trainData = append(trainData, doc)
+		}
+	}
+
+	// just builds a new vector index out of the train data provided.
+	// it'll be an IVF index, so the centroids are computed at this stage and
+	// this template will be used down the line to index the data vectors.
+	// s.segmentConfig will mark this as a training phase and zap will handle
+	// it accordingly.
+	//
+	// note: this might index text data too, how to handle this? s.segmentConfig?
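+	// (assumption: NewEx analyzes and indexes any text fields in the training
+	// docs as usual; the training flag below only changes how vector fields
+	// are handled by the segment implementation)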
+	// todo: updates/deletes -> data drift detection
+	s.segmentConfig["training"] = true
+	seg, _, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
+	if err != nil {
+		return err
+	}
+	// reset the training flag once completed
+	s.segmentConfig["training"] = false
+
+	trainReq := &trainRequest{
+		sample:   seg,
+		vecCount: len(trainData), // todo: multivector support
+		ackCh:    make(chan error),
+	}
+
+	s.train <- trainReq
+	err = <-trainReq.ackCh
+	if err != nil {
+		return err
+	}
+
+	centroidIndex, err := s.segPlugin.OpenEx(filepath.Join(s.path, "centroid_index"), s.segmentConfig)
+	if err != nil {
+		return err
+	}
+	s.centroidIndex = &SegmentSnapshot{
+		segment: centroidIndex,
+	}
+	return nil
+}
+
+// getCentroidIndex returns the coarse quantizer of the centroid index
+// belonging to the given field.
+func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
+	if s.centroidIndex == nil {
+		return nil, fmt.Errorf("no centroid index available")
+	}
+	centroidIndexSegment, ok := s.centroidIndex.segment.(segment.CentroidIndexSegment)
+	if !ok {
+		return nil, fmt.Errorf("segment (%T) is not a centroid index segment", s.centroidIndex.segment)
+	}
+	coarseQuantizer, err := centroidIndexSegment.GetCoarseQuantizer(field)
+	if err != nil {
+		return nil, err
+	}
+	return coarseQuantizer, nil
+}
+
 func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
 	internalOps map[string][]byte, persistedCallback index.BatchCallback,
 	stats *fieldStats,
 ) error {
@@ -940,6 +1153,84 @@ func (s *Scorch) CopyReader() index.CopyReader {
 	return rv
 }
 
+func (s *Scorch) UpdateFileInBolt(key []byte, value []byte) error {
+	tx, err := s.rootBolt.Begin(true)
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+
+	snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
+	if err != nil {
+		return err
+	}
+
+	// currently this is specific to the centroid index file update
+	if bytes.Equal(key, util.BoltCentroidIndexKey) {
+		// todo: guard against duplicate updates
+		centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
+		if err != nil {
+			return err
+		}
+		existingValue := centroidBucket.Get(util.BoltPathKey)
+		if existingValue != nil {
+			return fmt.Errorf("centroid index path already set for %v: %v", s.path, string(existingValue))
+		}
+
+		err = centroidBucket.Put(util.BoltPathKey, value)
+		if err != nil {
+			return err
+		}
+	}
+
+	err = tx.Commit()
+	if err != nil {
+		return err
+	}
+
+	err = s.rootBolt.Sync()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// CopyFile copies a specific file to a destination directory backed by a bleve
+// index. A plain io.Copy() isn't enough because the file needs to be tracked
+// in the destination's bolt metadata as well.
+func (s *Scorch) CopyFile(file string, d index.IndexDirectory) error {
+	s.rootLock.Lock()
+	defer s.rootLock.Unlock()
+
+	// this code is currently specific to the centroid index file, but is
+	// future-proofed for other files to be updated in the dest's bolt
+	if strings.HasSuffix(file, "centroid_index") {
+		// the centroid index file lives outside the snapshots domain, so the
+		// bolt update is different
+		err := d.UpdateFileInBolt(util.BoltCentroidIndexKey, []byte(file))
+		if err != nil {
+			return fmt.Errorf("error updating dest index bolt: %w", err)
+		}
+	}
+
+	dest, err := d.GetWriter(filepath.Join("store", file))
+	if err != nil {
+		return err
+	}
+	defer dest.Close()
+
+	source, err := os.Open(filepath.Join(s.path, file))
+	if err != nil {
+		return err
+	}
+	defer source.Close()
+
+	_, err = io.Copy(dest, source)
+	return err
+}
+
 // external API to fire a scorch event (EventKindIndexStart) externally from bleve
 func (s *Scorch) FireIndexEvent() {
 	s.fireEvent(EventKindIndexStart, 0)
diff --git a/index/scorch/segment_plugin.go b/index/scorch/segment_plugin.go
index 790a8008a..baa3c21dc 100644
--- a/index/scorch/segment_plugin.go
+++ b/index/scorch/segment_plugin.go
@@ -45,10 +45,14 @@ type SegmentPlugin interface {
 	// New takes a set of Documents and turns them into a new Segment
 	New(results []index.Document) (segment.Segment, uint64, error)
 
+	// NewEx is like New, but also passes a config map down to the segment implementation
+	NewEx(results []index.Document, config map[string]interface{}) (segment.Segment, uint64, error)
+
 	// Open attempts to open the file at the specified path and
 	// return the corresponding Segment
 	Open(path string) (segment.Segment, error)
 
+	// OpenEx is like Open, but also passes a config map down to the segment implementation
+	OpenEx(path string, config map[string]interface{}) (segment.Segment, error)
+
 	// Merge takes a set of Segments, and creates a new segment on disk at
 	// the specified path.
 	// Drops is a set of bitmaps (one for each segment) indicating which
@@ -66,6 +70,10 @@ type SegmentPlugin interface {
 	Merge(segments []segment.Segment, drops []*roaring.Bitmap, path string,
 		closeCh chan struct{}, s segment.StatsReporter) (
 		[][]uint64, uint64, error)
+
+	// MergeEx is like Merge, but also passes a config map down to the segment implementation
+	MergeEx(segments []segment.Segment, drops []*roaring.Bitmap, path string,
+		closeCh chan struct{}, s segment.StatsReporter, config map[string]interface{}) (
+		[][]uint64, uint64, error)
 }
 
 var supportedSegmentPlugins map[string]map[uint32]SegmentPlugin
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 3f2a330c5..6be348ab1 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -68,7 +68,8 @@ func init() {
 }
 
 type IndexSnapshot struct {
-	parent *Scorch
+	parent  *Scorch
+	segment []*SegmentSnapshot
 
 	offsets  []uint64
 	internal map[string][]byte
@@ -468,6 +469,10 @@ func (is *IndexSnapshot) Fields() ([]string, error) {
 }
 
 func (is *IndexSnapshot) GetInternal(key []byte) ([]byte, error) {
-	return is.internal[string(key)], nil
+	if v, ok := is.internal[string(key)]; ok {
+		return v, nil
+	}
+	// fall back to index-level internal keys (e.g. the training status key)
+	return is.parent.getInternal(key)
 }
diff --git a/index_alias_impl.go b/index_alias_impl.go
index 8212c74b9..16f20ac45 100644
--- a/index_alias_impl.go
+++ b/index_alias_impl.go
@@ -103,6 +103,14 @@ func (i *indexAliasImpl) IndexSynonym(id string, collection string, definition *
 	return ErrorSynonymSearchNotSupported
 }
 
+func (i *indexAliasImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	// TODO: implement training over aliased indexes
+	return fmt.Errorf("not a vector index")
+}
+
 func (i *indexAliasImpl) Delete(id string) error {
 	i.mutex.RLock()
 	defer i.mutex.RUnlock()
diff --git a/index_impl.go b/index_impl.go
index 8065d9c1e..18fae745c 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -369,6 +371,20 @@ func (i *indexImpl) IndexSynonym(id string, collection string, definition *Synon
 	return err
 }
 
+func (i *indexImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	if vi, ok := i.i.(index.VectorIndex); ok {
+		return vi.Train(batch.internal)
+	}
+	return fmt.Errorf("not a vector index")
+}
+
 // IndexAdvanced takes a document.Document object
 // skips the mapping and indexes it.
 func (i *indexImpl) IndexAdvanced(doc *document.Document) (err error) {
@@ -1362,6 +1378,39 @@ func (m *searchHitSorter) Less(i, j int) bool {
 	return c < 0
 }
 
+func (i *indexImpl) CopyFile(file string, d index.IndexDirectory) (err error) {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	copyIndex, ok := i.i.(index.IndexFileCopyable)
+	if !ok {
+		return fmt.Errorf("index implementation does not support file copy")
+	}
+
+	return copyIndex.CopyFile(file, d)
+}
+
+func (i *indexImpl) UpdateFileInBolt(key []byte, value []byte) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	copyIndex, ok := i.i.(index.IndexFileCopyable)
+	if !ok {
+		return fmt.Errorf("index implementation does not support file copy")
+	}
+
+	return copyIndex.UpdateFileInBolt(key, value)
+}
+
 func (i *indexImpl) CopyTo(d index.Directory) (err error) {
 	i.mutex.RLock()
 	defer i.mutex.RUnlock()
@@ -1388,7 +1439,7 @@ func (i *indexImpl) CopyTo(d index.Directory) (err error) {
 
 	err = copyReader.CopyTo(d)
 	if err != nil {
-		return fmt.Errorf("error copying index metadata: %v", err)
+		return fmt.Errorf("error copying index data: %v", err)
 	}
 
 	// copy the metadata
diff --git a/util/keys.go b/util/keys.go
index b71a7f48b..67415e782 100644
--- a/util/keys.go
+++ b/util/keys.go
@@ -17,6 +17,8 @@ package util
 
 var (
 	// Bolt keys
 	BoltSnapshotsBucket        = []byte{'s'}
+	BoltCentroidIndexKey       = []byte{'c'}
+	BoltVecSamplesProcessedKey = []byte{'v'}
 	BoltPathKey                = []byte{'p'}
 	BoltDeletedKey             = []byte{'d'}
 	BoltInternalKey            = []byte{'i'}
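Taken together, the intended call pattern is a one-shot Train() over a batch of sample documents keyed with index.TrainDataPrefix, followed by normal ingestion. A minimal driver sketch (hypothetical caller code, not part of this change; it assumes the vectors build tag and a vector-capable mapping like the one in TestCentroidIndex above):

	package main

	import (
		"fmt"
		"log"

		"github.com/blevesearch/bleve/v2"
		index "github.com/blevesearch/bleve_index_api"
	)

	func main() {
		// assumption: the mapping declares a 128-dim vector field named
		// "vector", as in TestCentroidIndex (mapping details elided)
		idx, err := bleve.New("example.bleve", bleve.NewIndexMapping())
		if err != nil {
			log.Fatal(err)
		}
		defer idx.Close()

		// placeholder training sample; real callers would sample their dataset
		trainingVectors := make([][]float32, 1000)
		for i := range trainingVectors {
			trainingVectors[i] = make([]float32, 128)
		}

		batch := idx.NewBatch()
		for i, vec := range trainingVectors {
			// training docs are keyed with index.TrainDataPrefix so that
			// Train() can pick them out of the batch
			docID := fmt.Sprintf("%s:%d", index.TrainDataPrefix, i)
			if err := batch.Index(docID, map[string]interface{}{"vector": vec}); err != nil {
				log.Fatal(err)
			}
		}

		// one-shot training pass: computes the IVF centroids that
		// later-built segments reuse as their template
		if vi, ok := idx.(bleve.VectorIndex); ok {
			if err := vi.Train(batch); err != nil {
				log.Fatal(err)
			}
		}

		// data vectors are then ingested through the normal Batch()/Index() path
	}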