diff --git a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/CoordinatorBasicAuthorizerResourceTest.java b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/CoordinatorBasicAuthorizerResourceTest.java
index ef10dcc30dee..47695275326b 100644
--- a/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/CoordinatorBasicAuthorizerResourceTest.java
+++ b/extensions-core/druid-basic-security/src/test/java/org/apache/druid/security/authorization/CoordinatorBasicAuthorizerResourceTest.java
@@ -25,6 +25,8 @@
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.metadata.MetadataStorageTablesConfig;
 import org.apache.druid.metadata.TestDerbyConnector;
 import org.apache.druid.security.basic.BasicAuthCommonCacheConfig;
@@ -53,9 +55,14 @@
 import javax.servlet.http.HttpServletRequest;
 import javax.ws.rs.core.Response;
 
+import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 
 public class CoordinatorBasicAuthorizerResourceTest
 {
@@ -581,6 +588,129 @@ public void testUsersRolesAndPerms()
     Assert.assertEquals(expectedRoleFull2, response.getEntity());
   }
 
+  @Test
+  public void testConcurrentUpdate()
+  {
+    final int testMultiple = 100;
+
+    // setup a user and the roles
+    Response response = resource.createUser(req, AUTHORIZER_NAME, "druid");
+    Assert.assertEquals(200, response.getStatus());
+
+    List<ResourceAction> perms = ImmutableList.of(
+        new ResourceAction(new Resource("A", ResourceType.DATASOURCE), Action.READ),
+        new ResourceAction(new Resource("B", ResourceType.DATASOURCE), Action.WRITE),
+        new ResourceAction(new Resource("C", ResourceType.CONFIG), Action.WRITE)
+    );
+
+    for (int i = 0; i < testMultiple; i++) {
+      String roleName = "druidRole-" + i;
+      response = resource.createRole(req, AUTHORIZER_NAME, roleName);
+      Assert.assertEquals(200, response.getStatus());
+
+      response = resource.setRolePermissions(req, AUTHORIZER_NAME, roleName, perms);
+      Assert.assertEquals(200, response.getStatus());
+    }
+
+    ExecutorService exec = Execs.multiThreaded(testMultiple, "thread---");
+    int[] responseCodesAssign = new int[testMultiple];
+
+    // assign 'testMultiple' roles to the user concurrently
+    List<Callable<Void>> addRoleCallables = new ArrayList<>();
+    for (int i = 0; i < testMultiple; i++) {
+      final int innerI = i;
+      String roleName = "druidRole-" + i;
+      addRoleCallables.add(
+          new Callable<Void>()
+          {
+            @Override
+            public Void call() throws Exception
+            {
+              Response response = resource.assignRoleToUser(req, AUTHORIZER_NAME, "druid", roleName);
+              responseCodesAssign[innerI] = response.getStatus();
+              return null;
+            }
+          }
+      );
+    }
+    try {
+      List<Future<Void>> futures = exec.invokeAll(addRoleCallables);
+      for (Future<Void> future : futures) {
+        future.get();
+      }
+    }
+    catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+
+    // the API can return !200 if the update attempt fails by exhausting retries because of
+    // too much contention from other conflicting requests, make sure that we don't get any successful requests
+    // that didn't actually take effect
+    Set<String> roleNames = getRoleNamesAssignedToUser("druid");
+    for (int i = 0; i < testMultiple; i++) {
+      String roleName = "druidRole-" + i;
+      if (responseCodesAssign[i] == 200 && !roleNames.contains(roleName)) {
+        Assert.fail(
+            StringUtils.format("Got response status 200 for assigning role [%s] but user did not have role.", roleName)
+        );
+      }
+    }
+
+    // Now unassign the roles concurrently
+    List<Callable<Void>> removeRoleCallables = new ArrayList<>();
+    int[] responseCodesRemove = new int[testMultiple];
+
+    for (int i = 0; i < testMultiple; i++) {
+      final int innerI = i;
+      String roleName = "druidRole-" + i;
+      removeRoleCallables.add(
+          new Callable<Void>()
+          {
+            @Override
+            public Void call() throws Exception
+            {
+              Response response = resource.unassignRoleFromUser(req, AUTHORIZER_NAME, "druid", roleName);
+              responseCodesRemove[innerI] = response.getStatus();
+              return null;
+            }
+          }
+      );
+    }
+    try {
+      List<Future<Void>> futures = exec.invokeAll(removeRoleCallables);
+      for (Future<Void> future : futures) {
+        future.get();
+      }
+    }
+    catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+
+    roleNames = getRoleNamesAssignedToUser("druid");
+    for (int i = 0; i < testMultiple; i++) {
+      String roleName = "druidRole-" + i;
+      if (responseCodesRemove[i] == 200 && roleNames.contains(roleName)) {
+        Assert.fail(
+            StringUtils.format("Got response status 200 for removing role [%s] but user still has role.", roleName)
+        );
+      }
+    }
+  }
+
+  private Set<String> getRoleNamesAssignedToUser(
+      String user
+  )
+  {
+    Response response = resource.getUser(req, AUTHORIZER_NAME, user, "");
+    Assert.assertEquals(200, response.getStatus());
+    BasicAuthorizerUserFull userFull = (BasicAuthorizerUserFull) response.getEntity();
+    Set<String> roleNames = new HashSet<>();
+    for (BasicAuthorizerRole role : userFull.getRoles()) {
+      roleNames.add(role.getName());
+    }
+    return roleNames;
+  }
+
   private static Map<String, String> errorMapWithMsg(String errorMsg)
   {
     return ImmutableMap.of("error", errorMsg);
diff --git a/extensions-core/postgresql-metadata-storage/src/main/java/org/apache/druid/metadata/storage/postgresql/PostgreSQLConnector.java b/extensions-core/postgresql-metadata-storage/src/main/java/org/apache/druid/metadata/storage/postgresql/PostgreSQLConnector.java
index e234a157989a..1d7187442f0e 100644
--- a/extensions-core/postgresql-metadata-storage/src/main/java/org/apache/druid/metadata/storage/postgresql/PostgreSQLConnector.java
+++ b/extensions-core/postgresql-metadata-storage/src/main/java/org/apache/druid/metadata/storage/postgresql/PostgreSQLConnector.java
@@ -20,21 +20,26 @@
 package org.apache.druid.metadata.storage.postgresql;
 
 import com.google.common.base.Supplier;
+import com.google.common.base.Throwables;
 import com.google.inject.Inject;
 import org.apache.commons.dbcp2.BasicDataSource;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.metadata.MetadataCASUpdate;
 import org.apache.druid.metadata.MetadataStorageConnectorConfig;
 import org.apache.druid.metadata.MetadataStorageTablesConfig;
 import org.apache.druid.metadata.SQLMetadataConnector;
 import org.postgresql.PGProperty;
+import org.postgresql.util.PSQLException;
 import org.skife.jdbi.v2.DBI;
 import org.skife.jdbi.v2.Handle;
+import org.skife.jdbi.v2.exceptions.CallbackFailedException;
 import org.skife.jdbi.v2.tweak.HandleCallback;
 import org.skife.jdbi.v2.util.StringMapper;
 
 import java.sql.DatabaseMetaData;
 import java.sql.SQLException;
+import java.util.List;
 
 public class PostgreSQLConnector extends SQLMetadataConnector
 {
@@ -42,6 +47,9 @@ public class PostgreSQLConnector extends SQLMetadataConnector
   private static final String PAYLOAD_TYPE = "BYTEA";
   private static final String SERIAL_TYPE = "BIGSERIAL";
   private static final String QUOTE_STRING = "\\\"";
+  private static final String PSQL_SERIALIZATION_FAILURE_MSG =
+      "ERROR: could not serialize access due to concurrent update";
+  private static final String PSQL_SERIALIZATION_FAILURE_SQL_STATE = "40001";
   public static final int DEFAULT_STREAMING_RESULT_SIZE = 100;
 
   private final DBI dbi;
@@ -208,6 +216,45 @@ public Void withHandle(Handle handle) throws Exception
     );
   }
 
