Skip to content
This repository was archived by the owner on Jan 20, 2025. It is now read-only.
20 changes: 19 additions & 1 deletion src/AsyncEventSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ void AsyncEventSource::onConnect(ArEventHandlerFunction cb){
_connectcb = cb;
}

void AsyncEventSource::onHandshake(ArHandshakeHandlerFunction cb){
_handshakecb = cb;
}

void AsyncEventSource::_addClient(AsyncEventSourceClient * client){
/*char * temp = (char *)malloc(2054);
if(temp != NULL){
Expand Down Expand Up @@ -333,13 +337,27 @@ bool AsyncEventSource::canHandle(AsyncWebServerRequest *request){
return false;
}
request->addInterestingHeader("Last-Event-ID");
request->addInterestingHeader("Cookie");
return true;
}

void AsyncEventSource::handleRequest(AsyncWebServerRequest *request){
if((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str()))
return request->requestAuthentication();
request->send(new AsyncEventSourceResponse(this));

// If Custom Handshake Handler is supplied
if(_handshakecb != nullptr){
if(_handshakecb(request)){
// Request Accepted
request->send(new AsyncEventSourceResponse(this));
}else{
// Request Rejected. Supply unauthorised http response.
request->send(401);
}
}else{
// No Custom Handshake Handler Supplied. Accept as default action.
request->send(new AsyncEventSourceResponse(this));
}
}

// Response
Expand Down
4 changes: 4 additions & 0 deletions src/AsyncEventSource.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class AsyncEventSource;
class AsyncEventSourceResponse;
class AsyncEventSourceClient;
typedef std::function<void(AsyncEventSourceClient *client)> ArEventHandlerFunction;
typedef std::function<bool(AsyncWebServerRequest *request)> ArHandshakeHandlerFunction;

class AsyncEventSourceMessage {
private:
Expand Down Expand Up @@ -100,13 +101,16 @@ class AsyncEventSource: public AsyncWebHandler {
String _url;
LinkedList<AsyncEventSourceClient *> _clients;
ArEventHandlerFunction _connectcb;
ArHandshakeHandlerFunction _handshakecb;
public:
AsyncEventSource(const String& url);
~AsyncEventSource();

const char * url() const { return _url.c_str(); }
void close();
void onConnect(ArEventHandlerFunction cb);
void onHandshake(ArHandshakeHandlerFunction cb);

void send(const char *message, const char *event=NULL, uint32_t id=0, uint32_t reconnect=0);
size_t count() const; //number clinets connected
size_t avgPacketsWaiting() const;
Expand Down
11 changes: 11 additions & 0 deletions src/AsyncWebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,12 +1159,14 @@ void AsyncWebSocket::binaryAll(const __FlashStringHelper *message, size_t len){
const char * WS_STR_CONNECTION = "Connection";
const char * WS_STR_UPGRADE = "Upgrade";
const char * WS_STR_ORIGIN = "Origin";
const char * WS_STR_COOKIE = "Cookie";
const char * WS_STR_VERSION = "Sec-WebSocket-Version";
const char * WS_STR_KEY = "Sec-WebSocket-Key";
const char * WS_STR_PROTOCOL = "Sec-WebSocket-Protocol";
const char * WS_STR_ACCEPT = "Sec-WebSocket-Accept";
const char * WS_STR_UUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";


bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){
if(!_enabled)
return false;
Expand All @@ -1175,6 +1177,7 @@ bool AsyncWebSocket::canHandle(AsyncWebServerRequest *request){
request->addInterestingHeader(WS_STR_CONNECTION);
request->addInterestingHeader(WS_STR_UPGRADE);
request->addInterestingHeader(WS_STR_ORIGIN);
request->addInterestingHeader(WS_STR_COOKIE);
request->addInterestingHeader(WS_STR_VERSION);
request->addInterestingHeader(WS_STR_KEY);
request->addInterestingHeader(WS_STR_PROTOCOL);
Expand All @@ -1189,6 +1192,14 @@ void AsyncWebSocket::handleRequest(AsyncWebServerRequest *request){
if((_username != "" && _password != "") && !request->authenticate(_username.c_str(), _password.c_str())){
return request->requestAuthentication();
}

if(_handshakeHandler != nullptr){
if(!_handshakeHandler(request)){
request->send(401);
return;
}
}

AsyncWebHeader* version = request->getHeader(WS_STR_VERSION);
if(version->value().toInt() != 13){
AsyncWebServerResponse *response = request->beginResponse(400);
Expand Down
7 changes: 7 additions & 0 deletions src/AsyncWebSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class AsyncWebSocketClient {
void _onData(void *pbuf, size_t plen);
};

typedef std::function<bool(AsyncWebServerRequest *request)> AwsHandshakeHandler;
typedef std::function<void(AsyncWebSocket * server, AsyncWebSocketClient * client, AwsEventType type, void * arg, uint8_t *data, size_t len)> AwsEventHandler;

//WebServer Handler implementation that plays the role of a socket server
Expand All @@ -248,6 +249,7 @@ class AsyncWebSocket: public AsyncWebHandler {
AsyncWebSocketClientLinkedList _clients;
uint32_t _cNextId;
AwsEventHandler _eventHandler;
AwsHandshakeHandler _handshakeHandler;
bool _enabled;
AsyncWebLock _lock;

Expand Down Expand Up @@ -315,6 +317,11 @@ class AsyncWebSocket: public AsyncWebHandler {
void onEvent(AwsEventHandler handler){
_eventHandler = handler;
}

// Handshake Handler
void handleHandshake(AwsHandshakeHandler handler){
_handshakeHandler = handler;
}

//system callbacks (do not call)
uint32_t _getNextId(){ return _cNextId++; }
Expand Down