Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.CONNECT_ACK;
import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.DISCONNECTED;
import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.ESTABLISHED;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getAuthData;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getAuthMethod;
import static io.streamnative.pulsar.handlers.mqtt.utils.NettyUtils.ATTR_KEY_CONNECTION;
import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater;
import io.netty.channel.Channel;
Expand Down Expand Up @@ -246,6 +248,8 @@ public void sendConnAck() {
.topicAliasMaximum(clientRestrictions.getTopicAliasMaximum())
.cleanSession(clientRestrictions.isCleanSession())
.maximumQos(MqttQoS.AT_LEAST_ONCE.value())
.authMethod(getAuthMethod(connectMessage))
.authData(getAuthData(connectMessage))
.maximumPacketSize(getServerRestrictions().getMaximumPacketSize());
MqttProperties.StringProperty resInformation = (MqttProperties.StringProperty) connectMessage.variableHeader()
.properties().getProperty(MqttProperties.MqttPropertyType.RESPONSE_INFORMATION.value());
Expand All @@ -259,7 +263,6 @@ public void sendConnAck() {
}
}


public enum ConnectionState {
DISCONNECTED,
CONNECT_ACK,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_BASIC;
import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_TOKEN;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectPayload;
import io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -62,6 +64,19 @@ private Map<String, AuthenticationProvider> getAuthenticationProviders(List<Stri
return providers;
}

public AuthenticationResult authenticate(MqttConnectMessage connectMessage) {
String authMethod = MqttMessageUtils.getAuthMethod(connectMessage);
if (authMethod != null) {
byte[] authData = MqttMessageUtils.getAuthData(connectMessage);
if (authData == null) {
return AuthenticationResult.FAILED;
}
return authenticate(connectMessage.payload().clientIdentifier(), authMethod,
new AuthenticationDataCommand(new String(authData)));
}
return authenticate(connectMessage.payload());
}

public AuthenticationResult authenticate(MqttConnectPayload payload) {
String userRole = null;
boolean authenticated = false;
Expand All @@ -79,6 +94,25 @@ public AuthenticationResult authenticate(MqttConnectPayload payload) {
return new AuthenticationResult(authenticated, userRole);
}

private AuthenticationResult authenticate(String clientIdentifier,
String authMethod,
AuthenticationDataCommand command) {
AuthenticationProvider authenticationProvider = authenticationProviders.get(authMethod);
if (authenticationProvider == null) {
log.warn("Authentication failed, no authMethod : {} for CId={}", clientIdentifier, authMethod);
return AuthenticationResult.FAILED;
}
String userRole = null;
boolean authenticated = false;
try {
userRole = authenticationProvider.authenticate(command);
authenticated = true;
} catch (AuthenticationException e) {
log.warn("Authentication failed for CId={}", clientIdentifier);
}
return new AuthenticationResult(authenticated, userRole);
}

public AuthenticationDataSource getAuthData(String authMethod, MqttConnectPayload payload) {
switch (authMethod) {
case AUTH_BASIC:
Expand All @@ -94,6 +128,8 @@ public AuthenticationDataSource getAuthData(String authMethod, MqttConnectPayloa
@Getter
@RequiredArgsConstructor
public static class AuthenticationResult {

public static final AuthenticationResult FAILED = new AuthenticationResult(false, null);
private final boolean authenticated;
private final String userRole;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ public final static class MqttConnectSuccessAckBuilder {

private int maximumPacketSize;

private String authMethod;

private byte[] authData;

public MqttConnectSuccessAckBuilder(int protocolVersion) {
this.protocolVersion = protocolVersion;
}
Expand Down Expand Up @@ -88,6 +92,16 @@ public MqttConnectSuccessAckBuilder maximumPacketSize(int maximumPacketSize) {
return this;
}

public MqttConnectSuccessAckBuilder authMethod(String authMethod) {
this.authMethod = authMethod;
return this;
}

public MqttConnectSuccessAckBuilder authData(byte[] authData) {
this.authData = authData;
return this;
}

public MqttAck build() {
MqttMessageBuilders.ConnAckBuilder commonBuilder = MqttMessageBuilders.connAck()
.sessionPresent(!cleanSession);
Expand Down Expand Up @@ -124,6 +138,17 @@ public MqttAck build() {
responseInformation);
properties.add(responseInformationProperty);
}
if (StringUtils.isNotEmpty(authMethod)) {
MqttProperties.StringProperty authMethodProperty =
new MqttProperties.StringProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value(),
authMethod);
properties.add(authMethodProperty);

MqttProperties.BinaryProperty authDataProperty =
new MqttProperties.BinaryProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value(),
authData);
properties.add(authDataProperty);
}
return MqttAck.createSupportedAck(
commonBuilder.returnCode(Mqtt5ConnReasonCode.SUCCESS.toConnectionReasonCode())
.properties(properties)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ public void processConnect(MqttAdapterMessage adapter) {
clientId, username);
}
} else {
MQTTAuthenticationService.AuthenticationResult authResult = authenticationService.authenticate(payload);
MQTTAuthenticationService.AuthenticationResult authResult = authenticationService
.authenticate(connectMessage);
if (authResult.isFailed()) {
MqttMessage mqttMessage = MqttConnectAck.errorBuilder().authFail(protocolVersion);
log.error("[CONNECT] Invalid or incorrect authentication. CId={}, username={}", clientId, username);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,18 @@ public static long getMessageExpiryInterval(MqttPublishMessage msg) {
MqttProperties.MqttPropertyType.PUBLICATION_EXPIRY_INTERVAL.value())
.stream().map(prop -> ((MqttProperties.IntegerProperty) prop).value()).findFirst().orElse(0);
}

public static String getAuthMethod(MqttConnectMessage connectMessage) {
MqttProperties properties = connectMessage.variableHeader().properties();
MqttProperties.StringProperty authMethodProperty = (MqttProperties.StringProperty) properties
.getProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value());
return authMethodProperty != null ? authMethodProperty.value() : null;
}

public static byte[] getAuthData(MqttConnectMessage connectMessage) {
MqttProperties properties = connectMessage.variableHeader().properties();
MqttProperties.BinaryProperty authDataProperty = (MqttProperties.BinaryProperty) properties
.getProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value());
return authDataProperty != null ? authDataProperty.value() : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,24 @@
*/
package io.streamnative.pulsar.handlers.mqtt.mqtt3.fusesource.base;

import com.hivemq.client.internal.shaded.org.jetbrains.annotations.NotNull;
import com.hivemq.client.mqtt.MqttGlobalPublishFilter;
import com.hivemq.client.mqtt.datatypes.MqttQos;
import com.hivemq.client.mqtt.datatypes.MqttUtf8String;
import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient;
import com.hivemq.client.mqtt.mqtt5.Mqtt5Client;
import com.hivemq.client.mqtt.mqtt5.Mqtt5ClientConfig;
import com.hivemq.client.mqtt.mqtt5.auth.Mqtt5EnhancedAuthMechanism;
import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5Auth;
import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5AuthBuilder;
import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5EnhancedAuthBuilder;
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5Connect;
import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAck;
import com.hivemq.client.mqtt.mqtt5.message.disconnect.Mqtt5Disconnect;
import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish;
import io.streamnative.pulsar.handlers.mqtt.base.BasicAuthenticationConfig;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
import org.fusesource.mqtt.client.BlockingConnection;
import org.fusesource.mqtt.client.MQTT;
Expand Down Expand Up @@ -47,6 +64,92 @@ public void testAuthenticate() throws Exception {
connection.disconnect();
}

@Test(timeOut = TIMEOUT)
public void testAuthenticateWithAuthMethod() throws Exception {
String topic = "persistent://public/default/testAuthenticateWithAuthMethod";
Mqtt5BlockingClient client1 = Mqtt5Client.builder()
.identifier("abc")
.serverHost("127.0.0.1")
.serverPort(getMqttBrokerPortList().get(0))
.buildBlocking();
client1.connectWith().enhancedAuth(new Mqtt5EnhancedAuthMechanism() {
@Override
public @NotNull MqttUtf8String getMethod() {
return MqttUtf8String.of("basic");
}

@Override
public int getTimeout() {
return 10;
}

@Override
public @NotNull CompletableFuture<Void> onAuth(@NotNull Mqtt5ClientConfig clientConfig,
@NotNull Mqtt5Connect connect,
@NotNull Mqtt5EnhancedAuthBuilder authBuilder) {
authBuilder.data("superUser:supepass".getBytes(StandardCharsets.UTF_8));
return CompletableFuture.completedFuture(null);
}

@Override
public @NotNull CompletableFuture<Void> onReAuth(@NotNull Mqtt5ClientConfig clientConfig,
@NotNull Mqtt5AuthBuilder authBuilder) {
return CompletableFuture.completedFuture(null);
}

@Override
public @NotNull CompletableFuture<Boolean> onContinue(@NotNull Mqtt5ClientConfig clientConfig,
@NotNull Mqtt5Auth auth,
@NotNull Mqtt5AuthBuilder authBuilder) {
return CompletableFuture.completedFuture(false);
}

@Override
public @NotNull CompletableFuture<Boolean> onAuthSuccess(@NotNull Mqtt5ClientConfig clientConfig,
@NotNull Mqtt5ConnAck connAck) {
return CompletableFuture.completedFuture(true);
}

@Override
public @NotNull CompletableFuture<Boolean> onReAuthSuccess(@NotNull Mqtt5ClientConfig clientConfig,
@NotNull Mqtt5Auth auth) {
return null;
}

@Override
public void onAuthRejected(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Mqtt5ConnAck connAck) {
//NOP
}

@Override
public void onReAuthRejected(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Mqtt5Disconnect disconnect) {
//NOP
}

@Override
public void onAuthError(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Throwable cause) {
//NOP
}

@Override
public void onReAuthError(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Throwable cause) {
//NOP
}
}).send();
Mqtt5Publish publishMessage = Mqtt5Publish.builder().topic(topic)
.qos(MqttQos.AT_LEAST_ONCE).build();
client1.subscribeWith()
.topicFilter(topic)
.qos(MqttQos.AT_LEAST_ONCE)
.send();
Mqtt5BlockingClient.Mqtt5Publishes publishes = client1.publishes(MqttGlobalPublishFilter.ALL);
client1.publish(publishMessage);
Mqtt5Publish message = publishes.receive();
Assert.assertNotNull(message);
publishes.close();
client1.disconnect();
}

@Test(expectedExceptions = {MQTTException.class}, timeOut = TIMEOUT)
public void testNoAuthenticated() throws Exception {
MQTT mqtt = createMQTTClient();
Expand Down