Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.metadata.MetadataStorageTablesConfig;
import org.apache.druid.metadata.TestDerbyConnector;
import org.apache.druid.security.basic.BasicAuthCommonCacheConfig;
Expand Down Expand Up @@ -53,9 +55,14 @@

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.Response;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

public class CoordinatorBasicAuthorizerResourceTest
{
Expand Down Expand Up @@ -581,6 +588,129 @@ public void testUsersRolesAndPerms()
Assert.assertEquals(expectedRoleFull2, response.getEntity());
}

@Test
public void testConcurrentUpdate()
{
final int testMultiple = 100;

// setup a user and the roles
Response response = resource.createUser(req, AUTHORIZER_NAME, "druid");
Assert.assertEquals(200, response.getStatus());

List<ResourceAction> perms = ImmutableList.of(
new ResourceAction(new Resource("A", ResourceType.DATASOURCE), Action.READ),
new ResourceAction(new Resource("B", ResourceType.DATASOURCE), Action.WRITE),
new ResourceAction(new Resource("C", ResourceType.CONFIG), Action.WRITE)
);

for (int i = 0; i < testMultiple; i++) {
String roleName = "druidRole-" + i;
response = resource.createRole(req, AUTHORIZER_NAME, roleName);
Assert.assertEquals(200, response.getStatus());

response = resource.setRolePermissions(req, AUTHORIZER_NAME, roleName, perms);
Assert.assertEquals(200, response.getStatus());
}

ExecutorService exec = Execs.multiThreaded(testMultiple, "thread---");
int[] responseCodesAssign = new int[testMultiple];

// assign 'testMultiple' roles to the user concurrently
List<Callable<Void>> addRoleCallables = new ArrayList<>();
for (int i = 0; i < testMultiple; i++) {
final int innerI = i;
String roleName = "druidRole-" + i;
addRoleCallables.add(
new Callable<Void>()
{
@Override
public Void call() throws Exception
{
Response response = resource.assignRoleToUser(req, AUTHORIZER_NAME, "druid", roleName);
responseCodesAssign[innerI] = response.getStatus();
return null;
}
}
);
}
try {
List<Future<Void>> futures = exec.invokeAll(addRoleCallables);
for (Future future : futures) {
future.get();
}
}
catch (Exception e) {
throw new RuntimeException(e);
}

// the API can return !200 if the update attempt fails by exhausting retries because of
// too much contention from other conflicting requests, make sure that we don't get any successful requests
// that didn't actually take effect
Set<String> roleNames = getRoleNamesAssignedToUser("druid");
for (int i = 0; i < testMultiple; i++) {
String roleName = "druidRole-" + i;
if (responseCodesAssign[i] == 200 && !roleNames.contains(roleName)) {
Assert.fail(
StringUtils.format("Got response status 200 for assigning role [%s] but user did not have role.", roleName)
);
}
}

// Now unassign the roles concurrently
List<Callable<Void>> removeRoleCallables = new ArrayList<>();
int[] responseCodesRemove = new int[testMultiple];

for (int i = 0; i < testMultiple; i++) {
final int innerI = i;
String roleName = "druidRole-" + i;
removeRoleCallables.add(
new Callable<Void>()
{
@Override
public Void call() throws Exception
{
Response response = resource.unassignRoleFromUser(req, AUTHORIZER_NAME, "druid", roleName);
responseCodesRemove[innerI] = response.getStatus();
return null;
}
}
);
}
try {
List<Future<Void>> futures = exec.invokeAll(removeRoleCallables);
for (Future future : futures) {
future.get();
}
}
catch (Exception e) {
throw new RuntimeException(e);
}

roleNames = getRoleNamesAssignedToUser("druid");
for (int i = 0; i < testMultiple; i++) {
String roleName = "druidRole-" + i;
if (responseCodesRemove[i] == 200 && roleNames.contains(roleName)) {
Assert.fail(
StringUtils.format("Got response status 200 for removing role [%s] but user still has role.", roleName)
);
}
}
}

