Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,10 @@ public class Protocol {
}
}

public static boolean apiVersionSupported(short apiKey, short apiVersion) {
return apiKey < CURR_VERSION.length && apiVersion >= MIN_VERSIONS[apiKey] && apiVersion <= CURR_VERSION[apiKey];
}

private static String indentString(int size) {
StringBuilder b = new StringBuilder(size);
for (int i = 0; i < size; i++)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.errors.UnsupportedVersionException;
import org.apache.kafka.common.network.Authenticator;
import org.apache.kafka.common.network.Mode;
import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.TransportLayer;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.protocol.Protocol;
import org.apache.kafka.common.protocol.types.SchemaException;
import org.apache.kafka.common.requests.AbstractRequest;
import org.apache.kafka.common.requests.AbstractRequestResponse;
Expand All @@ -75,7 +77,7 @@ public class SaslServerAuthenticator implements Authenticator {
private static final Logger LOG = LoggerFactory.getLogger(SaslServerAuthenticator.class);

public enum SaslState {
HANDSHAKE_REQUEST, AUTHENTICATE, COMPLETE, FAILED
GSSAPI_OR_HANDSHAKE_REQUEST, HANDSHAKE_REQUEST, AUTHENTICATE, COMPLETE, FAILED
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I am not quite sure what the new state GSSAPI_OR_HANDSHAKE_REQUEST is for. It's making the same call as the HANDSHAKE_REQUEST and there is no code to change saslState to GSSAPI_OR_HANDSHAKE_REQUEST.

Copy link
Copy Markdown
Contributor Author

@rajinisivaram rajinisivaram May 4, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@junrao Oops, sorry, missed out the change of initiate state. In HANDSHAKE_REQUEST state, any invalid request is treated as a GSSAPI token. This made sense when only a SaslHandshakeRequest was expected prior to SASL tokens. Since SaslHandshakeRequest can now be preceded by any number of ApiVersionsRequests, I think it makes sense to revert to 0.9.0.x GSSAPI authentication only for the first token. GSSAPI_OR_HANDSHAKE_REQUEST is the initiate state which checks that. Have updated the PR.

}

private final String node;
Expand All @@ -85,7 +87,7 @@ public enum SaslState {
private final String host;

// Current SASL state
private SaslState saslState = SaslState.HANDSHAKE_REQUEST;
private SaslState saslState = SaslState.GSSAPI_OR_HANDSHAKE_REQUEST;
// Next SASL state to be set when outgoing writes associated with the current SASL state complete
private SaslState pendingSaslState = null;
private SaslServer saslServer;
Expand Down Expand Up @@ -215,6 +217,9 @@ public void authenticate() throws IOException {
try {
switch (saslState) {
case HANDSHAKE_REQUEST:
handleKafkaRequest(clientToken);
break;
case GSSAPI_OR_HANDSHAKE_REQUEST:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we move this case to the first one after switch since it's the initial state?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@junrao We are relying on fall-through for the second case, so it's not straightforward to make that change.

if (handleKafkaRequest(clientToken))
break;
// For default GSSAPI, fall through to authenticate using the client token as the first GSSAPI packet.
Expand Down Expand Up @@ -288,39 +293,53 @@ private boolean handleKafkaRequest(byte[] requestBytes) throws IOException, Auth
try {
ByteBuffer requestBuffer = ByteBuffer.wrap(requestBytes);
RequestHeader requestHeader = RequestHeader.parse(requestBuffer);
AbstractRequest request = AbstractRequest.getRequest(requestHeader.apiKey(), requestHeader.apiVersion(), requestBuffer);
ApiKeys apiKey = ApiKeys.forId(requestHeader.apiKey());
// A valid Kafka request header was received. SASL authentication tokens are now expected only
// following a SaslHandshakeRequest since this is not a GSSAPI client token from a Kafka 0.9.0.x client.
setSaslState(SaslState.HANDSHAKE_REQUEST);
isKafkaRequest = true;

ApiKeys apiKey = ApiKeys.forId(requestHeader.apiKey());
LOG.debug("Handle Kafka request {}", apiKey);
switch (apiKey) {
case API_VERSIONS:
handleApiVersionsRequest(requestHeader, (ApiVersionsRequest) request);
break;
case SASL_HANDSHAKE:
clientMechanism = handleHandshakeRequest(requestHeader, (SaslHandshakeRequest) request);
break;
default:
throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake.");
if (!Protocol.apiVersionSupported(requestHeader.apiKey(), requestHeader.apiVersion())) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requestHeader.apiKey() can just be apiKey.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apiKey is of type ApiKeys while requestHeader.apiKey() returns a short.

if (apiKey == ApiKeys.API_VERSIONS)
sendKafkaResponse(requestHeader, ApiVersionsResponse.fromError(Errors.UNSUPPORTED_VERSION));
else
throw new UnsupportedVersionException("Version " + requestHeader.apiVersion() + " is not supported for apiKey " + apiKey);
} else {
AbstractRequest request = AbstractRequest.getRequest(requestHeader.apiKey(), requestHeader.apiVersion(), requestBuffer);

LOG.debug("Handle Kafka request {}", apiKey);
switch (apiKey) {
case API_VERSIONS:
handleApiVersionsRequest(requestHeader, (ApiVersionsRequest) request);
break;
case SASL_HANDSHAKE:
clientMechanism = handleHandshakeRequest(requestHeader, (SaslHandshakeRequest) request);
break;
default:
throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake.");
}
}
} catch (SchemaException | IllegalArgumentException e) {
// SchemaException is thrown if the request is not in Kafka format. IIlegalArgumentException is thrown
// if the API key is invalid. For compatibility with 0.9.0.x where the first packet is a GSSAPI token
// starting with 0x60, revert to GSSAPI for both these exceptions.
if (LOG.isDebugEnabled()) {
StringBuilder tokenBuilder = new StringBuilder();
for (byte b : requestBytes) {
tokenBuilder.append(String.format("%02x", b));
if (tokenBuilder.length() >= 20)
break;
if (saslState == SaslState.GSSAPI_OR_HANDSHAKE_REQUEST) {
// SchemaException is thrown if the request is not in Kafka format. IIlegalArgumentException is thrown
// if the API key is invalid. For compatibility with 0.9.0.x where the first packet is a GSSAPI token
// starting with 0x60, revert to GSSAPI for both these exceptions.
if (LOG.isDebugEnabled()) {
StringBuilder tokenBuilder = new StringBuilder();
for (byte b : requestBytes) {
tokenBuilder.append(String.format("%02x", b));
if (tokenBuilder.length() >= 20)
break;
}
LOG.debug("Received client packet of length {} starting with bytes 0x{}, process as GSSAPI packet", requestBytes.length, tokenBuilder);
}
LOG.debug("Received client packet of length {} starting with bytes 0x{}, process as GSSAPI packet", requestBytes.length, tokenBuilder);
}
if (enabledMechanisms.contains(SaslConfigs.GSSAPI_MECHANISM)) {
LOG.debug("First client packet is not a SASL mechanism request, using default mechanism GSSAPI");
clientMechanism = SaslConfigs.GSSAPI_MECHANISM;
if (enabledMechanisms.contains(SaslConfigs.GSSAPI_MECHANISM)) {
LOG.debug("First client packet is not a SASL mechanism request, using default mechanism GSSAPI");
clientMechanism = SaslConfigs.GSSAPI_MECHANISM;
} else
throw new UnsupportedSaslMechanismException("Exception handling first SASL packet from client, GSSAPI is not supported by server", e);
} else
throw new UnsupportedSaslMechanismException("Exception handling first SASL packet from client, GSSAPI is not supported by server", e);
throw e;
}
if (clientMechanism != null) {
createSaslServer(clientMechanism);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.kafka.common.requests.MetadataRequest;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.requests.RequestSend;
import org.apache.kafka.common.requests.ResponseHeader;
import org.apache.kafka.common.requests.SaslHandshakeRequest;
import org.apache.kafka.common.requests.SaslHandshakeResponse;
import org.apache.kafka.common.security.JaasUtils;
Expand Down Expand Up @@ -243,6 +244,62 @@ public void testUnauthenticatedApiVersionsRequestOverSsl() throws Exception {
testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_SSL);
}

/**
* Tests that unsupported version of ApiVersionsRequest before SASL handshake request
* returns error response and does not result in authentication failure. This test
* is similar to {@link #testUnauthenticatedApiVersionsRequest(SecurityProtocol)}
* where a non-SASL client is used to send requests that are processed by
* {@link SaslServerAuthenticator} of the server prior to client authentication.
*/
@Test
public void testApiVersionsRequestWithUnsupportedVersion() throws Exception {
SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

// Send ApiVersionsRequest with unsupported version and validate error response.
String node = "1";
createClientConnection(SecurityProtocol.PLAINTEXT, node);
RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS.id, Short.MAX_VALUE, "someclient", 1);
selector.send(new NetworkSend(node, RequestSend.serialize(header, new ApiVersionsRequest().toStruct())));
ByteBuffer responseBuffer = waitForResponse();
ResponseHeader.parse(responseBuffer);
ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer);
assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.errorCode());

// Send ApiVersionsRequest with a supported version. This should succeed.
sendVersionRequestReceiveResponse(node);

// Test that client can authenticate successfully
sendHandshakeRequestReceiveResponse(node);
authenticateUsingSaslPlainAndCheckConnection(node);
}

/**
* Tests that unsupported version of SASL handshake request returns error
* response and fails authentication. This test is similar to
* {@link #testUnauthenticatedApiVersionsRequest(SecurityProtocol)}
* where a non-SASL client is used to send requests that are processed by
* {@link SaslServerAuthenticator} of the server prior to client authentication.
*/
@Test
public void testSaslHandshakeRequestWithUnsupportedVersion() throws Exception {
SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
server = NetworkTestUtils.createEchoServer(securityProtocol, saslServerConfigs);

// Send ApiVersionsRequest and validate error response.
String node1 = "invalid1";
createClientConnection(SecurityProtocol.PLAINTEXT, node1);
RequestHeader header = new RequestHeader(ApiKeys.SASL_HANDSHAKE.id, Short.MAX_VALUE, "someclient", 2);
selector.send(new NetworkSend(node1, RequestSend.serialize(header, new SaslHandshakeRequest("PLAIN").toStruct())));
NetworkTestUtils.waitForChannelClose(selector, node1);
selector.close();

// Test good connection still works
createAndCheckClientConnection(securityProtocol, "good1");
}

/**
* Tests that any invalid data during Kafka SASL handshake request flow
* or the actual SASL authentication flow result in authentication failure
Expand Down Expand Up @@ -485,6 +542,11 @@ private void testUnauthenticatedApiVersionsRequest(SecurityProtocol securityProt
SaslHandshakeResponse handshakeResponse = sendHandshakeRequestReceiveResponse(node);
assertEquals(Collections.singletonList("PLAIN"), handshakeResponse.enabledMechanisms());

// Complete manual authentication and check send/receive succeed
authenticateUsingSaslPlainAndCheckConnection(node);
}

private void authenticateUsingSaslPlainAndCheckConnection(String node) throws Exception {
// Authenticate using PLAIN username/password
String authString = "\u0000" + TestJaasConfig.USERNAME + "\u0000" + TestJaasConfig.PASSWORD;
selector.send(new NetworkSend(node, ByteBuffer.wrap(authString.getBytes("UTF-8"))));
Expand Down
13 changes: 9 additions & 4 deletions core/src/main/scala/kafka/network/RequestChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import kafka.metrics.KafkaMetricsGroup
import kafka.utils.{Logging, SystemTime}
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.network.Send
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
import org.apache.kafka.common.requests.{RequestSend, ProduceRequest, AbstractRequest, RequestHeader}
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol, Protocol}
import org.apache.kafka.common.requests.{RequestSend, ProduceRequest, AbstractRequest, RequestHeader, ApiVersionsRequest}
import org.apache.kafka.common.security.auth.KafkaPrincipal
import org.apache.log4j.Logger

Expand Down Expand Up @@ -84,8 +84,13 @@ object RequestChannel extends Logging {
null
val body: AbstractRequest =
if (requestObj == null)
try AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer)
catch {
try {
// For unsupported version of ApiVersionsRequest, create a dummy request to enable an error response to be returned later
if (header.apiKey == ApiKeys.API_VERSIONS.id && !Protocol.apiVersionSupported(header.apiKey, header.apiVersion))
new ApiVersionsRequest
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be better to just respond with the error from here instead of creating this dummy request. But let's see what @junrao says.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the logic here is a bit awkward, but probably the simplest. Sending a response directly from SocketServer seems to require more work.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ijuma @junrao Thank you both for the reviews. I took a look at changing this, but left it as is to keep the code change small.

else
AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer)
} catch {
case ex: Throwable =>
throw new InvalidRequestException(s"Error getting request for apiKey: ${header.apiKey} and apiVersion: ${header.apiVersion}", ex)
}
Expand Down
4 changes: 1 addition & 3 deletions core/src/main/scala/kafka/server/KafkaApis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1028,9 +1028,7 @@ class KafkaApis(val requestChannel: RequestChannel,
// with client authentication which is performed at an earlier stage of the connection where the
// ApiVersionRequest is not available.
val responseHeader = new ResponseHeader(request.header.correlationId)
val isApiVersionsRequestVersionSupported = request.header.apiVersion <= Protocol.CURR_VERSION(ApiKeys.API_VERSIONS.id) &&
request.header.apiVersion >= Protocol.MIN_VERSIONS(ApiKeys.API_VERSIONS.id)
val responseBody = if (isApiVersionsRequestVersionSupported)
val responseBody = if (Protocol.apiVersionSupported(ApiKeys.API_VERSIONS.id, request.header.apiVersion))
ApiVersionsResponse.apiVersionsResponse
else
ApiVersionsResponse.fromError(Errors.UNSUPPORTED_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package kafka.server

import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.protocol.{ApiKeys, Errors}
import org.apache.kafka.common.requests.ApiVersionsResponse.ApiVersion
import org.apache.kafka.common.requests.{ApiVersionsRequest, ApiVersionsResponse}
import org.junit.Assert._
Expand Down Expand Up @@ -48,6 +48,12 @@ class ApiVersionsRequestTest extends BaseRequestTest {
ApiVersionsRequestTest.validateApiVersionsResponse(apiVersionsResponse)
}

@Test
def testApiVersionsRequestWithUnsupportedVersion() {
val apiVersionsResponse = sendApiVersionsRequest(new ApiVersionsRequest, Short.MaxValue)
assertEquals(Errors.UNSUPPORTED_VERSION.code(), apiVersionsResponse.errorCode)
}

private def sendApiVersionsRequest(request: ApiVersionsRequest, version: Short): ApiVersionsResponse = {
val response = send(request, ApiKeys.API_VERSIONS, version)
ApiVersionsResponse.parse(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ package kafka.server
import java.io.IOException
import java.net.Socket
import java.util.Collections
import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
import org.apache.kafka.common.protocol.{ApiKeys, Errors, SecurityProtocol}
import org.apache.kafka.common.requests.{ApiVersionsRequest, ApiVersionsResponse}
import org.apache.kafka.common.requests.SaslHandshakeRequest
import org.apache.kafka.common.requests.SaslHandshakeResponse
import org.apache.kafka.common.protocol.Errors
import org.junit.Test
import org.junit.Assert._
import kafka.api.SaslTestHarness
Expand Down Expand Up @@ -64,6 +63,20 @@ class SaslApiVersionsRequestTest extends BaseRequestTest with SaslTestHarness {
}
}

@Test
def testApiVersionsRequestWithUnsupportedVersion() {
val plaintextSocket = connect(protocol = securityProtocol)
try {
val apiVersionsResponse = sendApiVersionsRequest(plaintextSocket, new ApiVersionsRequest, Short.MaxValue)
assertEquals(Errors.UNSUPPORTED_VERSION.code(), apiVersionsResponse.errorCode)
val apiVersionsResponse2 = sendApiVersionsRequest(plaintextSocket, new ApiVersionsRequest, 0)
ApiVersionsRequestTest.validateApiVersionsResponse(apiVersionsResponse2)
sendSaslHandshakeRequestValidateResponse(plaintextSocket)
} finally {
plaintextSocket.close()
}
}

private def sendApiVersionsRequest(socket: Socket, request: ApiVersionsRequest, version: Short): ApiVersionsResponse = {
val response = send(socket, request, ApiKeys.API_VERSIONS, version)
ApiVersionsResponse.parse(response)
Expand Down