diff --git a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java index 3111cc63..9c4922e9 100644 --- a/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java +++ b/document-store/src/integrationTest/java/org/hypertrace/core/documentstore/FlatCollectionWriteTest.java @@ -3403,6 +3403,135 @@ void testCreateOrReplaceRefreshesSchemaOnDroppedColumn() throws Exception { } } + @Nested + @DisplayName("Key-Specific Bulk Update Operations") + class KeySpecificBulkUpdateTests { + + @Test + @DisplayName("Should update multiple keys with all operator types in a single batch") + void testBulkUpdateAllOperatorTypes() throws Exception { + Map> updates = new LinkedHashMap<>(); + updates.put( + rawKey("1"), + List.of( + SubDocumentUpdate.of("item", "UpdatedSoap"), + SubDocumentUpdate.builder() + .subDocument("price") + .operator(UpdateOperator.ADD) + .subDocumentValue(SubDocumentValue.of(5)) + .build(), + SubDocumentUpdate.builder() + .subDocument("props.brand") + .operator(UpdateOperator.SET) + .subDocumentValue(SubDocumentValue.of("NewBrand")) + .build())); + + updates.put( + rawKey("3"), + List.of( + SubDocumentUpdate.builder() + .subDocument("props.brand") + .operator(UpdateOperator.UNSET) + .build(), + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.APPEND_TO_LIST) + .subDocumentValue(SubDocumentValue.of(new String[] {"newTag1", "newTag2"})) + .build())); + + updates.put( + rawKey("5"), + List.of( + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.ADD_TO_LIST_IF_ABSENT) + .subDocumentValue(SubDocumentValue.of(new String[] {"hygiene", "uniqueTag"})) + .build())); + + updates.put( + rawKey("6"), + List.of( + SubDocumentUpdate.builder() + .subDocument("tags") + .operator(UpdateOperator.REMOVE_ALL_FROM_LIST) + 
.subDocumentValue(SubDocumentValue.of(new String[] {"plastic"})) + .build())); + + BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build()); + + assertEquals(4, result.getUpdatedCount()); + + try (CloseableIterator iter = flatCollection.find(queryById("1"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertEquals("UpdatedSoap", json.get("item").asText()); + assertEquals(15, json.get("price").asInt()); // 10 + 5 + assertEquals("NewBrand", json.get("props").get("brand").asText()); + assertEquals("M", json.get("props").get("size").asText()); // preserved + } + + try (CloseableIterator iter = flatCollection.find(queryById("3"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + assertFalse(json.get("props").has("brand")); + assertEquals("L", json.get("props").get("size").asText()); // preserved + JsonNode tagsNode = json.get("tags"); + assertEquals(6, tagsNode.size()); // Original 4 + 2 new + } + + try (CloseableIterator iter = flatCollection.find(queryById("5"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode tagsNode = json.get("tags"); + assertEquals(4, tagsNode.size()); // Original 3 + 1 new unique + Set tags = new HashSet<>(); + tagsNode.forEach(n -> tags.add(n.asText())); + assertTrue(tags.contains("uniqueTag")); + } + + try (CloseableIterator iter = flatCollection.find(queryById("6"))) { + assertTrue(iter.hasNext()); + JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson()); + JsonNode tagsNode = json.get("tags"); + assertEquals(2, tagsNode.size()); // grooming, essential remain + Set tags = new HashSet<>(); + tagsNode.forEach(n -> tags.add(n.asText())); + assertFalse(tags.contains("plastic")); + } + } + + @Test + @DisplayName("Should handle edge cases: empty map, null map, non-existent keys") + void testBulkUpdateEdgeCases() throws Exception { + UpdateOptions 
options = UpdateOptions.builder().build(); + + // Empty map + assertEquals(0, flatCollection.bulkUpdate(new HashMap<>(), options).getUpdatedCount()); + + // Null map + Map> nullUpdates = null; + assertEquals(0, flatCollection.bulkUpdate(nullUpdates, options).getUpdatedCount()); + + // Non-existent key + Map> nonExistent = new LinkedHashMap<>(); + nonExistent.put(rawKey("non-existent"), List.of(SubDocumentUpdate.of("item", "X"))); + assertEquals(0, flatCollection.bulkUpdate(nonExistent, options).getUpdatedCount()); + } + + // Creates a key with raw ID (matching test data format) + private Key rawKey(String id) { + return Key.from(id); + } + + private Query queryById(String id) { + return Query.builder() + .setFilter( + RelationalExpression.of( + IdentifierExpression.of("id"), RelationalOperator.EQ, ConstantExpression.of(id))) + .build(); + } + } + private static void executeInsertStatements() { PostgresDatastore pgDatastore = (PostgresDatastore) postgresDatastore; try { diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/Collection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/Collection.java index c1d5357a..a79e3c99 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/Collection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/Collection.java @@ -398,5 +398,54 @@ CloseableIterator bulkUpdate( final UpdateOptions updateOptions) throws IOException; + /** + * Bulk update sub-documents with key-specific updates. Each key can have its own set of + * SubDocumentUpdate operations, allowing different updates per document. + * + *

<p>This method supports all update operators (SET, UNSET, ADD, APPEND_TO_LIST,
+   * ADD_TO_LIST_IF_ABSENT, REMOVE_ALL_FROM_LIST). Updates for each individual key are applied
+   * atomically, but there is no atomicity guarantee across different keys - some keys may be
+   * updated while others fail. Batch-level atomicity is not guaranteed, while per-key update
+   * atomicity is guaranteed.
+   *

<p>Example usage:
+   *

<pre>{@code
+   * Map<Key, Collection<SubDocumentUpdate>> updates = new HashMap<>();
+   *
+   * // Key 1: SET a field and ADD to a number
+   * updates.put(key1, List.of(
+   *     SubDocumentUpdate.of("name", "NewName"),
+   *     SubDocumentUpdate.builder()
+   *         .subDocument("count")
+   *         .operator(UpdateOperator.ADD)
+   *         .subDocumentValue(SubDocumentValue.of(5))
+   *         .build()
+   * ));
+   *
+   * // Key 2: APPEND to an array
+   * updates.put(key2, List.of(
+   *     SubDocumentUpdate.builder()
+   *         .subDocument("tags")
+   *         .operator(UpdateOperator.APPEND_TO_LIST)
+   *         .subDocumentValue(SubDocumentValue.of(new String[]{"newTag"}))
+   *         .build()
+   * ));
+   *
+   * BulkUpdateResult result = collection.bulkUpdate(updates, UpdateOptions.builder().build());
+   * }
+ * + * @param updates Map of Key to Collection of SubDocumentUpdate operations. Each key's updates are + * applied atomically, but no cross-key atomicity is guaranteed. + * @param updateOptions Options for the update operation + * @return BulkUpdateResult containing the count of successfully updated documents + * @throws IOException if the update operation fails + */ + default BulkUpdateResult bulkUpdate( + Map> updates, UpdateOptions updateOptions) + throws IOException { + throw new UnsupportedOperationException("bulkUpdate is not supported!"); + } + String UNSUPPORTED_QUERY_OPERATION = "Query operation is not supported"; } diff --git a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java index ad8c1d1a..b54b36d1 100644 --- a/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java +++ b/document-store/src/main/java/org/hypertrace/core/documentstore/postgres/FlatPostgresCollection.java @@ -25,8 +25,19 @@ import java.sql.Timestamp; import java.sql.Types; import java.time.Instant; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import org.hypertrace.core.documentstore.BulkArrayValueUpdateRequest; import org.hypertrace.core.documentstore.BulkDeleteResult; @@ -896,6 +907,109 @@ public CloseableIterator bulkUpdate( } } + @Override + public BulkUpdateResult bulkUpdate( + Map> updates, UpdateOptions updateOptions) + throws IOException { + + if (updates == null || updates.isEmpty()) { + return new BulkUpdateResult(0); + } 
+ + Preconditions.checkArgument(updateOptions != null, "UpdateOptions cannot be NULL"); + + String tableName = tableIdentifier.getTableName(); + String quotedPkColumn = PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName)); + + Set updatedKeys = new HashSet<>(); + + try (Connection connection = client.getPooledConnection()) { + for (Map.Entry> entry : updates.entrySet()) { + Key key = entry.getKey(); + Collection keyUpdates = entry.getValue(); + + if (keyUpdates == null || keyUpdates.isEmpty()) { + continue; + } + + try { + boolean updated = updateSingleKey(connection, key, keyUpdates, tableName, quotedPkColumn); + if (updated) { + updatedKeys.add(key); + } + } catch (Exception e) { + LOGGER.warn("Failed to update key {}: {}", key, e.getMessage()); + // Continue with other keys - no cross-key atomicity + } + } + } catch (SQLException e) { + throw new IOException("Failed to get connection for bulk update", e); + } + + return new BulkUpdateResult(updatedKeys.size()); + } + + private boolean updateSingleKey( + Connection connection, + Key key, + Collection keyUpdates, + String tableName, + String quotedPkColumn) + throws IOException, SQLException { + + updateValidator.validate(keyUpdates); + Map resolvedColumns = resolvePathsToColumns(keyUpdates, tableName); + + return executeKeyUpdate( + connection, key, keyUpdates, tableName, quotedPkColumn, resolvedColumns); + } + + private boolean executeKeyUpdate( + Connection connection, + Key key, + java.util.Collection keyUpdates, + String tableName, + String quotedPkColumn, + Map resolvedColumns) + throws SQLException { + + List setFragments = new ArrayList<>(); + List params = new ArrayList<>(); + + boolean hasUpdates = + buildSetClauseFragments( + connection, keyUpdates, tableName, resolvedColumns, setFragments, params); + + if (!hasUpdates) { + return false; + } + + // Add lastUpdatedTime to the same UPDATE statement + if (lastUpdatedTsColumn != null) { + setFragments.add(String.format("\"%s\" = ?", 
lastUpdatedTsColumn)); + params.add(new Timestamp(System.currentTimeMillis())); + } + + // Build and execute UPDATE SQL for this key + String sql = + String.format( + "UPDATE %s SET %s WHERE %s = ?", + tableIdentifier, String.join(", ", setFragments), quotedPkColumn); + + params.add(key.toString()); + + LOGGER.debug("Executing key update SQL: {}", sql); + + try (PreparedStatement ps = connection.prepareStatement(sql)) { + int idx = 1; + for (Object param : params) { + ps.setObject(idx++, param); + } + int rowsUpdated = ps.executeUpdate(); + return rowsUpdated > 0; + } + } + /** * Validates all updates and resolves column names. * @@ -1014,6 +1128,56 @@ private void executeUpdate( String filterClause = filterParser.buildFilterClause(); Params filterParams = filterParser.getParamsBuilder().build(); + List setFragments = new ArrayList<>(); + List params = new ArrayList<>(); + + boolean hasUpdates = + buildSetClauseFragments( + connection, updates, tableName, resolvedColumns, setFragments, params); + + if (!hasUpdates) { + LOGGER.warn("All update paths were skipped - no valid columns to update"); + return; + } + + // Build final UPDATE SQL + String sql = + String.format( + "UPDATE %s SET %s %s", tableIdentifier, String.join(", ", setFragments), filterClause); + + LOGGER.debug("Executing update SQL: {}", sql); + + try (PreparedStatement ps = connection.prepareStatement(sql)) { + int idx = 1; + for (Object param : params) { + ps.setObject(idx++, param); + } + for (Object param : filterParams.getObjectParams().values()) { + ps.setObject(idx++, param); + } + int rowsUpdated = ps.executeUpdate(); + LOGGER.debug("Rows updated: {}", rowsUpdated); + } catch (SQLException e) { + LOGGER.error("Failed to execute update. SQL: {}, SQLState: {}", sql, e.getSQLState(), e); + throw e; + } + } + + /** + * Builds SET clause fragments for an UPDATE statement by grouping updates by column and chaining + * nested JSONB updates. 
+ * + * @return true if at least one valid update fragment was built, false otherwise + */ + private boolean buildSetClauseFragments( + Connection connection, + Collection updates, + String tableName, + Map resolvedColumns, + List setFragments, + List params) + throws SQLException { + // Group updates by column to handle multiple nested updates to the same JSONB column Map> updatesByColumn = new LinkedHashMap<>(); for (SubDocumentUpdate update : updates) { @@ -1026,9 +1190,9 @@ private void executeUpdate( updatesByColumn.computeIfAbsent(columnName, k -> new ArrayList<>()).add(update); } - // Build SET clause fragments - one per column - List setFragments = new ArrayList<>(); - List params = new ArrayList<>(); + if (updatesByColumn.isEmpty()) { + return false; + } for (Map.Entry> entry : updatesByColumn.entrySet()) { String columnName = entry.getKey(); @@ -1095,33 +1259,7 @@ private void executeUpdate( } } - // If all updates were skipped, nothing to do - if (setFragments.isEmpty()) { - LOGGER.warn("All update paths were skipped - no valid columns to update"); - return; - } - - // Build final UPDATE SQL - String sql = - String.format( - "UPDATE %s SET %s %s", tableIdentifier, String.join(", ", setFragments), filterClause); - - LOGGER.debug("Executing update SQL: {}", sql); - - try (PreparedStatement ps = connection.prepareStatement(sql)) { - int idx = 1; - for (Object param : params) { - ps.setObject(idx++, param); - } - for (Object param : filterParams.getObjectParams().values()) { - ps.setObject(idx++, param); - } - int rowsUpdated = ps.executeUpdate(); - LOGGER.debug("Rows updated: {}", rowsUpdated); - } catch (SQLException e) { - LOGGER.error("Failed to execute update. SQL: {}, SQLState: {}", sql, e.getSQLState(), e); - throw e; - } + return !setFragments.isEmpty(); } /*isRetry: Whether this is a retry attempt*/