private Set<String> getRoleNamesAssignedToUser(
String user
)
{
Response response = resource.getUser(req, AUTHORIZER_NAME, user, "");
Assert.assertEquals(200, response.getStatus());
BasicAuthorizerUserFull userFull = (BasicAuthorizerUserFull) response.getEntity();
Set<String> roleNames = new HashSet<>();
for (BasicAuthorizerRole role : userFull.getRoles()) {
roleNames.add(role.getName());
}
return roleNames;
}

private static Map<String, String> errorMapWithMsg(String errorMsg)
{
return ImmutableMap.of("error", errorMsg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,36 @@
package org.apache.druid.metadata.storage.postgresql;

import com.google.common.base.Supplier;
import com.google.common.base.Throwables;
import com.google.inject.Inject;
import org.apache.commons.dbcp2.BasicDataSource;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.metadata.MetadataCASUpdate;
import org.apache.druid.metadata.MetadataStorageConnectorConfig;
import org.apache.druid.metadata.MetadataStorageTablesConfig;
import org.apache.druid.metadata.SQLMetadataConnector;
import org.postgresql.PGProperty;
import org.postgresql.util.PSQLException;
import org.skife.jdbi.v2.DBI;
import org.skife.jdbi.v2.Handle;
import org.skife.jdbi.v2.exceptions.CallbackFailedException;
import org.skife.jdbi.v2.tweak.HandleCallback;
import org.skife.jdbi.v2.util.StringMapper;

import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.List;

public class PostgreSQLConnector extends SQLMetadataConnector
{
private static final Logger log = new Logger(PostgreSQLConnector.class);
private static final String PAYLOAD_TYPE = "BYTEA";
private static final String SERIAL_TYPE = "BIGSERIAL";
private static final String QUOTE_STRING = "\\\"";
private static final String PSQL_SERIALIZATION_FAILURE_MSG =
"ERROR: could not serialize access due to concurrent update";
private static final String PSQL_SERIALIZATION_FAILURE_SQL_STATE = "40001";
public static final int DEFAULT_STREAMING_RESULT_SIZE = 100;

private final DBI dbi;
Expand Down Expand Up @@ -208,6 +216,45 @@ public Void withHandle(Handle handle) throws Exception
);
}

@Override
public boolean compareAndSwap(List<MetadataCASUpdate> updates)
{
try {
return super.compareAndSwap(updates);
}
catch (CallbackFailedException cfe) {
Throwable root = Throwables.getRootCause(cfe);
if (checkRootCauseForPSQLSerializationFailure(root)) {
return false;
} else {
throw cfe;
}
}
}

/**
* Used by compareAndSwap to check if the transaction was terminated because of concurrent updates.
*
* The parent implementation's compareAndSwap transaction has isolation level REPEATABLE_READ.
* In Postgres, such transactions will be canceled when another transaction commits a conflicting update:
* https://www.postgresql.org/docs/10/transaction-iso.html#XACT-REPEATABLE-READ
*
* When this occurs, we need to retry the transaction from the beginning: by returning false in compareAndSwap,
* the calling code will attempt retries.
*/
private boolean checkRootCauseForPSQLSerializationFailure(
Throwable root
)
{
if (root instanceof PSQLException) {
PSQLException psqlException = (PSQLException) root;
return PSQL_SERIALIZATION_FAILURE_SQL_STATE.equals(psqlException.getSQLState()) &&
PSQL_SERIALIZATION_FAILURE_MSG.equals(psqlException.getMessage());
} else {
return false;
}
}

@Override
public DBI getDBI()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ public boolean compareAndSwap(
)
{
return getDBI().inTransaction(
TransactionIsolationLevel.REPEATABLE_READ,
new TransactionCallback<Boolean>()
{
@Override
Expand All @@ -467,7 +468,7 @@ public Boolean inTransaction(Handle handle, TransactionStatus transactionStatus)
byte[] currentValue = handle
.createQuery(
StringUtils.format(
"SELECT %1$s FROM %2$s WHERE %3$s = :key",
"SELECT %1$s FROM %2$s WHERE %3$s = :key FOR UPDATE",
update.getValueColumn(),
update.getTableName(),
update.getKeyColumn()
Expand Down