diff --git a/README.md b/README.md
index 47ff00773..5e6a4169d 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ A modern indexing + search library in GO
 * [geo spatial search](https://github.com/blevesearch/bleve/blob/master/geo/README.md)
 * approximate k-nearest neighbors via [vector search](https://github.com/blevesearch/bleve/blob/master/docs/vectors.md)
 * [synonym search](https://github.com/blevesearch/bleve/blob/master/docs/synonyms.md)
+* [hierarchy search](https://github.com/blevesearch/bleve/blob/master/docs/hierarchy.md)
 * [tf-idf](https://github.com/blevesearch/bleve/blob/master/docs/scoring.md#tf-idf) / [bm25](https://github.com/blevesearch/bleve/blob/master/docs/scoring.md#bm25) scoring models
 * Hybrid search: exact + semantic
 * Supports [RRF (Reciprocal Rank Fusion) and RSF (Relative Score Fusion)](docs/score_fusion.md)
diff --git a/docs/hierarchy.md b/docs/hierarchy.md
new file mode 100644
index 000000000..4aa62af57
--- /dev/null
+++ b/docs/hierarchy.md
@@ -0,0 +1,376 @@
+# Hierarchy search
+
+* *v2.6.0* (and later) ships with support for **array indexing and hierarchy search**.
+* We've achieved this by embedding nested documents within our bleve (scorch) indexes.
+* Uses zap file format [v17](https://github.com/blevesearch/zapx/blob/master/zap.md), which preserves hierarchical document relationships within segments while continuing to conform to the segmented architecture of *scorch*.
+
+## Supported
+
+* Indexing `Arrays`: allows specifying fields that contain arrays of objects. Each object in the array can have its own set of fields, enabling the representation of hierarchical data structures within a single document.
+
+```json
+{
+  "id": "1",
+  "name": "John Doe",
+  "addresses": [
+    {
+      "type": "home",
+      "street": "123 Main St",
+      "city": "Hometown",
+      "zip": "12345"
+    },
+    {
+      "type": "work",
+      "street": "456 Corporate Blvd",
+      "city": "Metropolis",
+      "zip": "67890"
+    }
+  ]
+}
+```
+
+* Multi-level arrays: Arrays can contain objects that themselves have array fields, allowing for deeply nested structures, such as a list of projects, each with its own list of tasks.
+
+```json
+{
+  "id": "2",
+  "name": "Jane Smith",
+  "projects": [
+    {
+      "name": "Project Alpha",
+      "tasks": [
+        {"title": "Task 1", "status": "completed"},
+        {"title": "Task 2", "status": "in-progress"}
+      ]
+    },
+    {
+      "name": "Project Beta",
+      "tasks": [
+        {"title": "Task A", "status": "not-started"},
+        {"title": "Task B", "status": "completed"}
+      ]
+    }
+  ]
+}
+```
+
+* Multiple arrays: A document can have multiple fields that are arrays, each representing different hierarchical data, such as a list of phone numbers and a list of email addresses.
+
+```json
+{
+  "id": "3",
+  "name": "Alice Johnson",
+  "phones": [
+    {"type": "mobile", "number": "555-1234"},
+    {"type": "home", "number": "555-5678"}
+  ],
+  "emails": [
+    {"type": "personal", "address": "alice@example.com"},
+    {"type": "work", "address": "alice@work.com"}
+  ]
+}
+```
+
+* Hybrid arrays: Multi-level and multiple arrays can be combined within the same document to represent complex hierarchical data structures, such as a company with multiple departments, each having its own list of employees and projects.
+
+```json
+{
+  "id": "doc1",
+  "company": {
+    "id": "c1",
+    "name": "TechCorp",
+    "departments": [
+      {
+        "name": "Engineering",
+        "budget": 2000000,
+        "employees": [
+          {"name": "Alice", "role": "Engineer"},
+          {"name": "Bob", "role": "Manager"}
+        ],
+        "projects": [
+          {"title": "Project X", "status": "ongoing"},
+          {"title": "Project Y", "status": "completed"}
+        ]
+      },
+      {
+        "name": "Sales",
+        "budget": 300000,
+        "employees": [
+          {"name": "Eve", "role": "Salesperson"},
+          {"name": "Mallory", "role": "Manager"}
+        ],
+        "projects": [
+          {"title": "Project A", "status": "completed"},
+          {"title": "Project B", "status": "ongoing"}
+        ]
+      }
+    ],
+    "locations": [
+      {"city": "Athens", "country": "Greece"},
+      {"city": "Berlin", "country": "USA"}
+    ]
+  }
+}
+```
+
+* Earlier versions of Bleve supported only flat arrays of primitive types (e.g., strings, numbers) and flattened nested structures, losing the hierarchical relationships, so complex documents like the one above could not be accurately represented or queried. For example, the "employees" and "projects" fields within each department would be flattened, making it impossible to associate employees with their respective departments.
+
+* From v2.6.0 onwards, Bleve allows for accurate representation and querying of complex nested structures, preserving the relationships between different levels of the hierarchy, across multi-level, multiple and hybrid arrays.
+
+* The addition of `nested` document mappings enables defining fields that contain arrays of objects, with the option to preserve the hierarchical relationships within the array during indexing. Leaving `nested` as false (the default) flattens the objects within the array, losing the hierarchy, which was the earlier behavior.
+
+```json
+{
+  "departments": {
+    "dynamic": false,
+    "enabled": true,
+    "nested": true,
+    "properties": {
+      "employees": {
+        "dynamic": false,
+        "enabled": true,
+        "nested": true
+      },
+      "projects": {
+        "dynamic": false,
+        "enabled": true,
+        "nested": true
+      }
+    }
+  },
+  "locations": {
+    "dynamic": false,
+    "enabled": true,
+    "nested": true
+  }
+}
+```
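+
+The same mapping can also be built programmatically. A minimal sketch using the `NewNestedDocumentMapping` constructor introduced by this change (variable names are illustrative; `Dynamic = false` mirrors the `"dynamic": false` in the JSON above, and the Indexing section below shows the full flow including field mappings):
+
+```go
+departmentsMapping := bleve.NewNestedDocumentMapping()
+departmentsMapping.Dynamic = false
+
+employeesMapping := bleve.NewNestedDocumentMapping()
+employeesMapping.Dynamic = false
+departmentsMapping.AddSubDocumentMapping("employees", employeesMapping)
+
+projectsMapping := bleve.NewNestedDocumentMapping()
+projectsMapping.Dynamic = false
+departmentsMapping.AddSubDocumentMapping("projects", projectsMapping)
+
+locationsMapping := bleve.NewNestedDocumentMapping()
+locationsMapping.Dynamic = false
+```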
+
+* Any Bleve query (e.g., match, phrase, term, fuzzy, numeric/date range, etc.) can be executed against fields within nested documents, with no special handling required. The query processor will automatically traverse the nested structures to find matches. Additional search constructs like vector search, synonym search, hybrid and pre-filtered vector search integrate seamlessly with hierarchy search.
+
+* Conjunction Queries (AND queries) and other queries that depend on term co-occurrence within the same hierarchical context will respect the boundaries of nested documents. This means that terms must appear within the same nested object to be considered a match. For example, a conjunction query searching for an employee named "Alice" with the role "Engineer" within the "Engineering" department will only return results where both the name and role terms are found within the same employee object, which is itself within an "Engineering" department object.
+
+* Some other search constructs gain enhanced precision with hierarchy search (a sketch follows this list):
+  * Field-Level Highlighting: Only fields within the matched nested object are retrieved and highlighted, ensuring highlights appear in the correct hierarchical context. For example, a match in `departments[name=Engineering].employees` highlights only employees in that department.
+
+  * Nested Faceting / Aggregations: Facets are computed within matched nested objects, producing context-aware buckets. E.g., a facet on `departments.projects.status` returns ongoing or completed only for projects in matched departments.
+
+  * Sorting by Nested Fields: Sorting can use fields from the relevant nested object, e.g., ordering companies by `departments.budget` sorts based on the budget of the specific matched department, not unrelated departments.
+
+* Vector Search (KNN / Multi-KNN): When an array of objects is marked as nested and contains vector fields, each vector is treated as belonging to its own nested document. Vector similarity is computed only within the same nested object, not across siblings. For example, if `departments.employees` is a nested array where each employee has a `skills_vector`, a KNN search using the embedding of `machine learning engineer` will match only employees whose own `skills_vector` is similar; other employees' vectors within the same department or document do not contribute to the score or match. This also means that a vector search query for `K = 3` will return the top 3 most similar employees across all departments and all companies, and may return multiple employees from the same department or company if they rank among the top 3 most similar overall.
+
+* Pre-Filtered Vector Search: When vector search is combined with filters on fields inside a nested array, the filters are applied first to pick which nested items are eligible. The vector search then runs only on those filtered items. For example, if `departments.employees` is a `nested` array, a pre-filtered KNN query for employees with the role `Manager` in the `Sales` department will first narrow the candidate set to only employees who meet those field conditions, and then compute vector similarity on the `skills_vector` of that filtered subset. This ensures that vector search results come only from the employees that satisfy the filter, while still treating each employee as an independent vector candidate.
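+
+A minimal sketch of the faceting and sorting behavior described above, using the standard facet and sort APIs; it assumes an index built with the nested mapping shown earlier and an open `bleve.Index` named `index`:
+
+```go
+q := bleve.NewMatchQuery("Engineering")
+q.SetField("company.departments.name")
+req := bleve.NewSearchRequest(q)
+// facet buckets are computed only from projects in matched departments
+req.AddFacet("project_status", bleve.NewFacetRequest("company.departments.projects.status", 10))
+// sort by the matched department's budget, descending
+req.SortBy([]string{"-company.departments.budget"})
+res, err := index.Search(req)
+```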
+
+## Indexing
+
+Below is an example of using the Bleve API to index documents with hierarchical structures, using hybrid arrays and nested mappings.
+
+```go
+// Define a document to be indexed.
+docJSON :=
+	`{
+	"company": {
+		"id": "c3",
+		"name": "WebSolutions",
+		"departments": [
+			{
+				"name": "HR",
+				"budget": 800000,
+				"employees": [
+					{"name": "Eve", "role": "Manager"},
+					{"name": "Frank", "role": "HR"}
+				],
+				"projects": [
+					{"title": "Project Beta", "status": "completed"},
+					{"title": "Project B", "status": "ongoing"}
+				]
+			},
+			{
+				"name": "Engineering",
+				"budget": 200000,
+				"employees": [
+					{"name": "Heidi", "role": "Support Engineer"},
+					{"name": "Ivan", "role": "Manager"}
+				],
+				"projects": [
+					{"title": "Project Helpdesk", "status": "ongoing"},
+					{"title": "Project FAQ", "status": "completed"}
+				]
+			}
+		],
+		"locations": [
+			{"city": "Edinburgh", "country": "UK"},
+			{"city": "London", "country": "Canada"}
+		]
+	}
+}`
+
+// Define departments as a nested document mapping (since it contains arrays of objects)
+// and index name and budget fields
+departmentsMapping := bleve.NewNestedDocumentMapping()
+departmentsMapping.AddFieldMappingsAt("name", bleve.NewTextFieldMapping())
+departmentsMapping.AddFieldMappingsAt("budget", bleve.NewNumericFieldMapping())
+
+// Define employees as a nested document mapping within departments (since it contains arrays of objects)
+// and index name and role fields
+employeesMapping := bleve.NewNestedDocumentMapping()
+employeesMapping.AddFieldMappingsAt("name", bleve.NewTextFieldMapping())
+employeesMapping.AddFieldMappingsAt("role", bleve.NewTextFieldMapping())
+departmentsMapping.AddSubDocumentMapping("employees", employeesMapping)
+
+// Define projects as a nested document mapping within departments (since it contains arrays of objects)
+// and index title and status fields
+projectsMapping := bleve.NewNestedDocumentMapping()
+projectsMapping.AddFieldMappingsAt("title", bleve.NewTextFieldMapping())
+projectsMapping.AddFieldMappingsAt("status", bleve.NewTextFieldMapping())
+departmentsMapping.AddSubDocumentMapping("projects", projectsMapping)
+
+// Define locations as a nested document mapping (since it contains arrays of objects)
+// and index city and country fields
+locationsMapping := bleve.NewNestedDocumentMapping()
+locationsMapping.AddFieldMappingsAt("city", bleve.NewTextFieldMapping())
+locationsMapping.AddFieldMappingsAt("country", bleve.NewTextFieldMapping())
+
+// Define company as a document mapping and index its name field and
+// add departments and locations as sub-document mappings
+companyMapping := bleve.NewDocumentMapping()
+companyMapping.AddFieldMappingsAt("name", bleve.NewTextFieldMapping())
+companyMapping.AddSubDocumentMapping("departments", departmentsMapping)
+companyMapping.AddSubDocumentMapping("locations", locationsMapping)
+
+// Define the final index mapping and add company as a sub-document mapping in the default mapping
+indexMapping := bleve.NewIndexMapping()
+indexMapping.DefaultMapping.AddSubDocumentMapping("company", companyMapping)
+
+// Create the index with the defined mapping
+index, err := bleve.New("hierarchy_example.bleve", indexMapping)
+if err != nil {
+	panic(err)
+}
+
+// Unmarshal the document JSON into a map, for indexing
+var doc map[string]interface{}
+err = json.Unmarshal([]byte(docJSON), &doc)
+if err != nil {
+	panic(err)
+}
+
+// Index the document
+err = index.Index("doc1", doc)
+if err != nil {
+	panic(err)
+}
+```
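+
+Bulk indexing works the same way; nested documents are still created per object element of a nested array under the hood. A minimal sketch, assuming `doc2` holds another unmarshalled document:
+
+```go
+batch := index.NewBatch()
+if err := batch.Index("doc2", doc2); err != nil {
+	panic(err)
+}
+if err := index.Batch(batch); err != nil {
+	panic(err)
+}
+```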
+
+## Querying
+
+```go
+// Open the index
+index, err := bleve.Open("hierarchy_example.bleve")
+if err != nil {
+	panic(err)
+}
+
+var (
+	req *bleve.SearchRequest
+	res *bleve.SearchResult
+)
+
+// Example 1: Simple Match Query on a field within a nested document;
+// it works as if it were a flat field.
+q1 := bleve.NewMatchQuery("Engineer")
+q1.SetField("company.departments.employees.role")
+req = bleve.NewSearchRequest(q1)
+res, err = index.Search(req)
+if err != nil {
+	panic(err)
+}
+fmt.Println("Match Query Results:", res)
+
+// Example 2: Conjunction Query (AND) on fields within the same nested document,
+// like finding employees with name "Eve" and role "Manager". This will only match
+// if both terms are in the same employee object.
+q1 = bleve.NewMatchQuery("Eve")
+q1.SetField("company.departments.employees.name")
+q2 := bleve.NewMatchQuery("Manager")
+q2.SetField("company.departments.employees.role")
+conjQuery := bleve.NewConjunctionQuery(
+	q1,
+	q2,
+)
+req = bleve.NewSearchRequest(conjQuery)
+res, err = index.Search(req)
+if err != nil {
+	panic(err)
+}
+fmt.Println("Conjunction Query Results:", res)
+
+// Example 3: Multi-level Nested Query, finding projects with status "ongoing"
+// within the "Engineering" department. This ensures both conditions are met
+// within the correct hierarchy, i.e., the ongoing project must belong to the
+// Engineering department.
+q1 = bleve.NewMatchQuery("Engineering")
+q1.SetField("company.departments.name")
+q2 = bleve.NewMatchQuery("ongoing")
+q2.SetField("company.departments.projects.status")
+multiLevelQuery := bleve.NewConjunctionQuery(
+	q1,
+	q2,
+)
+req = bleve.NewSearchRequest(multiLevelQuery)
+res, err = index.Search(req)
+if err != nil {
+	panic(err)
+}
+fmt.Println("Multi-level Nested Query Results:", res)
+
+// Example 4: Multiple Arrays Query, finding documents with a location in "London"
+// and an employee with the role "Manager". This checks conditions across different arrays.
+q1 = bleve.NewMatchQuery("London")
+q1.SetField("company.locations.city")
+q2 = bleve.NewMatchQuery("Manager")
+q2.SetField("company.departments.employees.role")
+multiArrayQuery := bleve.NewConjunctionQuery(
+	q1,
+	q2,
+)
+req = bleve.NewSearchRequest(multiArrayQuery)
+res, err = index.Search(req)
+if err != nil {
+	panic(err)
+}
+fmt.Println("Multiple Arrays Query Results:", res)
+
+// Example 5: Hybrid Arrays Query, combining multi-level and multiple arrays,
+// finding documents with a Manager named Ivan working in Edinburgh, UK
+q1 = bleve.NewMatchQuery("Ivan")
+q1.SetField("company.departments.employees.name")
+q2 = bleve.NewMatchQuery("Manager")
+q2.SetField("company.departments.employees.role")
+q3 := bleve.NewMatchQuery("Edinburgh")
+q3.SetField("company.locations.city")
+q4 := bleve.NewMatchQuery("UK")
+q4.SetField("company.locations.country")
+hybridArrayQuery := bleve.NewConjunctionQuery(
+	bleve.NewConjunctionQuery(
+		q1,
+		q2,
+	),
+	bleve.NewConjunctionQuery(
+		q3,
+		q4,
+	),
+)
+req = bleve.NewSearchRequest(hybridArrayQuery)
+res, err = index.Search(req)
+if err != nil {
+	panic(err)
+}
+fmt.Println("Hybrid Arrays Query Results:", res)
+
+// Close the index when done
+err = index.Close()
+if err != nil {
+	panic(err)
+}
+```
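+
+Hits that match inside nested objects also carry the stored fields and highlights of the matching nested documents, grouped under the `_$nested` key of `hit.Fields` (see the exported `NestedDocumentKey` constant). A minimal sketch of inspecting them:
+
+```go
+for _, hit := range res.Hits {
+	// value is a slice of nested document matches; printed generically here
+	if nested, ok := hit.Fields["_$nested"]; ok {
+		fmt.Printf("hit %s matched nested documents: %v\n", hit.ID, nested)
+	}
+}
+```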
diff --git a/document/document.go b/document/document.go
index 569d57bd6..7efea56da 100644
--- a/document/document.go
+++ b/document/document.go
@@ -18,6 +18,7 @@ import (
 	"fmt"
 	"reflect"
 
+	"github.com/blevesearch/bleve/v2/search"
 	"github.com/blevesearch/bleve/v2/size"
 	index "github.com/blevesearch/bleve_index_api"
 )
@@ -30,8 +31,9 @@ func init() {
 }
 
 type Document struct {
-	id     string  `json:"id"`
-	Fields []Field `json:"fields"`
+	id              string
+	Fields          []Field     `json:"fields"`
+	NestedDocuments []*Document `json:"nested_documents"`
 	CompositeFields []*CompositeField
 	StoredFieldsSize uint64
 	indexed          bool
@@ -157,3 +159,34 @@ func (d *Document) SetIndexed() {
 func (d *Document) Indexed() bool {
 	return d.indexed
 }
+
+func (d *Document) AddNestedDocument(doc *Document) {
+	d.NestedDocuments = append(d.NestedDocuments, doc)
+}
+
+func (d *Document) NestedFields() search.FieldSet {
+	if len(d.NestedDocuments) == 0 {
+		return nil
+	}
+	fieldSet := search.NewFieldSet()
+	var collectFields func(index.Document)
+	collectFields = func(doc index.Document) {
+		// Add all field names from this nested document
+		doc.VisitFields(func(field index.Field) {
+			fieldSet.AddField(field.Name())
+		})
+		// Recursively collect from this document's nested documents
+		if nd, ok := doc.(index.NestedDocument); ok {
+			nd.VisitNestedDocuments(collectFields)
+		}
+	}
+	// Start collection from nested documents only (not root document)
+	d.VisitNestedDocuments(collectFields)
+	return fieldSet
+}
+
+func (d *Document) VisitNestedDocuments(visitor func(doc index.Document)) {
+	for _, doc := range d.NestedDocuments {
+		visitor(doc)
+	}
+}
diff --git a/index/scorch/introducer.go b/index/scorch/introducer.go
index cb11d5072..ef26532b0 100644
--- a/index/scorch/introducer.go
+++ b/index/scorch/introducer.go
@@ -170,6 +170,11 @@ func (s *Scorch) introduceSegment(next *segmentIntroduction) error {
 			newss.deleted = nil
 		}
 
+		// update the deleted bitmap to include any nested/sub-documents as well
+		// if the segment supports that
+		if ns, ok := newss.segment.(segment.NestedSegment); ok {
+			newss.deleted = ns.AddNestedDocuments(newss.deleted)
+		}
 		// check for live size before copying
 		if newss.LiveSize() > 0 {
 			newSnapshot.segment = append(newSnapshot.segment, newss)
diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go
index 287d8e07f..329de598e 100644
--- a/index/scorch/scorch.go
+++ b/index/scorch/scorch.go
@@ -799,6 +799,12 @@ func analyze(d index.Document, fn customAnalyzerPluginInitFunc) {
 			}
 		}
 	})
+	if nd, ok := d.(index.NestedDocument); ok {
+		nd.VisitNestedDocuments(func(doc index.Document) {
+			doc.AddIDField()
+			analyze(doc, fn)
+		})
+	}
 }
 
 func (s *Scorch) AddEligibleForRemoval(epoch uint64) {
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 3f2a330c5..c6bc2a5e6 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -17,7 +17,6 @@ package scorch
 import (
 	"container/heap"
 	"context"
-	"encoding/binary"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -42,9 +41,8 @@ type asynchSegmentResult struct {
 	dict    segment.TermDictionary
 	dictItr segment.DictionaryIterator
 
-	cardinality int
-	index       int
-	docs        *roaring.Bitmap
+	index int
+	docs  *roaring.Bitmap
 
 	thesItr segment.ThesaurusIterator
@@ -59,11 +57,11 @@ func init() {
 	var err error
 	lb1, err = lev.NewLevenshteinAutomatonBuilder(1, true)
 	if err != nil {
-		panic(fmt.Errorf("Levenshtein automaton ed1 builder err: %v", err))
+		panic(fmt.Errorf("levenshtein automaton ed1 builder err: %v", err))
 	}
 	lb2, err = lev.NewLevenshteinAutomatonBuilder(2, true)
 	if err != nil {
-		panic(fmt.Errorf("Levenshtein automaton ed2 builder err: %v", err))
+		panic(fmt.Errorf("levenshtein automaton ed2 builder err: %v", err))
 	}
 }
@@ -474,7 +472,7 @@ func (is *IndexSnapshot) GetInternal(key []byte) ([]byte, error) {
 func (is *IndexSnapshot) DocCount() (uint64, error) {
 	var rv uint64
 	for _, segment := range is.segment {
-		rv += segment.Count()
+		rv += segment.CountRoot()
 	}
 	return rv, nil
 }
@@ -501,7 +499,7 @@ func (is *IndexSnapshot) Document(id string) (rv index.Document, err error) {
 		return nil, nil
 	}
 
-	docNum, err := docInternalToNumber(next.ID)
+	docNum, err := next.ID.Value()
 	if err != nil {
 		return nil, err
 	}
@@ -571,7 +569,7 @@ func (is *IndexSnapshot) segmentIndexAndLocalDocNumFromGlobal(docNum uint64) (in
 }
 
 func (is *IndexSnapshot) ExternalID(id index.IndexInternalID) (string, error) {
-	docNum, err := docInternalToNumber(id)
+	docNum, err := id.Value()
 	if err != nil {
 		return "", err
 	}
@@ -589,7 +587,7 @@ func (is *IndexSnapshot) ExternalID(id index.IndexInternalID) (string, error) {
 }
 
 func (is *IndexSnapshot) segmentIndexAndLocalDocNum(id index.IndexInternalID) (int, uint64, error) {
-	docNum, err := docInternalToNumber(id)
+	docNum, err := id.Value()
 	if err != nil {
 		return 0, 0, err
 	}
@@ -776,25 +774,6 @@ func (is *IndexSnapshot) recycleTermFieldReader(tfr *IndexSnapshotTermFieldReade
 	is.m2.Unlock()
 }
 
-func docNumberToBytes(buf []byte, in uint64) []byte {
-	if len(buf) != 8 {
-		if cap(buf) >= 8 {
-			buf = buf[0:8]
-		} else {
-			buf = make([]byte, 8)
-		}
-	}
-	binary.BigEndian.PutUint64(buf, in)
-	return buf
-}
-
-func docInternalToNumber(in index.IndexInternalID) (uint64, error) {
-	if len(in) != 8 {
-		return 0, fmt.Errorf("wrong len for IndexInternalID: %q", in)
-	}
-	return binary.BigEndian.Uint64(in), nil
-}
-
 func (is *IndexSnapshot) documentVisitFieldTermsOnSegment(
 	segmentIndex int, localDocNum uint64, fields []string, cFields []string,
 	visitor index.DocValueVisitor, dvs segment.DocVisitState) (
@@ -897,7 +876,7 @@ func (dvr *DocValueReader) BytesRead() uint64 {
 func (dvr *DocValueReader) VisitDocValues(id index.IndexInternalID,
 	visitor index.DocValueVisitor,
 ) (err error) {
-	docNum, err := docInternalToNumber(id)
+	docNum, err := id.Value()
 	if err != nil {
 		return err
 	}
@@ -1297,3 +1276,23 @@ func (is *IndexSnapshot) TermFrequencies(field string, limit int, descending boo
 
 	return termFreqs[:limit], nil
 }
+
+// Ancestors returns the ancestor IDs for the given document ID. The prealloc
+// slice can be provided to avoid allocations downstream, and MUST be empty.
+func (i *IndexSnapshot) Ancestors(ID index.IndexInternalID, prealloc []index.AncestorID) ([]index.AncestorID, error) {
+	// get segment and local doc num for the ID
+	seg, ldoc, err := i.segmentIndexAndLocalDocNum(ID)
+	if err != nil {
+		return nil, err
+	}
+	// get ancestors from the segment
+	prealloc = i.segment[seg].Ancestors(ldoc, prealloc)
+	// get global offset for the segment (correcting factor for multi-segment indexes)
+	globalOffset := i.offsets[seg]
+	// adjust ancestors to global doc numbers, not local to segment
+	for idx := range prealloc {
+		prealloc[idx] = prealloc[idx].Add(globalOffset)
+	}
+	// return adjusted ancestors
+	return prealloc, nil
+}
diff --git a/index/scorch/snapshot_index_doc.go b/index/scorch/snapshot_index_doc.go
index 0a979bfb5..4048a199b 100644
--- a/index/scorch/snapshot_index_doc.go
+++ b/index/scorch/snapshot_index_doc.go
@@ -15,7 +15,6 @@
 package scorch
 
 import (
-	"bytes"
 	"reflect"
 
 	"github.com/RoaringBitmap/roaring/v2"
@@ -49,7 +48,7 @@ func (i *IndexSnapshotDocIDReader) Next() (index.IndexInternalID, error) {
 		next := i.iterators[i.segmentOffset].Next()
 		// make segment number into global number by adding offset
 		globalOffset := i.snapshot.offsets[i.segmentOffset]
-		return docNumberToBytes(nil, uint64(next)+globalOffset), nil
+		return index.NewIndexInternalID(nil, uint64(next)+globalOffset), nil
 	}
 	return nil, nil
 }
@@ -63,7 +62,7 @@ func (i *IndexSnapshotDocIDReader) Advance(ID index.IndexInternalID) (index.Inde
 	if next == nil {
 		return nil, nil
 	}
-	for bytes.Compare(next, ID) < 0 {
+	for next.Compare(ID) < 0 {
 		next, err = i.Next()
 		if err != nil {
 			return nil, err
diff --git a/index/scorch/snapshot_index_tfr.go b/index/scorch/snapshot_index_tfr.go
index cd4d82dce..08d423925 100644
--- a/index/scorch/snapshot_index_tfr.go
+++ b/index/scorch/snapshot_index_tfr.go
@@ -15,7 +15,6 @@
 package scorch
 
 import (
-	"bytes"
 	"context"
 	"fmt"
 	"reflect"
@@ -94,7 +93,7 @@ func (i *IndexSnapshotTermFieldReader) Next(preAlloced *index.TermFieldDoc) (*in
 		// make segment number into global number by adding offset
 		globalOffset := i.snapshot.offsets[i.segmentOffset]
 		nnum := next.Number()
-		rv.ID = docNumberToBytes(rv.ID, nnum+globalOffset)
+		rv.ID = index.NewIndexInternalID(rv.ID, nnum+globalOffset)
 		i.postingToTermFieldDoc(next, rv)
 
 		i.currID = rv.ID
@@ -146,7 +145,7 @@ func (i *IndexSnapshotTermFieldReader) postingToTermFieldDoc(next segment.Postin
 func (i *IndexSnapshotTermFieldReader) Advance(ID index.IndexInternalID, preAlloced *index.TermFieldDoc) (*index.TermFieldDoc, error) {
 	// FIXME do something better
 	// for now, if we need to seek backwards, then restart from the beginning
-	if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
+	if i.currPosting != nil && i.currID.Compare(ID) >= 0 {
 		// Check if the TFR is a special unadorned composite optimization.
 		// Such a TFR will NOT have a valid `term` or `field` set, making it
 		// impossible for the TFR to replace itself with a new one.
@@ -171,7 +170,7 @@ func (i *IndexSnapshotTermFieldReader) Advance(ID index.IndexInternalID, preAllo
 			}
 		}
 	}
-	num, err := docInternalToNumber(ID)
+	num, err := ID.Value()
 	if err != nil {
 		return nil, fmt.Errorf("error converting to doc number % x - %v", ID, err)
 	}
@@ -196,7 +195,7 @@ func (i *IndexSnapshotTermFieldReader) Advance(ID index.IndexInternalID, preAllo
 	if preAlloced == nil {
 		preAlloced = &index.TermFieldDoc{}
 	}
-	preAlloced.ID = docNumberToBytes(preAlloced.ID, next.Number()+
+	preAlloced.ID = index.NewIndexInternalID(preAlloced.ID, next.Number()+
 		i.snapshot.offsets[segIndex])
 	i.postingToTermFieldDoc(next, preAlloced)
 	i.currID = preAlloced.ID
diff --git a/index/scorch/snapshot_index_vr.go b/index/scorch/snapshot_index_vr.go
index bd57ad3e0..5e510c4d6 100644
--- a/index/scorch/snapshot_index_vr.go
+++ b/index/scorch/snapshot_index_vr.go
@@ -18,7 +18,6 @@
 package scorch
 
 import (
-	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -96,7 +95,7 @@ func (i *IndexSnapshotVectorReader) Next(preAlloced *index.VectorDoc) (
 		// make segment number into global number by adding offset
 		globalOffset := i.snapshot.offsets[i.segmentOffset]
 		nnum := next.Number()
-		rv.ID = docNumberToBytes(rv.ID, nnum+globalOffset)
+		rv.ID = index.NewIndexInternalID(rv.ID, nnum+globalOffset)
 		rv.Score = float64(next.Score())
 
 		i.currID = rv.ID
@@ -113,7 +112,7 @@ func (i *IndexSnapshotVectorReader) Next(preAlloced *index.VectorDoc) (
 func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
 	preAlloced *index.VectorDoc) (*index.VectorDoc, error) {
-	if i.currPosting != nil && bytes.Compare(i.currID, ID) >= 0 {
+	if i.currPosting != nil && i.currID.Compare(ID) >= 0 {
 		i2, err := i.snapshot.VectorReader(i.ctx, i.vector, i.field, i.k,
 			i.searchParams, i.eligibleSelector)
 		if err != nil {
@@ -124,7 +123,7 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
 		*i = *(i2.(*IndexSnapshotVectorReader))
 	}
 
-	num, err := docInternalToNumber(ID)
+	num, err := ID.Value()
 	if err != nil {
 		return nil, fmt.Errorf("error converting to doc number % x - %v", ID, err)
 	}
@@ -149,7 +148,7 @@ func (i *IndexSnapshotVectorReader) Advance(ID index.IndexInternalID,
 	if preAlloced == nil {
 		preAlloced = &index.VectorDoc{}
 	}
-	preAlloced.ID = docNumberToBytes(preAlloced.ID, next.Number()+
+	preAlloced.ID = index.NewIndexInternalID(preAlloced.ID, next.Number()+
 		i.snapshot.offsets[segIndex])
 	i.currID = preAlloced.ID
 	i.currPosting = next
diff --git a/index/scorch/snapshot_segment.go b/index/scorch/snapshot_segment.go
index c6f3584cc..34f7a4695 100644
--- a/index/scorch/snapshot_segment.go
+++ b/index/scorch/snapshot_segment.go
@@ -113,6 +113,19 @@ func (s *SegmentSnapshot) Count() uint64 {
 	return rv
 }
 
+// CountRoot counts the live root documents in the segment. It differs from
+// Count(), which counts all live documents including nested children, whereas
+// this method counts only live root documents.
+func (s *SegmentSnapshot) CountRoot() uint64 {
+	var rv uint64
+	if nsb, ok := s.segment.(segment.NestedSegment); ok {
+		rv = nsb.CountRoot(s.deleted)
+	} else {
+		rv = s.Count()
+	}
+	return rv
+}
+
 func (s *SegmentSnapshot) DocNumbers(docIDs []string) (*roaring.Bitmap, error) {
 	rv, err := s.segment.DocNumbers(docIDs)
 	if err != nil {
@@ -361,3 +374,11 @@ func (c *cachedMeta) fetchMeta(field string) (rv interface{}) {
 	c.m.RUnlock()
 	return rv
 }
+
+func (s *SegmentSnapshot) Ancestors(docNum uint64, prealloc []index.AncestorID) []index.AncestorID {
+	nsb, ok := s.segment.(segment.NestedSegment)
+	if !ok {
+		return append(prealloc, index.NewAncestorID(docNum))
+	}
+	return nsb.Ancestors(docNum, prealloc)
+}
diff --git a/index_impl.go b/index_impl.go
index 8065d9c1e..bbc9a01a4 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -572,8 +572,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in
 		return nil, err
 	}
 
-	fs := make(query.FieldSet)
-	fs, err := query.ExtractFields(req.Query, i.m, fs)
+	fs, err := query.ExtractFields(req.Query, i.m, search.NewFieldSet())
 	if err != nil {
 		return nil, err
 	}
@@ -642,7 +641,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 	// ------------------------------------------------------------------------------------------
 
 	// set up additional contexts for any search operation that will proceed from
-	// here, such as presearch, collectors etc.
+	// here, such as presearch, knn collector, topn collector etc.
 
 	// Scoring model callback to be used to get scoring model
 	scoringModelCallback := func() string {
@@ -687,6 +686,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 	}
 	ctx = context.WithValue(ctx, search.GeoBufferPoolCallbackKey,
 		search.GeoBufferPoolCallbackFunc(getBufferPool))
+	// check if the index mapping has any nested fields, which should force
+	// all collectors and searchers to be run in nested mode
+	if nm, ok := i.m.(mapping.NestedMapping); ok {
+		if nm.CountNested() > 0 {
+			ctx = context.WithValue(ctx, search.NestedSearchKey, true)
+		}
+	}
 
 	// ------------------------------------------------------------------------------------------
 
 	if _, ok := ctx.Value(search.PreSearchKey).(bool); ok {
@@ -716,11 +722,9 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		req.SearchBefore = nil
 	}
 
-	var coll *collector.TopNCollector
-	if req.SearchAfter != nil {
-		coll = collector.NewTopNCollectorAfter(req.Size, req.Sort, req.SearchAfter)
-	} else {
-		coll = collector.NewTopNCollector(req.Size, req.From, req.Sort)
+	coll, err := i.buildTopNCollector(ctx, req, indexReader)
+	if err != nil {
+		return nil, err
 	}
 
 	var knnHits []*search.DocumentMatch
@@ -795,7 +799,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		// if score fusion, no faceting for knn hits is done
 		// hence we can skip setting the knn hits in the collector
 		if !contextScoreFusionKeyExists {
-			setKnnHitsInCollector(knnHits, req, coll)
+			setKnnHitsInCollector(knnHits, coll)
 		}
 
 		if fts != nil {
@@ -937,7 +941,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		if i.name != "" && hit.Index == "" {
 			hit.Index = i.name
 		}
-		err, storedFieldsBytes := LoadAndHighlightFields(hit, req, i.name, indexReader, highlighter)
+		err, storedFieldsBytes := LoadAndHighlightAllFields(hit, req, i.name, indexReader, highlighter)
 		if err != nil {
 			return nil, err
 		}
@@ -1105,6 +1109,56 @@ func LoadAndHighlightFields(hit *search.DocumentMatch, req *SearchRequest,
 	return nil, totalStoredFieldsBytes
 }
 
+const NestedDocumentKey = "_$nested"
+
+// LoadAndHighlightAllFields loads stored fields + highlights for root and its descendants.
+// All descendant documents are collected into a _$nested array in the root DocumentMatch.
+func LoadAndHighlightAllFields(
+	root *search.DocumentMatch,
+	req *SearchRequest,
+	indexName string,
+	r index.IndexReader,
+	highlighter highlight.Highlighter,
+) (error, uint64) {
+	var totalStoredFieldsBytes uint64
+	// load root fields/highlights
+	err, bytes := LoadAndHighlightFields(root, req, indexName, r, highlighter)
+	totalStoredFieldsBytes += bytes
+	if err != nil {
+		return err, totalStoredFieldsBytes
+	}
+	// collect all descendant documents
+	nestedDocs := make([]*search.NestedDocumentMatch, 0, len(root.Descendants))
+	// create a dummy desc DocumentMatch to reuse LoadAndHighlightFields
+	desc := &search.DocumentMatch{}
+	for _, descID := range root.Descendants {
+		extID, err := r.ExternalID(descID)
+		if err != nil {
+			return err, totalStoredFieldsBytes
+		}
+		// reset desc for reuse
+		desc.ID = extID
+		desc.IndexInternalID = descID
+		desc.Locations = root.Locations
+		err, bytes := LoadAndHighlightFields(desc, req, indexName, r, highlighter)
+		totalStoredFieldsBytes += bytes
+		if err != nil {
+			return err, totalStoredFieldsBytes
+		}
+		// copy fields to nested doc and append
+		if len(desc.Fields) != 0 || len(desc.Fragments) != 0 {
+			nestedDocs = append(nestedDocs, search.NewNestedDocumentMatch(desc.Fields, desc.Fragments))
+		}
+		desc.Fields = nil
+		desc.Fragments = nil
+	}
+	// add nested documents to root under _$nested key
+	if len(nestedDocs) > 0 {
+		root.AddFieldValue(NestedDocumentKey, nestedDocs)
+	}
+	return nil, totalStoredFieldsBytes
+}
+
 // Fields returns the name of all the fields this
 // Index has operated on.
 func (i *indexImpl) Fields() (fields []string, err error) {
@@ -1487,3 +1541,39 @@ func (i *indexImpl) CentroidCardinalities(field string, limit int, descending bo
 
 	return centroidCardinalities, nil
 }
+
+func (i *indexImpl) buildTopNCollector(ctx context.Context, req *SearchRequest, reader index.IndexReader) (*collector.TopNCollector, error) {
+	newCollector := func() *collector.TopNCollector {
+		if req.SearchAfter != nil {
+			return collector.NewTopNCollectorAfter(req.Size, req.Sort, req.SearchAfter)
+		}
+		return collector.NewTopNCollector(req.Size, req.From, req.Sort)
+	}
+
+	newNestedCollector := func(nr index.NestedReader) *collector.TopNCollector {
+		if req.SearchAfter != nil {
+			return collector.NewNestedTopNCollectorAfter(req.Size, req.Sort, req.SearchAfter, nr)
+		}
+		return collector.NewNestedTopNCollector(req.Size, req.From, req.Sort, nr)
+	}
+
+	// check if we are in nested mode
+	if nestedMode, ok := ctx.Value(search.NestedSearchKey).(bool); ok && nestedMode {
+		// get the nested reader from the index reader
+		if nr, ok := reader.(index.NestedReader); ok {
+			// check if the mapping has any nested fields that intersect
+			if nm, ok := i.m.(mapping.NestedMapping); ok {
+				fs, err := query.ExtractFields(req.Query, i.m, search.NewFieldSet())
+				if err != nil {
+					return nil, err
+				}
+				if nm.IntersectsPrefix(fs) {
+					return newNestedCollector(nr), nil
+				}
+			}
+		}
+	}
+	return newCollector(), nil
+}
diff --git a/index_test.go b/index_test.go
index 7ed27ff86..0cc6ce8e1 100644
--- a/index_test.go
+++ b/index_test.go
@@ -614,7 +614,7 @@ func TestBytesRead(t *testing.T) {
 	expectedBytesRead := uint64(22049)
 	if supportForVectorSearch {
-		expectedBytesRead = 22459
+		expectedBytesRead = 21574
 	}
 
 	if prevBytesRead != expectedBytesRead && res.Cost == prevBytesRead {
@@ -772,7 +772,7 @@ func TestBytesReadStored(t *testing.T) {
 	expectedBytesRead := uint64(11911)
 	if supportForVectorSearch {
-		expectedBytesRead = 12321
+		expectedBytesRead = 11435
 	}
 
 	if bytesRead != expectedBytesRead && bytesRead == res.Cost {
@@ -849,7 +849,7 @@ func TestBytesReadStored(t *testing.T) {
 	expectedBytesRead = uint64(4097)
 	if supportForVectorSearch {
-		expectedBytesRead = 4507
+		expectedBytesRead = 3622
 	}
 
 	if bytesRead != expectedBytesRead && bytesRead == res.Cost {
diff --git a/index_update.go b/index_update.go
index 5666d035b..cdd69e458 100644
--- a/index_update.go
+++ b/index_update.go
@@ -180,6 +180,10 @@ func checkUpdatedMapping(ori, upd *mapping.DocumentMapping) error {
 		return nil
 	}
 
+	if ori.Nested != upd.Nested {
+		return fmt.Errorf("nested property cannot be changed")
+	}
+
 	var err error
 	// Recursively go through the child mappings
 	for name, updDMapping := range upd.Properties {
diff --git a/index_update_test.go b/index_update_test.go
index 5d6326576..9ae9df83a 100644
--- a/index_update_test.go
+++ b/index_update_test.go
@@ -3082,3 +3082,133 @@ func BenchmarkIndexUpdateText(b *testing.B) {
 		}
 	}
 }
+
+func TestIndexUpdateNestedMapping(t *testing.T) {
+	// Helper: create a mapping with optional nested structure
+	createCompanyMapping := func(nestedEmployees, nestedDepartments, nestedProjects, nestedLocations bool) *mapping.IndexMappingImpl {
+		rv := mapping.NewIndexMapping()
+		companyMapping := mapping.NewDocumentMapping()
+
+		// Basic fields
+		companyMapping.AddFieldMappingsAt("id", mapping.NewTextFieldMapping())
+		companyMapping.AddFieldMappingsAt("name", mapping.NewTextFieldMapping())
+
+		// Departments nested conditionally
+		var deptMapping *mapping.DocumentMapping
+		if nestedDepartments {
+			deptMapping = mapping.NewNestedDocumentMapping()
+		} else {
+			deptMapping = mapping.NewDocumentMapping()
+		}
+		deptMapping.AddFieldMappingsAt("name", mapping.NewTextFieldMapping())
+		deptMapping.AddFieldMappingsAt("budget", mapping.NewNumericFieldMapping())
+
+		// Employees nested conditionally
+		var empMapping *mapping.DocumentMapping
+		if nestedEmployees {
+			empMapping = mapping.NewNestedDocumentMapping()
+		} else {
+			empMapping = mapping.NewDocumentMapping()
+		}
+		empMapping.AddFieldMappingsAt("name", mapping.NewTextFieldMapping())
+		empMapping.AddFieldMappingsAt("role", mapping.NewTextFieldMapping())
+		deptMapping.AddSubDocumentMapping("employees", empMapping)
+
+		// Projects nested conditionally
+		var projMapping *mapping.DocumentMapping
+		if nestedProjects {
+			projMapping = mapping.NewNestedDocumentMapping()
+		} else {
+			projMapping = mapping.NewDocumentMapping()
+		}
+		projMapping.AddFieldMappingsAt("title", mapping.NewTextFieldMapping())
+		projMapping.AddFieldMappingsAt("status", mapping.NewTextFieldMapping())
+		deptMapping.AddSubDocumentMapping("projects", projMapping)
+
+		companyMapping.AddSubDocumentMapping("departments", deptMapping)
+
+		// Locations nested conditionally
+		var locMapping *mapping.DocumentMapping
+		if nestedLocations {
+			locMapping = mapping.NewNestedDocumentMapping()
+		} else {
+			locMapping = mapping.NewDocumentMapping()
+		}
+		locMapping.AddFieldMappingsAt("address", mapping.NewTextFieldMapping())
+		locMapping.AddFieldMappingsAt("city", mapping.NewTextFieldMapping())
+
+		companyMapping.AddSubDocumentMapping("locations", locMapping)
+
+		rv.DefaultMapping.AddSubDocumentMapping("company", companyMapping)
+		return rv
+	}
+
+	tests := []struct {
+		name      string
+		original  *mapping.IndexMappingImpl
+		updated   *mapping.IndexMappingImpl
+		expectErr bool
+	}{
+		{
+			name:      "No nested to all nested",
+			original:  createCompanyMapping(false, false, false, false),
+			updated:   createCompanyMapping(true, true, true, true),
+			expectErr: true,
+		},
+		{
+			name:      "No nested to mixed nested",
+			original:  createCompanyMapping(false, false, false, false),
+			updated:   createCompanyMapping(true, false, true, false),
+			expectErr: true,
+		},
+		{
+			name:      "No nested to mostly nested",
+			original:  createCompanyMapping(false, false, false, false),
+			updated:   createCompanyMapping(true, true, true, false),
+			expectErr: true,
+		},
+		{
+			name:      "Mixed nested to no nested",
+			original:  createCompanyMapping(false, true, false, true),
+			updated:   createCompanyMapping(false, false, true, true),
+			expectErr: true,
+		},
+		{
+			name:      "All nested to no nested",
+			original:  createCompanyMapping(true, true, true, true),
+			updated:   createCompanyMapping(false, false, false, false),
+			expectErr: true,
+		},
+		{
+			name:      "Mixed nested to all nested",
+			original:  createCompanyMapping(true, false, true, false),
+			updated:   createCompanyMapping(true, true, true, true),
+			expectErr: true,
+		},
+		{
+			name:      "All nested to mixed nested",
+			original:  createCompanyMapping(true, true, true, true),
+			updated:   createCompanyMapping(true, false, true, false),
+			expectErr: true,
+		},
+		{
+			name:      "No nested to no nested",
+			original:  createCompanyMapping(false, false, false, false),
+			updated:   createCompanyMapping(false, false, false, false),
+			expectErr: false,
+		},
+		{
+			name:      "All nested to all nested",
+			original:  createCompanyMapping(true, true, true, true),
+			updated:   createCompanyMapping(true, true, true, true),
+			expectErr: false,
+		},
+	}
+
+	for _, test := range tests {
+		_, err := DeletedFields(test.original, test.updated)
+		if (err != nil) != test.expectErr {
+			t.Errorf("Test '%s' unexpected error state: got %v, expectErr %t", test.name, err, test.expectErr)
+		}
+	}
+}
diff --git a/mapping.go b/mapping.go
index 723105a29..af02db386 100644
--- a/mapping.go
+++ b/mapping.go
@@ -34,6 +34,20 @@ func NewDocumentStaticMapping() *mapping.DocumentMapping {
 	return mapping.NewDocumentStaticMapping()
 }
 
+// NewNestedDocumentMapping returns a new document mapping
+// that will treat all objects as nested documents.
+func NewNestedDocumentMapping() *mapping.DocumentMapping {
+	return mapping.NewNestedDocumentMapping()
+}
+
+// NewNestedDocumentStaticMapping returns a new document mapping
+// that will treat all objects as nested documents and
+// will not automatically index parts of a nested document
+// without an explicit mapping.
+func NewNestedDocumentStaticMapping() *mapping.DocumentMapping {
+	return mapping.NewNestedDocumentStaticMapping()
+}
+
 // NewDocumentDisabledMapping returns a new document
 // mapping that will not perform any indexing.
 func NewDocumentDisabledMapping() *mapping.DocumentMapping {
diff --git a/mapping/document.go b/mapping/document.go
index a78b27e11..3da925038 100644
--- a/mapping/document.go
+++ b/mapping/document.go
@@ -22,6 +22,7 @@ import (
 	"reflect"
 	"time"
 
+	"github.com/blevesearch/bleve/v2/document"
 	"github.com/blevesearch/bleve/v2/registry"
 	"github.com/blevesearch/bleve/v2/util"
 )
@@ -44,6 +45,7 @@ type DocumentMapping struct {
 	Dynamic              bool                        `json:"dynamic"`
 	Properties           map[string]*DocumentMapping `json:"properties,omitempty"`
 	Fields               []*FieldMapping             `json:"fields,omitempty"`
+	Nested               bool                        `json:"nested,omitempty"`
 	DefaultAnalyzer      string                      `json:"default_analyzer,omitempty"`
 	DefaultSynonymSource string                      `json:"default_synonym_source,omitempty"`
@@ -230,6 +232,17 @@ func NewDocumentMapping() *DocumentMapping {
 	}
 }
 
+// NewNestedDocumentMapping returns a new document
+// mapping that treats sub-documents as nested
+// objects.
+func NewNestedDocumentMapping() *DocumentMapping {
+	return &DocumentMapping{
+		Nested:  true,
+		Enabled: true,
+		Dynamic: true,
+	}
+}
+
 // NewDocumentStaticMapping returns a new document
 // mapping that will not automatically index parts
 // of a document without an explicit mapping.
@@ -239,6 +252,17 @@ func NewDocumentStaticMapping() *DocumentMapping {
 	}
 }
 
+// NewNestedDocumentStaticMapping returns a new document
+// mapping that treats sub-documents as nested
+// objects and will not automatically index parts
+// of the nested document without an explicit mapping.
+func NewNestedDocumentStaticMapping() *DocumentMapping {
+	return &DocumentMapping{
+		Enabled: true,
+		Nested:  true,
+	}
+}
+
 // NewDocumentDisabledMapping returns a new document
 // mapping that will not perform any indexing.
 func NewDocumentDisabledMapping() *DocumentMapping {
@@ -312,6 +336,11 @@ func (dm *DocumentMapping) UnmarshalJSON(data []byte) error {
 			if err != nil {
 				return err
 			}
+		case "nested":
+			err := util.UnmarshalJSON(v, &dm.Nested)
+			if err != nil {
+				return err
+			}
 		case "default_analyzer":
 			err := util.UnmarshalJSON(v, &dm.DefaultAnalyzer)
 			if err != nil {
@@ -381,6 +410,18 @@ func (dm *DocumentMapping) defaultSynonymSource(path []string) string {
 	return rv
 }
 
+// baseType returns the base type of v by dereferencing pointers
+func baseType(v interface{}) reflect.Type {
+	if v == nil {
+		return nil
+	}
+	t := reflect.TypeOf(v)
+	for t.Kind() == reflect.Pointer {
+		t = t.Elem()
+	}
+	return t
+}
+
 func (dm *DocumentMapping) walkDocument(data interface{}, path []string, indexes []uint64, context *walkContext) {
 	// allow default "json" tag to be overridden
 	structTagKey := dm.StructTagKey
@@ -434,11 +475,39 @@ func (dm *DocumentMapping) walkDocument(data interface{}, path []string, indexes
 		}
 	case reflect.Slice, reflect.Array:
+		subDocMapping, _ := dm.documentMappingForPathElements(path)
+		allowNested := subDocMapping != nil && subDocMapping.Nested
 		for i := 0; i < val.Len(); i++ {
-			if val.Index(i).CanInterface() {
-				fieldVal := val.Index(i).Interface()
-				dm.processProperty(fieldVal, path, append(indexes, uint64(i)), context)
+			// for each array element, check if it can be represented as an interface
+			idxVal := val.Index(i)
+			// skip invalid values
+			if !idxVal.CanInterface() {
+				continue
+			}
+			// get the actual value in interface form
+			actual := idxVal.Interface()
+			// if nested mapping, only create nested document for object elements
+			if allowNested && actual != nil {
+				// check the kind of the actual value, is it an object (struct or map)?
+				typ := baseType(actual)
+				if typ == nil {
+					continue
+				}
+				kind := typ.Kind()
+				// only create nested docs for real JSON objects
+				if kind == reflect.Struct || kind == reflect.Map {
+					// create a nested document only for object elements
+					nestedDocument := document.NewDocument(
+						fmt.Sprintf("%s_$%s_$%d", context.doc.ID(), encodePath(path), i))
+					nestedContext := context.im.newWalkContext(nestedDocument, dm)
+					dm.processProperty(actual, path, append(indexes, uint64(i)), nestedContext)
+					context.doc.AddNestedDocument(nestedDocument)
+					continue
+				}
 			}
+			// non-nested mapping, or non-object element in nested mapping;
+			// process the element normally
+			dm.processProperty(actual, path, append(indexes, uint64(i)), context)
 		}
 	case reflect.Ptr:
 		ptrElem := val.Elem()
diff --git a/mapping/index.go b/mapping/index.go
index 7878cce8b..bafb6ee89 100644
--- a/mapping/index.go
+++ b/mapping/index.go
@@ -17,12 +17,14 @@ package mapping
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
 
 	"github.com/blevesearch/bleve/v2/analysis"
 	"github.com/blevesearch/bleve/v2/analysis/analyzer/standard"
 	"github.com/blevesearch/bleve/v2/analysis/datetime/optional"
 	"github.com/blevesearch/bleve/v2/document"
 	"github.com/blevesearch/bleve/v2/registry"
+	"github.com/blevesearch/bleve/v2/search"
 	"github.com/blevesearch/bleve/v2/util"
 	index "github.com/blevesearch/bleve_index_api"
 )
@@ -195,11 +197,19 @@ func (im *IndexMappingImpl) Validate() error {
 	// the map will hold the fully qualified field name to FieldMapping, so we can
 	// check for conflicts as we validate each DocumentMapping.
 	fieldAliasCtx := make(map[string]*FieldMapping)
+	// ensure that the nested property is not set for top-level default mapping
+	if im.DefaultMapping.Nested {
+		return fmt.Errorf("default mapping cannot be nested")
+	}
 	err = im.DefaultMapping.Validate(im.cache, []string{}, fieldAliasCtx)
 	if err != nil {
 		return err
 	}
-	for _, docMapping := range im.TypeMapping {
+	for name, docMapping := range im.TypeMapping {
+		// ensure that the nested property is not set for top-level mappings
+		if docMapping.Nested {
+			return fmt.Errorf("type mapping named: %s cannot be nested", name)
+		}
 		err = docMapping.Validate(im.cache, []string{}, fieldAliasCtx)
 		if err != nil {
 			return err
@@ -366,7 +376,13 @@ func (im *IndexMappingImpl) MapDocument(doc *document.Document, data interface{}
 	// see if the _all field was disabled
 	allMapping, _ := docMapping.documentMappingForPath("_all")
 	if allMapping == nil || allMapping.Enabled {
-		field := document.NewCompositeFieldWithIndexingOptions("_all", true, []string{}, walkContext.excludedFromAll, index.IndexField|index.IncludeTermVectors)
+		excludedFromAll := walkContext.excludedFromAll
+		nf := doc.NestedFields()
+		if nf != nil {
+			// if the document has any nested fields, exclude them from _all
+			excludedFromAll = append(excludedFromAll, nf.Slice()...)
+		}
+		field := document.NewCompositeFieldWithIndexingOptions("_all", true, []string{}, excludedFromAll, index.IndexField|index.IncludeTermVectors)
 		doc.AddField(field)
 	}
 	doc.SetIndexed()
@@ -574,3 +590,70 @@ func (im *IndexMappingImpl) SynonymSourceVisitor(visitor analysis.SynonymSourceV
 	}
 	return nil
 }
+
+func (im *IndexMappingImpl) buildNestedPrefixes() map[string]int {
+	prefixDepth := make(map[string]int)
+	var collectNestedFields func(dm *DocumentMapping, pathComponents []string, currentDepth int)
+	collectNestedFields = func(dm *DocumentMapping, pathComponents []string, currentDepth int) {
+		for name, docMapping := range dm.Properties {
+			newPathComponents := append(pathComponents, name)
+			if docMapping.Nested {
+				// This is a nested field boundary
+				newDepth := currentDepth + 1
+				prefixDepth[strings.Join(newPathComponents, pathSeparator)] = newDepth
+				// Continue deeper with incremented depth
+				collectNestedFields(docMapping, newPathComponents, newDepth)
+			} else {
+				// Not nested, continue with same depth
+				collectNestedFields(docMapping, newPathComponents, currentDepth)
+			}
+		}
+	}
+	// Start from depth 0 (root)
+	if im.DefaultMapping != nil && im.DefaultMapping.Enabled {
+		collectNestedFields(im.DefaultMapping, []string{}, 0)
+	}
+	// Now do this for each type mapping
+	for _, docMapping := range im.TypeMapping {
+		if docMapping.Enabled {
+			collectNestedFields(docMapping, []string{}, 0)
+		}
+	}
+	return prefixDepth
+}
+
+func (im *IndexMappingImpl) NestedDepth(fs search.FieldSet) (int, int) {
+	if im.cache == nil || im.cache.NestedPrefixes == nil {
+		return 0, 0
+	}
+
+	im.cache.NestedPrefixes.InitOnce(func() map[string]int {
+		return im.buildNestedPrefixes()
+	})
+
+	return im.cache.NestedPrefixes.NestedDepth(fs)
+}
+
+func (im *IndexMappingImpl) CountNested() int {
+	if im.cache == nil || im.cache.NestedPrefixes == nil {
+		return 0
+	}
+
+	im.cache.NestedPrefixes.InitOnce(func() map[string]int {
+		return im.buildNestedPrefixes()
+	})
+
+	return im.cache.NestedPrefixes.CountNested()
+}
+
+func (im *IndexMappingImpl) IntersectsPrefix(fs search.FieldSet) bool {
+	if im.cache == nil || im.cache.NestedPrefixes == nil {
+		return false
+	}
+
+	im.cache.NestedPrefixes.InitOnce(func() map[string]int {
+		return im.buildNestedPrefixes()
+	})
+
+	return im.cache.NestedPrefixes.IntersectsPrefix(fs)
+}
diff --git a/mapping/mapping.go b/mapping/mapping.go
index a6c1591b8..7ff2f9927 100644
--- a/mapping/mapping.go
+++ b/mapping/mapping.go
@@ -20,6 +20,7 @@ import (
 
 	"github.com/blevesearch/bleve/v2/analysis"
 	"github.com/blevesearch/bleve/v2/document"
+	"github.com/blevesearch/bleve/v2/search"
 )
 
 // A Classifier is an interface describing any object which knows how to
@@ -74,3 +75,21 @@ type SynonymMapping interface {
 
 	SynonymSourceVisitor(visitor analysis.SynonymSourceVisitor) error
 }
+
+// A NestedMapping extends the IndexMapping interface to provide
+// additional methods for working with nested object mappings.
+type NestedMapping interface {
+	// NestedDepth returns two values:
+	// - common: the highest nested level that is common to all given field paths;
+	//   if 0, then there is no common nested level among the given field paths
+	// - max: the highest nested level that applies to at least one of the given field paths;
+	//   if 0, then none of the given field paths are nested
+	NestedDepth(fieldPaths search.FieldSet) (int, int)
+
+	// IntersectsPrefix returns true if any of the given
+	// field paths intersect with a known nested prefix
+	IntersectsPrefix(fieldPaths search.FieldSet) bool
+
+	// CountNested returns the number of nested object mappings
+	CountNested() int
+}
diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go
index c3dee9310..7c7ff1b98 100644
--- a/mapping/mapping_vectors.go
+++ b/mapping/mapping_vectors.go
@@ -20,6 +20,7 @@ package mapping
 import (
 	"fmt"
 	"reflect"
+	"slices"
 
 	"github.com/blevesearch/bleve/v2/document"
 	"github.com/blevesearch/bleve/v2/util"
@@ -151,8 +152,10 @@ func (fm *FieldMapping) processVector(propertyMightBeVector interface{},
 		vectorIndexOptimizedFor = index.DefaultIndexOptimization
 	}
 	// normalize raw vector if similarity is cosine
+	// Since the vector can be multi-vector (flattened array of multiple vectors),
+	// we use NormalizeMultiVector to normalize each sub-vector independently.
 	if similarity == index.CosineSimilarity {
-		vector = NormalizeVector(vector)
+		vector = NormalizeMultiVector(vector, fm.Dims)
 	}
 
 	fieldName := getFieldName(pathString, path, fm)
@@ -186,13 +189,15 @@ func (fm *FieldMapping) processVectorBase64(propertyMightBeVectorBase64 interfac
 	if err != nil || len(decodedVector) != fm.Dims {
 		return
 	}
-	// normalize raw vector if similarity is cosine
+	// normalize raw vector if similarity is cosine; multi-vector is not supported
+	// for base64 encoded vectors, so we use NormalizeVector directly.
 	if similarity == index.CosineSimilarity {
 		decodedVector = NormalizeVector(decodedVector)
 	}
 
 	fieldName := getFieldName(pathString, path, fm)
 	options := fm.Options()
+
 	field := document.NewVectorFieldWithIndexingOptions(fieldName, indexes, decodedVector, fm.Dims,
 		similarity, vectorIndexOptimizedFor, options)
 	context.doc.AddField(field)
@@ -292,11 +297,33 @@ func validateVectorFieldAlias(field *FieldMapping, path []string,
 	return nil
 }
 
+// NormalizeVector normalizes a single vector to unit length.
+// It makes a copy of the input vector to avoid modifying it in-place.
 func NormalizeVector(vec []float32) []float32 {
 	// make a copy of the vector to avoid modifying the original
 	// vector in-place
-	vecCopy := make([]float32, len(vec))
-	copy(vecCopy, vec)
+	vecCopy := slices.Clone(vec)
 	// normalize the vector copy using in-place normalization provided by faiss
 	return faiss.NormalizeVector(vecCopy)
 }
+
+// NormalizeMultiVector normalizes each sub-vector of size `dims` independently.
+// For a flattened array containing multiple vectors, each sub-vector is
+// normalized separately to unit length.
+// It makes a copy of the input vector to avoid modifying it in-place.
+func NormalizeMultiVector(vec []float32, dims int) []float32 {
+	if len(vec) == 0 || dims <= 0 || len(vec)%dims != 0 {
+		return vec
+	}
+	// Single vector - delegate to NormalizeVector
+	if len(vec) == dims {
+		return NormalizeVector(vec)
+	}
+	// Multi-vector - make a copy to avoid modifying the original
+	result := slices.Clone(vec)
+	// Normalize each sub-vector in-place
+	for i := 0; i < len(result); i += dims {
+		faiss.NormalizeVector(result[i : i+dims])
+	}
+	return result
+}
diff --git a/mapping/mapping_vectors_test.go b/mapping/mapping_vectors_test.go
index b00e5c094..0620510a0 100644
--- a/mapping/mapping_vectors_test.go
+++ b/mapping/mapping_vectors_test.go
@@ -18,6 +18,7 @@
 package mapping
 
 import (
+	"math"
 	"reflect"
 	"strings"
 	"testing"
@@ -1069,3 +1070,120 @@ func TestNormalizeVector(t *testing.T) {
 		}
 	}
 }
+
+func TestNormalizeMultiVectors(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    []float32
+		dims     int
+		expected []float32
+	}{
+		{
+			name:     "single vector - already normalized",
+			input:    []float32{1, 0, 0},
+			dims:     3,
+			expected: []float32{1, 0, 0},
+		},
+		{
+			name:     "single vector - needs normalization",
+			input:    []float32{3, 0, 0},
+			dims:     3,
+			expected: []float32{1, 0, 0},
+		},
+		{
+			name:     "two vectors - X and Y directions",
+			input:    []float32{3, 0, 0, 0, 4, 0},
+			dims:     3,
+			expected: []float32{1, 0, 0, 0, 1, 0},
+		},
+		{
+			name:     "three vectors",
+			input:    []float32{3, 0, 0, 0, 4, 0, 0, 0, 5},
+			dims:     3,
+			expected: []float32{1, 0, 0, 0, 1, 0, 0, 0, 1},
+		},
+		{
+			name:     "two 2D vectors",
+			input:    []float32{3, 4, 5, 12},
+			dims:     2,
+			expected: []float32{0.6, 0.8, 0.38461538, 0.92307693},
+		},
+		{
+			name:     "empty vector",
+			input:    []float32{},
+			dims:     3,
+			expected: []float32{},
+		},
+		{
+			name:     "zero dims",
+			input:    []float32{1, 2, 3},
+			dims:     0,
+			expected: []float32{1, 2, 3},
+		},
+		{
+			name:     "negative dims",
+			input:    []float32{1, 2, 3},
+			dims:     -1,
+			expected: []float32{1, 2, 3},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Make a copy of input to verify original is not modified
+			inputCopy := make([]float32, len(tt.input))
+			copy(inputCopy, tt.input)
+
+			result := NormalizeMultiVector(tt.input, tt.dims)
+
+			// Check result matches expected
+			if len(result) != len(tt.expected) {
+				t.Errorf("length mismatch: expected %d, got %d", len(tt.expected), len(result))
+				return
+			}
+
+			for i := range result {
+				if !floatApproxEqual(result[i], tt.expected[i], 1e-5) {
+					t.Errorf("value mismatch at index %d: expected %v, got %v",
+						i, tt.expected[i], result[i])
+				}
+			}
+
+			// Verify original input was not modified
+			if !reflect.DeepEqual(tt.input, inputCopy) {
+				t.Errorf("original input was modified: was %v, now %v", inputCopy, tt.input)
+			}
+
+			// For valid multi-vectors, verify each sub-vector has unit magnitude
+			if tt.dims > 0 && len(tt.input) > 0 && len(tt.input)%tt.dims == 0 {
+				numVecs := len(result) / tt.dims
+				for i := 0; i < numVecs; i++ {
+					subVec := result[i*tt.dims : (i+1)*tt.dims]
+					mag := magnitude(subVec)
+					// Allow for zero vectors (magnitude 0) or unit vectors (magnitude 1)
+					if mag > 1e-6 && !floatApproxEqual(mag, 1.0, 1e-5) {
+						t.Errorf("sub-vector %d has magnitude %v, expected 1.0", i, mag)
+					}
+				}
+			}
+		})
+	}
+}
+
+// Helper to compute magnitude of a vector
+func magnitude(v []float32) float32 {
+	var sum float32
+	for _, x := range v {
+		sum += x * x
+	}
+	return float32(math.Sqrt(float64(sum)))
+}
+
+// Helper for approximate float comparison
+func floatApproxEqual(a, b, epsilon float32) bool {
+	diff := a - b
+	if diff < 0 {
+		diff = -diff
+	}
+	return diff < epsilon
+}
diff --git a/registry/nested.go b/registry/nested.go
new file mode 100644
index 000000000..a6b5336dd
--- /dev/null
+++ b/registry/nested.go
@@ -0,0 +1,136 @@
+// Copyright (c) 2025 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package registry
+
+import (
+	"strings"
+	"sync"
+
+	"github.com/blevesearch/bleve/v2/search"
+)
+
+// NestedFieldCache caches nested field prefixes and their corresponding nesting levels.
+// A nested field prefix is a field path prefix that indicates the start of a nested document.
+// The nesting level indicates how deep the nested document is in the overall document structure.
+type NestedFieldCache struct {
+	// nested prefix -> nested level
+	prefixDepth map[string]int
+	once        sync.Once
+	m           sync.RWMutex
+}
+
+func NewNestedFieldCache() *NestedFieldCache {
+	return &NestedFieldCache{}
+}
+
+func (nfc *NestedFieldCache) InitOnce(buildFunc func() map[string]int) {
+	nfc.once.Do(func() {
+		nfc.m.Lock()
+		defer nfc.m.Unlock()
+		nfc.prefixDepth = buildFunc()
+	})
+}
+
+// NestedDepth returns two values:
+//   - common: The nesting level of the longest prefix that applies to every field path
+//     in the provided FieldSet. A value of 0 means no nested prefix is shared
+//     across all field paths.
+//   - max: The nesting level of the longest prefix that applies to at least one
+//     field path in the provided FieldSet. A value of 0 means none of the
+//     field paths match any nested prefix.
+func (nfc *NestedFieldCache) NestedDepth(fieldPaths search.FieldSet) (common int, max int) {
+	// if no field paths, no nested depth
+	if len(fieldPaths) == 0 {
+		return
+	}
+	nfc.m.RLock()
+	defer nfc.m.RUnlock()
+	// if no cached prefixes, no nested depth
+	if len(nfc.prefixDepth) == 0 {
+		return
+	}
+	// for each prefix, check if it's a common prefix or matches any path
+	// update common and max accordingly with the highest nesting level
+	// possible for each respective case
+	for prefix, level := range nfc.prefixDepth {
+		// only check prefixes that could increase one of the results
+		if level <= common && level <= max {
+			continue
+		}
+		// check prefix against field paths, getting whether it matches all paths (common)
+		// and whether it matches at least one path (any)
+		matchAll, matchAny := nfc.prefixMatch(prefix, fieldPaths)
+		// if it matches all paths, update common
+		if matchAll && level > common {
+			common = level
+		}
+		// if it matches any path, update max
+		if matchAny && level > max {
+			max = level
+		}
+	}
+	return common, max
+}
+
+// CountNested returns the number of nested prefixes
+func (nfc *NestedFieldCache) CountNested() int {
+	nfc.m.RLock()
+	defer nfc.m.RUnlock()
+
+	return len(nfc.prefixDepth)
+}
+
+// IntersectsPrefix returns true if any of the given
+// field paths have a nested prefix
+func (nfc *NestedFieldCache) IntersectsPrefix(fieldPaths search.FieldSet) bool {
+	// if no field paths, no intersection
+	if len(fieldPaths) == 0 {
+		return false
+	}
+	nfc.m.RLock()
+	defer nfc.m.RUnlock()
+	// if no cached prefixes, no intersection
+	if len(nfc.prefixDepth) == 0 {
+		return false
+	}
+	// Check each cached nested prefix to see if it intersects with any path
+	for prefix := range nfc.prefixDepth {
+		_, matchAny := nfc.prefixMatch(prefix, fieldPaths)
+		if matchAny {
+			return true
+		}
+	}
+	return false
+}
+
+// prefixMatch checks whether the prefix matches all paths (common) and whether it matches at least one path (any)
+// Caller must hold the read lock.
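+//
+// For example (hypothetical inputs), prefix "departments." checked against the
+// paths {"departments.name", "locations.city"} yields common=false, any=true.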
+func (nfc *NestedFieldCache) prefixMatch(prefix string, fieldPaths search.FieldSet) (common bool, any bool) { + common = true + any = false + for path := range fieldPaths { + has := strings.HasPrefix(path, prefix) + if has { + any = true + } else { + common = false + } + // early exit if we have determined both values + if any && !common { + break + } + } + return common, any +} diff --git a/registry/registry.go b/registry/registry.go index 69ee8dd86..36f209d4f 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -49,6 +49,7 @@ type Cache struct { Fragmenters *FragmenterCache Highlighters *HighlighterCache SynonymSources *SynonymSourceCache + NestedPrefixes *NestedFieldCache } func NewCache() *Cache { @@ -63,6 +64,7 @@ func NewCache() *Cache { Fragmenters: NewFragmenterCache(), Highlighters: NewHighlighterCache(), SynonymSources: NewSynonymSourceCache(), + NestedPrefixes: NewNestedFieldCache(), } } diff --git a/search.go b/search.go index ee53ac6e2..41fabbdaa 100644 --- a/search.go +++ b/search.go @@ -625,11 +625,35 @@ func formatHit(rv *strings.Builder, hit *search.DocumentMatch, hitNumber int) *s } } for otherFieldName, otherFieldValue := range hit.Fields { + if otherFieldName == NestedDocumentKey { + continue + } if _, ok := hit.Fragments[otherFieldName]; !ok { fmt.Fprintf(rv, "\t%s\n", otherFieldName) fmt.Fprintf(rv, "\t\t%v\n", otherFieldValue) } } + // nested documents + if nested, ok := hit.Fields[NestedDocumentKey]; ok { + if list, ok := nested.([]*search.NestedDocumentMatch); ok { + fmt.Fprintf(rv, "\t%s (%d nested documents)\n", NestedDocumentKey, len(list)) + for ni, nd := range list { + fmt.Fprintf(rv, "\t\tNested #%d:\n", ni+1) + for f, frags := range nd.Fragments { + fmt.Fprintf(rv, "\t\t\t%s\n", f) + for _, frag := range frags { + fmt.Fprintf(rv, "\t\t\t\t%s\n", frag) + } + } + for f, v := range nd.Fields { + if _, ok := nd.Fragments[f]; !ok { + fmt.Fprintf(rv, "\t\t\t%s\n", f) + fmt.Fprintf(rv, "\t\t\t\t%v\n", v) + } + } + } + } + } if len(hit.DecodedSort) > 0 { fmt.Fprintf(rv, "\t_sort: [") for k, v := range hit.DecodedSort { diff --git a/search/collector/nested.go b/search/collector/nested.go new file mode 100644 index 000000000..9b137c6fe --- /dev/null +++ b/search/collector/nested.go @@ -0,0 +1,103 @@ +// Copyright (c) 2025 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package collector + +import ( + "github.com/blevesearch/bleve/v2/search" + index "github.com/blevesearch/bleve_index_api" +) + +type collectStoreNested struct { + // descAdder is used to customize how descendants are merged into their parent + descAdder search.DescendantAdderCallbackFn + // nested reader to retrieve ancestor information + nr index.NestedReader + // the current root document match being built + currRoot *search.DocumentMatch + // the ancestor ID of the current root document being built + currRootAncestorID index.AncestorID + // prealloc slice for ancestor IDs + ancestors []index.AncestorID +} + +func newStoreNested(nr index.NestedReader, descAdder search.DescendantAdderCallbackFn) *collectStoreNested { + rv := &collectStoreNested{ + descAdder: descAdder, + nr: nr, + } + return rv +} + +// ProcessNestedDocument adds a document to the nested store, merging it into its root document +// as needed. If the returned DocumentMatch is nil, the incoming doc has been merged +// into its parent and should not be processed further. If the returned DocumentMatch +// is non-nil, it represents a complete root document that should be processed further. +// NOTE: This implementation assumes that documents are added in increasing order of their internal IDs +// which is guaranteed by all searchers in bleve. +func (c *collectStoreNested) ProcessNestedDocument(ctx *search.SearchContext, doc *search.DocumentMatch) (*search.DocumentMatch, error) { + // find ancestors for the doc + var err error + c.ancestors, err = c.nr.Ancestors(doc.IndexInternalID, c.ancestors[:0]) + if err != nil { + return nil, err + } + if len(c.ancestors) == 0 { + // should not happen, every doc should have at least itself as ancestor + return nil, nil + } + // root docID is the last ancestor + rootID := c.ancestors[len(c.ancestors)-1] + // check if there is an interim root already and if the incoming doc belongs to it + if c.currRoot != nil && c.currRootAncestorID.Equals(rootID) { + // there is an interim root already, and the incoming doc belongs to it + if err := c.descAdder(c.currRoot, doc); err != nil { + return nil, err + } + // recycle the child document now that it's merged into the interim root + ctx.DocumentMatchPool.Put(doc) + return nil, nil + } + // completedRoot is the root document match to return, if any + var completedRoot *search.DocumentMatch + if c.currRoot != nil { + // we have an existing interim root, return it for processing + completedRoot = c.currRoot + } + // no interim root for now so either we have a root document incoming + // or we have a child doc and need to create an interim root + if len(c.ancestors) == 1 { + // incoming doc is the root itself + c.currRoot = doc + c.currRootAncestorID = rootID + return completedRoot, nil + } + // this is a child doc, create interim root + newDM := ctx.DocumentMatchPool.Get() + newDM.IndexInternalID = rootID.ToIndexInternalID(newDM.IndexInternalID) + // merge the incoming doc into the new interim root + c.currRoot = newDM + c.currRootAncestorID = rootID + if err := c.descAdder(c.currRoot, doc); err != nil { + return nil, err + } + // recycle the child document now that it's merged into the interim root + ctx.DocumentMatchPool.Put(doc) + return completedRoot, nil +} + +// Current returns the current interim root document match being built, if any +func (c *collectStoreNested) Current() *search.DocumentMatch { + return c.currRoot +} diff --git a/search/collector/topn.go b/search/collector/topn.go index 739dd8348..bab318d5c 100644 --- 
a/search/collector/topn.go
+++ b/search/collector/topn.go
@@ -78,7 +78,9 @@ type TopNCollector struct {
 	searchAfter *search.DocumentMatch
 
 	knnHits             map[string]*search.DocumentMatch
-	computeNewScoreExpl search.ScoreExplCorrectionCallbackFunc
+	hybridMergeCallback search.HybridMergeCallbackFn
+
+	nestedStore *collectStoreNested
 }
 
 // CheckDoneEvery controls how frequently we check the context deadline
@@ -88,25 +90,74 @@ const CheckDoneEvery = uint64(1024)
 // skipping over the first 'skip' hits
 // ordering hits by the provided sort order
 func NewTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector {
-	return newTopNCollector(size, skip, sort)
+	return newTopNCollector(size, skip, sort, nil)
 }
 
 // NewTopNCollectorAfter builds a collector to find the top 'size' hits
 // skipping over the first 'skip' hits
 // ordering hits by the provided sort order
+// starting after the provided 'after' sort values
 func NewTopNCollectorAfter(size int, sort search.SortOrder, after []string) *TopNCollector {
-	rv := newTopNCollector(size, 0, sort)
+	rv := newTopNCollector(size, 0, sort, nil)
+	rv.searchAfter = createSearchAfterDocument(sort, after)
+	return rv
+}
+
+// NewNestedTopNCollector builds a collector to find the top 'size' hits
+// skipping over the first 'skip' hits
+// ordering hits by the provided sort order
+// while ensuring the nested documents are handled correctly
+// (i.e. parent document is returned instead of nested document)
+func NewNestedTopNCollector(size int, skip int, sort search.SortOrder, nr index.NestedReader) *TopNCollector {
+	return newTopNCollector(size, skip, sort, nr)
+}
+
+// NewNestedTopNCollectorAfter builds a collector to find the top 'size' hits
+// skipping over the first 'skip' hits
+// ordering hits by the provided sort order
+// starting after the provided 'after' sort values
+// while ensuring the nested documents are handled correctly
+// (i.e. parent document is returned instead of nested document)
+func NewNestedTopNCollectorAfter(size int, sort search.SortOrder, after []string, nr index.NestedReader) *TopNCollector {
+	rv := newTopNCollector(size, 0, sort, nr)
 	rv.searchAfter = createSearchAfterDocument(sort, after)
 	return rv
 }
 
-func newTopNCollector(size int, skip int, sort search.SortOrder) *TopNCollector {
+func newTopNCollector(size int, skip int, sort search.SortOrder, nr index.NestedReader) *TopNCollector {
 	hc := &TopNCollector{size: size, skip: skip, sort: sort}
 
 	hc.store = getOptimalCollectorStore(size, skip, func(i, j *search.DocumentMatch) int {
 		return hc.sort.Compare(hc.cachedScoring, hc.cachedDesc, i, j)
 	})
 
+	if nr != nil {
+		descAdder := func(parent, child *search.DocumentMatch) error {
+			// add descendant score to parent score
+			parent.Score += child.Score
+			// merge explanations
+			parent.Expl = parent.Expl.MergeWith(child.Expl)
+			// merge field term locations
+			parent.FieldTermLocations = search.MergeFieldTermLocationsFromMatch(parent.FieldTermLocations, child)
+			// add the child's ID to the parent's Descendants,
+			// but only if it is not the same document
+			if !parent.IndexInternalID.Equals(child.IndexInternalID) {
+				// Add a copy of child.IndexInternalID to descendants, because
+				// child.IndexInternalID will be reset when 'child' is recycled.
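+				// (for example, with len(parent.Descendants)==2 and cap==4, the stale
+				// element at index 2 is reused as the copy target below)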
+				var descendantID index.IndexInternalID
+				// first check if parent's descendants slice has capacity to reuse
+				if len(parent.Descendants) < cap(parent.Descendants) {
+					// reuse the buffer element at len(parent.Descendants)
+					descendantID = parent.Descendants[:len(parent.Descendants)+1][len(parent.Descendants)]
+				}
+				// copy the contents of id into descendantID, allocating if needed
+				parent.Descendants = append(parent.Descendants, index.NewIndexInternalIDFrom(descendantID, child.IndexInternalID))
+			}
+			return nil
+		}
+		hc.nestedStore = newStoreNested(nr, search.DescendantAdderCallbackFn(descAdder))
+	}
+
 	// these lookups traverse an interface, so do once up-front
 	if sort.RequiresDocID() {
 		hc.needDocIds = true
@@ -283,8 +334,13 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher,
 	default:
 		next, err = searcher.Next(searchContext)
 	}
+	// use a local totalDocs for counting total docs seen
+	// for context deadline checking, as hc.total is only
+	// incremented for actual (root) collected documents, and
+	// we need to check deadline for every document seen (root or nested)
+	var totalDocs uint64
 	for err == nil && next != nil {
-		if hc.total%CheckDoneEvery == 0 {
+		if totalDocs%CheckDoneEvery == 0 {
 			select {
 			case <-ctx.Done():
 				search.RecordSearchCost(ctx, search.AbortM, 0)
@@ -292,27 +348,60 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher,
 			default:
 			}
 		}
-
-		err = hc.adjustDocumentMatch(searchContext, reader, next)
-		if err != nil {
-			break
-		}
-
-		err = hc.prepareDocumentMatch(searchContext, reader, next, false)
-		if err != nil {
-			break
+		totalDocs++
+		if hc.nestedStore != nil {
+			// This may be a nested document, so add it to the nested store first.
+			// If the nested store returns nil, the document was merged into its parent
+			// and should not be processed further.
+			// If it returns a non-nil document, it represents a complete root document
+			// and should be processed further.
+			next, err = hc.nestedStore.ProcessNestedDocument(searchContext, next)
+			if err != nil {
+				break
+			}
 		}
-
-		err = dmHandler(next)
-		if err != nil {
-			break
+		if next != nil {
+			err = hc.adjustDocumentMatch(searchContext, reader, next)
+			if err != nil {
+				break
+			}
+			err = hc.prepareDocumentMatch(searchContext, reader, next, false)
+			if err != nil {
+				break
+			}
+			err = dmHandler(next)
+			if err != nil {
+				break
+			}
 		}
-
 		next, err = searcher.Next(searchContext)
 	}
 	if err != nil {
 		return err
 	}
+
+	// if we have a nested store, we may have an interim root
+	// that needs to be returned for processing
+	if hc.nestedStore != nil {
+		currRoot := hc.nestedStore.Current()
+		if currRoot != nil {
+			err = hc.adjustDocumentMatch(searchContext, reader, currRoot)
+			if err != nil {
+				return err
+			}
+			// no descendants at this point
+			err = hc.prepareDocumentMatch(searchContext, reader, currRoot, false)
+			if err != nil {
+				return err
+			}
+
+			err = dmHandler(currRoot)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
 	if hc.knnHits != nil {
 		// we may have some knn hits left that did not match any of the top N tf-idf hits
 		// we need to add them to the collector store to consider them as well.
@@ -366,7 +455,10 @@ func (hc *TopNCollector) adjustDocumentMatch(ctx *search.SearchContext,
 		return err
 	}
 	if knnHit, ok := hc.knnHits[d.ID]; ok {
-		d.Score, d.Expl = hc.computeNewScoreExpl(d, knnHit)
+		// we have a knn hit corresponding to this document
+		hc.hybridMergeCallback(d, knnHit)
+		// remove this knn hit from the map as it's already
+		// been merged
 		delete(hc.knnHits, d.ID)
 	}
 }
@@ -501,6 +593,14 @@ func (hc *TopNCollector) visitFieldTerms(reader index.IndexReader, d *search.Doc
 		}
 	}
 
+	// first visit descendants if any
+	for _, descID := range d.Descendants {
+		err := hc.dvReader.VisitDocValues(descID, v)
+		if err != nil {
+			return err
+		}
+	}
+
+	// now visit the doc values for this document
 	err := hc.dvReader.VisitDocValues(d.IndexInternalID, v)
 	if hc.facetsBuilder != nil {
 		hc.facetsBuilder.EndDoc()
@@ -579,10 +679,10 @@ func (hc *TopNCollector) FacetResults() search.FacetResults {
 	return nil
 }
 
-func (hc *TopNCollector) SetKNNHits(knnHits search.DocumentMatchCollection, newScoreExplComputer search.ScoreExplCorrectionCallbackFunc) {
+func (hc *TopNCollector) SetKNNHits(knnHits search.DocumentMatchCollection, hybridMergeCallback search.HybridMergeCallbackFn) {
 	hc.knnHits = make(map[string]*search.DocumentMatch, len(knnHits))
 	for _, hit := range knnHits {
 		hc.knnHits[hit.ID] = hit
 	}
-	hc.computeNewScoreExpl = newScoreExplComputer
+	hc.hybridMergeCallback = hybridMergeCallback
 }
diff --git a/search/explanation.go b/search/explanation.go
index 924050016..98c5e099d 100644
--- a/search/explanation.go
+++ b/search/explanation.go
@@ -29,6 +29,8 @@ func init() {
 	reflectStaticSizeExplanation = int(reflect.TypeOf(e).Size())
 }
 
+const MergedExplMessage = "sum of merged explanations:"
+
 type Explanation struct {
 	Value    float64 `json:"value"`
 	Message  string  `json:"message"`
@@ -54,3 +56,50 @@ func (expl *Explanation) Size() int {
 
 	return sizeInBytes
 }
+
+// MergeWith merges two explanations into one.
+// If either explanation is nil, the other is returned.
+// If the first explanation is already a merged explanation,
+// the second explanation is appended to its children.
+// Otherwise, a new merged explanation is created
+// with the two explanations as its children.
+func (expl *Explanation) MergeWith(other *Explanation) *Explanation {
+	if expl == nil {
+		return other
+	}
+	if other == nil || expl == other {
+		return expl
+	}
+
+	newScore := expl.Value + other.Value
+
+	// if both are merged explanations, combine children
+	if expl.Message == MergedExplMessage && other.Message == MergedExplMessage {
+		expl.Value = newScore
+		expl.Children = append(expl.Children, other.Children...)
+		return expl
+	}
+
+	// at least one is not a merged explanation; determine which one it is
+	// if expl is merged, append other
+	if expl.Message == MergedExplMessage {
+		// append other as a child of expl
+		expl.Value = newScore
+		expl.Children = append(expl.Children, other)
+		return expl
+	}
+
+	// if other is merged, append expl
+	if other.Message == MergedExplMessage {
+		other.Value = newScore
+		other.Children = append(other.Children, expl)
+		return other
+	}
+	// create a new explanation to hold the merged one
+	rv := &Explanation{
+		Value:    expl.Value + other.Value,
+		Message:  MergedExplMessage,
+		Children: []*Explanation{expl, other},
+	}
+	return rv
+}
diff --git a/search/facet/facet_builder_terms_test.go b/search/facet/facet_builder_terms_test.go
index fad4be301..3ed2fcccb 100644
--- a/search/facet/facet_builder_terms_test.go
+++ b/search/facet/facet_builder_terms_test.go
@@ -201,11 +201,11 @@ func TestTermsFacetPrefixAndRegex(t *testing.T) {
 	terms := []string{
 		"env:prod",
 		"env:staging",
-		"env:dev", // has prefix but doesn't match regex
-		"env:test", // has prefix but doesn't match regex
-		"type:server", // no prefix
-		"env:prod", // duplicate
-		"env:staging", // duplicate
+		"env:dev",     // has prefix but doesn't match regex
+		"env:test",    // has prefix but doesn't match regex
+		"type:server", // no prefix
+		"env:prod",    // duplicate
+		"env:staging", // duplicate
 	}
 
 	for _, term := range terms {
diff --git a/search/highlight/highlighter/simple/highlighter_simple.go b/search/highlight/highlighter/simple/highlighter_simple.go
index e898a1e61..d0adfa81f 100644
--- a/search/highlight/highlighter/simple/highlighter_simple.go
+++ b/search/highlight/highlighter/simple/highlighter_simple.go
@@ -146,12 +146,8 @@ func (s *Highlighter) BestFragmentsInField(dm *search.DocumentMatch, doc index.D
 			formattedFragments[i] += s.sep
 		}
 	}
-
-	if dm.Fragments == nil {
-		dm.Fragments = make(search.FieldFragmentMap, 0)
-	}
 	if len(formattedFragments) > 0 {
-		dm.Fragments[field] = formattedFragments
+		dm.AddFragments(field, formattedFragments)
 	}
 
 	return formattedFragments
diff --git a/search/query/conjunction.go b/search/query/conjunction.go
index a2043720a..6870b1ae2 100644
--- a/search/query/conjunction.go
+++ b/search/query/conjunction.go
@@ -54,14 +54,39 @@ func (q *ConjunctionQuery) AddQuery(aq ...Query) {
 
 func (q *ConjunctionQuery) Searcher(ctx context.Context, i
index.IndexReader, m return searcher.NewMatchNoneSearcher(i) } + if nestedMode { + // first determine the nested depth info for the query fields + commonDepth, maxDepth := nm.NestedDepth(qfs) + // if we have common depth == max depth then we can just use + // the normal conjunction searcher, as all fields share the same + // nested context, otherwise we need to use the nested conjunction searcher + if commonDepth < maxDepth { + return searcher.NewNestedConjunctionSearcher(ctx, i, ss, commonDepth, options) + } + } + return searcher.NewConjunctionSearcher(ctx, i, ss, options) } diff --git a/search/query/knn.go b/search/query/knn.go index ea3d38ce4..ea8780a41 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -53,7 +53,7 @@ func (q *KNNQuery) SetK(k int64) { q.K = k } -func (q *KNNQuery) SetFieldVal(field string) { +func (q *KNNQuery) SetField(field string) { q.VectorField = field } diff --git a/search/query/query.go b/search/query/query.go index 27c3978b1..06e924882 100644 --- a/search/query/query.go +++ b/search/query/query.go @@ -455,13 +455,10 @@ func DumpQuery(m mapping.IndexMapping, query Query) (string, error) { return string(data), err } -// FieldSet represents a set of queried fields. -type FieldSet map[string]struct{} - // ExtractFields returns a set of fields referenced by the query. // The returned set may be nil if the query does not explicitly reference any field // and the DefaultSearchField is unset in the index mapping. -func ExtractFields(q Query, m mapping.IndexMapping, fs FieldSet) (FieldSet, error) { +func ExtractFields(q Query, m mapping.IndexMapping, fs search.FieldSet) (search.FieldSet, error) { if q == nil || m == nil { return fs, nil } @@ -474,9 +471,9 @@ func ExtractFields(q Query, m mapping.IndexMapping, fs FieldSet) (FieldSet, erro } if f != "" { if fs == nil { - fs = make(FieldSet) + fs = search.NewFieldSet() } - fs[f] = struct{}{} + fs.AddField(f) } case *QueryStringQuery: var expandedQuery Query diff --git a/search/scorer/scorer_knn.go b/search/scorer/scorer_knn.go index 8d9043427..06f50cd4a 100644 --- a/search/scorer/scorer_knn.go +++ b/search/scorer/scorer_knn.go @@ -123,7 +123,7 @@ func (sqs *KNNQueryScorer) Score(ctx *search.SearchContext, if sqs.options.Explain { rv.Expl = scoreExplanation } - rv.IndexInternalID = append(rv.IndexInternalID, knnMatch.ID...) + rv.IndexInternalID = index.NewIndexInternalIDFrom(rv.IndexInternalID, knnMatch.ID) return rv } diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index f5f8ec935..d7e77f977 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -243,7 +243,7 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term } } - rv.IndexInternalID = append(rv.IndexInternalID, termMatch.ID...) + rv.IndexInternalID = index.NewIndexInternalIDFrom(rv.IndexInternalID, termMatch.ID) if len(termMatch.Vectors) > 0 { if cap(rv.FieldTermLocations) < len(termMatch.Vectors) { diff --git a/search/search.go b/search/search.go index 724025787..541bbe42a 100644 --- a/search/search.go +++ b/search/search.go @@ -165,9 +165,9 @@ type DocumentMatch struct { // used to indicate the sub-scores that combined to form the // final score for this document match. This is only populated - // when the search request's query is a DisjunctionQuery - // or a ConjunctionQuery. The map key is the index of the sub-query - // in the DisjunctionQuery or ConjunctionQuery. The map value is the + // when the search request's query is a DisjunctionQuery. 
+	// The map key is the index of the sub-query
+	// in the DisjunctionQuery. The map value is the
 	// sub-score for that sub-query.
 	ScoreBreakdown map[int]float64 `json:"score_breakdown,omitempty"`
 
@@ -178,6 +178,10 @@ type DocumentMatch struct {
 	// of the index that this match came from
 	// of the current alias view, used in alias of aliases scenario
 	IndexNames []string `json:"index_names,omitempty"`
+
+	// Descendants holds the IDs of any child/descendant document that contributed
+	// to this root DocumentMatch.
+	Descendants []index.IndexInternalID `json:"-"`
 }
 
 func (dm *DocumentMatch) AddFieldValue(name string, value interface{}) {
@@ -201,6 +205,21 @@ func (dm *DocumentMatch) AddFieldValue(name string, value interface{}) {
 	dm.Fields[name] = valSlice
 }
 
+func (dm *DocumentMatch) AddFragments(field string, fragments []string) {
+	if dm.Fragments == nil {
+		dm.Fragments = make(FieldFragmentMap)
+	}
+OUTER:
+	for _, newFrag := range fragments {
+		for _, existingFrag := range dm.Fragments[field] {
+			if existingFrag == newFrag {
+				continue OUTER // no duplicates allowed
+			}
+		}
+		dm.Fragments[field] = append(dm.Fragments[field], newFrag)
+	}
+}
+
 // Reset allows an already allocated DocumentMatch to be reused
 func (dm *DocumentMatch) Reset() *DocumentMatch {
 	// remember the []byte used for the IndexInternalID
@@ -218,6 +237,11 @@ func (dm *DocumentMatch) Reset() *DocumentMatch {
 	scoreBreakdown := dm.ScoreBreakdown
 	// clear out the score breakdown map
 	clear(scoreBreakdown)
+	// remember the Descendants backing array
+	descendants := dm.Descendants
+	for i := range descendants { // recycle each IndexInternalID
+		descendants[i] = descendants[i][:0]
+	}
 	// idiom to copy over from empty DocumentMatch (0 allocations)
 	*dm = DocumentMatch{}
 	// reuse the []byte already allocated (and reset len to 0)
@@ -228,6 +252,8 @@ func (dm *DocumentMatch) Reset() *DocumentMatch {
 	dm.DecodedSort = decodedSort[:0]
 	// reuse the FieldTermLocations already allocated (and reset len to 0)
 	dm.FieldTermLocations = ftls[:0]
+	// reuse the Descendants already allocated (and reset len to 0)
+	dm.Descendants = descendants[:0]
 	// reuse the score breakdown map already allocated (after clearing it)
 	dm.ScoreBreakdown = scoreBreakdown
 	return dm
@@ -402,3 +428,20 @@ func (sc *SearchContext) Size() int {
 
 	return sizeInBytes
 }
+
+// A NestedDocumentMatch is like a DocumentMatch but used for nested documents;
+// it does not have a score or locations, and is mainly used to hold field
+// values and fragments, to be embedded in the parent DocumentMatch
+type NestedDocumentMatch struct {
+	Fields    map[string]interface{} `json:"fields,omitempty"`
+	Fragments FieldFragmentMap       `json:"fragments,omitempty"`
+}
+
+// NewNestedDocumentMatch creates a new NestedDocumentMatch instance
+// with the given fields and fragments
+func NewNestedDocumentMatch(fields map[string]interface{}, fragments FieldFragmentMap) *NestedDocumentMatch {
+	return &NestedDocumentMatch{
+		Fields:    fields,
+		Fragments: fragments,
+	}
+}
diff --git a/search/searcher/search_conjunction_nested.go b/search/searcher/search_conjunction_nested.go
new file mode 100644
index 000000000..688142e13
--- /dev/null
+++ b/search/searcher/search_conjunction_nested.go
@@ -0,0 +1,499 @@
+// Copyright (c) 2025 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package searcher + +import ( + "context" + "fmt" + "math" + "reflect" + "slices" + + "github.com/blevesearch/bleve/v2/search" + "github.com/blevesearch/bleve/v2/size" + index "github.com/blevesearch/bleve_index_api" +) + +var reflectStaticSizeNestedConjunctionSearcher int + +func init() { + var ncs NestedConjunctionSearcher + reflectStaticSizeNestedConjunctionSearcher = int(reflect.TypeOf(ncs).Size()) +} + +type NestedConjunctionSearcher struct { + nestedReader index.NestedReader + searchers []search.Searcher + queryNorm float64 + currs []*search.DocumentMatch + currAncestors [][]index.AncestorID + currKeys []index.AncestorID + initialized bool + joinIdx int + options search.SearcherOptions + docQueue *CoalesceQueue + // reusable ID buffer for Advance() calls + advanceID index.IndexInternalID + // reusable buffer for Advance() calls + ancestors []index.AncestorID +} + +func NewNestedConjunctionSearcher(ctx context.Context, indexReader index.IndexReader, + searchers []search.Searcher, joinIdx int, options search.SearcherOptions) (search.Searcher, error) { + + var nr index.NestedReader + var ok bool + if nr, ok = indexReader.(index.NestedReader); !ok { + return nil, fmt.Errorf("indexReader does not support nested documents") + } + + // build our searcher + rv := NestedConjunctionSearcher{ + nestedReader: nr, + options: options, + searchers: searchers, + currs: make([]*search.DocumentMatch, len(searchers)), + currAncestors: make([][]index.AncestorID, len(searchers)), + currKeys: make([]index.AncestorID, len(searchers)), + joinIdx: joinIdx, + docQueue: NewCoalesceQueue(), + } + rv.computeQueryNorm() + + return &rv, nil +} + +func (s *NestedConjunctionSearcher) computeQueryNorm() { + // first calculate sum of squared weights + sumOfSquaredWeights := 0.0 + for _, searcher := range s.searchers { + sumOfSquaredWeights += searcher.Weight() + } + // now compute query norm from this + s.queryNorm = 1.0 / math.Sqrt(sumOfSquaredWeights) + // finally tell all the downstream searchers the norm + for _, searcher := range s.searchers { + searcher.SetQueryNorm(s.queryNorm) + } +} + +func (s *NestedConjunctionSearcher) Size() int { + sizeInBytes := reflectStaticSizeNestedConjunctionSearcher + size.SizeOfPtr + + for _, entry := range s.searchers { + sizeInBytes += entry.Size() + } + + for _, entry := range s.currs { + if entry != nil { + sizeInBytes += entry.Size() + } + } + + return sizeInBytes +} + +func (s *NestedConjunctionSearcher) Weight() float64 { + var rv float64 + for _, searcher := range s.searchers { + rv += searcher.Weight() + } + return rv +} + +func (s *NestedConjunctionSearcher) SetQueryNorm(qnorm float64) { + for _, searcher := range s.searchers { + searcher.SetQueryNorm(qnorm) + } +} + +func (s *NestedConjunctionSearcher) Count() uint64 { + // for now return a worst case + var sum uint64 + for _, searcher := range s.searchers { + sum += searcher.Count() + } + return sum +} + +func (s *NestedConjunctionSearcher) Close() (rv error) { + for _, searcher := range s.searchers { + err := searcher.Close() + if err != nil && rv == nil { + rv = err + } + } + return rv +} + +func (s 
*NestedConjunctionSearcher) Min() int {
+	return 0
+}
+
+func (s *NestedConjunctionSearcher) DocumentMatchPoolSize() int {
+	rv := len(s.currs)
+	for _, s := range s.searchers {
+		rv += s.DocumentMatchPoolSize()
+	}
+	return rv
+}
+
+func (s *NestedConjunctionSearcher) initialize(ctx *search.SearchContext) (bool, error) {
+	var err error
+	for i, searcher := range s.searchers {
+		if s.currs[i] != nil {
+			ctx.DocumentMatchPool.Put(s.currs[i])
+		}
+		s.currs[i], err = searcher.Next(ctx)
+		if err != nil {
+			return false, err
+		}
+		if s.currs[i] == nil {
+			// one of the searchers is exhausted, so we are done
+			return true, nil
+		}
+		// get the ancestry chain for this match
+		s.currAncestors[i], err = s.nestedReader.Ancestors(s.currs[i].IndexInternalID, s.currAncestors[i][:0])
+		if err != nil {
+			return false, err
+		}
+		// check if the ancestry chain is longer than joinIdx; if not, we reset the joinIdx
+		// to the minimum possible value across all searchers; ideally this would be
+		// done at query construction time itself, by using the covering depth across
+		// all sub-queries, but we do this here as a fallback
+		if s.joinIdx >= len(s.currAncestors[i]) {
+			s.joinIdx = len(s.currAncestors[i]) - 1
+		}
+	}
+	// build currKeys for each searcher, do it here as we may have adjusted joinIdx
+	for i := range s.searchers {
+		s.currKeys[i] = ancestorFromRoot(s.currAncestors[i], s.joinIdx)
+	}
+	s.initialized = true
+	return false, nil
+}
+
+func (s *NestedConjunctionSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch, error) {
+	// initialize on first call to Next, by getting first match
+	// from each searcher and their ancestry chains
+	if !s.initialized {
+		done, err := s.initialize(ctx)
+		if err != nil {
+			return nil, err
+		}
+		if done {
+			return nil, nil
+		}
+	}
+	// check if the docQueue has any buffered matches
+	if s.docQueue.Len() > 0 {
+		return s.docQueue.Dequeue()
+	}
+	// now enter the main alignment loop
+	n := len(s.searchers)
+OUTER:
+	for {
+		// pick the pivot searcher with the highest key (ancestor at joinIdx level)
+		if s.currs[0] == nil {
+			return nil, nil
+		}
+		maxKey := s.currKeys[0]
+		for i := 1; i < n; i++ {
+			// currs[i] is nil means one of the searchers is exhausted
+			if s.currs[i] == nil {
+				return nil, nil
+			}
+			currKey := s.currKeys[i]
+			if maxKey.Compare(currKey) < 0 {
+				maxKey = currKey
+			}
+		}
+		// store maxKey as advanceID at most once, and only if needed
+		var advanceID index.IndexInternalID
+		// flag to track if all searchers are aligned
+		aligned := true
+		// now try to align all the other searchers to maxKey:
+		// we check whether each searcher's key matches maxKey;
+		// if not, we advance that searcher to maxKey,
+		// otherwise we do nothing and move on to the next searcher
+		for i := 0; i < n; i++ {
+			cmp := s.currKeys[i].Compare(maxKey)
+			if cmp < 0 {
+				// not aligned, so advance this searcher to maxKey
+				// convert maxKey to advanceID only once
+				if advanceID == nil {
+					advanceID = s.toAdvanceID(maxKey)
+				}
+				var err error
+				ctx.DocumentMatchPool.Put(s.currs[i])
+				s.currs[i], err = s.searchers[i].Advance(ctx, advanceID)
+				if err != nil {
+					return nil, err
+				}
+				if s.currs[i] == nil {
+					// one of the searchers is exhausted, so we are done
+					return nil, nil
+				}
+				// recalc ancestors
+				s.currAncestors[i], err = s.nestedReader.Ancestors(s.currs[i].IndexInternalID, s.currAncestors[i][:0])
+				if err != nil {
+					return nil, err
+				}
+				// recalc key
+				s.currKeys[i] = ancestorFromRoot(s.currAncestors[i], s.joinIdx)
+				// recalc cmp
+				cmp = s.currKeys[i].Compare(maxKey)
+			}
+			if cmp != 0 {
+				// not aligned
+				aligned = false
+			}
+		}
+		// now check if all the searchers are aligned at the same maxKey
+		// if they are not aligned, we need to restart the loop of picking
+		// the pivot searcher with the highest key
+		if !aligned {
+			continue OUTER
+		}
+		// if we are here, all the searchers are aligned at maxKey
+		// now we need to buffer all the intermediate matches for every
+		// searcher at this key, until either the searcher's key changes
+		// or the searcher is exhausted
+		for i := 0; i < n; i++ {
+			for {
+				// buffer the current match
+				recycle, err := s.docQueue.Enqueue(s.currs[i])
+				if err != nil {
+					return nil, err
+				}
+				if recycle != nil {
+					// we got a match to recycle
+					ctx.DocumentMatchPool.Put(recycle)
+				}
+				// advance to next match
+				s.currs[i], err = s.searchers[i].Next(ctx)
+				if err != nil {
+					return nil, err
+				}
+				if s.currs[i] == nil {
+					// searcher exhausted, break out
+					break
+				}
+				// recalc ancestors
+				s.currAncestors[i], err = s.nestedReader.Ancestors(s.currs[i].IndexInternalID, s.currAncestors[i][:0])
+				if err != nil {
+					return nil, err
+				}
+				// recalc key
+				s.currKeys[i] = ancestorFromRoot(s.currAncestors[i], s.joinIdx)
+				// check if key has changed
+				if !s.currKeys[i].Equals(maxKey) {
+					// key changed, break out
+					break
+				}
+			}
+		}
+		// finalize the docQueue for dequeueing
+		s.docQueue.Finalize()
+		// finally return the first buffered match
+		return s.docQueue.Dequeue()
+	}
+}
+
+// ancestorFromRoot gets the AncestorID at the given position from the root
+// if pos is 0, it returns the root AncestorID, and so on
+func ancestorFromRoot(ancestors []index.AncestorID, pos int) index.AncestorID {
+	return ancestors[len(ancestors)-pos-1]
+}
+
+// toAdvanceID converts an AncestorID to IndexInternalID, reusing the advanceID buffer.
+// The returned ID is safe to pass to Advance() since Advance() never retains references.
+func (s *NestedConjunctionSearcher) toAdvanceID(key index.AncestorID) index.IndexInternalID {
+	// Reset length to 0 while preserving capacity for buffer reuse
+	s.advanceID = s.advanceID[:0]
+	// Convert key to IndexInternalID, reusing the underlying buffer
+	s.advanceID = key.ToIndexInternalID(s.advanceID)
+	return s.advanceID
+}
+
+func (s *NestedConjunctionSearcher) Advance(ctx *search.SearchContext, ID index.IndexInternalID) (*search.DocumentMatch, error) {
+	if !s.initialized {
+		done, err := s.initialize(ctx)
+		if err != nil {
+			return nil, err
+		}
+		if done {
+			return nil, nil
+		}
+	}
+	// first check if the docQueue has any buffered matches;
+	// if so, check whether any of them can satisfy the Advance(ID)
+	for s.docQueue.Len() > 0 {
+		dm, err := s.docQueue.Dequeue()
+		if err != nil {
+			return nil, err
+		}
+		if dm.IndexInternalID.Compare(ID) >= 0 {
+			return dm, nil
+		}
+		// otherwise recycle this match
+		ctx.DocumentMatchPool.Put(dm)
+	}
+	var err error
+	// now get the ancestry chain for the given ID
+	s.ancestors, err = s.nestedReader.Ancestors(ID, s.ancestors[:0])
+	if err != nil {
+		return nil, err
+	}
+	// we now follow the following logic for each searcher:
+	// let S be the length of the ancestry chain for the searcher
+	// let I be the length of the ancestry chain for the given ID
+	// 1. if S > I:
+	//    then we just Advance() the searcher to the given ID if required
+	// 2. else if S <= I:
+	//    then we get the AncestorID at position (S - 1) from the root of
+	//    the given ID's ancestry chain, and Advance() the searcher to
+	//    it if required
+	for i, searcher := range s.searchers {
+		if s.currs[i] == nil {
+			return nil, nil // already exhausted, nothing to do
+		}
+		var targetID index.IndexInternalID
+		S := len(s.currAncestors[i])
+		I := len(s.ancestors)
+		if S > I {
+			// case 1: S > I
+			targetID = ID
+		} else {
+			// case 2: S <= I
+			targetID = s.toAdvanceID(ancestorFromRoot(s.ancestors, S-1))
+		}
+		if s.currs[i].IndexInternalID.Compare(targetID) < 0 {
+			// need to advance this searcher
+			ctx.DocumentMatchPool.Put(s.currs[i])
+			s.currs[i], err = searcher.Advance(ctx, targetID)
+			if err != nil {
+				return nil, err
+			}
+			if s.currs[i] == nil {
+				// one of the searchers is exhausted, so we are done
+				return nil, nil
+			}
+			// recalc ancestors
+			s.currAncestors[i], err = s.nestedReader.Ancestors(s.currs[i].IndexInternalID, s.currAncestors[i][:0])
+			if err != nil {
+				return nil, err
+			}
+			// recalc key
+			s.currKeys[i] = ancestorFromRoot(s.currAncestors[i], s.joinIdx)
+		}
+	}
+	// we need to call Next() in a loop until we reach or exceed the given ID
+	// the Next() call basically gives us a match that is aligned correctly, but
+	// if joinIdx < I, we can have multiple matches for the same joinIdx ancestor
+	// and they may be < ID, so we need to loop
+	for {
+		next, err := s.Next(ctx)
+		if err != nil {
+			return nil, err
+		}
+		if next == nil {
+			return nil, nil
+		}
+		if next.IndexInternalID.Compare(ID) >= 0 {
+			return next, nil
+		}
+		ctx.DocumentMatchPool.Put(next)
+	}
+}
+
+// ------------------------------------------------------------------------------------------
+type CoalesceQueue struct {
+	order []*search.DocumentMatch          // queue of DocumentMatch
+	items map[uint64]*search.DocumentMatch // map of ID to DocumentMatch
+}
+
+func NewCoalesceQueue() *CoalesceQueue {
+	cq := &CoalesceQueue{
+		order: make([]*search.DocumentMatch, 0),
+		items: make(map[uint64]*search.DocumentMatch),
+	}
+	return cq
+}
+
+// Enqueue adds the given DocumentMatch to the queue. If a DocumentMatch with the same
+// IndexInternalID already exists in the queue, it merges the scores and explanations,
+// and returns the given DocumentMatch for recycling. If it's a new entry, it adds it
+// to the queue and returns nil.
+func (cq *CoalesceQueue) Enqueue(it *search.DocumentMatch) (*search.DocumentMatch, error) {
+	val, err := it.IndexInternalID.Value()
+	if err != nil {
+		// cannot coalesce without a valid uint64 ID
+		return nil, err
+	}
+
+	if existing, ok := cq.items[val]; ok {
+		// merge with current version
+		existing.Score += it.Score
+		existing.Expl = existing.Expl.MergeWith(it.Expl)
+		existing.FieldTermLocations = search.MergeFieldTermLocationsFromMatch(
+			existing.FieldTermLocations, it)
+		// return it to caller for recycling
+		return it, nil
+	}
+
+	// first time we see this ID, so enqueue it
+	cq.items[val] = it
+	// append to the order slice (popped from the end once Finalize sorts it)
+	cq.order = append(cq.order, it)
+	// no recycling needed as we added a new item
+	return nil, nil
+}
+
+// Finalize prepares the queue for dequeue operations by sorting the items based on
+// their IndexInternalID values. This MUST be called before any Dequeue operations,
+// and after all Enqueue operations are complete. The sort is done in descending order
+// so that dequeueing will basically be popping from the end of the slice, allowing for
+// slice reuse.
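+//
+// Illustrative example (hypothetical internal IDs): enqueueing matches with IDs
+// 5, 3, 9, 3 coalesces the duplicate 3; Finalize then sorts order to [9, 5, 3],
+// and successive Dequeue calls yield 3, 5, 9 by popping from the end.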
+func (cq *CoalesceQueue) Finalize() { + slices.SortFunc(cq.order, func(a, b *search.DocumentMatch) int { + return b.IndexInternalID.Compare(a.IndexInternalID) + }) +} + +// Dequeue removes and returns the next DocumentMatch from the queue in sorted order. +// If the queue is empty, it returns nil. +func (cq *CoalesceQueue) Dequeue() (*search.DocumentMatch, error) { + if cq.Len() == 0 { + return nil, nil + } + + // pop from end of slice + rv := cq.order[len(cq.order)-1] + cq.order = cq.order[:len(cq.order)-1] + + val, err := rv.IndexInternalID.Value() + if err != nil { + return nil, err + } + + delete(cq.items, val) + return rv, nil +} + +// Len returns the number of DocumentMatch items currently in the queue. +func (cq *CoalesceQueue) Len() int { + return len(cq.order) +} diff --git a/search/searcher/search_disjunction_heap.go b/search/searcher/search_disjunction_heap.go index 3da876bd3..4c68e5691 100644 --- a/search/searcher/search_disjunction_heap.go +++ b/search/searcher/search_disjunction_heap.go @@ -15,7 +15,6 @@ package searcher import ( - "bytes" "container/heap" "context" "math" @@ -169,7 +168,7 @@ func (s *DisjunctionHeapSearcher) updateMatches() error { matchingIdxs = append(matchingIdxs, next.matchingIdx) // now as long as top of heap matches, keep popping - for len(s.heap) > 0 && bytes.Compare(next.curr.IndexInternalID, s.heap[0].curr.IndexInternalID) == 0 { + for len(s.heap) > 0 && next.curr.IndexInternalID.Equals(s.heap[0].curr.IndexInternalID) { next = heap.Pop(s).(*SearcherCurr) matching = append(matching, next.curr) matchingCurrs = append(matchingCurrs, next) @@ -264,7 +263,7 @@ func (s *DisjunctionHeapSearcher) Advance(ctx *search.SearchContext, // find all searchers that actually need to be advanced // advance them, using s.matchingCurrs as temp storage - for len(s.heap) > 0 && bytes.Compare(s.heap[0].curr.IndexInternalID, ID) < 0 { + for len(s.heap) > 0 && s.heap[0].curr.IndexInternalID.Compare(ID) < 0 { searcherCurr := heap.Pop(s).(*SearcherCurr) ctx.DocumentMatchPool.Put(searcherCurr.curr) curr, err := searcherCurr.searcher.Advance(ctx, ID) @@ -347,7 +346,7 @@ func (s *DisjunctionHeapSearcher) Less(i, j int) bool { } else if s.heap[j].curr == nil { return false } - return bytes.Compare(s.heap[i].curr.IndexInternalID, s.heap[j].curr.IndexInternalID) < 0 + return s.heap[i].curr.IndexInternalID.Compare(s.heap[j].curr.IndexInternalID) < 0 } func (s *DisjunctionHeapSearcher) Swap(i, j int) { diff --git a/search/searcher/search_match_all.go b/search/searcher/search_match_all.go index 57d8d0727..57966a924 100644 --- a/search/searcher/search_match_all.go +++ b/search/searcher/search_match_all.go @@ -36,6 +36,8 @@ type MatchAllSearcher struct { reader index.DocIDReader scorer *scorer.ConstantScorer count uint64 + nested bool + ancestors []index.AncestorID } func NewMatchAllSearcher(ctx context.Context, indexReader index.IndexReader, boost float64, options search.SearcherOptions) (*MatchAllSearcher, error) { @@ -50,11 +52,15 @@ func NewMatchAllSearcher(ctx context.Context, indexReader index.IndexReader, boo } scorer := scorer.NewConstantScorer(1.0, boost, options) + // check if we are in nested mode + nested, _ := ctx.Value(search.NestedSearchKey).(bool) + return &MatchAllSearcher{ indexReader: indexReader, reader: reader, scorer: scorer, count: count, + nested: nested, }, nil } @@ -76,6 +82,23 @@ func (s *MatchAllSearcher) SetQueryNorm(qnorm float64) { s.scorer.SetQueryNorm(qnorm) } +func (s *MatchAllSearcher) isNested(id index.IndexInternalID) bool { + // if not running in 
nested mode, always return false + if !s.nested { + return false + } + var err error + // check if this doc has ancestors, if so it is nested + if nr, ok := s.reader.(index.NestedReader); ok { + s.ancestors, err = nr.Ancestors(id, s.ancestors[:0]) + if err != nil { + return false + } + return len(s.ancestors) > 1 + } + return false +} + func (s *MatchAllSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatch, error) { id, err := s.reader.Next() if err != nil { @@ -86,6 +109,11 @@ func (s *MatchAllSearcher) Next(ctx *search.SearchContext) (*search.DocumentMatc return nil, nil } + if s.isNested(id) { + // if nested then skip and get next + return s.Next(ctx) + } + // score match docMatch := s.scorer.Score(ctx, id) // return doc match @@ -103,6 +131,11 @@ func (s *MatchAllSearcher) Advance(ctx *search.SearchContext, ID index.IndexInte return nil, nil } + if s.isNested(id) { + // if nested then return next + return s.Next(ctx) + } + // score match docMatch := s.scorer.Score(ctx, id) diff --git a/search/searcher/search_numeric_range.go b/search/searcher/search_numeric_range.go index f086051c1..cd8f00719 100644 --- a/search/searcher/search_numeric_range.go +++ b/search/searcher/search_numeric_range.go @@ -132,7 +132,7 @@ func filterCandidateTerms(indexReader index.IndexReader, for err == nil && tfd != nil { termBytes := []byte(tfd.Term) i := sort.Search(len(terms), func(i int) bool { return bytes.Compare(terms[i], termBytes) >= 0 }) - if i < len(terms) && bytes.Compare(terms[i], termBytes) == 0 { + if i < len(terms) && bytes.Equal(terms[i], termBytes) { rv = append(rv, terms[i]) } terms = terms[i:] diff --git a/search/util.go b/search/util.go index 005fda67d..b12f7e780 100644 --- a/search/util.go +++ b/search/util.go @@ -50,41 +50,54 @@ func MergeTermLocationMaps(rv, other TermLocationMap) TermLocationMap { func MergeFieldTermLocations(dest []FieldTermLocation, matches []*DocumentMatch) []FieldTermLocation { n := len(dest) for _, dm := range matches { - n += len(dm.FieldTermLocations) + if dm != nil { + n += len(dm.FieldTermLocations) + } } if cap(dest) < n { dest = append(make([]FieldTermLocation, 0, n), dest...) } for _, dm := range matches { - for _, ftl := range dm.FieldTermLocations { - dest = append(dest, FieldTermLocation{ - Field: ftl.Field, - Term: ftl.Term, - Location: Location{ - Pos: ftl.Location.Pos, - Start: ftl.Location.Start, - End: ftl.Location.End, - ArrayPositions: append(ArrayPositions(nil), ftl.Location.ArrayPositions...), - }, - }) + if dm != nil { + dest = mergeFieldTermLocationFromMatch(dest, dm) } } return dest } -type SearchIOStatsCallbackFunc func(uint64) +// MergeFieldTermLocationsFromMatch merges field term locations from a single DocumentMatch +// into dest, returning the updated slice. +func MergeFieldTermLocationsFromMatch(dest []FieldTermLocation, match *DocumentMatch) []FieldTermLocation { + if match == nil { + return dest + } + n := len(dest) + len(match.FieldTermLocations) + if cap(dest) < n { + dest = append(make([]FieldTermLocation, 0, n), dest...) + } + return mergeFieldTermLocationFromMatch(dest, match) +} + +// mergeFieldTermLocationFromMatch appends field term locations from a DocumentMatch into dest. +// Assumes dest has sufficient capacity. 
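+// ArrayPositions are deep-copied, so the source DocumentMatch can safely be
+// recycled after the merge.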
+func mergeFieldTermLocationFromMatch(dest []FieldTermLocation, dm *DocumentMatch) []FieldTermLocation {
+	for _, ftl := range dm.FieldTermLocations {
+		dest = append(dest, FieldTermLocation{
+			Field: ftl.Field,
+			Term:  ftl.Term,
+			Location: Location{
+				Pos:            ftl.Location.Pos,
+				Start:          ftl.Location.Start,
+				End:            ftl.Location.End,
+				ArrayPositions: append(ArrayPositions(nil), ftl.Location.ArrayPositions...),
+			},
+		})
+	}
 
-// Implementation of SearchIncrementalCostCallbackFn should handle the following messages
-// - add: increment the cost of a search operation
-// (which can be specific to a query type as well)
-// - abort: query was aborted due to a cancel of search's context (for eg),
-// which can be handled differently as well
-// - done: indicates that a search was complete and the tracked cost can be
-// handled safely by the implementation.
-type SearchIncrementalCostCallbackFn func(SearchIncrementalCostCallbackMsg,
-	SearchQueryType, uint64)
+	return dest
+}
 
 type (
 	SearchIncrementalCostCallbackMsg uint
@@ -156,6 +169,10 @@ const (
 	// ScoreFusionKey is used to communicate whether KNN hits need to be preserved for
 	// hybrid search algorithms (like RRF)
 	ScoreFusionKey ContextKey = "_fusion_rescoring_key"
+
+	// NestedSearchKey is used to communicate whether the search is performed
+	// in an index with nested documents
+	NestedSearchKey ContextKey = "_nested_search_key"
 )
 
 func RecordSearchCost(ctx context.Context,
@@ -184,9 +201,7 @@ const (
 	MinGeoBufPoolSize = 24
 )
 
-type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool
-
-// *PreSearchDataKey are used to store the data gathered during the presearch phase
+// PreSearchDataKey are used to store the data gathered during the presearch phase
 // which would be use in the actual search phase.
 const (
 	KnnPreSearchDataKey = "_knn_pre_search_data_key"
@@ -197,14 +212,39 @@ const (
 const GlobalScoring = "_global_scoring"
 
 type (
+	// SearcherStartCallbackFn is a callback function type used to signal the start of
+	// a searcher creation phase.
 	SearcherStartCallbackFn func(size uint64) error
-	SearcherEndCallbackFn   func(size uint64) error
+	// SearcherEndCallbackFn is a callback function type used to signal the end of
+	// a searcher creation phase.
+	SearcherEndCallbackFn func(size uint64) error
+	// GetScoringModelCallbackFn is a callback function type used to get the scoring model
+	// to be used for scoring documents during search.
+	GetScoringModelCallbackFn func() string
+	// HybridMergeCallbackFn is a callback function type used to merge a KNN document match
+	// into a full text search document match with the same docID, as part of hybrid search.
+	HybridMergeCallbackFn func(ftsMatch *DocumentMatch, knnMatch *DocumentMatch)
+	// DescendantAdderCallbackFn is a callback function type used to customize how a descendant
+	// DocumentMatch is merged into its parent. This allows different descendant addition strategies for
+	// different use cases (e.g., TopN vs KNN collection).
+	DescendantAdderCallbackFn func(parent *DocumentMatch, descendant *DocumentMatch) error
+	// GeoBufferPoolCallbackFunc is a callback function type used to get the geo buffer pool
+	// to be used during geo searches.
+	GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool
+	// SearchIOStatsCallbackFunc is a callback function type used to report search IO stats
+	// during search.
+	SearchIOStatsCallbackFunc func(uint64)
+	// Implementation of SearchIncrementalCostCallbackFn should handle the following messages
+	// - add: increment the cost of a search operation
+	//   (which can be specific to a query type as well)
+	// - abort: query was aborted due to a cancel of search's context (for eg),
+	//   which can be handled differently as well
+	// - done: indicates that a search was complete and the tracked cost can be
+	//   handled safely by the implementation.
+	SearchIncrementalCostCallbackFn func(SearchIncrementalCostCallbackMsg,
+		SearchQueryType, uint64)
 )
 
-type GetScoringModelCallbackFn func() string
-
-type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)
-
 // field -> term -> synonyms
 type FieldTermSynonymMap map[string]map[string][]string
 
@@ -237,3 +277,25 @@ type BM25Stats struct {
 	DocCount         float64        `json:"doc_count"`
 	FieldCardinality map[string]int `json:"field_cardinality"`
 }
+
+// FieldSet represents a set of queried fields.
+type FieldSet map[string]struct{}
+
+// NewFieldSet creates a new FieldSet.
+func NewFieldSet() FieldSet {
+	return make(map[string]struct{})
+}
+
+// AddField adds a field to the set.
+func (fs FieldSet) AddField(field string) {
+	fs[field] = struct{}{}
+}
+
+// Slice returns the fields in this set as a slice of strings.
+func (fs FieldSet) Slice() []string {
+	rv := make([]string, 0, len(fs))
+	for field := range fs {
+		rv = append(rv, field)
+	}
+	return rv
+}
diff --git a/search_knn.go b/search_knn.go
index fae4f52e9..203d02629 100644
--- a/search_knn.go
+++ b/search_knn.go
@@ -288,10 +288,15 @@ func createKNNQuery(req *SearchRequest, knnFilterResults map[int]index.EligibleD
 		// If it's a filtered kNN but has no eligible filter hits, then
 		// do not run the kNN query.
 		if selector, exists := knnFilterResults[i]; exists && selector == nil {
+			// if the kNN query is filtered and has no eligible filter hits, then
+			// do not run the kNN query, so we add a match_none query to the subQueries.
+			// this will ensure that the score breakdown is set to 0 for this kNN query.
+			subQueries = append(subQueries, NewMatchNoneQuery())
+			kArray = append(kArray, 0)
 			continue
 		}
 		knnQuery := query.NewKNNQuery(knn.Vector)
-		knnQuery.SetFieldVal(knn.Field)
+		knnQuery.SetField(knn.Field)
 		knnQuery.SetK(knn.K)
 		knnQuery.SetBoost(knn.Boost.Value())
 		knnQuery.SetParams(knn.Params)
@@ -372,7 +377,7 @@ func addSortAndFieldsToKNNHits(req *SearchRequest, knnHits []*search.DocumentMat
 		}
 	}
 	req.Sort.Value(hit)
-	err, _ = LoadAndHighlightFields(hit, req, "", reader, nil)
+	err, _ = LoadAndHighlightAllFields(hit, req, "", reader, nil)
 	if err != nil {
 		return err
 	}
@@ -469,17 +474,15 @@ func (i *indexImpl) runKnnCollector(ctx context.Context, req *SearchRequest, rea
 	return knnHits, nil
 }
 
-func setKnnHitsInCollector(knnHits []*search.DocumentMatch, req *SearchRequest, coll *collector.TopNCollector) {
+func setKnnHitsInCollector(knnHits []*search.DocumentMatch, coll *collector.TopNCollector) {
 	if len(knnHits) > 0 {
-		newScoreExplComputer := func(queryMatch *search.DocumentMatch, knnMatch *search.DocumentMatch) (float64, *search.Explanation) {
-			totalScore := queryMatch.Score + knnMatch.Score
-			if !req.Explain {
-				// exit early as we don't need to compute the explanation
-				return totalScore, nil
-			}
-			return totalScore, &search.Explanation{Value: totalScore, Message: "sum of:", Children: []*search.Explanation{queryMatch.Expl, knnMatch.Expl}}
+		mergeFn := func(ftsMatch *search.DocumentMatch, knnMatch *search.DocumentMatch) {
+			// Boost the FTS score using the KNN score
+			ftsMatch.Score += knnMatch.Score
+			// Combine the FTS explanation with the KNN explanation, if present
+			ftsMatch.Expl = ftsMatch.Expl.MergeWith(knnMatch.Expl)
 		}
-		coll.SetKNNHits(knnHits, search.ScoreExplCorrectionCallbackFunc(newScoreExplComputer))
+		coll.SetKNNHits(knnHits, search.HybridMergeCallbackFn(mergeFn))
 	}
 }
diff --git a/search_knn_test.go b/search_knn_test.go
index f518d337e..988f8ec73 100644
--- a/search_knn_test.go
+++ b/search_knn_test.go
@@ -1281,23 +1281,29 @@ func TestKNNScoreBoosting(t *testing.T) {
 	searchRequest.AddKNN("vector", queryVec, 3, 1.0)
 	searchRequest.Fields = []string{"content", "vector"}
 
-	hits, _ := index.Search(searchRequest)
+	hits, err := index.Search(searchRequest)
+	if err != nil {
+		t.Fatal(err)
+	}
 	hitsMap := make(map[string]float64, 0)
 	for _, hit := range hits.Hits {
 		hitsMap[hit.ID] = (hit.Score)
 	}
 
-	searchRequest2 := NewSearchRequest(NewMatchNoneQuery())
+	searchRequest = NewSearchRequest(NewMatchNoneQuery())
 	searchRequest.AddKNN("vector", queryVec, 3, 10.0)
 	searchRequest.Fields = []string{"content", "vector"}
 
-	hits2, _ := index.Search(searchRequest2)
+	hits, err = index.Search(searchRequest)
+	if err != nil {
+		t.Fatal(err)
+	}
 	hitsMap2 := make(map[string]float64, 0)
-	for _, hit := range hits2.Hits {
+	for _, hit := range hits.Hits {
 		hitsMap2[hit.ID] = (hit.Score)
 	}
 
-	for _, hit := range hits2.Hits {
+	for _, hit := range hits.Hits {
 		if hitsMap[hit.ID] != hitsMap2[hit.ID]/10 {
 			t.Errorf("boosting not working: %v %v \n", hitsMap[hit.ID], hitsMap2[hit.ID])
 		}
@@ -1645,6 +1651,347 @@ func TestNestedVectors(t *testing.T) {
 	}
 }
 
+// -----------------------------------------------------------------------------
+// TestMultiVector tests the KNN functionality which handles duplicate
+// vectors being matched within the same document. When a document has multiple vectors
+// (via [[]] array of vectors or [{}] array of objects with vectors), the KNN
+// searcher must pick the best scoring vector match for that document.
This test covers these scenarios: +// - Single vector field (baseline) +// - [[]] style: array of vectors (same doc appears multiple times) +// - [{}] style: array of objects with vector field (chunks pattern) +func TestMultiVector(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + // JSON documents covering merger scenarios: + // - Single vector (baseline) + // - [[]] style: array of vectors (same doc appears multiple times) + // - [{}] style: array of objects with vector field (chunks pattern) + docs := map[string]string{ + // Single vector - baseline + "doc1": `{ + "vec": [10, 10, 10], + "vecB": [100, 100, 100] + }`, + // [[]] style - array of 2 vectors + "doc2": `{ + "vec": [[0, 0, 0], [500, 500, 500]], + "vecB": [[900, 900, 900], [950, 950, 950], [975, 975, 975], [990, 990, 990]] + }`, + // [[]] style - array of 3 vectors + "doc3": `{ + "vec": [[50, 50, 50], [200, 200, 200], [400, 400, 400]], + "vecB": [[800, 800, 800], [850, 850, 850]] + }`, + // Single vector - baseline + "doc4": `{ + "vec": [1000, 1000, 1000], + "vecB": [1, 1, 1] + }`, + // [{}] style - array of objects with vector field (chunks pattern) + "doc5": `{ + "chunks": [ + {"vec": [10, 10, 10], "text": "chunk1"}, + {"vec": [20, 20, 20], "text": "chunk2"}, + {"vec": [30, 30, 30], "text": "chunk3"}, + {"vec": [40, 40, 40], "text": "chunk4"} + ] + }`, + "doc6": `{ + "chunks": [ + {"vec": [[10, 10, 10],[20, 20, 20]], "text": "chunk1"}, + {"vec": [[30, 30, 30],[40, 40, 40]], "text": "chunk2"} + ] + }`, + } + + // Parse JSON documents + dataset := make(map[string]map[string]interface{}) + for docID, jsonStr := range docs { + var doc map[string]interface{} + if err := json.Unmarshal([]byte(jsonStr), &doc); err != nil { + t.Fatalf("failed to unmarshal %s: %v", docID, err) + } + dataset[docID] = doc + } + + // Index mapping + indexMapping := NewIndexMapping() + + vecMapping := mapping.NewVectorFieldMapping() + vecMapping.Dims = 3 + vecMapping.Similarity = index.InnerProduct + indexMapping.DefaultMapping.AddFieldMappingsAt("vec", vecMapping) + indexMapping.DefaultMapping.AddFieldMappingsAt("vecB", vecMapping) + + // Nested chunks mapping for [{}] style + chunksMapping := mapping.NewDocumentMapping() + chunksMapping.AddFieldMappingsAt("vec", vecMapping) + indexMapping.DefaultMapping.AddSubDocumentMapping("chunks", chunksMapping) + + // Create and populate index + idx, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := idx.Close(); err != nil { + t.Fatal(err) + } + }() + + batch := idx.NewBatch() + for docID, doc := range dataset { + if err := batch.Index(docID, doc); err != nil { + t.Fatal(err) + } + } + if err := idx.Batch(batch); err != nil { + t.Fatal(err) + } + + // Test: Single KNN query - basic functionality + t.Run("VecFieldSingle", func(t *testing.T) { + searchReq := NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vec", []float32{1, 1, 1}, 20, 1.0) + res, err := idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + // Inner product: score = sum(query_i * doc_i) + // doc1 vec=[10,10,10]: 1*10*3 = 30 + // doc2 vec best is [500,500,500]: 1*500*3 = 1500 + // doc3 vec best is [400,400,400]: 1*400*3 = 1200 + // doc4 vec=[1000,1000,1000]: 1*1000*3 = 3000 + expectedResult := []struct { + docID string + expectedScore float64 + }{ + {docID: "doc4", expectedScore: 3000}, + {docID: "doc2", expectedScore: 1500}, + {docID: "doc3", expectedScore: 1200}, + {docID: "doc1", expectedScore: 30}, + } + + if 
len(res.Hits) != len(expectedResult) { + t.Fatalf("expected %d hits, got %d", len(expectedResult), len(res.Hits)) + } + + for i, expected := range expectedResult { + if res.Hits[i].ID != expected.docID { + t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expected.docID, res.Hits[i].ID) + } + if res.Hits[i].Score != expected.expectedScore { + t.Fatalf("at rank %d, expected score %v, got %v", i+1, expected.expectedScore, res.Hits[i].Score) + } + } + }) + + // Test: Single KNN query on vecB field + t.Run("VecBFieldSingle", func(t *testing.T) { + searchReq := NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vecB", []float32{1000, 1000, 1000}, 20, 1.0) + res, err := idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + // Inner product: score = sum(query_i * doc_i) for each dimension + // doc1: vecB=[100,100,100] -> 1000*100*3 = 300,000 + // doc2: vecB best is [990,990,990] -> 1000*990*3 = 2,970,000 + // doc3: vecB best is [850,850,850] -> 1000*850*3 = 2,550,000 + // doc4: vecB=[1,1,1] -> 1000*1*3 = 3,000 + expectedResult := []struct { + docID string + expectedScore float64 + }{ + {docID: "doc2", expectedScore: 2970000}, + {docID: "doc3", expectedScore: 2550000}, + {docID: "doc1", expectedScore: 300000}, + {docID: "doc4", expectedScore: 3000}, + } + + if len(res.Hits) != len(expectedResult) { + t.Fatalf("expected %d hits, got %d", len(expectedResult), len(res.Hits)) + } + + for i, expected := range expectedResult { + if res.Hits[i].ID != expected.docID { + t.Fatalf("at rank %d, expected docID %s, got %s", i+1, expected.docID, res.Hits[i].ID) + } + if res.Hits[i].Score != expected.expectedScore { + t.Fatalf("at rank %d, expected score %v, got %v", i+1, expected.expectedScore, res.Hits[i].Score) + } + } + }) + + // Test: Single KNN query on nested chunks.vec field + t.Run("ChunksVecFieldSingle", func(t *testing.T) { + searchReq := NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("chunks.vec", []float32{1, 1, 1}, 20, 1.0) + searchReq.SortBy([]string{"_score", "docID"}) + res, err := idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + + // Only doc5 and doc6 have chunks.vec + // doc5 chunks: [10,10,10], [20,20,20], [30,30,30], [40,40,40] + // Best score: 1*40*3 = 120 + // doc6 chunks: [[10,10,10],[20,20,20]], [[30,30,30],[40,40,40]] + // Best score: 1*40*3 = 120 + if len(res.Hits) != 2 { + t.Fatalf("expected 2 hits, got %d", len(res.Hits)) + } + + // Both should have score 120 + for _, hit := range res.Hits { + if hit.ID != "doc5" && hit.ID != "doc6" { + t.Fatalf("unexpected docID %s, expected doc5 or doc6", hit.ID) + } + if hit.Score != 120 { + t.Fatalf("for %s, expected score 120, got %v", hit.ID, hit.Score) + } + } + }) +} + +// TestMultiVectorCosineNormalization verifies that multi-vector fields are +// normalized correctly with cosine similarity. Each sub-vector in a multi-vector +// should be independently normalized, producing correct similarity scores. 
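+// For reference, cosine similarity is (q · d) / (|q| |d|): with q = [1, 0, 0]
+// and d = [3, 0, 0] the score is 3 / (1 * 3) = 1.0, while the orthogonal
+// d = [0, 4, 0] scores 0 / (1 * 4) = 0.0, which is exactly what this test asserts.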
+func TestMultiVectorCosineNormalization(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + const dims = 3 + + // Create index with cosine similarity + indexMapping := NewIndexMapping() + vecFieldMapping := mapping.NewVectorFieldMapping() + vecFieldMapping.Dims = dims + vecFieldMapping.Similarity = index.CosineSimilarity + indexMapping.DefaultMapping.AddFieldMappingsAt("vec", vecFieldMapping) + + // Multi-vector field + vecFieldMappingNested := mapping.NewVectorFieldMapping() + vecFieldMappingNested.Dims = dims + vecFieldMappingNested.Similarity = index.CosineSimilarity + indexMapping.DefaultMapping.AddFieldMappingsAt("vec_nested", vecFieldMappingNested) + + idx, err := New(tmpIndexPath, indexMapping) + if err != nil { + t.Fatal(err) + } + defer func() { + err := idx.Close() + if err != nil { + t.Fatal(err) + } + }() + + docsString := []string{ + `{"vec": [3, 0, 0]}`, + `{"vec": [0, 4, 0]}`, + `{"vec_nested": [[3, 0, 0], [0, 4, 0]]}`, + } + + for i, docStr := range docsString { + var doc map[string]interface{} + err = json.Unmarshal([]byte(docStr), &doc) + if err != nil { + t.Fatal(err) + } + err = idx.Index(fmt.Sprintf("doc%d", i+1), doc) + if err != nil { + t.Fatal(err) + } + } + + // Query for X direction [1,0,0] + searchReq := NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vec", []float32{1, 0, 0}, 3, 1.0) + res, err := idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != 2 { + t.Fatalf("expected 2 hits, got %d", len(res.Hits)) + } + // Hit 1 should be doc1 with score 1.0 (perfect match) + if res.Hits[0].ID != "doc1" { + t.Fatalf("expected doc1 as first hit, got %s", res.Hits[0].ID) + } + if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 { + t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score) + } + // Hit 2 should be doc2 with a score of 0.0 (orthogonal) + if res.Hits[1].ID != "doc2" { + t.Fatalf("expected doc2 as second hit, got %s", res.Hits[1].ID) + } + if math.Abs(float64(res.Hits[1].Score-0.0)) > 1e-6 { + t.Fatalf("expected score 0.0, got %f", res.Hits[1].Score) + } + + // Query for Y direction [0,1,0] + searchReq = NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vec", []float32{0, 1, 0}, 3, 1.0) + res, err = idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != 2 { + t.Fatalf("expected 2 hits, got %d", len(res.Hits)) + } + // Hit 1 should be doc2 with score 1.0 (perfect match) + if res.Hits[0].ID != "doc2" { + t.Fatalf("expected doc2 as first hit, got %s", res.Hits[0].ID) + } + if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 { + t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score) + } + // Hit 2 should be doc1 with a score of 0.0 (orthogonal) + if res.Hits[1].ID != "doc1" { + t.Fatalf("expected doc1 as second hit, got %s", res.Hits[1].ID) + } + if math.Abs(float64(res.Hits[1].Score-0.0)) > 1e-6 { + t.Fatalf("expected score 0.0, got %f", res.Hits[1].Score) + } + + // Now test querying the nested multi-vector field + searchReq = NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vec_nested", []float32{1, 0, 0}, 3, 1.0) + res, err = idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != 1 { + t.Fatalf("expected 1 hit, got %d", len(res.Hits)) + } + // Hit should be doc3 with score 1.0 (perfect match on first sub-vector) + if res.Hits[0].ID != "doc3" { + t.Fatalf("expected doc3 as first hit, got %s", res.Hits[0].ID) + } + if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 { + 
t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score) + } + // Query for Y direction [0,1,0] on nested field + searchReq = NewSearchRequest(query.NewMatchNoneQuery()) + searchReq.AddKNN("vec_nested", []float32{0, 1, 0}, 3, 1.0) + res, err = idx.Search(searchReq) + if err != nil { + t.Fatal(err) + } + if len(res.Hits) != 1 { + t.Fatalf("expected 1 hit, got %d", len(res.Hits)) + } + // Hit should be doc3 with score 1.0 (perfect match on second sub-vector) + if res.Hits[0].ID != "doc3" { + t.Fatalf("expected doc3 as first hit, got %s", res.Hits[0].ID) + } + if math.Abs(float64(res.Hits[0].Score-1.0)) > 1e-6 { + t.Fatalf("expected score 1.0, got %f", res.Hits[0].Score) + } +} + func TestNumVecsStat(t *testing.T) { dataset, _, err := readDatasetAndQueries(testInputCompressedFile) diff --git a/search_test.go b/search_test.go index 3768e11fe..e3e03421f 100644 --- a/search_test.go +++ b/search_test.go @@ -5219,3 +5219,1024 @@ func TestSearchRequestValidatePagination(t *testing.T) { }) } } + +func createNestedIndexMapping() mapping.IndexMapping { + + /* + company + ├── id + ├── name + ├── departments[] (nested) + │ ├── name + │ ├── budget + │ ├── employees[] (nested) + │ │ ├── name + │ │ ├── role + │ └── projects[] (nested) + │ ├── title + │ ├── status + └── locations[] (nested) + ├── city + ├── country + */ + + // Create the index mapping + imap := mapping.NewIndexMapping() + + // Create company mapping + companyMapping := mapping.NewDocumentMapping() + + // Company ID field + companyIDField := mapping.NewTextFieldMapping() + companyMapping.AddFieldMappingsAt("id", companyIDField) + + // Company name field + companyNameField := mapping.NewTextFieldMapping() + companyMapping.AddFieldMappingsAt("name", companyNameField) + + // Departments mapping + departmentsMapping := mapping.NewNestedDocumentMapping() + + // Department name field + deptNameField := mapping.NewTextFieldMapping() + departmentsMapping.AddFieldMappingsAt("name", deptNameField) + + // Department budget field + deptBudgetField := mapping.NewNumericFieldMapping() + departmentsMapping.AddFieldMappingsAt("budget", deptBudgetField) + + // Employees mapping + employeesMapping := mapping.NewNestedDocumentMapping() + + // Employee name field + empNameField := mapping.NewTextFieldMapping() + employeesMapping.AddFieldMappingsAt("name", empNameField) + + // Employee role field + empRoleField := mapping.NewTextFieldMapping() + employeesMapping.AddFieldMappingsAt("role", empRoleField) + + departmentsMapping.AddSubDocumentMapping("employees", employeesMapping) + + // Projects mapping + projectsMapping := mapping.NewNestedDocumentMapping() + + // Project title field + projTitleField := mapping.NewTextFieldMapping() + projectsMapping.AddFieldMappingsAt("title", projTitleField) + + // Project status field + projStatusField := mapping.NewTextFieldMapping() + projectsMapping.AddFieldMappingsAt("status", projStatusField) + + departmentsMapping.AddSubDocumentMapping("projects", projectsMapping) + + companyMapping.AddSubDocumentMapping("departments", departmentsMapping) + + // Locations mapping + locationsMapping := mapping.NewNestedDocumentMapping() + + // Location city field + cityField := mapping.NewTextFieldMapping() + locationsMapping.AddFieldMappingsAt("city", cityField) + + // Location country field + countryField := mapping.NewTextFieldMapping() + locationsMapping.AddFieldMappingsAt("country", countryField) + + companyMapping.AddSubDocumentMapping("locations", locationsMapping) + + // Add company to type mapping + 
imap.DefaultMapping.AddSubDocumentMapping("company", companyMapping)
+
+	return imap
+}
+
+func TestNestedPrefixes(t *testing.T) {
+	imap := createNestedIndexMapping()
+
+	tmpIndexPath := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath)
+
+	idx, err := New(tmpIndexPath, imap)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := idx.Close(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	nmap, ok := imap.(mapping.NestedMapping)
+	if !ok {
+		t.Fatal("index mapping is not a NestedMapping")
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 1: Employee Role AND Employee Name
+	// ----------------------------------------------------------------------
+	fs := search.NewFieldSet()
+	fs.AddField("company.departments.employees.role")
+	fs.AddField("company.departments.employees.name")
+
+	expectedCommon := 2
+	expectedMax := 2
+
+	common, max := nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test1: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 2: Employee Role AND Employee Name AND Department Name
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.departments.employees.role")
+	fs.AddField("company.departments.employees.name")
+	fs.AddField("company.departments.name")
+
+	expectedCommon = 1
+	expectedMax = 2 // employees nested deeper
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test2: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 3: Employee Role AND Location City
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.departments.employees.role")
+	fs.AddField("company.locations.city")
+
+	expectedCommon = 0
+	expectedMax = 2 // employees deeper than locations (1)
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test3: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 4: Company Name AND Location Country AND Location City
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.name")
+	fs.AddField("company.locations.country")
+	fs.AddField("company.locations.city")
+
+	expectedCommon = 0
+	expectedMax = 1 // locations.country and locations.city share depth 1
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test4: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 5: Department Budget AND Project Status AND Employee Name
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.departments.budget")
+	fs.AddField("company.departments.projects.status")
+	fs.AddField("company.departments.employees.name")
+
+	expectedCommon = 1
+	expectedMax = 2 // employees + projects go deeper
+
+	common, max =
nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test5: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 6: Single Field
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.id")
+
+	expectedCommon = 0
+	expectedMax = 0
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test6: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 7: No Fields
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+
+	expectedCommon = 0
+	expectedMax = 0
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test7: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 8: All Fields
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.id")
+	fs.AddField("company.name")
+	fs.AddField("company.departments.name")
+	fs.AddField("company.departments.budget")
+	fs.AddField("company.departments.employees.name")
+	fs.AddField("company.departments.employees.role")
+	fs.AddField("company.departments.projects.title")
+	fs.AddField("company.departments.projects.status")
+	fs.AddField("company.locations.city")
+	fs.AddField("company.locations.country")
+
+	expectedCommon = 0 // spans different contexts
+	expectedMax = 2
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test8: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 9: Project Title AND Project Status
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.departments.projects.title")
+	fs.AddField("company.departments.projects.status")
+
+	expectedCommon = 2
+	expectedMax = 2
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test9: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+
+	// ----------------------------------------------------------------------
+	// Test 10: Department Name AND Location Country AND Location City
+	// ----------------------------------------------------------------------
+	fs = search.NewFieldSet()
+	fs.AddField("company.departments.name")
+	fs.AddField("company.locations.country")
+	fs.AddField("company.locations.city")
+
+	expectedCommon = 0
+	expectedMax = 1 // locations share depth 1
+
+	common, max = nmap.NestedDepth(fs)
+	if common != expectedCommon || max != expectedMax {
+		t.Fatalf("Test10: expected (common=%d, max=%d), got (common=%d, max=%d)",
+			expectedCommon, expectedMax, common, max)
+	}
+}
+
+func TestNestedConjunctionQuery(t *testing.T) {
+	imap := createNestedIndexMapping()
+	err := imap.Validate()
+	if err != nil {
+		t.Fatalf("expected valid nested index mapping, got error: %v", err)
+	}
+	tmpIndexPath :=
createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + idx, err := New(tmpIndexPath, imap) + if err != nil { + t.Fatal(err) + } + defer func() { + err = idx.Close() + if err != nil { + t.Fatal(err) + } + }() + // Index 3 sample documents + docs := []struct { + id string + data string + }{ + { + id: "doc1", + data: `{ + "company": { + "id": "c1", + "name": "TechCorp", + "departments": [ + { + "name": "Engineering", + "budget": 2000000, + "employees": [ + {"name": "Alice", "role": "Engineer"}, + {"name": "Bob", "role": "Manager"} + ], + "projects": [ + {"title": "Project X", "status": "ongoing"}, + {"title": "Project Y", "status": "completed"} + ] + }, + { + "name": "Sales", + "budget": 300000, + "employees": [ + {"name": "Eve", "role": "Salesperson"}, + {"name": "Mallory", "role": "Manager"} + ], + "projects": [ + {"title": "Project A", "status": "completed"}, + {"title": "Project B", "status": "ongoing"} + ] + } + ], + "locations": [ + {"city": "Athens", "country": "Greece"}, + {"city": "Berlin", "country": "USA"} + ] + } + }`, + }, + { + id: "doc2", + data: `{ + "company" : { + "id": "c2", + "name": "BizInc", + "departments": [ + { + "name": "Marketing", + "budget": 800000, + "employees": [ + {"name": "Eve", "role": "Marketer"}, + {"name": "David", "role": "Manager"} + ], + "projects": [ + {"title": "Project Z", "status": "ongoing"}, + {"title": "Project W", "status": "planned"} + ] + }, + { + "name": "Engineering", + "budget": 800000, + "employees": [ + {"name": "Frank", "role": "Manager"}, + {"name": "Grace", "role": "Engineer"} + ], + "projects": [ + {"title": "Project Alpha", "status": "completed"}, + {"title": "Project Beta", "status": "ongoing"} + ] + } + ], + "locations": [ + {"city": "Athens", "country": "USA"}, + {"city": "London", "country": "UK"} + ] + } + }`, + }, + { + id: "doc3", + data: `{ + "company": { + "id": "c3", + "name": "WebSolutions", + "departments": [ + { + "name": "HR", + "budget": 800000, + "employees": [ + {"name": "Eve", "role": "Manager"}, + {"name": "Frank", "role": "HR"} + ], + "projects": [ + {"title": "Project Beta", "status": "completed"}, + {"title": "Project B", "status": "ongoing"} + ] + }, + { + "name": "Engineering", + "budget": 200000, + "employees": [ + {"name": "Heidi", "role": "Support Engineer"}, + {"name": "Ivan", "role": "Manager"} + ], + "projects": [ + {"title": "Project Helpdesk", "status": "ongoing"}, + {"title": "Project FAQ", "status": "completed"} + ] + } + ], + "locations": [ + {"city": "Edinburgh", "country": "UK"}, + {"city": "London", "country": "Canada"} + ] + } + }`, + }, + } + + for _, doc := range docs { + var dataMap map[string]interface{} + err := json.Unmarshal([]byte(doc.data), &dataMap) + if err != nil { + t.Fatalf("failed to unmarshal document %s: %v", doc.id, err) + } + err = idx.Index(doc.id, dataMap) + if err != nil { + t.Fatalf("failed to index document %s: %v", doc.id, err) + } + } + + var buildReq = func(subQueries []query.Query) *SearchRequest { + rv := NewSearchRequest(query.NewConjunctionQuery(subQueries)) + rv.SortBy([]string{"_id"}) + rv.Fields = []string{"*"} + rv.Highlight = NewHighlightWithStyle(ansi.Name) + return rv + } + + var ( + req *SearchRequest + res *SearchResult + deptNameQuery *query.MatchQuery + deptBudgetQuery *query.NumericRangeQuery + empNameQuery *query.MatchQuery + empRoleQuery *query.MatchQuery + projTitleQuery *query.MatchPhraseQuery + projStatusQuery *query.MatchQuery + countryQuery *query.MatchQuery + cityQuery *query.MatchQuery + ) + + // Test 1: Find companies with a 
department named "Engineering" AND a budget of at least 800000
+	deptNameQuery = query.NewMatchQuery("Engineering")
+	deptNameQuery.SetField("company.departments.name")
+
+	min := float64(800000)
+	deptBudgetQuery = query.NewNumericRangeQuery(&min, nil)
+	deptBudgetQuery.SetField("company.departments.budget")
+
+	req = buildReq([]query.Query{deptNameQuery, deptBudgetQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 2 {
+		t.Fatalf("expected 2 hits, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc1" || res.Hits[1].ID != "doc2" {
+		t.Fatalf("unexpected hit IDs: %v, %v", res.Hits[0].ID, res.Hits[1].ID)
+	}
+
+	// Test 2: Find companies with an employee named "Eve" AND project status "completed"
+	empNameQuery = query.NewMatchQuery("Eve")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	projStatusQuery = query.NewMatchQuery("completed")
+	projStatusQuery.SetField("company.departments.projects.status")
+
+	req = buildReq([]query.Query{empNameQuery, projStatusQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 2 {
+		t.Fatalf("expected 2 hits, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc1" || res.Hits[1].ID != "doc3" {
+		t.Fatalf("unexpected hit IDs: %v, %v", res.Hits[0].ID, res.Hits[1].ID)
+	}
+
+	// Test 3: Find companies located in "Athens, USA" AND with an Engineering department
+	countryQuery = query.NewMatchQuery("USA")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("Athens")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery := query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	deptNameQuery = query.NewMatchQuery("Engineering")
+	deptNameQuery.SetField("company.departments.name")
+
+	req = buildReq([]query.Query{locQuery, deptNameQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 1 {
+		t.Fatalf("expected 1 hit, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc2" {
+		t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID)
+	}
+
+	// Test 4a: Find companies located in "Athens, USA" AND with an Engineering department with a budget over 1M
+	countryQuery = query.NewMatchQuery("USA")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("Athens")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	deptNameQuery = query.NewMatchQuery("Engineering")
+	deptNameQuery.SetField("company.departments.name")
+
+	min = float64(1000000)
+	deptBudgetQuery = query.NewNumericRangeQuery(&min, nil)
+	deptBudgetQuery.SetField("company.departments.budget")
+
+	deptQuery := query.NewConjunctionQuery([]query.Query{deptNameQuery, deptBudgetQuery})
+
+	req = buildReq([]query.Query{locQuery, deptQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 0 {
+		t.Fatalf("expected 0 hits, got %d", len(res.Hits))
+	}
+
+	// Test 4b: Find companies located in "Athens, Greece" AND with an Engineering department with a budget over 1M
+	countryQuery = query.NewMatchQuery("Greece")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("Athens")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	deptNameQuery = query.NewMatchQuery("Engineering")
+
deptNameQuery.SetField("company.departments.name")
+
+	min = float64(1000000)
+	deptBudgetQuery = query.NewNumericRangeQuery(&min, nil)
+	deptBudgetQuery.SetField("company.departments.budget")
+
+	deptQuery = query.NewConjunctionQuery([]query.Query{deptNameQuery, deptBudgetQuery})
+
+	req = buildReq([]query.Query{locQuery, deptQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 1 {
+		t.Fatalf("expected 1 hit, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc1" {
+		t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID)
+	}
+
+	// Test 5a: Find companies with an employee named "Frank" AND role "Manager" whose department is
+	// handling a project titled "Project Beta" which is marked as "completed"
+	empNameQuery = query.NewMatchQuery("Frank")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery := query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	projTitleQuery = query.NewMatchPhraseQuery("Project Beta")
+	projTitleQuery.SetField("company.departments.projects.title")
+
+	projStatusQuery = query.NewMatchQuery("completed")
+	projStatusQuery.SetField("company.departments.projects.status")
+
+	projQuery := query.NewConjunctionQuery([]query.Query{projTitleQuery, projStatusQuery})
+
+	req = buildReq([]query.Query{empQuery, projQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 0 {
+		t.Fatalf("expected 0 hits, got %d", len(res.Hits))
+	}
+
+	// Test 5b: Find companies with an employee named "Frank" AND role "Manager" whose department is
+	// handling a project titled "Project Beta" which is marked as "ongoing"
+	empNameQuery = query.NewMatchQuery("Frank")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	projTitleQuery = query.NewMatchPhraseQuery("Project Beta")
+	projTitleQuery.SetField("company.departments.projects.title")
+
+	projStatusQuery = query.NewMatchQuery("ongoing")
+	projStatusQuery.SetField("company.departments.projects.status")
+
+	projQuery = query.NewConjunctionQuery([]query.Query{projTitleQuery, projStatusQuery})
+
+	req = buildReq([]query.Query{empQuery, projQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 1 {
+		t.Fatalf("expected 1 hit, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc2" {
+		t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID)
+	}
+
+	// Test 6a: Find companies with an employee named "Eve" AND role "Manager"
+	// who is working at a company located in "London, UK"
+	empNameQuery = query.NewMatchQuery("Eve")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	countryQuery = query.NewMatchQuery("UK")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("London")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	req = buildReq([]query.Query{empQuery, locQuery})
+	res, err =
idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 0 {
+		t.Fatalf("expected 0 hits, got %d", len(res.Hits))
+	}
+
+	// Test 6b: Find companies with an employee named "Eve" AND role "Manager"
+	// who is working at a company located in "London, Canada"
+	empNameQuery = query.NewMatchQuery("Eve")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	countryQuery = query.NewMatchQuery("Canada")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("London")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	req = buildReq([]query.Query{empQuery, locQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 1 {
+		t.Fatalf("expected 1 hit, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc3" {
+		t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID)
+	}
+
+	// Test 7a: Find companies where Ivan the Manager works in London, UK
+
+	empNameQuery = query.NewMatchQuery("Ivan")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	countryQuery = query.NewMatchQuery("UK")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("London")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	req = buildReq([]query.Query{empQuery, locQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 0 {
+		t.Fatalf("expected 0 hits, got %d", len(res.Hits))
+	}
+
+	// Test 7b: Find companies where Ivan the Manager works in London, Canada
+
+	empNameQuery = query.NewMatchQuery("Ivan")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	countryQuery = query.NewMatchQuery("Canada")
+	countryQuery.SetField("company.locations.country")
+
+	cityQuery = query.NewMatchQuery("London")
+	cityQuery.SetField("company.locations.city")
+
+	locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery})
+
+	req = buildReq([]query.Query{empQuery, locQuery})
+	res, err = idx.Search(req)
+	if err != nil {
+		t.Fatalf("search failed: %v", err)
+	}
+	if len(res.Hits) != 1 {
+		t.Fatalf("expected 1 hit, got %d", len(res.Hits))
+	}
+	if res.Hits[0].ID != "doc3" {
+		t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID)
+	}
+
+	// Test 8: Find companies where Frank the Manager works in an Engineering
+	// department, at a company located in London, UK
+	empNameQuery = query.NewMatchQuery("Frank")
+	empNameQuery.SetField("company.departments.employees.name")
+
+	empRoleQuery = query.NewMatchQuery("Manager")
+	empRoleQuery.SetField("company.departments.employees.role")
+
+	empQuery = query.NewConjunctionQuery([]query.Query{empNameQuery, empRoleQuery})
+
+	deptNameQuery = query.NewMatchQuery("Engineering")
+	deptNameQuery.SetField("company.departments.name")
+
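+	// deptQuery below scopes the Frank/Manager employee match and the
+	// department name "Engineering" to the same department object.
+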
deptQuery = query.NewConjunctionQuery([]query.Query{empQuery, deptNameQuery}) + + countryQuery = query.NewMatchQuery("UK") + countryQuery.SetField("company.locations.country") + + cityQuery = query.NewMatchQuery("London") + cityQuery.SetField("company.locations.city") + + locQuery = query.NewConjunctionQuery([]query.Query{countryQuery, cityQuery}) + + req = buildReq([]query.Query{deptQuery, locQuery}) + res, err = idx.Search(req) + if err != nil { + t.Fatalf("search failed: %v", err) + } + if len(res.Hits) != 1 { + t.Fatalf("expected 1 hit, got %d", len(res.Hits)) + } + if res.Hits[0].ID != "doc2" { + t.Fatalf("unexpected hit ID: %v", res.Hits[0].ID) + } +} + +func TestNestedArrayConjunctionQuery(t *testing.T) { + imap := NewIndexMapping() + groupsMapping := mapping.NewNestedDocumentMapping() + + nameField := mapping.NewTextFieldMapping() + groupsMapping.AddFieldMappingsAt("first_name", nameField) + groupsMapping.AddFieldMappingsAt("last_name", nameField) + + imap.DefaultMapping.AddSubDocumentMapping("groups", groupsMapping) + + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + idx, err := New(tmpIndexPath, imap) + if err != nil { + t.Fatal(err) + } + defer func() { + err = idx.Close() + if err != nil { + t.Fatal(err) + } + }() + + docs := []string{ + `{ + "groups": [ + [ + { + "first_name": "Alice", + "last_name": "Smith" + }, + { + "first_name": "Bob", + "last_name": "Johnson" + } + ], + [ + { + "first_name": "Charlie", + "last_name": "Williams" + }, + { + "first_name": "Diana", + "last_name": "Brown" + } + ] + ] + }`, + `{ + "groups": [ + { + "first_name": "Alice", + "last_name": "Smith" + }, + { + "first_name": "Bob", + "last_name": "Johnson" + }, + { + "first_name": "Charlie", + "last_name": "Williams" + }, + { + "first_name": "Diana", + "last_name": "Brown" + } + ] + }`, + } + + for i, doc := range docs { + var dataMap map[string]interface{} + err := json.Unmarshal([]byte(doc), &dataMap) + if err != nil { + t.Fatalf("failed to unmarshal document %d: %v", i, err) + } + err = idx.Index(fmt.Sprintf("%d", i+1), dataMap) + if err != nil { + t.Fatalf("failed to index document %d: %v", i, err) + } + } + + var ( + firstNameQuery *query.MatchQuery + lastNameQuery *query.MatchQuery + conjQuery *query.ConjunctionQuery + searchReq *SearchRequest + res *SearchResult + ) + + // Search for documents where first_name is "Alice" AND last_name is "Johnson" + firstNameQuery = query.NewMatchQuery("Alice") + firstNameQuery.SetField("groups.first_name") + + lastNameQuery = query.NewMatchQuery("Johnson") + lastNameQuery.SetField("groups.last_name") + + conjQuery = query.NewConjunctionQuery([]query.Query{firstNameQuery, lastNameQuery}) + + searchReq = NewSearchRequest(conjQuery) + searchReq.SortBy([]string{"_id"}) + + res, err = idx.Search(searchReq) + if err != nil { + t.Fatalf("search failed: %v", err) + } + + if len(res.Hits) != 0 { + t.Fatalf("expected 0 hits, got %d", len(res.Hits)) + } + + // Search for documents where first_name is "Bob" AND last_name is "Johnson" + firstNameQuery = query.NewMatchQuery("Bob") + firstNameQuery.SetField("groups.first_name") + + lastNameQuery = query.NewMatchQuery("Johnson") + lastNameQuery.SetField("groups.last_name") + + conjQuery = query.NewConjunctionQuery([]query.Query{firstNameQuery, lastNameQuery}) + + searchReq = NewSearchRequest(conjQuery) + searchReq.SortBy([]string{"_id"}) + + res, err = idx.Search(searchReq) + if err != nil { + t.Fatalf("search failed: %v", err) + } + + if len(res.Hits) != 2 { + t.Fatalf("expected 2 hits, 
got %d", len(res.Hits)) + } + + if res.Hits[0].ID != "1" || res.Hits[1].ID != "2" { + t.Fatalf("unexpected hit IDs: %v, %v", res.Hits[0].ID, res.Hits[1].ID) + } + + // Search for documents where first_name is "Alice" AND last_name is "Williams" + firstNameQuery = query.NewMatchQuery("Alice") + firstNameQuery.SetField("groups.first_name") + + lastNameQuery = query.NewMatchQuery("Williams") + lastNameQuery.SetField("groups.last_name") + + conjQuery = query.NewConjunctionQuery([]query.Query{firstNameQuery, lastNameQuery}) + + searchReq = NewSearchRequest(conjQuery) + searchReq.SortBy([]string{"_id"}) + + res, err = idx.Search(searchReq) + if err != nil { + t.Fatalf("search failed: %v", err) + } + + if len(res.Hits) != 0 { + t.Fatalf("expected 0 hits, got %d", len(res.Hits)) + } + + // Search for documents where first_name is "Diana" AND last_name is "Brown" + firstNameQuery = query.NewMatchQuery("Diana") + firstNameQuery.SetField("groups.first_name") + + lastNameQuery = query.NewMatchQuery("Brown") + lastNameQuery.SetField("groups.last_name") + + conjQuery = query.NewConjunctionQuery([]query.Query{firstNameQuery, lastNameQuery}) + + searchReq = NewSearchRequest(conjQuery) + searchReq.SortBy([]string{"_id"}) + + res, err = idx.Search(searchReq) + if err != nil { + t.Fatalf("search failed: %v", err) + } + + if len(res.Hits) != 2 { + t.Fatalf("expected 2 hits, got %d", len(res.Hits)) + } + + if res.Hits[0].ID != "1" || res.Hits[1].ID != "2" { + t.Fatalf("unexpected hit IDs: %v, %v", res.Hits[0].ID, res.Hits[1].ID) + } +} + +func TestValidNestedMapping(t *testing.T) { + // ensure that top-level mappings - DefaultMapping and any type mappings - cannot be nested mappings + imap := mapping.NewIndexMapping() + nestedMapping := mapping.NewNestedDocumentMapping() + imap.DefaultMapping = nestedMapping + err := imap.Validate() + if err == nil { + t.Fatalf("expected error for nested DefaultMapping, got nil") + } + // invalid nested type mapping + imap = mapping.NewIndexMapping() + imap.AddDocumentMapping("type1", nestedMapping) + err = imap.Validate() + if err == nil { + t.Fatalf("expected error for nested type mapping, got nil") + } + // valid nested mappings within DefaultMapping + imap = mapping.NewIndexMapping() + docMapping := mapping.NewDocumentMapping() + nestedMapping = mapping.NewNestedDocumentMapping() + fieldMapping := mapping.NewTextFieldMapping() + nestedMapping.AddFieldMappingsAt("field1", fieldMapping) + docMapping.AddSubDocumentMapping("nestedField", nestedMapping) + imap.DefaultMapping = docMapping + err = imap.Validate() + if err != nil { + t.Fatalf("expected valid nested mapping, got error: %v", err) + } + // valid nested mappings within type mapping + imap = mapping.NewIndexMapping() + docMapping = mapping.NewDocumentMapping() + nestedMapping = mapping.NewNestedDocumentMapping() + fieldMapping = mapping.NewTextFieldMapping() + nestedMapping.AddFieldMappingsAt("field1", fieldMapping) + docMapping.AddSubDocumentMapping("nestedField", nestedMapping) + imap.AddDocumentMapping("type1", docMapping) + err = imap.Validate() + if err != nil { + t.Fatalf("expected valid nested mapping, got error: %v", err) + } + // some nested type mappings + imap = mapping.NewIndexMapping() + nestedMapping = mapping.NewNestedDocumentMapping() + regularMapping := mapping.NewDocumentMapping() + imap.AddDocumentMapping("non_nested1", regularMapping) + imap.AddDocumentMapping("non_nested2", regularMapping) + imap.AddDocumentMapping("nested1", nestedMapping) + imap.AddDocumentMapping("nested2", nestedMapping) + err 
= imap.Validate() + if err == nil { + t.Fatalf("expected error for nested type mappings, got nil") + } +}
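
Beyond the diff itself, here is a minimal sketch of how the new `search.FieldSet` type and the `mapping.NestedMapping` interface exercised by these tests fit together. `buildMapping` is a hypothetical stand-in for the `createNestedIndexMapping` helper above, and the expected depths follow the conventions asserted in `TestNestedPrefixes`:

```go
package main

import (
	"fmt"

	"github.com/blevesearch/bleve/v2/mapping"
	"github.com/blevesearch/bleve/v2/search"
)

// buildMapping mirrors, in miniature, the createNestedIndexMapping helper
// from search_test.go: a single nested array-of-objects field.
func buildMapping() mapping.IndexMapping {
	imap := mapping.NewIndexMapping()

	employees := mapping.NewNestedDocumentMapping()
	employees.AddFieldMappingsAt("name", mapping.NewTextFieldMapping())
	employees.AddFieldMappingsAt("role", mapping.NewTextFieldMapping())
	imap.DefaultMapping.AddSubDocumentMapping("employees", employees)

	return imap
}

func main() {
	nmap, ok := buildMapping().(mapping.NestedMapping)
	if !ok {
		panic("index mapping does not implement mapping.NestedMapping")
	}

	// Both queried fields live inside the same nested "employees" objects.
	fs := search.NewFieldSet()
	fs.AddField("employees.name")
	fs.AddField("employees.role")

	// Based on TestNestedPrefixes above, both the common and the maximum
	// nested depth should be 1 here (one nested level: employees).
	common, max := nmap.NestedDepth(fs)
	fmt.Println(common, max, fs.Slice())
}
```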