diff --git a/vnext/.editorconfig b/vnext/.editorconfig index e0bfb830ed3..3309e13016a 100644 --- a/vnext/.editorconfig +++ b/vnext/.editorconfig @@ -21,3 +21,6 @@ insert_final_newline = false [*.ps1] indent_style = tab indent_size = 4 + +[package.json] +insert_final_newline = false diff --git a/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj b/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj index ab76178ff12..6960c5517ee 100644 --- a/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj +++ b/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj @@ -57,6 +57,7 @@ true BOOST_ASIO_HAS_IOCP;_WIN32_WINNT=_WIN32_WINNT_WIN7;WIN32;_WINDOWS;REACTNATIVEWIN32_EXPORTS;FOLLY_NO_CONFIG;WIN32_LEAN_AND_MEAN;NOMINMAX;GLOG_NO_ABBREVIATED_SEVERITIES;_HAS_AUTO_PTR_ETC;CHAKRACORE;RN_PLATFORM=windesktop;RN_EXPORT=;JSI_EXPORT=;NOJSC;%(PreprocessorDefinitions) + %(AdditionalOptions) /bigobj $(VCInstallDir)UnitTest\include;%(AdditionalIncludeDirectories) ProgramDatabase true @@ -102,6 +103,7 @@ + @@ -130,6 +132,7 @@ + diff --git a/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj.filters b/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj.filters index 25f5bfe4563..b1ff9efcace 100644 --- a/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj.filters +++ b/vnext/Desktop.IntegrationTests/React.Windows.Desktop.IntegrationTests.vcxproj.filters @@ -45,6 +45,9 @@ Integration Tests + + Source Files + Source Files @@ -56,6 +59,9 @@ Header Files + + Header Files + Header Files diff --git a/vnext/Desktop.IntegrationTests/WebSocketIntegrationTest.cpp b/vnext/Desktop.IntegrationTests/WebSocketIntegrationTest.cpp index cd404f88bcd..19ce87d4828 100644 --- a/vnext/Desktop.IntegrationTests/WebSocketIntegrationTest.cpp +++ b/vnext/Desktop.IntegrationTests/WebSocketIntegrationTest.cpp @@ -3,27 +3,31 @@ #include #include -#include "unicode.h" +#include #include #include #include -using namespace facebook::react; +using namespace Microsoft::React; using namespace Microsoft::VisualStudio::CppUnitTestFramework; using std::chrono::milliseconds; using std::condition_variable; +using std::make_shared; using std::unique_lock; using std::lock_guard; using std::promise; using std::string; +using CloseCode = IWebSocket::CloseCode; + TEST_CLASS(WebSocketIntegrationTest) { TEST_METHOD(ConnectClose) { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + auto server = make_shared(5556); + auto ws = IWebSocket::Make("ws://localhost:5556/"); Assert::IsFalse(nullptr == ws); bool connected = false; string message; @@ -32,8 +36,10 @@ TEST_CLASS(WebSocketIntegrationTest) connected = true; }); + server->Start(); ws->Connect(); - ws->Close(IWebSocket::CloseCode::Normal, "Closing"); + ws->Close(CloseCode::Normal, "Closing"); + server->Stop(); Assert::IsTrue(connected); } @@ -41,10 +47,12 @@ TEST_CLASS(WebSocketIntegrationTest) TEST_METHOD(ConnectNoClose) { bool connected = false; + auto server = make_shared(5556); + server->Start(); // IWebSocket scope. Ensures object is closed implicitly by destructor. { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + auto ws = IWebSocket::Make("ws://localhost:5556/"); ws->SetOnConnect([&connected]() { connected = true; @@ -53,12 +61,17 @@ TEST_CLASS(WebSocketIntegrationTest) ws->Connect(); } + server->Stop(); + Assert::IsTrue(connected); } TEST_METHOD(PingClose) { - auto ws = IWebSocket::Make("ws://localhost:5555"); + auto server = make_shared(5556); + server->Start(); + + auto ws = IWebSocket::Make("ws://localhost:5556"); promise pingPromise; ws->SetOnPing([&pingPromise]() { @@ -75,43 +88,14 @@ TEST_CLASS(WebSocketIntegrationTest) auto pingFuture = pingPromise.get_future(); pingFuture.wait(); bool pinged = pingFuture.get(); + ws->Close(CloseCode::Normal, "Closing after reading"); - ws->Close(IWebSocket::CloseCode::Normal, "Closing after reading"); + server->Stop(); Assert::IsTrue(pinged); Assert::AreEqual({}, errorString); } - TEST_METHOD(SendReceiveNoClose) - { - auto ws = IWebSocket::Make("ws://localhost:5555/"); - promise response; - ws->SetOnMessage([&response](size_t size, const string& message) - { - // Ignore greeting message. - if (message == "hello") - return; - - response.set_value(message); - }); - string errorMessage; - ws->SetOnError([&errorMessage](IWebSocket::Error err) - { - errorMessage = err.Message; - }); - - ws->Connect(); - ws->Send("suffixme"); - - // Block until respone is received. Fail in case of a remote endpoint failure. - auto future = response.get_future(); - future.wait(); - string result = future.get(); - - Assert::AreEqual({}, errorMessage); - Assert::AreEqual(string("suffixme_response"), result); - } - // Emulate promise/future functionality. // Fails when connecting to stock package bundler. BEGIN_TEST_METHOD_ATTRIBUTE(WaitForBundlerResponseNoClose) @@ -157,7 +141,12 @@ TEST_CLASS(WebSocketIntegrationTest) TEST_METHOD(SendReceiveClose) { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + auto server = make_shared(5556); + server->SetMessageFactory([](string&& message) + { + return message + "_response"; + }); + auto ws = IWebSocket::Make("ws://localhost:5556/"); promise sentSizePromise; ws->SetOnSend([&sentSizePromise](size_t size) { @@ -166,10 +155,6 @@ TEST_CLASS(WebSocketIntegrationTest) promise receivedPromise; ws->SetOnMessage([&receivedPromise](size_t size, const string& message) { - // Ignore greeting message - if (message == "hello") - return; - receivedPromise.set_value(message); }); string errorMessage; @@ -178,7 +163,8 @@ TEST_CLASS(WebSocketIntegrationTest) errorMessage = err.Message; }); - string sent = "suffixme"; + server->Start(); + string sent = "prefix"; ws->Connect(); ws->Send(sent); @@ -191,31 +177,34 @@ TEST_CLASS(WebSocketIntegrationTest) string received = receivedFuture.get(); Assert::AreEqual({}, errorMessage); - ws->Close(IWebSocket::CloseCode::Normal, "Closing after reading"); + ws->Close(CloseCode::Normal, "Closing after reading"); + server->Stop(); Assert::AreEqual({}, errorMessage); Assert::AreEqual(sent.length(), sentSize); - Assert::AreEqual(string("suffixme_response"), received); + Assert::AreEqual({ "prefix_response" }, received); } TEST_METHOD(SendReceiveLargeMessage) { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + auto server = make_shared(5556); + server->SetMessageFactory([](string&& message) + { + return message + "_response"; + }); + auto ws = IWebSocket::Make("ws://localhost:5556/"); promise response; ws->SetOnMessage([&response](size_t size, const string& message) { - // Ignore greeting message - if (message == "hello") - return; - response.set_value(message); }); - ws->SetOnError([](IWebSocket::Error err) + string errorMessage; + ws->SetOnError([&errorMessage](IWebSocket::Error err) { - auto message = facebook::react::unicode::utf8ToUtf16(err.Message); - Assert::Fail(message.c_str()); + errorMessage = err.Message; }); + server->Start(); ws->Connect(); char digits[] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' }; @@ -234,8 +223,10 @@ TEST_CLASS(WebSocketIntegrationTest) future.wait(); string result = future.get(); - ws->Close(IWebSocket::CloseCode::Normal, "Closing after reading"); + ws->Close(CloseCode::Normal, "Closing after reading"); + server->Stop(); + Assert::AreEqual({}, errorMessage); Assert::AreEqual(static_cast(LEN + string("_response").length()), result.length()); } @@ -261,13 +252,24 @@ TEST_CLASS(WebSocketIntegrationTest) END_TEST_METHOD_ATTRIBUTE() TEST_METHOD(AdditionalHeaders) { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + string cookie; + auto server = make_shared(5556); + server->SetOnHandshake([server](boost::beast::websocket::response_type& response) + { + auto cookie = response[boost::beast::http::field::cookie].to_string(); + server->SetMessageFactory([cookie](string&&) + { + return cookie; + }); + }); + auto ws = IWebSocket::Make("ws://localhost:5556/"); promise response; ws->SetOnMessage([&response](size_t size, const string& message) { response.set_value(message); }); + server->Start(); ws->Connect({}, {{ L"Cookie", "JSESSIONID=AD9A320CC4034641997FF903F1D10906" }}); ws->Send(""); @@ -275,65 +277,42 @@ TEST_CLASS(WebSocketIntegrationTest) future.wait(); string result = future.get(); - Assert::AreEqual(string("JSESSIONID=AD9A320CC4034641997FF903F1D10906"), result); + Assert::AreEqual({ "JSESSIONID=AD9A320CC4034641997FF903F1D10906" }, result); - ws->Close(IWebSocket::CloseCode::Normal, "No reason"); + ws->Close(CloseCode::Normal, "No reason"); + server->Stop(); } - /// - // Run this test against a valid WebSocket server runing on SSL. - // See sample below. - /// - /* -const WebSocket = require('ws'); -const fs = require('fs'); -const https = require('https'); - -const httpsServer = https.createServer({ - key: fs.readFileSync('key.pem'), - cert: fs.readFileSync('cert.pem') -}); -const server = new WebSocket.Server({ - server:httpsServer -}); - -server.on('connection', (ws) => { - ws.on('message', (message) => { - console.log('Received message:', message); - if (message === 'exit') { - console.log('WebSocket integration test server exit'); - process.exit(0); - } - console.log('Cookie:', ws.upgradeReq.headers.cookie); - ws.send(message + '_response'); - }); - - ws.send('hello'); -}); - -httpsServer.listen(443); - */ - BEGIN_TEST_METHOD_ATTRIBUTE(SendAndReceiveSsl) - TEST_IGNORE() - END_TEST_METHOD_ATTRIBUTE() - TEST_METHOD(SendAndReceiveSsl) + TEST_METHOD(SendReceiveSsl) { - auto ws = IWebSocket::Make("wss://localhost/"); - string message; - ws->SetOnMessage([&message](size_t size, const string& messageIn) + auto server = make_shared(5556, /*isSecure*/ true); + server->SetMessageFactory([](string&& message) { - message = messageIn; + return message + "_response"; + }); + auto ws = IWebSocket::Make("wss://localhost:5556"); + promise response; + ws->SetOnMessage([&response](size_t size, const string& messageIn) + { + response.set_value(messageIn); }); + server->Start(); ws->Connect(); ws->Send("suffixme"); - ws->Close(IWebSocket::CloseCode::Normal, "Closing after reading"); - Assert::AreEqual(string("hello"), message); + auto result = response.get_future(); + result.wait(); + + ws->Close(CloseCode::Normal, "Closing after reading"); + server->Stop(); + + Assert::AreEqual({ "suffixme_response" }, result.get()); } + //TODO: Use Test::WebSocketServer!!! BEGIN_TEST_METHOD_ATTRIBUTE(SendBinary) - //TEST_IGNORE() + TEST_IGNORE() END_TEST_METHOD_ATTRIBUTE() TEST_METHOD(SendBinary) { @@ -386,23 +365,24 @@ httpsServer.listen(443); Assert::AreEqual(messages[i], response); } - ws->Close(IWebSocket::CloseCode::Normal, "Closing after reading"); + ws->Close(CloseCode::Normal, "Closing after reading"); Assert::AreEqual({}, errorMessage); } TEST_METHOD(SendConsecutive) { - auto ws = IWebSocket::Make("ws://localhost:5555/"); + auto server = make_shared(5556); + server->SetMessageFactory([](string&& message) + { + return message + "_response"; + }); + auto ws = IWebSocket::Make("ws://localhost:5556/"); promise response; const int writes = 10; int count = 0; ws->SetOnMessage([&response, &count, writes](size_t size, const string& message) { - // Ignore greeting message. - if (message == "hello") - return; - if (++count < writes) return; @@ -414,6 +394,7 @@ httpsServer.listen(443); errorMessage = err.Message; }); + server->Start(); ws->Connect(); // Consecutive immediate writes should be enqueued. @@ -426,8 +407,11 @@ httpsServer.listen(443); future.wait(); string result = future.get(); - ws->Close(IWebSocket::CloseCode::Normal, "Closing"); + ws->Close(CloseCode::Normal, "Closing"); + server->Stop(); + Assert::AreEqual({}, errorMessage); - Assert::AreEqual(string("suffixme_response"), result); + Assert::AreEqual(writes, count); + Assert::AreEqual({ "suffixme_response" }, result); } }; diff --git a/vnext/Desktop.IntegrationTests/WebSocketServer.cpp b/vnext/Desktop.IntegrationTests/WebSocketServer.cpp new file mode 100644 index 00000000000..db5efc24ae7 --- /dev/null +++ b/vnext/Desktop.IntegrationTests/WebSocketServer.cpp @@ -0,0 +1,390 @@ +#include "WebSocketServer.h" + +#include +#include + +using namespace boost::asio; + +using boost::system::error_code; +using std::function; +using std::placeholders::_1; +using std::placeholders::_2; +using std::string; + +namespace websocket = boost::beast::websocket; + +namespace Microsoft { +namespace React { +namespace Test { + +#pragma region BaseWebSocketSession + +template +BaseWebSocketSession::BaseWebSocketSession(WebSocketServiceCallbacks& callbacks) + : m_callbacks{ callbacks } + , m_state{ State::Stopped } +{ +} + +template +BaseWebSocketSession::~BaseWebSocketSession() +{ +} + +template +void BaseWebSocketSession::Start() +{ + Accept(); +} + +template +void BaseWebSocketSession::Accept() +{ + m_stream->async_accept_ex( + bind_executor(*m_strand, std::bind( + &BaseWebSocketSession::OnHandshake, + this->SharedFromThis(), + _1 // response + )), + bind_executor(*m_strand, std::bind( + &BaseWebSocketSession::OnAccept, + this->SharedFromThis(), + _1 // ec + )) + ); +} + +template +void BaseWebSocketSession::OnHandshake(websocket::response_type& response) +{ + if (m_callbacks.OnHandshake) + m_callbacks.OnHandshake(response); +} + +template +void BaseWebSocketSession::OnAccept(error_code ec) +{ + if (ec) + return;//TODO: fail + + m_state = State::Started; + + if (m_callbacks.OnConnection) + m_callbacks.OnConnection(); + + Read(); +} + +template +void BaseWebSocketSession::Read() +{ + if (State::Stopped == m_state) + return; + + m_stream->async_read(m_buffer, bind_executor(*m_strand, std::bind( + &BaseWebSocketSession::OnRead, + this->SharedFromThis(), + _1, // ec + _2 // transferred + ))); +} + +template +void BaseWebSocketSession::OnRead(error_code ec, size_t /*transferred*/) +{ + if (websocket::error::closed == ec) + return; + + if (ec) + return;//TODO: fail instead + + if (!m_callbacks.MessageFactory) + { + m_buffer.consume(m_buffer.size()); + return Read(); + } + + m_message = m_callbacks.MessageFactory(buffers_to_string(m_buffer.data())); + m_buffer.consume(m_buffer.size()); + + m_stream->text(m_stream->got_text()); + m_stream->async_write(buffer(m_message), bind_executor(*m_strand, std::bind( + &BaseWebSocketSession::OnWrite, + this->SharedFromThis(), + _1, // ec + _2 // transferred + ))); +} + +template +void BaseWebSocketSession::OnWrite(error_code ec, size_t /*transferred*/) +{ + if (ec) + return; //TODO: fail + + // Clear outgoing message contents. + m_message.clear(); + + Read(); +} + +#pragma endregion // BaseWebSocketSession + +#pragma region WebSocketSession + +WebSocketSession::WebSocketSession(ip::tcp::socket socket, WebSocketServiceCallbacks& callbacks) + : BaseWebSocketSession(callbacks) +{ + m_stream = std::make_shared>(std::move(socket)); + m_strand = std::make_shared>(m_stream->get_executor()); +} + +WebSocketSession::~WebSocketSession() {} + +#pragma region BaseWebSocketSession + +std::shared_ptr> WebSocketSession::SharedFromThis() /*override*/ +{ + return this->shared_from_this(); +} + +#pragma endregion // BaseWebSocketSession + +#pragma endregion // WebSocketSession + +#pragma region SecureWebSocketSession + +SecureWebSocketSession::SecureWebSocketSession(ip::tcp::socket socket, WebSocketServiceCallbacks& callbacks) + : BaseWebSocketSession(callbacks) +{ + // Initialize SSL context. + string const cert = + "-----BEGIN CERTIFICATE-----\n" + "MIIDhjCCAm6gAwIBAgIJAPh+egUebaStMA0GCSqGSIb3DQEBCwUAMFgxCzAJBgNV\n" + "BAYTAlVTMRMwEQYDVQQIDApXYXNoaW5ndG9uMRAwDgYDVQQHDAdSZWRtb25kMRIw\n" + "EAYDVQQKDAlNaWNyb3NvZnQxDjAMBgNVBAsMBVJlYWN0MB4XDTE5MDYwMTA4MDcx\n" + "M1oXDTI5MDUyOTA4MDcxM1owWDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCldhc2hp\n" + "bmd0b24xEDAOBgNVBAcMB1JlZG1vbmQxEjAQBgNVBAoMCU1pY3Jvc29mdDEOMAwG\n" + "A1UECwwFUmVhY3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCgHKC9\n" + "XC+1FjDg2Xdmbpf3ILiJQtGFiD3WFb+duNXThOA0LY6gytVBY6LitAzB7Jm7spvx\n" + "BbA46kw8Dsmv02hI0diVzFKCB5bTTs0N8bgAsem6qvDpo/mvp2TtDU2J8o4RhMQ3\n" + "BQvdZUGgtH4VR2W4vIHufNjVUvN9hTL2eOBz0EYElsMSogG8f97N+m/7L6JeyjPo\n" + "kFwXYTFMjv3ihJmev/cBNkxuchLUT7NAc7bMCtmv5lzsKMKe6g0lUxDSBYxXztqU\n" + "l3huo2g990VbvTWH/lhz3bgdnon/AUKWBmS2eRmK9hH/rGlm1NeMCjexMZrYC3m8\n" + "vvfIR25plGmNjyQJAgMBAAGjUzBRMB0GA1UdDgQWBBTNLE3Nl0s3O40wDEXf9t/7\n" + "r6Y1QjAfBgNVHSMEGDAWgBTNLE3Nl0s3O40wDEXf9t/7r6Y1QjAPBgNVHRMBAf8E\n" + "BTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAbRL+69uZLW3Q6JIQ9qg3CPjm6574o\n" + "cOiyJ9DX49i23AyYNsCYulvLApHgrmltMJHrC4U7EasQCTtwpAMyLJDLIDdujSSs\n" + "ynSe5PcNeElLmEkH4PxyAFsb/2oWI2PSJh0vseuugUpeKTHJv5MPkLUm7WMLHYj5\n" + "uOQzYDwJ+IuveVzX5TRXtkY8jF9ErL4iF8eYTyp0ANeY11vJOAbd2FcAy5baYjB3\n" + "JVczHy7eegwfOQJFM9mIZE7+Ac0SaknC0Jun9486cJ0mvbdrMSd+vgF85OrpWwYU\n" + "ISfux0NDVN1gjvSgdkEB+CWnV1rNKsVKlg4psDkpq33AJEwnx3qNxtKG\n" + "-----END CERTIFICATE-----\n"; + + string const key = + "-----BEGIN RSA PRIVATE KEY-----\n" + "MIIEowIBAAKCAQEAoBygvVwvtRYw4Nl3Zm6X9yC4iULRhYg91hW/nbjV04TgNC2O\n" + "oMrVQWOi4rQMweyZu7Kb8QWwOOpMPA7Jr9NoSNHYlcxSggeW007NDfG4ALHpuqrw\n" + "6aP5r6dk7Q1NifKOEYTENwUL3WVBoLR+FUdluLyB7nzY1VLzfYUy9njgc9BGBJbD\n" + "EqIBvH/ezfpv+y+iXsoz6JBcF2ExTI794oSZnr/3ATZMbnIS1E+zQHO2zArZr+Zc\n" + "7CjCnuoNJVMQ0gWMV87alJd4bqNoPfdFW701h/5Yc924HZ6J/wFClgZktnkZivYR\n" + "/6xpZtTXjAo3sTGa2At5vL73yEduaZRpjY8kCQIDAQABAoIBAA/bpgP7THJYF1E9\n" + "2LiZfY2pfP2DU7MxEkbQ8qCRfQQtJfOlC3pbfJG0Z56ijJzsbTGM+CsAEDsi4ZgV\n" + "Mt6qRqrntdboXMeqLsMRC/g0l6/h7y9g3OmXJxTBtJpR9fsSvgV4K+LzKgCslbpw\n" + "BgjfgHCyov/W97bxN1KYTbrhsAFoWFwyTglDIkTYo/92suwsyYt14pxnK54QyNrM\n" + "tWKS0K9rZmSMjaVYW+nnSLelFVAOAzW4SOt8CE1V0usjIkmD5smdadjT14exWnxV\n" + "zVMhsbrfUFi3oBfW8X+TuWQjBnVfX0akZALY6vmnmOEWLz4pXJBLmazSGpXyM5o2\n" + "JpxSUS0CgYEAywE9MhWqKhZXpOwGWbNe1Nibh4l8vt+pTDs2TpFdHXNT8UfuUk03\n" + "ycCxGKrDPUAUPdVGygvmqdKHpXLbWalPGdFR6Xcn5YjVxM7L37jGl28oQOdNLI5u\n" + "Lw7hJ5L84M0LZqMI5589jA63WrgkLNQ9eKnuFn9N/3n9r/uZqtGVjLsCgYEAyejc\n" + "I/St33V0CNDtfEZ8dBDztXhx1WjDzv+JgNiy7pLUB+8yW0/iU2Y8ptpmfQ0nRDDv\n" + "sGK5myuBv85PRgWQnPskL3V5+L+DK40hyYnrL4bKhxT8az2CQdWI23sS4Nq5b59A\n" + "ylUyGIUXv3P62nUgMq3kM7L6mMgz/cCxeVeQyAsCgYAlZAIIgpMIE3trJgn5ZZ9W\n" + "5tqmuT0fzwRYxSM4j2+uJ/rTGyObRxu6bmJwH6u8UVwpE2ppdo2yw9M2NxSNzDCE\n" + "mdhTfx37Ghv9lvVYLKlvZQruAWxmg4lp43y3FEy9fybVbbwLJXppnKBK9lW7aBA2\n" + "dF4lCKeuIaMHUfk4zEeWVwKBgF81HXEa9E4VfUSW+BUMy6yTPcgJZmwCParDFlya\n" + "Ui1rMO4Y3X7vOUKoR9tJyuAWrrhZ9vwOYYUIy+Lc7saO4zUSu2phk8U20SxdHVyC\n" + "W1MK1T9DJw+ObniKr0EHVMyQdrZqusttxvSG9b7Cerw+VJNxKdUzBTW72cBC96zH\n" + "HK8nAoGBAMGKiSm4a1O5SpSiiZ6kNZHw9wB98Jtic7ozoUrlVEPSeRrXq/BzsyyH\n" + "md/sN/1v/Qq4SPlUtnzRakcPa2sntDb4SEt/Lrr97ouX1C/qfWljh69jRDLPBBez\n" + "cAlHdEzualsWQsACr7I71UebXvha+v0XXAKiIRqAKRrFLmPPCFrP\n" + "-----END RSA PRIVATE KEY-----\n"; + + string dh = + "-----BEGIN DH PARAMETERS-----\n" + "MIIBCAKCAQEA5VbTCtf4s2qPpqTtk2pXsYcqo7cLF0LVQaXMhOZNmif0TKDyclSV\n" + "NQANJcl0K9C5cGfh/1oEZs30A+Ww1zCtjkwJFvQdUAhCy/1U/qhRO2swXtz+CGZL\n" + "7PL0yu0Xht3EqGRS4z98LPCALVYvuqbNKTnFHUZl8oYJT0Xx0lzzZ+r5uFYYghQU\n" + "nCohXf/O0VLCPJMnd/oLY70CcPEL9V1KDb80oTzlYzrVPAHidcOXkiZpmOHgdiA/\n" + "LLG0h495hZhL5OqqDrLM7IWxHNmzgwhQ04PdGa6zPP4fnt7L4Ia5/lYOolvdmNkx\n" + "XgdewtScX7P5ltOMhhcWS4Og+qZn18a3kwIBAg==\n" + "-----END DH PARAMETERS-----\n"; + + auto context = ssl::context(ssl::context::sslv23); + + //TODO: Remove if not used. + context.set_password_callback([](size_t, ssl::context_base::password_purpose) + { + return "test"; + }); + + context.set_options(ssl::context::default_workarounds | ssl::context::no_sslv2 | ssl::context::single_dh_use); + context.use_certificate_chain(buffer(cert.data(), cert.size())); + context.use_private_key(buffer(key.data(), key.size()), ssl::context::file_format::pem); + context.use_tmp_dh(buffer(dh.data(), dh.size())); + + m_stream = std::make_shared>>(std::move(socket), context); + m_strand = std::make_shared>(m_stream->get_executor()); +} + +SecureWebSocketSession::~SecureWebSocketSession() {} + +#pragma region BaseWebSocketSession + +std::shared_ptr>> SecureWebSocketSession::SharedFromThis() /*override*/ +{ + return this->shared_from_this(); +} + +#pragma endregion // BaseWebSocketSession + +#pragma region IWebSocketSession + +void SecureWebSocketSession::Start() /*override*/ +{ + m_stream->next_layer().async_handshake(ssl::stream_base::server, bind_executor(*m_strand, std::bind( + &SecureWebSocketSession::OnSslHandshake, + this->shared_from_this(), + _1 // ec + ))); +} + +void SecureWebSocketSession::OnSslHandshake(error_code ec) +{ + if (ec) + return; + + Accept(); +} + +#pragma endregion // IWebSocketSession + +#pragma endregion // SecureWebSocketSession + +#pragma region WebSocketServer + +WebSocketServer::WebSocketServer(uint16_t port, bool isSecure) + : m_acceptor{ m_context } + , m_socket{ m_context } + , m_sessions{} + , m_isSecure{ isSecure } +{ + ip::tcp::endpoint ep{ip::make_address("0.0.0.0"), port }; + error_code ec; + + m_acceptor.open(ep.protocol(), ec); + if (ec) + { + return; //TODO: handle + } + + m_acceptor.set_option(socket_base::reuse_address(true), ec); + if (ec) + { + return; //TODO: handle + } + + m_acceptor.bind(ep, ec); + if (ec) + { + return; //TODO: handle + } + + m_acceptor.listen(socket_base::max_listen_connections, ec); + if (ec) + { + return; //TODO: handle + } +} + +WebSocketServer::WebSocketServer(int port, bool isSecure) + : WebSocketServer(static_cast(port), isSecure) +{ +} + +void WebSocketServer::Start() +{ + if (!m_acceptor.is_open()) + return; + + Accept(); + + m_contextThread = std::thread([self = shared_from_this()]() + { + self->m_context.run(); + }); +} + +void WebSocketServer::Accept() +{ + m_acceptor.async_accept(m_socket, std::bind(&WebSocketServer::OnAccept, shared_from_this(), /*ec*/ _1)); +} + +void WebSocketServer::Stop() +{ + if (m_acceptor.is_open()) + m_acceptor.close(); + + m_contextThread.join(); +} + +void WebSocketServer::OnAccept(error_code ec) +{ + if (ec) + { + //TODO: fail + } + else + { + std::shared_ptr session; + if (m_isSecure) + session = std::shared_ptr(new SecureWebSocketSession(std::move(m_socket), m_callbacks)); + else + session = std::shared_ptr(new WebSocketSession(std::move(m_socket), m_callbacks)); + + m_sessions.push_back(session); + session->Start(); + } + + //TODO: Accept again. + //Accept(); +} + +void WebSocketServer::SetOnConnection(function&& func) +{ + m_callbacks.OnConnection = std::move(func); +} + +void WebSocketServer::SetOnHandshake(function&& func) +{ + m_callbacks.OnHandshake = std::move(func); +} + +void WebSocketServer::SetOnMessage(function&& func) +{ + m_callbacks.OnMessage = std::move(func); +} + +void WebSocketServer::SetMessageFactory(function&& func) +{ + m_callbacks.MessageFactory = std::move(func); +} + +void WebSocketServer::SetOnError(function&& func) +{ + m_callbacks.OnError = std::move(func); +} + +#pragma endregion // WebSocketServer + +} } } // Microsoft::React::Test diff --git a/vnext/Desktop.IntegrationTests/WebSocketServer.h b/vnext/Desktop.IntegrationTests/WebSocketServer.h new file mode 100644 index 00000000000..ce7ff5dc650 --- /dev/null +++ b/vnext/Desktop.IntegrationTests/WebSocketServer.h @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace Microsoft { +namespace React { +namespace Test { + +struct WebSocketServiceCallbacks +{ + std::function OnConnection; + std::function OnHandshake; + std::function OnMessage; + std::function MessageFactory; + std::function OnError; +}; + +struct IWebSocketSession +{ + virtual ~IWebSocketSession() {} + + virtual void Start() = 0; +}; + +template +class BaseWebSocketSession : public IWebSocketSession +{ + enum class State : std::size_t + { + Started, + Stopped + }; + + boost::beast::multi_buffer m_buffer; + std::string m_message; + WebSocketServiceCallbacks& m_callbacks; + State m_state; + + std::function m_errorHandler; + + void Read(); + + void OnAccept(boost::system::error_code ec); + void OnHandshake(boost::beast::websocket::response_type& response); + void OnRead(boost::system::error_code ec, std::size_t transferred); + void OnWrite(boost::system::error_code ec, std::size_t transferred); + +protected: + std::shared_ptr> m_stream; + std::shared_ptr> m_strand; + + void Accept(); + + virtual std::shared_ptr> SharedFromThis() = 0; + +public: + BaseWebSocketSession(WebSocketServiceCallbacks& callbacks); + ~BaseWebSocketSession(); + + virtual void Start() override; +}; + +class WebSocketSession + : public std::enable_shared_from_this + , public BaseWebSocketSession +{ + std::shared_ptr> SharedFromThis() override; + +public: + WebSocketSession(boost::asio::ip::tcp::socket socket, WebSocketServiceCallbacks& callbacks); + ~WebSocketSession(); +}; + +class SecureWebSocketSession + : public std::enable_shared_from_this + , public BaseWebSocketSession> +{ + std::shared_ptr>> SharedFromThis() override; + +public: + SecureWebSocketSession(boost::asio::ip::tcp::socket socket, WebSocketServiceCallbacks& callbacks); + ~SecureWebSocketSession(); + + void OnSslHandshake(boost::system::error_code ec); + + #pragma region IWebSocketSession + + void Start() override; + + #pragma endregion //IWebSocketSession +}; + +class WebSocketServer : public std::enable_shared_from_this +{ + std::thread m_contextThread; + boost::asio::io_context m_context; + boost::asio::ip::tcp::acceptor m_acceptor; + boost::asio::ip::tcp::socket m_socket; + WebSocketServiceCallbacks m_callbacks; + std::vector> m_sessions; + bool m_isSecure; + + void Accept(); + + void OnAccept(boost::system::error_code ec); + +public: + WebSocketServer(std::uint16_t port, bool isSecure); + WebSocketServer(int port, bool isSecure = false); + + void Start(); + void Stop(); + + void SetOnConnection(std::function&& func); + void SetOnHandshake(std::function&& func); + void SetOnMessage(std::function&& func); + void SetMessageFactory(std::function&& func); + void SetOnError(std::function&& func); +}; + +} } } // Microsoft::React::Test