diff --git a/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/Connection.java b/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/Connection.java index eabd01e23..d7a9b2bb7 100644 --- a/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/Connection.java +++ b/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/Connection.java @@ -16,6 +16,8 @@ import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.CONNECT_ACK; import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.DISCONNECTED; import static io.streamnative.pulsar.handlers.mqtt.Connection.ConnectionState.ESTABLISHED; +import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getAuthData; +import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.getAuthMethod; import static io.streamnative.pulsar.handlers.mqtt.utils.NettyUtils.ATTR_KEY_CONNECTION; import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; import io.netty.channel.Channel; @@ -246,6 +248,8 @@ public void sendConnAck() { .topicAliasMaximum(clientRestrictions.getTopicAliasMaximum()) .cleanSession(clientRestrictions.isCleanSession()) .maximumQos(MqttQoS.AT_LEAST_ONCE.value()) + .authMethod(getAuthMethod(connectMessage)) + .authData(getAuthData(connectMessage)) .maximumPacketSize(getServerRestrictions().getMaximumPacketSize()); MqttProperties.StringProperty resInformation = (MqttProperties.StringProperty) connectMessage.variableHeader() .properties().getProperty(MqttProperties.MqttPropertyType.RESPONSE_INFORMATION.value()); @@ -259,7 +263,6 @@ public void sendConnAck() { } } - public enum ConnectionState { DISCONNECTED, CONNECT_ACK, diff --git a/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/MQTTAuthenticationService.java b/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/MQTTAuthenticationService.java index e542f7c33..52f91be6d 100644 --- a/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/MQTTAuthenticationService.java +++ b/mqtt-impl/src/main/java/io/streamnative/pulsar/handlers/mqtt/MQTTAuthenticationService.java @@ -15,7 +15,9 @@ import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_BASIC; import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_TOKEN; +import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttConnectPayload; +import io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -62,6 +64,19 @@ private Map getAuthenticationProviders(List ((MqttProperties.IntegerProperty) prop).value()).findFirst().orElse(0); } + + public static String getAuthMethod(MqttConnectMessage connectMessage) { + MqttProperties properties = connectMessage.variableHeader().properties(); + MqttProperties.StringProperty authMethodProperty = (MqttProperties.StringProperty) properties + .getProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_METHOD.value()); + return authMethodProperty != null ? authMethodProperty.value() : null; + } + + public static byte[] getAuthData(MqttConnectMessage connectMessage) { + MqttProperties properties = connectMessage.variableHeader().properties(); + MqttProperties.BinaryProperty authDataProperty = (MqttProperties.BinaryProperty) properties + .getProperty(MqttProperties.MqttPropertyType.AUTHENTICATION_DATA.value()); + return authDataProperty != null ? authDataProperty.value() : null; + } } diff --git a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/BasicAuthenticationTest.java b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/BasicAuthenticationTest.java index 17c7b83e7..307ae1031 100644 --- a/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/BasicAuthenticationTest.java +++ b/tests/src/test/java/io/streamnative/pulsar/handlers/mqtt/mqtt3/fusesource/base/BasicAuthenticationTest.java @@ -13,7 +13,24 @@ */ package io.streamnative.pulsar.handlers.mqtt.mqtt3.fusesource.base; +import com.hivemq.client.internal.shaded.org.jetbrains.annotations.NotNull; +import com.hivemq.client.mqtt.MqttGlobalPublishFilter; +import com.hivemq.client.mqtt.datatypes.MqttQos; +import com.hivemq.client.mqtt.datatypes.MqttUtf8String; +import com.hivemq.client.mqtt.mqtt5.Mqtt5BlockingClient; +import com.hivemq.client.mqtt.mqtt5.Mqtt5Client; +import com.hivemq.client.mqtt.mqtt5.Mqtt5ClientConfig; +import com.hivemq.client.mqtt.mqtt5.auth.Mqtt5EnhancedAuthMechanism; +import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5Auth; +import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5AuthBuilder; +import com.hivemq.client.mqtt.mqtt5.message.auth.Mqtt5EnhancedAuthBuilder; +import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5Connect; +import com.hivemq.client.mqtt.mqtt5.message.connect.connack.Mqtt5ConnAck; +import com.hivemq.client.mqtt.mqtt5.message.disconnect.Mqtt5Disconnect; +import com.hivemq.client.mqtt.mqtt5.message.publish.Mqtt5Publish; import io.streamnative.pulsar.handlers.mqtt.base.BasicAuthenticationConfig; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.CompletableFuture; import lombok.extern.slf4j.Slf4j; import org.fusesource.mqtt.client.BlockingConnection; import org.fusesource.mqtt.client.MQTT; @@ -47,6 +64,92 @@ public void testAuthenticate() throws Exception { connection.disconnect(); } + @Test(timeOut = TIMEOUT) + public void testAuthenticateWithAuthMethod() throws Exception { + String topic = "persistent://public/default/testAuthenticateWithAuthMethod"; + Mqtt5BlockingClient client1 = Mqtt5Client.builder() + .identifier("abc") + .serverHost("127.0.0.1") + .serverPort(getMqttBrokerPortList().get(0)) + .buildBlocking(); + client1.connectWith().enhancedAuth(new Mqtt5EnhancedAuthMechanism() { + @Override + public @NotNull MqttUtf8String getMethod() { + return MqttUtf8String.of("basic"); + } + + @Override + public int getTimeout() { + return 10; + } + + @Override + public @NotNull CompletableFuture onAuth(@NotNull Mqtt5ClientConfig clientConfig, + @NotNull Mqtt5Connect connect, + @NotNull Mqtt5EnhancedAuthBuilder authBuilder) { + authBuilder.data("superUser:supepass".getBytes(StandardCharsets.UTF_8)); + return CompletableFuture.completedFuture(null); + } + + @Override + public @NotNull CompletableFuture onReAuth(@NotNull Mqtt5ClientConfig clientConfig, + @NotNull Mqtt5AuthBuilder authBuilder) { + return CompletableFuture.completedFuture(null); + } + + @Override + public @NotNull CompletableFuture onContinue(@NotNull Mqtt5ClientConfig clientConfig, + @NotNull Mqtt5Auth auth, + @NotNull Mqtt5AuthBuilder authBuilder) { + return CompletableFuture.completedFuture(false); + } + + @Override + public @NotNull CompletableFuture onAuthSuccess(@NotNull Mqtt5ClientConfig clientConfig, + @NotNull Mqtt5ConnAck connAck) { + return CompletableFuture.completedFuture(true); + } + + @Override + public @NotNull CompletableFuture onReAuthSuccess(@NotNull Mqtt5ClientConfig clientConfig, + @NotNull Mqtt5Auth auth) { + return null; + } + + @Override + public void onAuthRejected(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Mqtt5ConnAck connAck) { + //NOP + } + + @Override + public void onReAuthRejected(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Mqtt5Disconnect disconnect) { + //NOP + } + + @Override + public void onAuthError(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Throwable cause) { + //NOP + } + + @Override + public void onReAuthError(@NotNull Mqtt5ClientConfig clientConfig, @NotNull Throwable cause) { + //NOP + } + }).send(); + Mqtt5Publish publishMessage = Mqtt5Publish.builder().topic(topic) + .qos(MqttQos.AT_LEAST_ONCE).build(); + client1.subscribeWith() + .topicFilter(topic) + .qos(MqttQos.AT_LEAST_ONCE) + .send(); + Mqtt5BlockingClient.Mqtt5Publishes publishes = client1.publishes(MqttGlobalPublishFilter.ALL); + client1.publish(publishMessage); + Mqtt5Publish message = publishes.receive(); + Assert.assertNotNull(message); + publishes.close(); + client1.disconnect(); + } + @Test(expectedExceptions = {MQTTException.class}, timeOut = TIMEOUT) public void testNoAuthenticated() throws Exception { MQTT mqtt = createMQTTClient();