+  @Override
+  public boolean compareAndSwap(List<MetadataCASUpdate> updates)
+  {
+    try {
+      return super.compareAndSwap(updates);
+    }
+    catch (CallbackFailedException cfe) {
+      Throwable root = Throwables.getRootCause(cfe);
+      if (checkRootCauseForPSQLSerializationFailure(root)) {
+        return false;
+      } else {
+        throw cfe;
+      }
+    }
+  }
+
+  /**
+   * Used by compareAndSwap to check if the transaction was terminated because of concurrent updates.
+   *
+   * The parent implementation's compareAndSwap transaction has isolation level REPEATABLE_READ.
+   * In Postgres, such transactions will be canceled when another transaction commits a conflicting update:
+   * https://www.postgresql.org/docs/10/transaction-iso.html#XACT-REPEATABLE-READ
+   *
+   * When this occurs, we need to retry the transaction from the beginning: by returning false in compareAndSwap,
+   * the calling code will attempt retries.
+   */
+  private boolean checkRootCauseForPSQLSerializationFailure(
+      Throwable root
+  )
+  {
+    if (root instanceof PSQLException) {
+      PSQLException psqlException = (PSQLException) root;
+      return PSQL_SERIALIZATION_FAILURE_SQL_STATE.equals(psqlException.getSQLState()) &&
+             PSQL_SERIALIZATION_FAILURE_MSG.equals(psqlException.getMessage());
+    } else {
+      return false;
+    }
+  }
+
   @Override
   public DBI getDBI()
   {
diff --git a/server/src/main/java/org/apache/druid/metadata/SQLMetadataConnector.java b/server/src/main/java/org/apache/druid/metadata/SQLMetadataConnector.java
index 2104c165d403..3d8433272156 100644
--- a/server/src/main/java/org/apache/druid/metadata/SQLMetadataConnector.java
+++ b/server/src/main/java/org/apache/druid/metadata/SQLMetadataConnector.java
@@ -455,6 +455,7 @@ public boolean compareAndSwap(
   )
   {
     return getDBI().inTransaction(
+        TransactionIsolationLevel.REPEATABLE_READ,
        new TransactionCallback<Boolean>()
        {
          @Override
@@ -467,7 +468,7 @@ public Boolean inTransaction(Handle handle, TransactionStatus transactionStatus)
              byte[] currentValue = handle
                  .createQuery(
                      StringUtils.format(
-                          "SELECT %1$s FROM %2$s WHERE %3$s = :key",
+                          "SELECT %1$s FROM %2$s WHERE %3$s = :key FOR UPDATE",
                          update.getValueColumn(),
                          update.getTableName(),
                          update.getKeyColumn()