From 1e28c1ed7e0e677e48ad4b70bcdbfd369d2b4922 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Mon, 16 Jun 2025 15:23:42 +0530
Subject: [PATCH 01/12] fastmerge wip

---
 index/scorch/snapshot_index.go | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 3f2a330c5..ad2c781fe 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -68,13 +68,16 @@ func init() {
 }
 
 type IndexSnapshot struct {
-	parent   *Scorch
-	segment  []*SegmentSnapshot
-	offsets  []uint64
-	internal map[string][]byte
-	epoch    uint64
-	size     uint64
-	creator  string
+	parent *Scorch
+
+	// POC: trainData is ephemeral
+	trainData [][]float32
+	segment   []*SegmentSnapshot
+	offsets   []uint64
+	internal  map[string][]byte
+	epoch     uint64
+	size      uint64
+	creator   string
 
 	m    sync.Mutex // Protects the fields that follow.
 	refs int64

From e626b64541f7baedf8a940c6bb436f4cda7624a7 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Mon, 16 Jun 2025 15:28:24 +0530
Subject: [PATCH 02/12] passing zap config via new plugin APIs

---
 index/scorch/merge.go          | 10 +++++-----
 index/scorch/persister.go      |  4 ++--
 index/scorch/scorch.go         |  4 +++-
 index/scorch/segment_plugin.go |  8 ++++++++
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index 9abcf2db6..df0cf69b7 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -360,8 +360,8 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 		prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
-		newDocNums, _, err := s.segPlugin.Merge(segmentsToMerge, docsToDrop, path,
-			cw.cancelCh, s)
+		newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
+			cw.cancelCh, s, s.segmentConfig)
 		atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
 
 		fileMergeZapTime := uint64(time.Since(fileMergeZapStartTime))
@@ -379,7 +379,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 			return fmt.Errorf("merging failed: %v", err)
 		}
 
-		seg, err = s.segPlugin.Open(path)
+		seg, err = s.segPlugin.OpenEx(path, s.segmentConfig)
 		if err != nil {
 			s.unmarkIneligibleForRemoval(filename)
 			atomic.AddUint64(&s.stats.TotFileMergePlanTasksErr, 1)
@@ -528,7 +528,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 			// the newly merged segment is already flushed out to disk, just needs
 			// to be opened using mmap.
 			newDocIDs, _, err :=
-				s.segPlugin.Merge(segsBatch, dropsBatch, path, s.closeCh, s)
+				s.segPlugin.MergeEx(segsBatch, dropsBatch, path, s.closeCh, s, s.segmentConfig)
 			if err != nil {
 				em.Lock()
 				errs = append(errs, err)
@@ -543,7 +543,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 			s.markIneligibleForRemoval(filename)
 			newMergedSegmentIDs[id] = newSegmentID
 			newDocIDsSet[id] = newDocIDs
-			newMergedSegments[id], err = s.segPlugin.Open(path)
+			newMergedSegments[id], err = s.segPlugin.OpenEx(path, s.segmentConfig)
 			if err != nil {
 				em.Lock()
 				errs = append(errs, err)
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index d92c3a85b..1e4860150 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -793,7 +793,7 @@ func (s *Scorch) persistSnapshotDirect(snapshot *IndexSnapshot, exclude map[uint
 		}
 	}()
 	for segmentID, path := range newSegmentPaths {
-		newSegments[segmentID], err = s.segPlugin.Open(path)
+		newSegments[segmentID], err = s.segPlugin.OpenEx(path, s.segmentConfig)
 		if err != nil {
 			return fmt.Errorf("error opening new segment at %s, %v", path, err)
 		}
@@ -1005,7 +1005,7 @@ func (s *Scorch) loadSegment(segmentBucket *bolt.Bucket) (*SegmentSnapshot, erro
 		return nil, fmt.Errorf("segment path missing")
 	}
 	segmentPath := s.path + string(os.PathSeparator) + string(pathBytes)
-	seg, err := s.segPlugin.Open(segmentPath)
+	seg, err := s.segPlugin.OpenEx(segmentPath, s.segmentConfig)
 	if err != nil {
 		return nil, fmt.Errorf("error opening bolt segment: %v", err)
 	}
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 83924978e..32e91842a 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -45,6 +45,7 @@ type Scorch struct {
 	readOnly      bool
 	version       uint8
 	config        map[string]interface{}
+	segmentConfig map[string]interface{}
 	analysisQueue *index.AnalysisQueue
 	path          string
 
@@ -123,6 +124,7 @@ func NewScorch(storeName string,
 		forceMergeRequestCh: make(chan *mergerCtrl, 1),
 		segPlugin:           defaultSegmentPlugin,
 		copyScheduled:       map[string]int{},
+		segmentConfig:       make(map[string]interface{}),
 	}
 
 	forcedSegmentType, forcedSegmentVersion, err := configForceSegmentTypeVersion(config)
@@ -466,7 +468,7 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	stats := newFieldStats()
 
 	if len(analysisResults) > 0 {
-		newSegment, bufBytes, err = s.segPlugin.New(analysisResults)
+		newSegment, bufBytes, err = s.segPlugin.NewEx(analysisResults, s.segmentConfig)
 		if err != nil {
 			return err
 		}
diff --git a/index/scorch/segment_plugin.go b/index/scorch/segment_plugin.go
index 790a8008a..baa3c21dc 100644
--- a/index/scorch/segment_plugin.go
+++ b/index/scorch/segment_plugin.go
@@ -45,10 +45,14 @@ type SegmentPlugin interface {
 	// New takes a set of Documents and turns them into a new Segment
 	New(results []index.Document) (segment.Segment, uint64, error)
 
+	NewEx(results []index.Document, config map[string]interface{}) (segment.Segment, uint64, error)
+
 	// Open attempts to open the file at the specified path and
 	// return the corresponding Segment
 	Open(path string) (segment.Segment, error)
 
+	OpenEx(path string, config map[string]interface{}) (segment.Segment, error)
+
 	// Merge takes a set of Segments, and creates a new segment on disk at
 	// the specified path.
 	// Drops is a set of bitmaps (one for each segment) indicating which
@@ -66,6 +70,10 @@ type SegmentPlugin interface {
 	Merge(segments []segment.Segment, drops []*roaring.Bitmap, path string,
 		closeCh chan struct{}, s segment.StatsReporter) (
 		[][]uint64, uint64, error)
+
+	MergeEx(segments []segment.Segment, drops []*roaring.Bitmap, path string,
+		closeCh chan struct{}, s segment.StatsReporter, config map[string]interface{}) (
+		[][]uint64, uint64, error)
 }
 
 var supportedSegmentPlugins map[string]map[uint32]SegmentPlugin

From 328b5a0ded46541ad737e6d20907e1bf0f9374b7 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Tue, 17 Jun 2025 12:12:16 +0530
Subject: [PATCH 03/12] use callbacks to collect and use train data while
 merging

---
 index/scorch/introducer.go     |  4 ++++
 index/scorch/merge.go          | 20 ++++++++++++++++++++
 index/scorch/snapshot_index.go |  2 +-
 3 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/index/scorch/introducer.go b/index/scorch/introducer.go
index 8191584d2..e16deae0f 100644
--- a/index/scorch/introducer.go
+++ b/index/scorch/introducer.go
@@ -352,6 +352,10 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
 		creator: "introduceMerge",
 	}
 
+	if len(nextMerge.trainData) > 0 {
+		newSnapshot.trainData = append(root.trainData, nextMerge.trainData...)
+	}
+
 	var running, docsToPersistCount, memSegments, fileSegments uint64
 	var droppedSegmentFiles []string
 	newSegmentDeleted := make([]*roaring.Bitmap, len(nextMerge.new))
diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index df0cf69b7..229cb3c87 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -17,6 +17,7 @@ package scorch
 import (
 	"context"
 	"fmt"
+	"math"
 	"os"
 	"strings"
 	"sync"
@@ -360,6 +361,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 		prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
+		s.segmentConfig["trainData"] = ourSnapshot.trainData
 		newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
 			cw.cancelCh, s, s.segmentConfig)
 		atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
@@ -469,6 +471,7 @@ type mergedSegmentHistory struct {
 type segmentMerge struct {
 	id               []uint64
 	new              []segment.Segment
+	trainData        [][]float32
 	mergedSegHistory map[uint64]*mergedSegmentHistory
 	notifyCh         chan *mergeTaskIntroStatus
 	mmaped           uint32
@@ -515,6 +518,22 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 	var em sync.Mutex
 	var errs []error
 
+	var trainingSample [][]float32
+	collectTrainData := func(segTrainData [][]float32) {
+		trainingSample = append(trainingSample, segTrainData...)
+	}
+
+	numDocs, err := snapshot.DocCount()
+	if err != nil {
+		return nil, nil, err
+	}
+	trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 39)
+
+	// collect train data only if needed
+	if len(snapshot.trainData) < int(trainingSampleSize) {
+		s.segmentConfig["collectTrainDataCallback"] = collectTrainData
+	}
+	s.segmentConfig["trainData"] = snapshot.trainData
 	// deploy the workers to merge and flush the batches of segments concurrently
 	// and create a new file segment
 	for i := 0; i < numFlushes; i++ {
@@ -589,6 +608,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 		mergedSegHistory: make(map[uint64]*mergedSegmentHistory, numSegments),
 		notifyCh:         make(chan *mergeTaskIntroStatus),
 		newCount:         newMergedCount,
+		trainData:        trainingSample,
 	}
 
 	// create a history map which maps the old in-memory segments with the specific
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index ad2c781fe..ec4ec12d7 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -70,7 +70,7 @@ func init() {
 type IndexSnapshot struct {
 	parent *Scorch
 
-	// POC: trainData is ephemeral
+	// POC: trainData is ephemeral and read-only just like []*SegmentSnapshot
 	trainData [][]float32
 	segment   []*SegmentSnapshot
 	offsets   []uint64

From aee2333b0511deeab85b7db5a766d2823563f763 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Wed, 18 Jun 2025 16:49:32 +0530
Subject: [PATCH 04/12] serialized float array

---
 index/scorch/merge.go          | 10 ++++++----
 index/scorch/snapshot_index.go |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index 229cb3c87..b25a920bd 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"math"
 	"os"
+	"slices"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -471,7 +472,7 @@ type mergedSegmentHistory struct {
 type segmentMerge struct {
 	id               []uint64
 	new              []segment.Segment
-	trainData        [][]float32
+	trainData        []float32
 	mergedSegHistory map[uint64]*mergedSegmentHistory
 	notifyCh         chan *mergeTaskIntroStatus
 	mmaped           uint32
@@ -518,9 +519,10 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 	var em sync.Mutex
 	var errs []error
 
-	var trainingSample [][]float32
-	collectTrainData := func(segTrainData [][]float32) {
-		trainingSample = append(trainingSample, segTrainData...)
+	var trainingSample []float32
+	collectTrainData := func(segTrainData []float32) {
+		// append a clone of the training sample
+		trainingSample = append(trainingSample, slices.Clone(segTrainData)...)
 	}
 
 	numDocs, err := snapshot.DocCount()
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index ec4ec12d7..70269d5ec 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -71,7 +71,7 @@ type IndexSnapshot struct {
 	parent *Scorch
 
 	// POC: trainData is ephemeral and read-only just like []*SegmentSnapshot
-	trainData [][]float32
+	trainData []float32
 	segment   []*SegmentSnapshot
 	offsets   []uint64
 	internal  map[string][]byte

From 350cf7dcb1657382f7b7f547c70aa3ce7f64c313 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Tue, 8 Jul 2025 11:46:06 +0530
Subject: [PATCH 05/12] collect training sample on the file path as well

---
 index/scorch/merge.go     | 29 +++++++++++++++++++++++------
 index/scorch/persister.go |  4 ++++
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index b25a920bd..f5592f274 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -353,6 +353,11 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 	var seg segment.Segment
 	var filename string
 
+	var trainingSample []float32
+	collectTrainData := func(segTrainData []float32) {
+		// append a clone of the training sample
+		trainingSample = append(trainingSample, slices.Clone(segTrainData)...)
+	}
 	if len(segmentsToMerge) > 0 {
 		filename = zapFileName(newSegmentID)
 		s.markIneligibleForRemoval(filename)
@@ -362,7 +367,13 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 		prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
-		s.segmentConfig["trainData"] = ourSnapshot.trainData
+
+		trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(1000000)) * 39)
+		if len(ourSnapshot.trainData) < int(trainingSampleSize) {
+			s.segmentConfig["collectTrainDataCallback"] = collectTrainData
+		} else {
+			s.segmentConfig["trainData"] = ourSnapshot.trainData
+		}
 		newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
 			cw.cancelCh, s, s.segmentConfig)
 		atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
@@ -408,6 +419,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 			newCount: seg.Count(),
 			notifyCh: make(chan *mergeTaskIntroStatus),
 			mmaped:   1,
+			trainData: trainingSample,
 		}
 
 		s.fireEvent(EventKindMergeTaskIntroductionStart, 0)
@@ -525,17 +537,22 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 		trainingSample = append(trainingSample, slices.Clone(segTrainData)...)
 	}
 
-	numDocs, err := snapshot.DocCount()
-	if err != nil {
-		return nil, nil, err
-	}
+	// numDocs, err := snapshot.DocCount()
+	// if err != nil {
+	// 	return nil, nil, err
+	// }
+
+	// hardcoding the total docs for now, need to get it from CB level
+	numDocs := 1000000
 	trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 39)
 
 	// collect train data only if needed
 	if len(snapshot.trainData) < int(trainingSampleSize) {
 		s.segmentConfig["collectTrainDataCallback"] = collectTrainData
+	} else {
+		s.segmentConfig["trainData"] = snapshot.trainData
 	}
-	s.segmentConfig["trainData"] = snapshot.trainData
+
 	// deploy the workers to merge and flush the batches of segments concurrently
 	// and create a new file segment
 	for i := 0; i < numFlushes; i++ {
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index 1e4860150..d505d2c95 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -994,6 +994,10 @@ func (s *Scorch) loadSnapshot(snapshot *bolt.Bucket) (*IndexSnapshot, error) {
 			rv.MergeUpdateFieldsInfo(segmentSnapshot.updatedFields)
 		}
 		running += segmentSnapshot.segment.Count()
+		// persistedSegment, ok := segmentSnapshot.segment.(segment.PersistedSegment)
+		// if ok {
+		// 	fmt.Println("segment path", persistedSegment.Path())
+		// }
 	}
 
 	return rv, nil

From e5bf978063db40a6eea92a381d9348108ca6bc56 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Thu, 21 Aug 2025 19:52:06 +0530
Subject: [PATCH 06/12] cleanup debug logs

---
 index/scorch/introducer.go     | 10 +++++++++-
 index/scorch/merge.go          | 10 +++++-----
 index/scorch/persister.go      |  4 ----
 index/scorch/snapshot_index.go | 14 ++++++++------
 4 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/index/scorch/introducer.go b/index/scorch/introducer.go
index e16deae0f..261f6f1db 100644
--- a/index/scorch/introducer.go
+++ b/index/scorch/introducer.go
@@ -126,6 +126,10 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
 		creator: "introduceSegment",
 	}
 
+	if len(root.trainData) > 0 {
+		newSnapshot.trainData = root.trainData
+	}
+
 	// iterate through current segments
 	var running uint64
 	var docsToPersistCount, memSegments, fileSegments uint64
@@ -276,6 +280,10 @@ func (s *Scorch) introducePersist(persist *persistIntroduction) {
 		creator: "introducePersist",
 	}
 
+	if len(root.trainData) > 0 {
+		newIndexSnapshot.trainData = root.trainData
+	}
+
 	var docsToPersistCount, memSegments, fileSegments uint64
 	for i, segmentSnapshot := range root.segment {
 		// see if this segment has been replaced
@@ -353,7 +361,7 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
 	}
 
 	if len(nextMerge.trainData) > 0 {
-		newSnapshot.trainData = append(root.trainData, nextMerge.trainData...)
+		newSnapshot.trainData = nextMerge.trainData
 	}
 
 	var running, docsToPersistCount, memSegments, fileSegments uint64
diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index f5592f274..901c466e1 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -368,8 +368,8 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 		prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
-		trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(1000000)) * 39)
-		if len(ourSnapshot.trainData) < int(trainingSampleSize) {
+		trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(1000000)) * 50)
+		if len(ourSnapshot.trainData)/768 < int(trainingSampleSize) {
 			s.segmentConfig["collectTrainDataCallback"] = collectTrainData
 		} else {
 			s.segmentConfig["trainData"] = ourSnapshot.trainData
@@ -534,7 +534,7 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 	var trainingSample []float32
 	collectTrainData := func(segTrainData []float32) {
 		// append a clone of the training sample
-		trainingSample = append(trainingSample, slices.Clone(segTrainData)...)
+		trainingSample = append(trainingSample, segTrainData...)
 	}
 
 	// numDocs, err := snapshot.DocCount()
@@ -544,10 +544,10 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 
 	// hardcoding the total docs for now, need to get it from CB level
 	numDocs := 1000000
-	trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 39)
+	trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 50)
 
 	// collect train data only if needed
-	if len(snapshot.trainData) < int(trainingSampleSize) {
+	if len(snapshot.trainData)/768 < int(trainingSampleSize) {
 		s.segmentConfig["collectTrainDataCallback"] = collectTrainData
 	} else {
 		s.segmentConfig["trainData"] = snapshot.trainData
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index d505d2c95..1e4860150 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -994,10 +994,6 @@ func (s *Scorch) loadSnapshot(snapshot *bolt.Bucket) (*IndexSnapshot, error) {
 			rv.MergeUpdateFieldsInfo(segmentSnapshot.updatedFields)
 		}
 		running += segmentSnapshot.segment.Count()
-		// persistedSegment, ok := segmentSnapshot.segment.(segment.PersistedSegment)
-		// if ok {
-		// 	fmt.Println("segment path", persistedSegment.Path())
-		// }
 	}
 
 	return rv, nil
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 70269d5ec..91099a848 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -72,12 +72,14 @@ type IndexSnapshot struct {
 
 	// POC: trainData is ephemeral and read-only just like []*SegmentSnapshot
 	trainData []float32
-	segment   []*SegmentSnapshot
-	offsets   []uint64
-	internal  map[string][]byte
-	epoch     uint64
-	size      uint64
-	creator   string
+	// trainSegments []*SegmentSnapshot // either store []float32 or []faissIndexes aka centroid indexes
+
+	segment  []*SegmentSnapshot
+	offsets  []uint64
+	internal map[string][]byte
+	epoch    uint64
+	size     uint64
+	creator  string
 
 	m    sync.Mutex // Protects the fields that follow.
 	refs int64

From f72871062e2f58fabd9841db99b7d5c34c210b97 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Tue, 28 Oct 2025 16:30:10 -0700
Subject: [PATCH 07/12] vector sources API

---
 mapping/mapping.go            |  1 +
 mapping/mapping_no_vectors.go |  7 +++++++
 mapping/mapping_vectors.go    | 15 +++++++++++++++
 3 files changed, 23 insertions(+)

diff --git a/mapping/mapping.go b/mapping/mapping.go
index a6c1591b8..8a2aaaaba 100644
--- a/mapping/mapping.go
+++ b/mapping/mapping.go
@@ -57,6 +57,7 @@ type IndexMapping interface {
 	AnalyzerNamed(name string) analysis.Analyzer
 
 	FieldMappingForPath(path string) FieldMapping
+	VectorSources() []string
 }
 
 // A SynonymMapping extends the IndexMapping interface to provide
diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go
index 90cb1e225..2f8c312e4 100644
--- a/mapping/mapping_no_vectors.go
+++ b/mapping/mapping_no_vectors.go
@@ -42,3 +42,10 @@ func validateFieldMapping(field *FieldMapping, parentName string,
 	fieldAliasCtx map[string]*FieldMapping) error {
 	return validateFieldType(field)
 }
+
+// -----------------------------------------------------------------------------
+// vector source functions
+
+func (im *IndexMappingImpl) VectorSources() []string {
+	return []string{"vector indexing is not implemented"}
+}
diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go
index 20cbac6a8..88184036e 100644
--- a/mapping/mapping_vectors.go
+++ b/mapping/mapping_vectors.go
@@ -270,3 +270,18 @@ func NormalizeVector(vec []float32) []float32 {
 	// normalize the vector copy using in-place normalization provided by faiss
 	return faiss.NormalizeVector(vecCopy)
 }
+
+// -----------------------------------------------------------------------------
+// vector source functions
+
+func (im *IndexMappingImpl) VectorSources() []string {
+	var sources []string
+	for name, v := range im.TypeMapping {
+		for _, f := range v.Fields {
+			if f.Type == "vector" || f.Type == "vector_base64" {
+				sources = append(sources, name)
+			}
+		}
+	}
+	return sources
+}

From d3cae6f43ac5e156b890d3fd232f09d42e067214 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Wed, 26 Nov 2025 11:14:04 -0800
Subject: [PATCH 08/12] batch training support

---
 go.mod                         |  4 ++
 index.go                       |  4 ++
 index/scorch/introducer.go     | 12 ------
 index/scorch/merge.go          | 38 -----------------
 index/scorch/persister.go      |  5 +++
 index/scorch/scorch.go         | 76 ++++++++++++++++++++++++++++++++++
 index/scorch/snapshot_index.go |  4 --
 index_alias_impl.go            | 19 +++++++++
 index_impl.go                  | 17 ++++++++
 mapping/mapping.go             |  1 -
 mapping/mapping_no_vectors.go  |  7 ----
 mapping/mapping_vectors.go     | 15 -------
 12 files changed, 125 insertions(+), 77 deletions(-)

diff --git a/go.mod b/go.mod
index c4bc98254..2604a57df 100644
--- a/go.mod
+++ b/go.mod
@@ -44,3 +44,7 @@ require (
 	github.com/spf13/pflag v1.0.6 // indirect
 	golang.org/x/sys v0.29.0 // indirect
 )
+
+replace github.com/blevesearch/scorch_segment_api/v2 => /Users/thejas.orkombu/fts/blevesearch/scorch_segment_api
+
+replace github.com/blevesearch/bleve_index_api => /Users/thejas.orkombu/fts/blevesearch/bleve_index_api
\ No newline at end of file
diff --git a/index.go b/index.go
index 2f1ba5fbf..c083787c4 100644
--- a/index.go
+++ b/index.go
@@ -396,3 +396,7 @@ type InsightsIndex interface {
 	// CentroidCardinalities returns the centroids (clusters) from IVF indexes ordered by data density.
 	CentroidCardinalities(field string, limit int, desceding bool) ([]index.CentroidCardinality, error)
 }
+type VectorIndex interface {
+	Index
+	Train(*Batch) error
+}
diff --git a/index/scorch/introducer.go b/index/scorch/introducer.go
index 261f6f1db..8191584d2 100644
--- a/index/scorch/introducer.go
+++ b/index/scorch/introducer.go
@@ -126,10 +126,6 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
 		creator: "introduceSegment",
 	}
 
-	if len(root.trainData) > 0 {
-		newSnapshot.trainData = root.trainData
-	}
-
 	// iterate through current segments
 	var running uint64
 	var docsToPersistCount, memSegments, fileSegments uint64
@@ -280,10 +276,6 @@ func (s *Scorch) introducePersist(persist *persistIntroduction) {
 		creator: "introducePersist",
 	}
 
-	if len(root.trainData) > 0 {
-		newIndexSnapshot.trainData = root.trainData
-	}
-
 	var docsToPersistCount, memSegments, fileSegments uint64
 	for i, segmentSnapshot := range root.segment {
 		// see if this segment has been replaced
@@ -360,10 +352,6 @@ func (s *Scorch) introduceMerge(nextMerge *segmentMerge) {
 		creator: "introduceMerge",
 	}
 
-	if len(nextMerge.trainData) > 0 {
-		newSnapshot.trainData = nextMerge.trainData
-	}
-
 	var running, docsToPersistCount, memSegments, fileSegments uint64
 	var droppedSegmentFiles []string
 	newSegmentDeleted := make([]*roaring.Bitmap, len(nextMerge.new))
diff --git a/index/scorch/merge.go b/index/scorch/merge.go
index 901c466e1..e2aa8b03b 100644
--- a/index/scorch/merge.go
+++ b/index/scorch/merge.go
@@ -17,9 +17,7 @@ package scorch
 import (
 	"context"
 	"fmt"
-	"math"
 	"os"
-	"slices"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -353,11 +351,6 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 	var seg segment.Segment
 	var filename string
 
-	var trainingSample []float32
-	collectTrainData := func(segTrainData []float32) {
-		// append a clone of the training sample
-		trainingSample = append(trainingSample, slices.Clone(segTrainData)...)
-	}
 	if len(segmentsToMerge) > 0 {
 		filename = zapFileName(newSegmentID)
 		s.markIneligibleForRemoval(filename)
@@ -368,12 +361,6 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 		atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
 		prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)
 
-		trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(1000000)) * 50)
-		if len(ourSnapshot.trainData)/768 < int(trainingSampleSize) {
-			s.segmentConfig["collectTrainDataCallback"] = collectTrainData
-		} else {
-			s.segmentConfig["trainData"] = ourSnapshot.trainData
-		}
 		newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
 			cw.cancelCh, s, s.segmentConfig)
 		atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)
@@ -419,7 +406,6 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
 			newCount: seg.Count(),
 			notifyCh: make(chan *mergeTaskIntroStatus),
 			mmaped:   1,
-			trainData: trainingSample,
 		}
 
 		s.fireEvent(EventKindMergeTaskIntroductionStart, 0)
@@ -484,7 +470,6 @@ type mergedSegmentHistory struct {
 type segmentMerge struct {
 	id               []uint64
 	new              []segment.Segment
-	trainData        []float32
 	mergedSegHistory map[uint64]*mergedSegmentHistory
 	notifyCh         chan *mergeTaskIntroStatus
 	mmaped           uint32
@@ -531,28 +516,6 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 	var em sync.Mutex
 	var errs []error
 
-	var trainingSample []float32
-	collectTrainData := func(segTrainData []float32) {
-		// append a clone of the training sample
-		trainingSample = append(trainingSample, segTrainData...)
-	}
-
-	// numDocs, err := snapshot.DocCount()
-	// if err != nil {
-	// 	return nil, nil, err
-	// }
-
-	// hardcoding the total docs for now, need to get it from CB level
-	numDocs := 1000000
-	trainingSampleSize := math.Ceil(4 * math.Sqrt(float64(numDocs)) * 50)
-
-	// collect train data only if needed
-	if len(snapshot.trainData)/768 < int(trainingSampleSize) {
-		s.segmentConfig["collectTrainDataCallback"] = collectTrainData
-	} else {
-		s.segmentConfig["trainData"] = snapshot.trainData
-	}
-
 	// deploy the workers to merge and flush the batches of segments concurrently
 	// and create a new file segment
 	for i := 0; i < numFlushes; i++ {
@@ -627,7 +590,6 @@ func (s *Scorch) mergeAndPersistInMemorySegments(snapshot *IndexSnapshot,
 		mergedSegHistory: make(map[uint64]*mergedSegmentHistory, numSegments),
 		notifyCh:         make(chan *mergeTaskIntroStatus),
 		newCount:         newMergedCount,
-		trainData:        trainingSample,
 	}
 
 	// create a history map which maps the old in-memory segments with the specific
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index 1e4860150..b1363c222 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -564,6 +564,11 @@ func copyToDirectory(srcPath string, d index.Directory) (int64, error) {
 		return 0, fmt.Errorf("GetWriter err: %v", err)
 	}
 
+	// skip
+	if dest == nil {
+		return 0, nil
+	}
+
 	sourceFileStat, err := os.Stat(srcPath)
 	if err != nil {
 		return 0, err
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 32e91842a..a6096b4d2 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -27,6 +28,7 @@ import (
 	"github.com/blevesearch/bleve/v2/registry"
 	"github.com/blevesearch/bleve/v2/util"
 	index "github.com/blevesearch/bleve_index_api"
+	"github.com/blevesearch/go-faiss"
 	segment "github.com/blevesearch/scorch_segment_api/v2"
 	bolt "go.etcd.io/bbolt"
 )
@@ -78,6 +80,8 @@ type Scorch struct {
 	persisterNotifier chan *epochWatcher
 	rootBolt          *bolt.DB
 	asyncTasks        sync.WaitGroup
+	// not a real searchable segment, singleton
+	centroidIndex *SegmentSnapshot
 
 	onEvent      func(event Event) bool
 	onAsyncError func(err error, path string)
@@ -139,6 +143,12 @@ func NewScorch(storeName string,
 		}
 	}
 
+	// "pretraining": true
+	segConfig, ok := config["segmentConfig"].(map[string]interface{})
+	if ok {
+		rv.segmentConfig = segConfig
+	}
+
 	typ, ok := config["spatialPlugin"].(string)
 	if ok {
 		if err := rv.loadSpatialAnalyzerPlugin(typ); err != nil {
@@ -503,6 +513,72 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	return err
 }
 
+func (s *Scorch) Train(batch *index.Batch) error {
+	s.rootLock.Lock()
+	defer s.rootLock.Unlock()
+	if s.centroidIndex != nil {
+		// singleton API
+		return nil
+	}
+	var trainData []index.Document
+	if s.centroidIndex == nil {
+		for key, doc := range batch.IndexOps {
+			if strings.HasPrefix(key, index.TrainDataPrefix) {
+				trainData = append(trainData, doc)
+			}
+		}
+	}
+
+	// just builds a new vector index out of the train data provided
+	// it'll be an IVF index so the centroids are computed at this stage and
+	// this template will be used in the indexing down the line to index
+	// the data vectors. s.segmentConfig will mark this as a training phase
+	// and zap will handle it accordingly.
+	//
+	// note: this might index text data too, how to handle this? s.segmentConfig?
+	// todo: updates/deletes -> data drift detection
+	seg, _, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
+	if err != nil {
+		return err
+	}
+	filename := "centroid_index.zap"
+	path := filepath.Join(s.path, filename)
+
+	switch seg := seg.(type) {
+	case segment.UnpersistedSegment:
+		err = persistToDirectory(seg, nil, path)
+		if err != nil {
+			return err
+		}
+	default:
+		return fmt.Errorf("segment is not an unpersisted segment")
+	}
+
+	// persist and open the segment mmap mode.
+	persistedSegment, err := s.segPlugin.OpenEx(path, s.segmentConfig)
+	if err != nil {
+		return err
+	}
+	s.centroidIndex = &SegmentSnapshot{
+		segment: persistedSegment,
+	}
+	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
+	return nil
+}
+
+func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
+	// return the coarse quantizer of the centroid index belonging to the field
+	centroidIndexSegment, ok := s.centroidIndex.segment.(segment.CentroidIndexSegment)
+	if !ok {
+		return nil, fmt.Errorf("segment is not a centroid index segment")
+	}
+	coarseQuantizer, err := centroidIndexSegment.GetCoarseQuantizer(field)
+	if err != nil {
+		return nil, err
+	}
+	return coarseQuantizer, nil
+}
+
 func (s *Scorch) prepareSegment(newSegment segment.Segment, ids []string,
 	internalOps map[string][]byte, persistedCallback index.BatchCallback,
 	stats *fieldStats,
 ) error {
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 91099a848..0d0fc9d41 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -70,10 +70,6 @@ func init() {
 type IndexSnapshot struct {
 	parent *Scorch
 
-	// POC: trainData is ephemeral and read-only just like []*SegmentSnapshot
-	trainData []float32
-	// trainSegments []*SegmentSnapshot // either store []float32 or []faissIndexes aka centroid indexes
-
 	segment  []*SegmentSnapshot
 	offsets  []uint64
 	internal map[string][]byte
diff --git a/index_alias_impl.go b/index_alias_impl.go
index 8212c74b9..ee7fbf2a6 100644
--- a/index_alias_impl.go
+++ b/index_alias_impl.go
@@ -103,6 +103,25 @@ func (i *indexAliasImpl) IndexSynonym(id string, collection string, definition *
 	return ErrorSynonymSearchNotSupported
 }
 
+func (i *indexAliasImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	err := i.isAliasToSingleIndex()
+	if err != nil {
+		return err
+	}
+
+	if vi, ok := i.indexes[0].(VectorIndex); ok {
+		return vi.Train(batch)
+	}
+	return fmt.Errorf("not a vector index")
+}
+
 func (i *indexAliasImpl) Delete(id string) error {
 	i.mutex.RLock()
 	defer i.mutex.RUnlock()
diff --git a/index_impl.go b/index_impl.go
index 8065d9c1e..a969253c1 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -369,6 +369,20 @@ func (i *indexImpl) IndexSynonym(id string, collection string, definition *Synon
 	return err
 }
 
+func (i *indexImpl) Train(batch *Batch) error {
+	i.mutex.RLock()
+	defer i.mutex.RUnlock()
+
+	if !i.open {
+		return ErrorIndexClosed
+	}
+
+	if vi, ok := i.i.(VectorIndex); ok {
+		return vi.Train(batch)
+	}
+	return fmt.Errorf("not a vector index")
+}
+
 // IndexAdvanced takes a document.Document object
 // skips the mapping and indexes it.
 func (i *indexImpl) IndexAdvanced(doc *document.Document) (err error) {
@@ -1362,6 +1376,7 @@ func (m *searchHitSorter) Less(i, j int) bool {
 	return c < 0
 }
 
+// CopyTo (index.Directory, filter)
 func (i *indexImpl) CopyTo(d index.Directory) (err error) {
 	i.mutex.RLock()
 	defer i.mutex.RUnlock()
@@ -1375,6 +1390,8 @@ func (i *indexImpl) CopyTo(d index.Directory) (err error) {
 		return fmt.Errorf("index implementation does not support copy reader")
 	}
 
+	// copyIndex.Copy() -> copies the centroid index
+
 	copyReader := copyIndex.CopyReader()
 	if copyReader == nil {
 		return fmt.Errorf("index's copyReader is nil")
diff --git a/mapping/mapping.go b/mapping/mapping.go
index 8a2aaaaba..a6c1591b8 100644
--- a/mapping/mapping.go
+++ b/mapping/mapping.go
@@ -57,7 +57,6 @@ type IndexMapping interface {
 	AnalyzerNamed(name string) analysis.Analyzer
 
 	FieldMappingForPath(path string) FieldMapping
-	VectorSources() []string
 }
 
 // A SynonymMapping extends the IndexMapping interface to provide
diff --git a/mapping/mapping_no_vectors.go b/mapping/mapping_no_vectors.go
index 2f8c312e4..90cb1e225 100644
--- a/mapping/mapping_no_vectors.go
+++ b/mapping/mapping_no_vectors.go
@@ -42,10 +42,3 @@ func validateFieldMapping(field *FieldMapping, parentName string,
 	fieldAliasCtx map[string]*FieldMapping) error {
 	return validateFieldType(field)
 }
-
-// -----------------------------------------------------------------------------
-// vector source functions
-
-func (im *IndexMappingImpl) VectorSources() []string {
-	return []string{"vector indexing is not implemented"}
-}
diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go
index 88184036e..20cbac6a8 100644
--- a/mapping/mapping_vectors.go
+++ b/mapping/mapping_vectors.go
@@ -270,18 +270,3 @@ func NormalizeVector(vec []float32) []float32 {
 	// normalize the vector copy using in-place normalization provided by faiss
 	return faiss.NormalizeVector(vecCopy)
 }
-
-// -----------------------------------------------------------------------------
-// vector source functions
-
-func (im *IndexMappingImpl) VectorSources() []string {
-	var sources []string
-	for name, v := range im.TypeMapping {
-		for _, f := range v.Fields {
-			if f.Type == "vector" || f.Type == "vector_base64" {
-				sources = append(sources, name)
-			}
-		}
-	}
-	return sources
-}

From acce0036fe5deaf85e254ae31775554d99085894 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Wed, 26 Nov 2025 11:15:13 -0800
Subject: [PATCH 09/12] wip: batch training + interfaces to reuse pre-trained
 file

---
 index.go                       |   5 +
 index/scorch/persister.go      |  29 +++++-
 index/scorch/scorch.go         | 165 ++++++++++++++++++++++++++++++++-
 index/scorch/snapshot_index.go |   4 +
 index_alias_impl.go            |   5 +-
 index_impl.go                  |  38 +++++++-
 util/keys.go                   |   1 +
 7 files changed, 233 insertions(+), 14 deletions(-)

diff --git a/index.go b/index.go
index c083787c4..21d016610 100644
--- a/index.go
+++ b/index.go
@@ -353,6 +353,11 @@ type IndexCopyable interface {
 	CopyTo(d index.Directory) error
 }
 
+type IndexFileCopyable interface {
+	UpdateFileInBolt(key []byte, value []byte) error
+	CopyFile(file string, d index.IndexDirectory) error
+}
+
 // FileSystemDirectory is the default implementation for the
 // index.Directory interface.
 type FileSystemDirectory string
diff --git a/index/scorch/persister.go b/index/scorch/persister.go
index b1363c222..9ebc8b559 100644
--- a/index/scorch/persister.go
+++ b/index/scorch/persister.go
@@ -564,11 +564,6 @@ func copyToDirectory(srcPath string, d index.Directory) (int64, error) {
 		return 0, fmt.Errorf("GetWriter err: %v", err)
 	}
 
-	// skip
-	if dest == nil {
-		return 0, nil
-	}
-
 	sourceFileStat, err := os.Stat(srcPath)
 	if err != nil {
 		return 0, err
@@ -847,10 +842,34 @@ func zapFileName(epoch uint64) string {
 	return fmt.Sprintf("%012x.zap", epoch)
 }
 
+func (s *Scorch) updateCentroidIndex(bucket *bolt.Bucket) error {
+	if bucket == nil {
+		return nil
+	}
+	segmentSnapshot, err := s.loadSegment(bucket)
+	if err != nil {
+		return err
+	}
+	s.rootLock.Lock()
+	defer s.rootLock.Unlock()
+
+	s.centroidIndex = segmentSnapshot
+	return nil
+}
+
 // bolt snapshot code
 
 func (s *Scorch) loadFromBolt() error {
 	err := s.rootBolt.View(func(tx *bolt.Tx) error {
+		centroidIndexBucket := tx.Bucket(util.BoltCentroidIndexKey)
+		if centroidIndexBucket == nil {
+			return nil
+		}
+		err := s.updateCentroidIndex(centroidIndexBucket)
+		if err != nil {
+			return err
+		}
+
 		snapshots := tx.Bucket(util.BoltSnapshotsBucket)
 		if snapshots == nil {
 			return nil
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index a6096b4d2..604c46b92 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -15,8 +15,10 @@
 package scorch
 
 import (
+	"bytes"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"path/filepath"
 	"strings"
@@ -513,7 +515,19 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 	return err
 }
 
+func (s *Scorch) getInternal(key []byte) ([]byte, error) {
+	s.rootLock.RLock()
+	defer s.rootLock.RUnlock()
+	if string(key) == "_centroid_index_complete" {
+		return []byte(fmt.Sprintf("%t", s.centroidIndex != nil)), nil
+	}
+	return nil, nil
+}
+
+// min 39 per centroid, recommended 50
+// max 256
 func (s *Scorch) Train(batch *index.Batch) error {
+	// is the lock really needed?
 	s.rootLock.Lock()
 	defer s.rootLock.Unlock()
 	if s.centroidIndex != nil {
@@ -523,6 +537,12 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	var trainData []index.Document
 	if s.centroidIndex == nil {
 		for key, doc := range batch.IndexOps {
+			if doc != nil {
+				// insert _id field
+				// no need to track updates/deletes over here since
+				// the API is singleton
+				doc.AddIDField()
+			}
 			if strings.HasPrefix(key, index.TrainDataPrefix) {
 				trainData = append(trainData, doc)
 			}
@@ -548,10 +568,15 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	// note: this might index text data too, how to handle this? s.segmentConfig?
 	// todo: updates/deletes -> data drift detection
-	seg, _, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
+	s.segmentConfig["training"] = true
+	seg, n, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
 	if err != nil {
 		return err
 	}
-	filename := "centroid_index.zap"
+	// reset the training flag once completed
+	s.segmentConfig["training"] = false
+	// not suffixing with .zap since the current garbage collection is tailored to .zap ext files
+	// we don't want to gc this file ever.
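+	// (the file is written to filepath.Join(s.path, filename), i.e. it sits in
+	// the index directory alongside the regular *.zap segment files)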
+ filename := "centroid_index" path := filepath.Join(s.path, filename) switch seg := seg.(type) { @@ -562,7 +587,56 @@ func (s *Scorch) Train(batch *index.Batch) error { s.centroidIndex = &SegmentSnapshot{ segment: persistedSegment, } - s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex + + fmt.Println("number of bytes written to centroid index", n) + // s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex + // updateBolt(tx, cetntroid) + // filename := "centroid_index" + // path := filepath.Join(s.path, filename) + // f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0600) + // if err != nil { + // return err + // } + + // bufw := bufio.NewWriter(f) + // _, err = bufw.Write([]byte(strings.Join([]string{"centroid_index1", path}, " "))) + // if err != nil { + // return err + // } + // err = bufw.Flush() + // if err != nil { + // return err + // } + // err = f.Sync() + // if err != nil { + // return err + // } + // err = f.Close() + // if err != nil { + // return err + // } + + tx, err := s.rootBolt.Begin(true) + if err != nil { + return err + } + defer tx.Rollback() + + snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket) + if err != nil { + return err + } + + err = snapshotsBucket.Put(util.BoltCentroidIndexKey, []byte(path)) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + return nil } @@ -1018,6 +1092,91 @@ func (s *Scorch) CopyReader() index.CopyReader { return rv } +func (s *Scorch) updateCentroidIndexInBolt(tx *bolt.Tx) error { + centroidIndexBucket, err := tx.CreateBucketIfNotExists(util.BoltCentroidIndexKey) + if err != nil { + return err + } + + err = centroidIndexBucket.Put(util.BoltPathKey, []byte("centroid_index.zap")) + if err != nil { + return err + } + + return nil +} + +func (s *Scorch) UpdateFileInBolt(key []byte, value []byte) error { + tx, err := s.rootBolt.Begin(true) + if err != nil { + return err + } + defer tx.Rollback() + + snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket) + if err != nil { + return err + } + + // currently this is specific to centroid index file update + if bytes.Equal(key, util.BoltCentroidIndexKey) { + // guard against duplicate updates + existingValue := snapshotsBucket.Get(key) + if existingValue != nil { + return fmt.Errorf("key already exists") + } + + err = snapshotsBucket.Put(key, value) + if err != nil { + return err + } + } + + err = tx.Commit() + if err != nil { + return err + } + + err = s.rootBolt.Sync() + if err != nil { + return err + } + + return nil +} + +// CopyFile copies a specific file to a destination directory which has an access to a bleve index +// doing a io.Copy() isn't enough because the file needs to be tracked in bolt file as well +func (s *Scorch) CopyFile(file string, d index.IndexDirectory) error { + s.rootLock.Lock() + defer s.rootLock.Unlock() + + // this code is currently specific to centroid index file but is future proofed for other files + // to be updated in the dest's bolt + if strings.HasSuffix(file, "centroid_index") { + // centroid index file - this is outside the snapshots domain so the bolt update is different + err := d.UpdateFileInBolt(util.BoltCentroidIndexKey, []byte(file)) + if err != nil { + return err + } + } + + dest, err := d.GetWriter(filepath.Join("store", file)) + if err != nil { + return err + } + + source, err := os.Open(filepath.Join(s.path, file)) + if err != nil { + return err + } + + defer source.Close() + defer dest.Close() + _, err = io.Copy(dest, source) + return 
err +} + // external API to fire a scorch event (EventKindIndexStart) externally from bleve func (s *Scorch) FireIndexEvent() { s.fireEvent(EventKindIndexStart, 0) diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 0d0fc9d41..6be348ab1 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -469,6 +469,10 @@ func (is *IndexSnapshot) Fields() ([]string, error) { } func (is *IndexSnapshot) GetInternal(key []byte) ([]byte, error) { + _, ok := is.internal[string(key)] + if !ok { + return is.parent.getInternal(key) + } return is.internal[string(key)], nil } diff --git a/index_alias_impl.go b/index_alias_impl.go index ee7fbf2a6..8cc1d90ed 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -110,14 +110,13 @@ func (i *indexAliasImpl) Train(batch *Batch) error { if !i.open { return ErrorIndexClosed } - err := i.isAliasToSingleIndex() if err != nil { return err } - if vi, ok := i.indexes[0].(VectorIndex); ok { - return vi.Train(batch) + if vi, ok := i.indexes[0].(index.VectorIndex); ok { + return vi.Train(batch.internal) } return fmt.Errorf("not a vector index") } diff --git a/index_impl.go b/index_impl.go index a969253c1..c8b8e8a0e 100644 --- a/index_impl.go +++ b/index_impl.go @@ -377,8 +377,8 @@ func (i *indexImpl) Train(batch *Batch) error { return ErrorIndexClosed } - if vi, ok := i.i.(VectorIndex); ok { - return vi.Train(batch) + if vi, ok := i.i.(index.VectorIndex); ok { + return vi.Train(batch.internal) } return fmt.Errorf("not a vector index") } @@ -1376,6 +1376,38 @@ func (m *searchHitSorter) Less(i, j int) bool { return c < 0 } +func (i *indexImpl) CopyFile(file string, d index.IndexDirectory) (err error) { + i.mutex.RLock() + defer i.mutex.RUnlock() + + if !i.open { + return ErrorIndexClosed + } + + copyIndex, ok := i.i.(index.IndexFileCopyable) + if !ok { + return fmt.Errorf("index implementation does not support copy reader") + } + + return copyIndex.CopyFile(file, d) +} + +func (i *indexImpl) UpdateFileInBolt(key []byte, value []byte) error { + i.mutex.RLock() + defer i.mutex.RUnlock() + + if !i.open { + return ErrorIndexClosed + } + + copyIndex, ok := i.i.(index.IndexFileCopyable) + if !ok { + return fmt.Errorf("index implementation does not support file copy") + } + + return copyIndex.UpdateFileInBolt(key, value) +} + // CopyTo (index.Directory, filter) func (i *indexImpl) CopyTo(d index.Directory) (err error) { i.mutex.RLock() @@ -1405,7 +1437,7 @@ func (i *indexImpl) CopyTo(d index.Directory) (err error) { err = copyReader.CopyTo(d) if err != nil { - return fmt.Errorf("error copying index metadata: %v", err) + return fmt.Errorf("error copying index data: %v", err) } // copy the metadata diff --git a/util/keys.go b/util/keys.go index b71a7f48b..11c918865 100644 --- a/util/keys.go +++ b/util/keys.go @@ -17,6 +17,7 @@ package util var ( // Bolt keys BoltSnapshotsBucket = []byte{'s'} + BoltCentroidIndexKey = []byte{'c'} BoltPathKey = []byte{'p'} BoltDeletedKey = []byte{'d'} BoltInternalKey = []byte{'i'} From b60c5f0a01446a154ce5cce172e30b5bce92f635 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 11 Dec 2025 14:43:28 -0800 Subject: [PATCH 10/12] bug fix, debug logging --- centroid_index_test.go | 74 ++++++++++++++++++++++++++++++++++++++ go.mod | 20 ++++++++++- go.sum | 18 ---------- index.go | 2 ++ index/scorch/persister.go | 24 +++++++------ index/scorch/scorch.go | 75 ++++++++++++++------------------------- index_alias_impl.go | 12 +------ index_impl.go | 2 ++ 8 files changed, 138 insertions(+), 89 
deletions(-) create mode 100644 centroid_index_test.go diff --git a/centroid_index_test.go b/centroid_index_test.go new file mode 100644 index 000000000..a7334236b --- /dev/null +++ b/centroid_index_test.go @@ -0,0 +1,74 @@ +//go:build vectors +// +build vectors + +package bleve + +import ( + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/blevesearch/bleve/v2/analysis/lang/en" + "github.com/blevesearch/bleve/v2/mapping" + index "github.com/blevesearch/bleve_index_api" +) + +func loadSiftData() ([]map[string]interface{}, error) { + fileContent, err := os.ReadFile("~/fts/data/datasets/vec-sift-bucket.json") + if err != nil { + return nil, err + } + var documents []map[string]interface{} + err = json.Unmarshal(fileContent, &documents) + if err != nil { + return nil, err + } + return documents, nil +} + +func TestCentroidIndex(t *testing.T) { + _, _, err := readDatasetAndQueries(testInputCompressedFile) + if err != nil { + t.Fatal(err) + } + documents, err := loadSiftData() + if err != nil { + t.Fatal(err) + } + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Analyzer = en.AnalyzerName + + vecFieldMappingL2 := mapping.NewVectorFieldMapping() + vecFieldMappingL2.Dims = 128 + vecFieldMappingL2.Similarity = index.EuclideanDistance + + indexMappingL2Norm := NewIndexMapping() + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("content", contentFieldMapping) + indexMappingL2Norm.DefaultMapping.AddFieldMappingsAt("vector", vecFieldMappingL2) + + idx, err := newIndexUsing(t.TempDir(), indexMappingL2Norm, Config.DefaultIndexType, Config.DefaultKVStore, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + err := idx.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch := idx.NewBatch() + for _, doc := range documents[:100000] { + docId := fmt.Sprintf("%s:%s", index.TrainDataPrefix, doc["id"]) + err = batch.Index(docId, doc) + if err != nil { + t.Fatal(err) + } + } + + err = idx.Train(batch) + if err != nil { + t.Fatal(err) + } +} diff --git a/go.mod b/go.mod index 2604a57df..61d2ae87d 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,24 @@ require ( golang.org/x/sys v0.29.0 // indirect ) +replace github.com/blevesearch/bleve/v2 => /Users/thejas.orkombu/fts/blevesearch/bleve + +replace github.com/blevesearch/zapx/v11 => /Users/thejas.orkombu/fts/blevesearch/zapx11 + +replace github.com/blevesearch/zapx/v12 => /Users/thejas.orkombu/fts/blevesearch/zapx12 + +replace github.com/blevesearch/zapx/v13 => /Users/thejas.orkombu/fts/blevesearch/zapx13 + +replace github.com/blevesearch/zapx/v14 => /Users/thejas.orkombu/fts/blevesearch/zapx14 + +replace github.com/blevesearch/zapx/v15 => /Users/thejas.orkombu/fts/blevesearch/zapx15 + +replace github.com/blevesearch/zapx/v16 => /Users/thejas.orkombu/fts/blevesearch/zapx + replace github.com/blevesearch/scorch_segment_api/v2 => /Users/thejas.orkombu/fts/blevesearch/scorch_segment_api -replace github.com/blevesearch/bleve_index_api => /Users/thejas.orkombu/fts/blevesearch/bleve_index_api \ No newline at end of file +replace github.com/blevesearch/go-faiss => /Users/thejas.orkombu/fts/blevesearch/go-faiss + +replace github.com/blevesearch/bleve_index_api => /Users/thejas.orkombu/fts/blevesearch/bleve_index_api + +replace github.com/blevesearch/sear => /Users/thejas.orkombu/fts/blevesearch/sear diff --git a/go.sum b/go.sum index b46bebcef..fcb958d67 100644 --- a/go.sum +++ b/go.sum @@ -3,12 +3,8 @@ github.com/RoaringBitmap/roaring/v2 v2.4.5/go.mod h1:FiJcsfkGje/nZBZgCu0ZxCPOKD/ github.com/bits-and-blooms/bitset 
v1.12.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCkcs2uw7w4= github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/blevesearch/bleve_index_api v1.2.11 h1:bXQ54kVuwP8hdrXUSOnvTQfgK0KI1+f9A0ITJT8tX1s= -github.com/blevesearch/bleve_index_api v1.2.11/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0= github.com/blevesearch/geo v0.2.4 h1:ECIGQhw+QALCZaDcogRTNSJYQXRtC8/m8IKiA706cqk= github.com/blevesearch/geo v0.2.4/go.mod h1:K56Q33AzXt2YExVHGObtmRSFYZKYGv0JEN5mdacJJR8= -github.com/blevesearch/go-faiss v1.0.26 h1:4dRLolFgjPyjkaXwff4NfbZFdE/dfywbzDqporeQvXI= -github.com/blevesearch/go-faiss v1.0.26/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:kDy+zgJFJJoJYBvdfBSiZYBbdsUL0XcjHYWezpQBGPA= github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= @@ -20,8 +16,6 @@ github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.3.13 h1:ZPjv/4VwWvHJZKeMSgScCapOy8+DdmsmRyLmSB88UoY= -github.com/blevesearch/scorch_segment_api/v2 v2.3.13/go.mod h1:ENk2LClTehOuMS8XzN3UxBEErYmtwkE7MAArFTXs9Vc= github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A= @@ -34,18 +28,6 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w= github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y= -github.com/blevesearch/zapx/v11 v11.4.2 h1:l46SV+b0gFN+Rw3wUI1YdMWdSAVhskYuvxlcgpQFljs= -github.com/blevesearch/zapx/v11 v11.4.2/go.mod h1:4gdeyy9oGa/lLa6D34R9daXNUvfMPZqUYjPwiLmekwc= -github.com/blevesearch/zapx/v12 v12.4.2 h1:fzRbhllQmEMUuAQ7zBuMvKRlcPA5ESTgWlDEoB9uQNE= -github.com/blevesearch/zapx/v12 v12.4.2/go.mod h1:TdFmr7afSz1hFh/SIBCCZvcLfzYvievIH6aEISCte58= -github.com/blevesearch/zapx/v13 v13.4.2 h1:46PIZCO/ZuKZYgxI8Y7lOJqX3Irkc3N8W82QTK3MVks= -github.com/blevesearch/zapx/v13 v13.4.2/go.mod h1:knK8z2NdQHlb5ot/uj8wuvOq5PhDGjNYQQy0QDnopZk= -github.com/blevesearch/zapx/v14 v14.4.2 h1:2SGHakVKd+TrtEqpfeq8X+So5PShQ5nW6GNxT7fWYz0= -github.com/blevesearch/zapx/v14 v14.4.2/go.mod h1:rz0XNb/OZSMjNorufDGSpFpjoFKhXmppH9Hi7a877D8= -github.com/blevesearch/zapx/v15 v15.4.2 h1:sWxpDE0QQOTjyxYbAVjt3+0ieu8NCE0fDRaFxEsp31k= -github.com/blevesearch/zapx/v15 v15.4.2/go.mod h1:1pssev/59FsuWcgSnTa0OeEpOzmhtmr/0/11H0Z8+Nw= -github.com/blevesearch/zapx/v16 v16.2.7 h1:xcgFRa7f/tQXOwApVq7JWgPYSlzyUMmkuYa54tMDuR0= -github.com/blevesearch/zapx/v16 v16.2.7/go.mod h1:murSoCJPCk25MqURrcJaBQ1RekuqSCSfMjXH4rHyA14= github.com/couchbase/ghistogram v0.1.0 
h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps= github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o= diff --git a/index.go b/index.go index 21d016610..bd5421d85 100644 --- a/index.go +++ b/index.go @@ -51,10 +51,12 @@ func (b *Batch) Index(id string, data interface{}) error { eventIndex.FireIndexEvent() } doc := document.NewDocument(id) + // fmt.Printf("data is before mapping %#v\n", data) err := b.index.Mapping().MapDocument(doc, data) if err != nil { return err } + // fmt.Printf("data is after mapping %#v\n", doc) b.internal.Update(doc) b.lastDocSize = uint64(doc.Size() + diff --git a/index/scorch/persister.go b/index/scorch/persister.go index 9ebc8b559..622d3ead7 100644 --- a/index/scorch/persister.go +++ b/index/scorch/persister.go @@ -846,13 +846,14 @@ func (s *Scorch) updateCentroidIndex(bucket *bolt.Bucket) error { if bucket == nil { return nil } + fmt.Println("updateCentroidIndex bucket", bucket != nil) segmentSnapshot, err := s.loadSegment(bucket) if err != nil { return err } s.rootLock.Lock() defer s.rootLock.Unlock() - + fmt.Println("updateCentroidIndex", segmentSnapshot.segment != nil) s.centroidIndex = segmentSnapshot return nil } @@ -861,15 +862,6 @@ func (s *Scorch) updateCentroidIndex(bucket *bolt.Bucket) error { func (s *Scorch) loadFromBolt() error { err := s.rootBolt.View(func(tx *bolt.Tx) error { - centroidIndexBucket := tx.Bucket(util.BoltCentroidIndexKey) - if centroidIndexBucket == nil { - return nil - } - err := s.updateCentroidIndex(centroidIndexBucket) - if err != nil { - return err - } - snapshots := tx.Bucket(util.BoltSnapshotsBucket) if snapshots == nil { return nil @@ -886,6 +878,12 @@ func (s *Scorch) loadFromBolt() error { s.AddEligibleForRemoval(snapshotEpoch) continue } + // fmt.Println("loadFromBolt key %s", k) + // if k[0] == util.BoltCentroidIndexKey[0] { + // fmt.Println("loadFromBolt centroid index key", string(k)) + + // continue + // } snapshot := snapshots.Bucket(k) if snapshot == nil { log.Printf("snapshot key, but bucket missing %x, continuing", k) @@ -917,6 +915,12 @@ func (s *Scorch) loadFromBolt() error { foundRoot = true } + + centroidIndexBucket := snapshots.Bucket(util.BoltCentroidIndexKey) + err := s.updateCentroidIndex(centroidIndexBucket) + if err != nil { + return err + } return nil }) if err != nil { diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go index 604c46b92..1b41df2e5 100644 --- a/index/scorch/scorch.go +++ b/index/scorch/scorch.go @@ -524,8 +524,6 @@ func (s *Scorch) getInternal(key []byte) ([]byte, error) { return nil, nil } -// min 39 per centroid, recommeded 50 -// max 256 func (s *Scorch) Train(batch *index.Batch) error { // is the lock really needed? s.rootLock.Lock() @@ -557,6 +555,7 @@ func (s *Scorch) Train(batch *index.Batch) error { // // note: this might index text data too, how to handle this? s.segmentConfig? 
// todo: updates/deletes -> data drift detection + s.segmentConfig["training"] = true seg, n, err := s.segPlugin.NewEx(trainData, s.segmentConfig) if err != nil { @@ -589,33 +588,14 @@ func (s *Scorch) Train(batch *index.Batch) error { } fmt.Println("number of bytes written to centroid index", n) - // s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex - // updateBolt(tx, cetntroid) - // filename := "centroid_index" - // path := filepath.Join(s.path, filename) - // f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0600) - // if err != nil { - // return err - // } - - // bufw := bufio.NewWriter(f) - // _, err = bufw.Write([]byte(strings.Join([]string{"centroid_index1", path}, " "))) - // if err != nil { - // return err - // } - // err = bufw.Flush() - // if err != nil { - // return err - // } - // err = f.Sync() - // if err != nil { - // return err - // } - // err = f.Close() - // if err != nil { - // return err - // } + s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex + // a bolt transaction is necessary for failover-recovery scenario and also serves as a checkpoint + // where we can be sure that the centroid index is available for the indexing operations downstream + // + // note: when the scale increases massively especially with real world dimensions of 1536+, this API + // will have to be refactored to persist in a more resource efficient way. so having this bolt related + // code will help in tracking the progress a lot better and avoid any redudant data streaming operations. tx, err := s.rootBolt.Begin(true) if err != nil { return err @@ -627,7 +607,11 @@ func (s *Scorch) Train(batch *index.Batch) error { return err } - err = snapshotsBucket.Put(util.BoltCentroidIndexKey, []byte(path)) + centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey) + if err != nil { + return err + } + err = centroidBucket.Put(util.BoltPathKey, []byte(filename)) if err != nil { return err } @@ -644,7 +628,7 @@ func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) { // return the coarse quantizer of the centroid index belonging to the field centroidIndexSegment, ok := s.centroidIndex.segment.(segment.CentroidIndexSegment) if !ok { - return nil, fmt.Errorf("segment is not a centroid index segment") + return nil, fmt.Errorf("segment is not a centroid index segment", s.centroidIndex.segment != nil) } coarseQuantizer, err := centroidIndexSegment.GetCoarseQuantizer(field) if err != nil { @@ -1092,20 +1076,6 @@ func (s *Scorch) CopyReader() index.CopyReader { return rv } -func (s *Scorch) updateCentroidIndexInBolt(tx *bolt.Tx) error { - centroidIndexBucket, err := tx.CreateBucketIfNotExists(util.BoltCentroidIndexKey) - if err != nil { - return err - } - - err = centroidIndexBucket.Put(util.BoltPathKey, []byte("centroid_index.zap")) - if err != nil { - return err - } - - return nil -} - func (s *Scorch) UpdateFileInBolt(key []byte, value []byte) error { tx, err := s.rootBolt.Begin(true) if err != nil { @@ -1120,13 +1090,20 @@ func (s *Scorch) UpdateFileInBolt(key []byte, value []byte) error { // currently this is specific to centroid index file update if bytes.Equal(key, util.BoltCentroidIndexKey) { - // guard against duplicate updates - existingValue := snapshotsBucket.Get(key) + // todo: guard against duplicate updates + centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey) + if err != nil { + return err + } + if centroidBucket == nil { + return fmt.Errorf("centroid bucket not found") + } + 
From 5a1b287b2c8d4568184718eb9e0a9e825996da71 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Mon, 15 Dec 2025 11:35:30 -0800
Subject: [PATCH 11/12] wip: implement async trainer loop with incremental
 training support

---
 index/scorch/scorch.go | 173 +++++++++++++++++++++++++++++------------
 util/keys.go           |   1 +
 2 files changed, 125 insertions(+), 49 deletions(-)

diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 1b41df2e5..d7490301f 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -16,6 +16,7 @@ package scorch
 
 import (
 	"bytes"
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -84,6 +85,7 @@ type Scorch struct {
 	asyncTasks sync.WaitGroup
 	// not a real searchable segment, singleton
 	centroidIndex *SegmentSnapshot
+	train         chan *trainRequest
 
 	onEvent      func(event Event) bool
 	onAsyncError func(err error, path string)
@@ -95,6 +97,12 @@ type Scorch struct {
 	spatialPlugin index.SpatialAnalyzerPlugin
 }
 
+type trainRequest struct {
+	sample   segment.Segment
+	vecCount int
+	ackCh    chan error
+}
+
 // AsyncPanicError is passed to scorch asyncErrorHandler when panic occurs in scorch background process
 type AsyncPanicError struct {
 	Source string
@@ -518,13 +526,118 @@ func (s *Scorch) Batch(batch *index.Batch) (err error) {
 func (s *Scorch) getInternal(key []byte) ([]byte, error) {
 	s.rootLock.RLock()
 	defer s.rootLock.RUnlock()
+	// todo: return the total number of vectors that have been processed so far in training;
+	// in cbft, use that as a checkpoint to resume training for the remaining n-x samples.
 	if string(key) == "_centroid_index_complete" {
 		return []byte(fmt.Sprintf("%t", s.centroidIndex != nil)), nil
 	}
 	return nil, nil
 }
 
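One gap worth flagging before the loop itself: none of the hunks in this series allocate s.train or spawn trainerLoop. Presumably that wiring lands next to scorch's other background goroutines; a sketch of what it might look like, purely an assumption (startTrainer is a hypothetical helper, not part of the patch):

    // hypothetical wiring, not in this series: give the trainer loop the same
    // lifecycle treatment as the other background routines, so Close() can stop
    // it via closeCh and wait for it through asyncTasks.
    func (s *Scorch) startTrainer() {
    	s.train = make(chan *trainRequest)
    	s.asyncTasks.Add(1)
    	go func() {
    		defer s.asyncTasks.Done()
    		s.trainerLoop()
    	}()
    }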
+// This is not a routine that runs throughout the lifetime of the index. Its purpose
+// is only to train the vector index before data ingestion starts.
+func (s *Scorch) trainerLoop() {
+	// one-time setup: expose the centroid index to the segment plugin and
+	// precompute the on-disk location of the centroid index file
+	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
+	var totalSamplesProcessed int
+	filename := "centroid_index"
+	path := filepath.Join(s.path, filename)
+	buf := make([]byte, binary.MaxVarintLen64)
+	for {
+		select {
+		case <-s.closeCh:
+			return
+		case trainReq := <-s.train:
+			sampleSeg := trainReq.sample
+			if s.centroidIndex == nil {
+				// new centroid index: persist the sample first, then publish it
+				switch seg := sampleSeg.(type) {
+				case segment.UnpersistedSegment:
+					err := persistToDirectory(seg, nil, path)
+					if err != nil {
+						// todo: simplify this error handling
+						trainReq.ackCh <- fmt.Errorf("error persisting segment: %v", err)
+						close(trainReq.ackCh)
+						continue
+					}
+				default:
+					// don't tear down the whole index over a bad sample; report and move on
+					trainReq.ackCh <- fmt.Errorf("segment is not an unpersisted segment")
+					close(trainReq.ackCh)
+					continue
+				}
+				s.centroidIndex = &SegmentSnapshot{
+					segment: sampleSeg,
+				}
+			} else {
+				// merge the new segment with the existing one, no need to persist?
+				// persist in a tmp file and then rename - is that a fair strategy?
+				_, _, err := s.segPlugin.MergeEx([]segment.Segment{s.centroidIndex.segment, sampleSeg},
+					[]*roaring.Bitmap{nil, nil}, filepath.Join(s.path, "centroid_index.tmp"), s.closeCh, nil, s.segmentConfig)
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error merging centroid index: %v", err)
+					close(trainReq.ackCh)
+					continue
+				}
+
+				// close the existing centroid segment - it's supposed to be gc'd at this point;
+				// Train() reopens the merged file and swaps s.centroidIndex afterwards
+				s.centroidIndex.segment.Close()
+				err = os.Rename(filepath.Join(s.path, "centroid_index.tmp"), filepath.Join(s.path, "centroid_index"))
+				if err != nil {
+					trainReq.ackCh <- fmt.Errorf("error renaming centroid index: %v", err)
+					close(trainReq.ackCh)
+					continue
+				}
+			}
+
+			totalSamplesProcessed += trainReq.vecCount
+			// a bolt transaction is necessary for the failover-recovery scenario and also serves as a checkpoint
+			// where we can be sure that the centroid index is available for the indexing operations downstream
+			//
+			// note: when the scale increases massively, especially with real-world dimensions of 1536+, this API
+			// will have to be refactored to persist in a more resource-efficient way. having this bolt-related
+			// code will help in tracking the progress a lot better and avoid any redundant data streaming operations.
+			tx, err := s.rootBolt.Begin(true)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error starting bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+			// note: roll back explicitly on each error path; a defer here would
+			// accumulate for the lifetime of the loop
+
+			snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error creating snapshots bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error creating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			err = centroidBucket.Put(util.BoltPathKey, []byte(filename))
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			// total number of vectors that have been processed so far for the training
+			n := binary.PutUvarint(buf, uint64(totalSamplesProcessed))
+			err = centroidBucket.Put(util.BoltVecSamplesProcessedKey, buf[:n])
+			if err != nil {
+				_ = tx.Rollback()
+				trainReq.ackCh <- fmt.Errorf("error updating vec samples processed: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			err = tx.Commit()
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
+				close(trainReq.ackCh)
+				continue
+			}
+
+			close(trainReq.ackCh)
+		}
+	}
+}
+
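The loop records its progress as a uvarint under BoltVecSamplesProcessedKey, which is what the resume-training TODO in getInternal would read back. A decode sketch of the read side, assuming the same bucket layout; samplesProcessed is a hypothetical helper that could sit in index/scorch/scorch.go, reusing that file's existing bolt/util/binary imports:

    // samplesProcessed decodes the training checkpoint written by trainerLoop;
    // it returns 0 when no checkpoint has been persisted yet.
    func samplesProcessed(tx *bolt.Tx) (uint64, error) {
    	snapshots := tx.Bucket(util.BoltSnapshotsBucket)
    	if snapshots == nil {
    		return 0, nil
    	}
    	centroid := snapshots.Bucket(util.BoltCentroidIndexKey)
    	if centroid == nil {
    		return 0, nil
    	}
    	v := centroid.Get(util.BoltVecSamplesProcessedKey)
    	if v == nil {
    		return 0, nil
    	}
    	n, read := binary.Uvarint(v)
    	if read <= 0 {
    		return 0, fmt.Errorf("corrupt vec-samples-processed checkpoint")
    	}
    	return n, nil
    }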
 func (s *Scorch) Train(batch *index.Batch) error {
 	// is the lock really needed?
 	s.rootLock.Lock()
 	defer s.rootLock.Unlock()
@@ -555,7 +668,6 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	// // note: this might index text data too, how to handle this? s.segmentConfig?
 	// todo: updates/deletes -> data drift detection
-	s.segmentConfig["training"] = true
 
 	seg, n, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
 	if err != nil {
@@ -563,65 +675,28 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	}
 	// reset the training flag once completed
 	s.segmentConfig["training"] = false
-	// not suffixing with .zap since the current garbage collection is tailored to .zap ext files
-	// we don't want to gc this file ever.
-	filename := "centroid_index"
-	path := filepath.Join(s.path, filename)
-	switch seg := seg.(type) {
-	case segment.UnpersistedSegment:
-		err = persistToDirectory(seg, nil, path)
-		if err != nil {
-			return err
-		}
-	default:
-		return fmt.Errorf("segment is not a unpersisted segment")
+	trainReq := &trainRequest{
+		sample:   seg,
+		vecCount: len(trainData), // todo: multivector support
+		ackCh:    make(chan error),
 	}
-	// persist and open the segment mmap mode.
-	persistedSegment, err := s.segPlugin.OpenEx(path, s.segmentConfig)
+	s.train <- trainReq
+	err = <-trainReq.ackCh
 	if err != nil {
 		return err
 	}
-	s.centroidIndex = &SegmentSnapshot{
-		segment: persistedSegment,
-	}
-	fmt.Println("number of bytes written to centroid index", n)
-	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
-
-	// a bolt transaction is necessary for the failover-recovery scenario and also serves as a checkpoint
-	// where we can be sure that the centroid index is available for the indexing operations downstream
-	//
-	// note: when the scale increases massively, especially with real-world dimensions of 1536+, this API
-	// will have to be refactored to persist in a more resource-efficient way. having this bolt-related
-	// code will help in tracking the progress a lot better and avoid any redundant data streaming operations.
-	tx, err := s.rootBolt.Begin(true)
+	centroidIndex, err := s.segPlugin.OpenEx(filepath.Join(s.path, "centroid_index"), s.segmentConfig)
 	if err != nil {
 		return err
 	}
-	defer tx.Rollback()
-
-	snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
-	if err != nil {
-		return err
-	}
-
-	centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
-	if err != nil {
-		return err
-	}
-	err = centroidBucket.Put(util.BoltPathKey, []byte(filename))
-	if err != nil {
-		return err
-	}
-
-	err = tx.Commit()
-	if err != nil {
-		return err
+	s.centroidIndex = &SegmentSnapshot{
+		segment: centroidIndex,
 	}
-
-	return nil
+	fmt.Println("number of bytes written to centroid index", n)
+	return nil
 }
 
 func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
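Train now blocks on ackCh, so training stays synchronous for the caller even though the work happens on the trainer goroutine. The handshake is worth spelling out: error paths send exactly once before closing, the success path just closes, and a receive from a closed chan error yields nil, which Train treats as success. Distilled into a helper for illustration only (the patch inlines this logic rather than calling such a function):

    // ack reports the outcome of one training request to the waiting Train call:
    // a non-nil err is delivered by the send; a bare close makes the caller's
    // receive return nil. Each request's channel must be acked exactly once per
    // loop iteration; a second send or close would panic.
    func ack(ch chan error, err error) {
    	if err != nil {
    		ch <- err
    	}
    	close(ch)
    }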
diff --git a/util/keys.go b/util/keys.go
index 11c918865..67415e782 100644
--- a/util/keys.go
+++ b/util/keys.go
@@ -18,6 +18,7 @@ var (
 	// Bolt keys
 	BoltSnapshotsBucket = []byte{'s'}
 	BoltCentroidIndexKey = []byte{'c'}
+	BoltVecSamplesProcessedKey = []byte{'v'}
 	BoltPathKey = []byte{'p'}
 	BoltDeletedKey = []byte{'d'}
 	BoltInternalKey = []byte{'i'}

From 45c2ee95531faf077892809b3acfd4da72a54f63 Mon Sep 17 00:00:00 2001
From: Thejas-bhat
Date: Mon, 15 Dec 2025 11:36:44 -0800
Subject: [PATCH 12/12] regulate train function using EventKindIndexStart

---
 index/scorch/scorch.go | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index d7490301f..365ef972a 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -637,6 +637,8 @@ func (s *Scorch) trainerLoop() {
 }
 
 func (s *Scorch) Train(batch *index.Batch) error {
+	// regulate the Train entry point via the EventKindIndexStart callback
+	s.FireIndexEvent()
 	// is the lock really needed?
 	s.rootLock.Lock()
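For completeness, the regulation hook on the embedding side: FireIndexEvent routes through whatever event callback the index was configured with, so an application can observe (or gate) training from there. A sketch assuming scorch's existing RegistryEventCallbacks registry, the "eventCallbackName" config convention, and the EventKindIndexStart kind named in the commit subject; the observer name and counting logic are illustrative only:

    package trainhooks

    import (
    	"sync/atomic"

    	"github.com/blevesearch/bleve/v2/index/scorch"
    )

    var trainStarts atomic.Int64

    func init() {
    	// register under a name, then select it per index via the
    	// "eventCallbackName" entry in the scorch config
    	scorch.RegistryEventCallbacks["trainObserver"] = func(e scorch.Event) bool {
    		if e.Kind == scorch.EventKindIndexStart {
    			trainStarts.Add(1)
    		}
    		// the bool return is only consulted for veto-style events
    		// (e.g. pre-merge checks); returning true is the safe default
    		return true
    	}
    }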