diff --git a/src/AsyncEventSource.cpp b/src/AsyncEventSource.cpp index f2914df54..51a84f031 100644 --- a/src/AsyncEventSource.cpp +++ b/src/AsyncEventSource.cpp @@ -263,6 +263,10 @@ void AsyncEventSource::onConnect(ArEventHandlerFunction cb){ _connectcb = cb; } +void AsyncEventSource::onHandshake(ArHandshakeHandlerFunction cb){ + _handshakecb = cb; +} + void AsyncEventSource::_addClient(AsyncEventSourceClient * client){ /*char * temp = (char *)malloc(2054); if(temp != NULL){ @@ -333,13 +337,27 @@ bool AsyncEventSource::canHandle(AsyncWebServerRequest *request){ return false; } request->addInterestingHeader("Last-Event-ID"); + request->addInterestingHeader("Cookie"); return true; } void AsyncEventSource::handleRequest(AsyncWebServerRequest *request){ if((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())) return request->requestAuthentication(); - request->send(new AsyncEventSourceResponse(this)); + + // If Custom Handshake Handler is supplied + if(_handshakecb != nullptr){ + if(_handshakecb(request)){ + // Request Accepted + request->send(new AsyncEventSourceResponse(this)); + }else{ + // Request Rejected. Supply unauthorised http response. + request->send(401); + } + }else{ + // No Custom Handshake Handler Supplied. Accept as default action. + request->send(new AsyncEventSourceResponse(this)); + } } // Response diff --git a/src/AsyncEventSource.h b/src/AsyncEventSource.h index b097fa623..8d547aec5 100644 --- a/src/AsyncEventSource.h +++ b/src/AsyncEventSource.h @@ -49,6 +49,7 @@ class AsyncEventSource; class AsyncEventSourceResponse; class AsyncEventSourceClient; typedef std::function ArEventHandlerFunction; +typedef std::function ArHandshakeHandlerFunction; class AsyncEventSourceMessage { private: @@ -100,6 +101,7 @@ class AsyncEventSource: public AsyncWebHandler { String _url; LinkedList _clients; ArEventHandlerFunction _connectcb; + ArHandshakeHandlerFunction _handshakecb; public: AsyncEventSource(const String& url); ~AsyncEventSource(); @@ -107,6 +109,8 @@ class AsyncEventSource: public AsyncWebHandler { const char * url() const { return _url.c_str(); } void close(); void onConnect(ArEventHandlerFunction cb); + void onHandshake(ArHandshakeHandlerFunction cb); + void send(const char *message, const char *event=NULL, uint32_t id=0, uint32_t reconnect=0); size_t count() const; //number clinets connected size_t avgPacketsWaiting() const; diff --git a/src/AsyncWebSocket.cpp b/src/AsyncWebSocket.cpp index 52dcd75f0..ff30df1d4 100644 --- a/src/AsyncWebSocket.cpp +++ b/src/AsyncWebSocket.cpp @@ -1159,12 +1159,14 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *message, size_t len){ const char * WS_STR_CONNECTION = "Connection"; const char * WS_STR_UPGRADE = "Upgrade"; const char * WS_STR_ORIGIN = "Origin"; +const char * WS_STR_COOKIE = "Cookie"; const char * WS_STR_VERSION = "Sec-WebSocket-Version"; const char * WS_STR_KEY = "Sec-WebSocket-Key"; const char * WS_STR_PROTOCOL = "Sec-WebSocket-Protocol"; const char * WS_STR_ACCEPT = "Sec-WebSocket-Accept"; const char * WS_STR_UUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){ if(!_enabled) return false; @@ -1175,6 +1177,7 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){ request->addInterestingHeader(WS_STR_CONNECTION); request->addInterestingHeader(WS_STR_UPGRADE); request->addInterestingHeader(WS_STR_ORIGIN); + request->addInterestingHeader(WS_STR_COOKIE); request->addInterestingHeader(WS_STR_VERSION); request->addInterestingHeader(WS_STR_KEY); request->addInterestingHeader(WS_STR_PROTOCOL); @@ -1189,6 +1192,14 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){ if((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())){ return request->requestAuthentication(); } + + if(_handshakeHandler != nullptr){ + if(!_handshakeHandler(request)){ + request->send(401); + return; + } + } + AsyncWebHeader* version = request->getHeader(WS_STR_VERSION); if(version->value().toInt() != 13){ AsyncWebServerResponse *response = request->beginResponse(400); diff --git a/src/AsyncWebSocket.h b/src/AsyncWebSocket.h index 5b03aceb9..2f97830bc 100644 --- a/src/AsyncWebSocket.h +++ b/src/AsyncWebSocket.h @@ -237,6 +237,7 @@ class AsyncWebSocketClient { void _onData(void *pbuf, size_t plen); }; +typedef std::function AwsHandshakeHandler; typedef std::function AwsEventHandler; //WebServer Handler implementation that plays the role of a socket server @@ -248,6 +249,7 @@ class AsyncWebSocket: public AsyncWebHandler { AsyncWebSocketClientLinkedList _clients; uint32_t _cNextId; AwsEventHandler _eventHandler; + AwsHandshakeHandler _handshakeHandler; bool _enabled; AsyncWebLock _lock; @@ -315,6 +317,11 @@ class AsyncWebSocket: public AsyncWebHandler { void onEvent(AwsEventHandler handler){ _eventHandler = handler; } + + // Handshake Handler + void handleHandshake(AwsHandshakeHandler handler){ + _handshakeHandler = handler; + } //system callbacks (do not call) uint32_t _getNextId(){ return _cNextId++; }