diff --git a/go.mod b/go.mod index 9a1985795..70abb8efb 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( ) require ( + github.com/PaesslerAG/jsonpath v0.1.1 github.com/go-chi/chi/v5 v5.0.10 github.com/kloudlite/container-registry-authorizer v0.0.0-20231021122509-161dc30fde55 github.com/miekg/dns v1.1.55 @@ -58,6 +59,7 @@ require ( ) require ( + github.com/PaesslerAG/gval v1.0.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/lipgloss v0.10.0 // indirect github.com/charmbracelet/log v0.4.0 // indirect diff --git a/go.sum b/go.sum index e277383b8..4637c250f 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,11 @@ github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7Y github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= +github.com/PaesslerAG/gval v1.0.0 h1:GEKnRwkWDdf9dOmKcNrar9EA1bz1z9DqPIO1+iLzhd8= +github.com/PaesslerAG/gval v1.0.0/go.mod h1:y/nm5yEyTeX6av0OfKJNp9rBNj2XrGhAf5+v24IBN1I= +github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8= +github.com/PaesslerAG/jsonpath v0.1.1 h1:c1/AToHQMVsduPAa4Vh6xp2U0evy4t8SWp8imEsylIk= +github.com/PaesslerAG/jsonpath v0.1.1/go.mod h1:lVboNxFGal/VwW6d9JzIy56bUsYAP6tH/x80vjnCseY= github.com/PuerkitoBio/goquery v1.9.1 h1:mTL6XjbJTZdpfL+Gwl5U2h1l9yEkJjhmlTeV9VPW7UI= github.com/PuerkitoBio/goquery v1.9.1/go.mod h1:cW1n6TmIMDoORQU5IU/P1T3tGFunOeXEpGP2WHRwkbY= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= @@ -175,8 +180,6 @@ github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLA github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kloudlite/container-registry-authorizer v0.0.0-20231021122509-161dc30fde55 h1:YnZh3TL6AG4EfoInx1/L5zcPHd2QxgLKseJB1KtHjdQ= github.com/kloudlite/container-registry-authorizer v0.0.0-20231021122509-161dc30fde55/go.mod h1:GZj3wZmIw/qCciclRhgQTgmGiqe8wxoVzMXQjbOfnbc= -github.com/kloudlite/operator v0.0.0-20240710071747-9a61e7de9e93 h1:vbF6PPTjgmtE5pNHKdZmTMmjfC4njjGW1EO8m7Njx1w= -github.com/kloudlite/operator v0.0.0-20240710071747-9a61e7de9e93/go.mod h1:c6FiZvYztvr92/UcIUvQurp3oWMrrEK7deAriHckTPw= github.com/kloudlite/operator v0.0.0-20240718072819-c625b77c43b2 h1:K6tlpBcl4PWv6ZiojFTXADYuAvGjcOkrYQbsPvbXaOU= github.com/kloudlite/operator v0.0.0-20240718072819-c625b77c43b2/go.mod h1:c6FiZvYztvr92/UcIUvQurp3oWMrrEK7deAriHckTPw= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= diff --git a/pkg/repos/db-repo-mongo.go b/pkg/repos/db-repo-mongo.go index ef2a0b0fc..5f7e2b2c7 100644 --- a/pkg/repos/db-repo-mongo.go +++ b/pkg/repos/db-repo-mongo.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "fmt" + "github.com/PaesslerAG/jsonpath" + "go.mongodb.org/mongo-driver/bson/primitive" "regexp" "strings" "time" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" "go.uber.org/fx" "github.com/kloudlite/api/pkg/errors" @@ -81,24 +82,24 @@ func bsonToStruct[T any](r *mongo.SingleResult) (T, error) { return result, nil } -func cursorToStruct[T any](ctx context.Context, curr *mongo.Cursor) ([]T, error) { +func cursorToStruct[T any](ctx context.Context, curr *mongo.Cursor) ([]T, []map[string]any, error) { var m []map[string]any var results []T if err := curr.All(ctx, &m); err != nil { - return results, errors.NewE(err) + return results, m, errors.NewE(err) } b, err := json.Marshal(m) if err != nil { - return results, errors.NewE(err) + return results, m, errors.NewE(err) } if err := json.Unmarshal(b, &results); err != nil { - return results, errors.NewE(err) + return results, m, errors.NewE(err) } - return results, nil + return results, m, nil } func (repo *dbRepo[T]) NewId() ID { @@ -122,8 +123,8 @@ func (repo *dbRepo[T]) Find(ctx context.Context, query Query) ([]T, error) { } return nil, errors.NewE(err) } - - return cursorToStruct[T](ctx, curr) + toStruct, _, err := cursorToStruct[T](ctx, curr) + return toStruct, err } func (repo *dbRepo[T]) Count(ctx context.Context, filter Filter) (int64, error) { @@ -171,6 +172,14 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat return nil, errors.Newf("paramter `before` requires paramter `last` to be specified") } + var cursorKey string + + if pagination.OrderBy == "" { + cursorKey = "_id" + } else { + cursorKey = pagination.OrderBy + } + queryFilter := Filter{} for k, v := range filter { @@ -182,11 +191,13 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat if err != nil { return nil, errors.NewE(err) } - objectID, err := primitive.ObjectIDFromHex(string(aft)) - if err != nil { - return nil, errors.NewE(err) + + if pagination.SortDirection == SortDirectionAsc { + queryFilter[cursorKey] = bson.M{"$gte": string(aft)} + } else { + queryFilter[cursorKey] = bson.M{"$lte": string(aft)} } - queryFilter["_id"] = bson.M{"$gt": objectID} + } if pagination.Before != nil { @@ -194,11 +205,12 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat if err != nil { return nil, errors.NewE(err) } - objectID, err := primitive.ObjectIDFromHex(string(bef)) - if err != nil { - return nil, errors.NewE(err) + + if pagination.SortDirection == SortDirectionAsc { + queryFilter[cursorKey] = bson.M{"$lte": string(bef)} + } else { + queryFilter[cursorKey] = bson.M{"$gte": string(bef)} } - queryFilter["_id"] = bson.M{"$lt": objectID} } var limit int64 @@ -210,7 +222,6 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat limit = *pagination.First + 1 } - // var results []T curr, err := repo.db.Collection(repo.collectionName).Find( ctx, queryFilter, &options.FindOptions{ Limit: &limit, @@ -226,7 +237,7 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat return nil, errors.NewE(err) } - results, err := cursorToStruct[T](ctx, curr) + results, rawResults, err := cursorToStruct[T](ctx, curr) if err != nil { return nil, errors.NewE(err) } @@ -238,9 +249,18 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat pageInfo := PageInfo{} + getCursorOfResult := func(r T, m map[string]any) (string, error) { + if cursorKey == "_id" { + return CursorToBase64(Cursor(r.GetId())), nil + } + val, err := jsonpath.Get(fmt.Sprintf("$.%s", cursorKey), m) + if err != nil { + return "", errors.NewE(err) + } + return CursorToBase64(Cursor(fmt.Sprintf("%v", val))), nil + } + if len(results) > 0 { - pageInfo.StartCursor = CursorToBase64(Cursor(string(results[0].GetPrimitiveID()))) - pageInfo.EndCursor = CursorToBase64(Cursor(string(results[len(results)-1].GetPrimitiveID()))) if pagination.First != nil { pageInfo.HasNextPage = fn.New(len(results) > int(*pagination.First)) @@ -253,18 +273,30 @@ func (repo *dbRepo[T]) FindPaginated(ctx context.Context, filter Filter, paginat if pagination.Last != nil { pageInfo.HasNextPage = fn.New(pagination.Before != nil) pageInfo.HasPrevPage = fn.New(len(results) > int(*pagination.Last)) - if pageInfo.HasPrevPage != nil && *pageInfo.HasPrevPage { results = results[:*pagination.Last] } } + + pageInfo.StartCursor, err = getCursorOfResult(results[0], rawResults[0]) + if err != nil { + return nil, errors.NewE(err) + } + pageInfo.EndCursor, err = getCursorOfResult(results[len(results)-1], rawResults[len(results)-1]) + if err != nil { + return nil, errors.NewE(err) + } } edges := make([]RecordEdge[T], len(results)) for i := range results { + c, err := getCursorOfResult(results[i], rawResults[i]) + if err != nil { + return nil, errors.NewE(err) + } edges[i] = RecordEdge[T]{ Node: results[i], - Cursor: CursorToBase64(Cursor(results[i].GetPrimitiveID())), + Cursor: c, } }