diff --git a/.clangd b/.clangd deleted file mode 100644 index 9342164e..00000000 --- a/.clangd +++ /dev/null @@ -1,18 +0,0 @@ -If: - PathMatch: "(ssh\/.*)|(ids\/.*)|(backend\/.*)|(dependencies\/Nui\/.*)" - PathExclude: "(dependencies\/Nui\/nui\/include\/nui\/frontend.*)" -CompileFlags: - CompilationDatabase: "build/clang_debug" - Add: - - "-IE:/DevelopmentFast/scp/build/clang_debug/_deps/emscripten-src/upstream/emscripten/system/include" - - "-ID:/msys2/clang64/include" - - "-D__cplusplus=202302L" ---- -If: - PathMatch: "(frontend\/.*)|(nui-file-explorer\/.*)|(dependencies\/Nui\/nui\/include\/nui\/frontend.*)" -CompileFlags: - CompilationDatabase: "build/clang_debug/module_nui-scp" - Add: - - "-D__EMSCRIPTEN__" - - "-IE:/DevelopmentFast/scp/build/clang_debug/_deps/emscripten-src/upstream/emscripten/system/include" - - "-ID:/msys2/clang64/include" diff --git a/.gitignore b/.gitignore index e3d5bf0d..82c65bbe 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ dependencies/* node_modules/* themes/node_modules/* themes/node_modules -todo.txt \ No newline at end of file +todo.txt +.clangd \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 6775509b..18bd658c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ set(NUI_FETCH_ROAR OFF CACHE BOOL "Dont fetch roar, use local copy" FORCE) set(NUI_FETCH_NLOHMANN_JSON OFF CACHE BOOL "Dont fetch nlohmann, use local copy" FORCE) set(ROAR_EXTERNAL_NLOHMANN_JSON ON CACHE BOOL "Use an external nlohmann_json library (provide it manually)" FORCE) set(SPDLOG_BUILD_SHARED OFF CACHE BOOL "Dont build shared spdlog" FORCE) +set(NUI_FILE_EXPLORER_MEMORY64 ON CACHE BOOL "Enable MEMORY64 for nui-file-explorer" FORCE) option(BUILD_SHARED_LIBS "Do not build shared libraries" OFF) option(BUILD_STATIC_LIBS "Build static libraries" ON) @@ -69,5 +70,6 @@ else() include(./_cmake/dependencies/gtest.cmake) enable_testing() add_subdirectory(ssh/test/ssh) + add_subdirectory(backend/test/backend) endif() endif() \ No newline at end of file diff --git a/_cmake/common_options.cmake b/_cmake/common_options.cmake index a036205e..5ab709ea 100644 --- a/_cmake/common_options.cmake +++ b/_cmake/common_options.cmake @@ -5,6 +5,12 @@ if (${MSVC}) else() target_compile_options(core-target INTERFACE -Wall -Wextra -Wpedantic) endif() -target_compile_options(core-target INTERFACE -Wbad-function-cast -Wcast-function-type -fexceptions -pedantic $<$:-g;-Werror=return-type> $<$:-O3>) -target_link_options(core-target INTERFACE $<$:-s;-static-libgcc;-static-libstdc++>) + +set(MEM64 "") +if (EMSCRIPTEN) + set(MEM64 "-sMEMORY64=1") +endif() + +target_compile_options(core-target INTERFACE -Wbad-function-cast -Wcast-function-type -fexceptions -pedantic $<$:-g;-Werror=return-type> $<$:-O3> ${MEM64}) +target_link_options(core-target INTERFACE $<$:-s;-static-libgcc;-static-libstdc++> ${MEM64}) target_compile_features(core-target INTERFACE cxx_std_23) \ No newline at end of file diff --git a/_cmake/dependencies/spdlog.cmake b/_cmake/dependencies/spdlog.cmake index f64ef8e8..70055da2 100755 --- a/_cmake/dependencies/spdlog.cmake +++ b/_cmake/dependencies/spdlog.cmake @@ -4,7 +4,7 @@ include(FetchContent) FetchContent_Declare( spdlog GIT_REPOSITORY https://github.com/gabime/spdlog.git - GIT_TAG ae1de0dc8cf480f54eaa425c4a9dc4fef29b28ba + GIT_TAG v1.16.0 ) FetchContent_MakeAvailable(spdlog) diff --git a/backend/include/backend/main.hpp b/backend/include/backend/main.hpp index 1e0a0a1d..e8a10b84 100644 --- a/backend/include/backend/main.hpp +++ b/backend/include/backend/main.hpp @@ -1,9 +1,10 @@ #pragma once #include -#include +#include #include #include +#include #include #include @@ -37,7 +38,7 @@ class Main Nui::RpcHub hub_; ProcessStore processes_; PasswordPrompter prompter_; - SshSessionManager sshSessionManager_; + std::shared_ptr sshSessionManager_; std::atomic_bool shuttingDown_; boost::asio::steady_timer childSignalTimer_; }; \ No newline at end of file diff --git a/backend/include/backend/rpc_helper.hpp b/backend/include/backend/rpc_helper.hpp new file mode 100644 index 00000000..c3693de4 --- /dev/null +++ b/backend/include/backend/rpc_helper.hpp @@ -0,0 +1,317 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// TODO: split and document this file! +namespace RpcHelper +{ + namespace Detail + { + template + struct CallOperatorApplyTypesOnly + {}; + + template + struct CallOperatorApplyTypesOnly> + { + template + static void apply(F&& func) + { + return std::forward(func).template operator()(); + } + }; + + template + struct DebugPrintType; + } + + class RpcInCorrectThread + { + public: + RpcInCorrectThread(Nui::Window& wnd, Nui::RpcHub& hub, std::string const& responseId) + : wnd_{&wnd} + , hub_{&hub} + , responseId_{responseId} + {} + RpcInCorrectThread(RpcInCorrectThread&&) = default; + RpcInCorrectThread& operator=(RpcInCorrectThread&&) = default; + RpcInCorrectThread(RpcInCorrectThread const&) = default; + RpcInCorrectThread& operator=(RpcInCorrectThread const&) = default; + ~RpcInCorrectThread() = default; + + void operator()(nlohmann::json&& json) const + { + wnd_->runInJavascriptThread([this, json = std::move(json)]() { + try + { + hub_->callRemote(responseId_, json); + } + catch (const std::exception& e) + { + Log::error("Failed to call rpc respond '{}': {}", responseId_, e.what()); + } + }); + } + + protected: + Nui::Window* wnd_; + Nui::RpcHub* hub_; + std::string responseId_; + }; + + class RpcOnce + { + public: + RpcOnce(Nui::Window& wnd, Nui::RpcHub& hub, std::string const& responseId) + : wnd_{&wnd} + , hub_{&hub} + , responseId_{responseId} + {} + RpcOnce(RpcOnce&&) = default; + RpcOnce& operator=(RpcOnce&&) = default; + RpcOnce(RpcOnce const&) = delete; + RpcOnce& operator=(RpcOnce const&) = delete; + ~RpcOnce() = default; + + void operator()(nlohmann::json&& json) const + { + if (called_) + { + Log::warn("RPC response with id '{}' already sent. Not sending: {}", responseId_, json.dump(4)); + return; + } + called_ = true; + + wnd_->runInJavascriptThread([hub = hub_, responseId = std::move(responseId_), json = std::move(json)]() { + try + { + hub->callRemote(responseId, json); + } + catch (const std::exception& e) + { + Log::error("Failed to call rpc respond '{}': {}", responseId, e.what()); + } + }); + } + + private: + Nui::Window* wnd_; + Nui::RpcHub* hub_; + mutable std::string responseId_; + mutable bool called_{false}; + }; + + class ParameterVerifyView + { + private: + RpcOnce& rpcOnce_; + nlohmann::json const& json_; + std::string const functionName_; + + public: + ParameterVerifyView(RpcOnce& rpcOnce, std::string const& functionName, nlohmann::json const& json) + : rpcOnce_{rpcOnce} + , json_{json} + , functionName_{functionName} + {} + ~ParameterVerifyView() = default; + ParameterVerifyView(ParameterVerifyView const&) = delete; + ParameterVerifyView& operator=(ParameterVerifyView const&) = delete; + + // Lets not have this view moveable somewhere, to avoid it being used after the original json is gone + ParameterVerifyView(ParameterVerifyView&&) = delete; + ParameterVerifyView& operator=(ParameterVerifyView&&) = delete; + + template + requires(std::convertible_to && (std::convertible_to && ...)) + bool hasValueDeep(KeyT&& key, KeysT&&... keys) + { + std::vector keyChain{}; + + auto onFail = [this, &keyChain]() { + keyChain.pop_back(); + rpcOnce_({ + { + "error", + fmt::format( + "Missing parameter to function '{}': {}", + functionName_, + keyChain | std::views::join | std::ranges::to()), + }, + }); + return false; + }; + + auto iter = json_.find(key); + auto currentEnd = end(json_); + + auto checkOnce = [&iter, &keyChain, &onFail, ¤tEnd](std::string_view key) { + keyChain.push_back(std::string{key}); + keyChain.push_back("."); + if (iter == currentEnd) + return onFail(); + currentEnd = end(*iter); + return true; + }; + + auto findNext = [&iter, &checkOnce](std::string_view key) { + iter = iter->find(key); + return checkOnce(key); + }; + + return checkOnce(key) && (findNext(keys) && ...); + } + }; + + template + auto rpcSafe(RpcOnce&& rpcOnce, FunctionT&& func) + { + return [rpcOnce = std::make_shared(std::move(rpcOnce)), + func = std::forward(func)](auto&&... args) mutable { + try + { + if constexpr (std::is_invocable_v) + func(std::move(*rpcOnce), std::forward(args)...); + else if constexpr (std::is_invocable_v) + func(std::forward(args)...); + else + static_assert(false, "FunctionT is not callable with the given arguments"); + } + catch (const std::exception& e) + { + (*rpcOnce)({{"error", e.what()}}); + } + }; + } + + class StrandRpc + { + public: + StrandRpc(boost::asio::any_io_executor executor, Nui::Window& wnd, Nui::RpcHub& hub) + : executor_{executor} + , strand_{std::make_shared>(executor_)} + , wnd_{&wnd} + , hub_{&hub} + , timer_{executor_} + {} + + StrandRpc( + boost::asio::any_io_executor executor, + std::shared_ptr> strand, + Nui::Window& wnd, + Nui::RpcHub& hub) + : executor_{executor} + , strand_{std::move(strand)} + , wnd_{&wnd} + , hub_{&hub} + , timer_{executor_} + {} + + virtual ~StrandRpc() = default; + StrandRpc(StrandRpc const&) = delete; + StrandRpc& operator=(StrandRpc const&) = delete; + StrandRpc(StrandRpc&&) = default; + StrandRpc& operator=(StrandRpc&&) = default; + + template + void registerOnStrand(std::string_view functionName, FunctionT&& func) + { + // (RpcReply&&, Param1, Param2, ...) => (Param1, Param2, ...) + using ArgsTuple = mplex::pop_front_t>::ArgsTuple>; + + // Call lambda with template arguments FunctionT, Param1, Param2, ... and no parameters. + Detail::CallOperatorApplyTypesOnly::apply([this, + functionName = std::string{functionName}, + func = std::forward( + func)]() mutable { + // Register function that is wrapped in a strand execute call: + registeredFunctions_.emplace_back(hub_->autoRegisterFunction( + functionName, + [this, func = std::move(func), functionName](std::string responseId, ParameterTs&&... parameters) { + strand_->execute([responseId = std::move(responseId), + ... parameters = std::forward(parameters), + func, + wnd = wnd_, + hub = hub_]() mutable { + // Threadsafe do: + RpcHelper::rpcSafe( + RpcHelper::RpcOnce{*wnd, *hub, std::move(responseId)}, + [¶meters..., &func](RpcOnce&& reply) mutable { + // Call actual function + func(std::move(reply), std::forward(parameters)...); + })(); + }); + })); + }); + } + + class Proxy + { + private: + StrandRpc* parent_; + std::string functionName; + + public: + Proxy(StrandRpc* parent, std::string functionName) + : parent_(parent) + , functionName(std::move(functionName)) + {} + + template + void perform(Args&&... args) + { + parent_->registerOnStrand(functionName, std::forward(args)...); + } + }; + Proxy on(std::string_view functionName) + { + return Proxy{this, std::string{functionName}}; + } + + void within_strand_do(auto&& func) const + { + if (!strand_->running_in_this_thread()) + return strand_->execute(std::forward(func)); + func(); + } + + void within_strand_do_no_recurse(auto&& func) const + { + return strand_->execute(std::forward(func)); + } + + void within_strand_do_delayed(auto&& func, std::chrono::steady_clock::duration delay) + { + timer_.expires_after(delay); + timer_.async_wait([func = std::forward(func)](auto const& ec) { + if (ec) + Log::info("Timer canceled in within_strand_do_delayed: {}", ec.message()); + + // Run func anyway, if its because of a shutdown, it will stop early. + func(); + }); + } + + protected: + boost::asio::any_io_executor executor_{}; + std::shared_ptr> strand_; + Nui::Window* wnd_; + Nui::RpcHub* hub_; + std::vector registeredFunctions_{}; + boost::asio::steady_timer timer_; + }; +} \ No newline at end of file diff --git a/backend/include/backend/session.hpp b/backend/include/backend/session.hpp new file mode 100644 index 00000000..562e337c --- /dev/null +++ b/backend/include/backend/session.hpp @@ -0,0 +1,202 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +/** + * @brief This session is the implementation equivalent of one tab in the UI. + * A tab in the UI represents everything related to one server address and port. + * + * These sessions are managed by SessionManager. Each session has an id, as well as ssh channels, sftp channels and + */ +class Session + : public RpcHelper::StrandRpc + , public std::enable_shared_from_this +{ + public: + constexpr static auto futureTimeout = std::chrono::seconds{10}; + constexpr static auto queueStartThrottle = std::chrono::milliseconds{5}; + constexpr static auto queueMaxThrottle = std::chrono::seconds{3}; + + Session( + Ids::SessionId id, + std::unique_ptr session, + boost::asio::any_io_executor executor, + std::shared_ptr> strand, + Nui::Window& wnd, + Nui::RpcHub& hub, + Persistence::SftpOptions const& sftpOptions); + + Session(const Session&) = delete; + Session& operator=(const Session&) = delete; + Session(Session&&) = delete; + Session& operator=(Session&&) = delete; + ~Session() = default; + + void start(); + void stop(); + + private: + /** + * Handles calls from the frontend to create a new channel with the following payload: + * { + * engine: { + * sshSessionOptions: Persistence::SshSessionOptions, + * environment: std::unordered_map + * }, + * fileMode: boolean (optional, default false) + * } + * + */ + void registerRpcCreateChannel(); + + /** + * Handles calls from the frontend to start reading from a channel with the following payload: + * { + * channelId: string + * } + */ + void registerRpcStartChannelRead(); + + /** + * Handles calls from the frontend to close a channel with the following payload: + * { + * channelId: string + * } + */ + void registerRpcChannelClose(); + + /** + * Handles calls from the frontend to write data to a channel with the following payload: + * { + * channelId: string, + * data: string (base64 encoded) + * } + */ + void registerRpcChannelWrite(); + + /** + * Handles calls from the frontend to resize the pty of a channel with the following payload: + * { + * channelId: string, + * cols: int, + * rows: int + * } + */ + void registerRpcChannelPtyResize(); + + /** + * Handles calls from the frontend to list a directory over sftp with the following payload: + * { + * sftpChannelId: string, + * path: string + * } + */ + void registerRpcSftpListDirectory(); + + /** + * Handles calls from the frontend to create a directory over sftp with the following payload: + * { + * sftpChannelId: string, + * path: string + * } + */ + void registerRpcSftpCreateDirectory(); + + /** + * Handles calls from the frontend to create a file over sftp with the following payload: + * { + * sftpChannelId: string, + * path: string + * } + */ + void registerRpcSftpCreateFile(); + void registerRpcSftpAddDownloadOperation(); + void registerOperationQueuePauseUnpause(); + + void removeChannel(Ids::ChannelId channelId); + + void removeSftpChannel(Ids::ChannelId channelId); + + template + void withChannelDo(Ids::ChannelId channelId, FunctionT&& func, ReplyCallable&& reply) + { + within_strand_do([this, + channelId = std::move(channelId), + func = std::forward(func), + reply = std::forward(reply)]() mutable { + if (auto iter = channels_.find(channelId); iter != channels_.end()) + { + if (auto channel = iter->second.lock(); channel) + { + func(std::move(reply), std::move(channel)); + } + else + { + Log::error("Failed to lock channel with id: {}", channelId.value()); + removeChannel(channelId); + return reply({{"error", "Failed to lock channel"}}); + } + } + else + { + Log::error("No channel found with id: {}", channelId.value()); + return reply({{"error", "No channel found with id"}}); + } + }); + } + + template + void withSftpChannelDo(Ids::ChannelId channelId, FunctionT&& func, ReplyCallable&& reply) + { + within_strand_do([this, + channelId = std::move(channelId), + func = std::forward(func), + reply = std::forward(reply)]() mutable { + if (auto iter = sftpChannels_.find(channelId); iter != sftpChannels_.end()) + { + if (auto channel = iter->second.lock(); channel) + { + func(std::move(reply), std::move(channel)); + } + else + { + Log::error("Failed to lock channel with id: {}", channelId.value()); + removeSftpChannel(channelId); + return reply({{"error", "Failed to lock channel"}}); + } + } + else + { + Log::error("No channel found with id: {}", channelId.value()); + return reply({{"error", "No channel found with id"}}); + } + }); + } + + void doOperationQueueWork(); + + private: + void resetQueueThrottle(); + + private: + Ids::SessionId id_; + /// Has nothing to do with pause/unpause - this is used for shutdown of the session. + std::atomic_bool running_; + std::chrono::milliseconds operationThrottle_{5}; + bool queueThrottleTimerIsRunning_{false}; + int unthrottledLimitCounter_{0}; + std::unique_ptr session_{}; + std::unordered_map, Ids::IdHash> channels_{}; + std::unordered_map, Ids::IdHash> sftpChannels_{}; + std::shared_ptr operationQueue_; +}; \ No newline at end of file diff --git a/backend/include/backend/session_manager.hpp b/backend/include/backend/session_manager.hpp new file mode 100644 index 00000000..e85f1051 --- /dev/null +++ b/backend/include/backend/session_manager.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +class SessionManager + : public std::enable_shared_from_this + , public RpcHelper::StrandRpc +{ + public: + SessionManager( + boost::asio::any_io_executor executor, + Persistence::StateHolder& stateHolder, + Nui::Window& wnd, + Nui::RpcHub& hub); + ~SessionManager() = default; + SessionManager(SessionManager const&) = delete; + SessionManager& operator=(SessionManager const&) = delete; + SessionManager(SessionManager&&) = delete; + SessionManager& operator=(SessionManager&&) = delete; + + void registerRpc(); + + void addPasswordProvider(int priority, PasswordProvider* provider); + void joinSessionAdder(); + void addSession( + Persistence::SshTerminalEngine const& engine, + std::function const&)> onComplete); + + friend int askPassDefault(char const* prompt, char* buf, std::size_t length, int echo, int verify, void* userdata); + + // void startUpdateDispatching(); + // void stopUpdateDispatching(); + + private: + // void dispatchUpdates(); + + /** + * Handles calls from the frontend to connect to a new ssh server with the following payload: + * { + * engine: Persistence::SshTerminalEngine + * } + */ + void registerRpcSessionConnect(); + + /** + * Handles calls from the frontend to disconnect from an ssh server with the following payload: + * { + * sessionId: string + * } + */ + void registerRpcSessionDisconnect(); + + /// Removes a session and closes all its channels. Safe to call from any thread. + void removeSession(Ids::SessionId sessionId); + + private: + Persistence::StateHolder* stateHolder_{}; + std::unordered_map, Ids::IdHash> sessions_{}; + + std::map passwordProviders_{}; + std::unique_ptr addSessionThread_{}; + std::vector pwCache_{}; + std::atomic_bool updateDispatchRunning_{false}; +}; + +int askPassDefault(char const* prompt, char* buf, std::size_t length, int echo, int verify, void* userdata); \ No newline at end of file diff --git a/backend/include/backend/sftp/all_operations.hpp b/backend/include/backend/sftp/all_operations.hpp new file mode 100644 index 00000000..c0a48093 --- /dev/null +++ b/backend/include/backend/sftp/all_operations.hpp @@ -0,0 +1,30 @@ +#include +#include +#include +#include + +template +auto Operation::visit(FunctionT&& func) const +{ + using enum SharedData::OperationType; + switch (type()) + { + case Download: + { + return func(static_cast(*this)); + } + case Scan: + { + return func(static_cast(*this)); + } + case BulkDownload: + { + return func(static_cast(*this)); + } + default: + { + Log::error("Operation: Unknown operation type: {}", static_cast(type())); + return func(std::nullopt); + } + } +} \ No newline at end of file diff --git a/backend/include/backend/sftp/bulk_download_operation.hpp b/backend/include/backend/sftp/bulk_download_operation.hpp new file mode 100644 index 00000000..a4e76af6 --- /dev/null +++ b/backend/include/backend/sftp/bulk_download_operation.hpp @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +// TODO: concurrent downloads! +class BulkDownloadOperation : public Operation +{ + public: + struct BulkDownloadOperationOptions + { + std::function + overallProgressCallback = [](auto const&, auto, auto, auto, auto) {}; + + // TODO: as tar, do compress tar? + DownloadOperation::DownloadOperationOptions individualOptions = {}; + }; + + BulkDownloadOperation(SecureShell::SftpSession& sftp, BulkDownloadOperationOptions options); + ~BulkDownloadOperation() override; + BulkDownloadOperation(BulkDownloadOperation const&) = delete; + BulkDownloadOperation(BulkDownloadOperation&&) = delete; + BulkDownloadOperation& operator=(BulkDownloadOperation const&) = delete; + BulkDownloadOperation& operator=(BulkDownloadOperation&&) = delete; + + std::expected work() override; + SharedData::OperationType type() const override; + std::expected cancel(bool adoptCancelState) override; + + void setScanResult(std::vector&& entries); + + bool isBarrier() const noexcept override + { + return false; + } + + // TODO: can do more than 1. + int parallelWorkDoable(int) const noexcept override + { + return 1; + } + + SecureShell::ProcessingStrand* strand() const override; + + private: + SecureShell::SftpSession* sftp_; + BulkDownloadOperationOptions options_; + std::unique_ptr currentDownload_; + std::vector entries_; + std::uint64_t totalBytes_{0}; + std::uint64_t currentIndex_{0}; + std::uint64_t currentBytes_{0}; + std::chrono::seconds futureTimeout_{5}; +}; \ No newline at end of file diff --git a/backend/include/backend/sftp/download_operation.hpp b/backend/include/backend/sftp/download_operation.hpp new file mode 100644 index 00000000..7459274a --- /dev/null +++ b/backend/include/backend/sftp/download_operation.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +class DownloadOperation : public Operation +{ + public: + struct DownloadOperationOptions + { + std::function progressCallback = + [](auto, auto, auto) {}; + std::filesystem::path remotePath{}; + std::filesystem::path localPath{}; + std::string tempFileSuffix{".filepart"}; + bool mayOverwrite{false}; + bool reserveSpace{false}; + bool tryContinue{false}; + bool inheritPermissions{false}; + bool doCleanup{true}; + std::optional permissions{std::nullopt}; + std::chrono::seconds futureTimeout{5}; + }; + + SecureShell::ProcessingStrand* strand() const override + { + if (auto stream = fileStream_.lock(); stream) + return stream->strand(); + return nullptr; + } + + DownloadOperation(std::weak_ptr fileStream, DownloadOperationOptions options); + ~DownloadOperation() override; + DownloadOperation(DownloadOperation const&) = delete; + DownloadOperation(DownloadOperation&&) = delete; + DownloadOperation& operator=(DownloadOperation const&) = delete; + DownloadOperation& operator=(DownloadOperation&&) = delete; + + std::expected work() override; + + bool isBarrier() const noexcept override + { + return false; + } + + /** + * @brief How much parallel work does this operation do. + * + * @param parallel Maximum parallelism allowed. + * @return The amount of parallel work that can be done maxed by parallel parameter. + */ + int parallelWorkDoable(int) const noexcept override + { + return 1; + } + + SharedData::OperationType type() const override + { + return SharedData::OperationType::Download; + } + + std::filesystem::path remotePath() const + { + return remotePath_; + } + + std::filesystem::path localPath() const + { + return localPath_; + } + + std::expected cancel(bool adoptCancelState) override; + + std::expected prepare(); + std::expected finalize(); + + private: + /// Returns true if there is more data to read, false if the operation is complete. + std::expected readOnce(); + + std::expected openOrAdoptFile(SecureShell::IFileStream& stream); + + void cleanup(); + + private: + std::weak_ptr fileStream_; + std::filesystem::path remotePath_; + std::filesystem::path localPath_; + std::string tempFileSuffix_; + std::function progressCallback_; + bool mayOverwrite_; + bool reserveSpace_; + bool tryContinue_; + bool inheritPermissions_; + bool doCleanup_; + std::optional permissions_; + std::ofstream localFile_; + std::uint64_t fileSize_; + std::chrono::seconds futureTimeout_; + std::array buffer_; +}; \ No newline at end of file diff --git a/backend/include/backend/sftp/operation.hpp b/backend/include/backend/sftp/operation.hpp new file mode 100644 index 00000000..751eb805 --- /dev/null +++ b/backend/include/backend/sftp/operation.hpp @@ -0,0 +1,126 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +class Operation +{ + public: + Operation() + : id_{Ids::generateOperationId()} + {} + Operation(Operation const&) = delete; + Operation& operator=(Operation const&) = delete; + Operation(Operation&&) = delete; + Operation& operator=(Operation&&) = delete; + + virtual ~Operation() = default; + + using ErrorType = SharedData::OperationErrorType; + using Error = SharedData::OperationError; + + virtual SharedData::OperationType type() const = 0; + + template + auto visit(FunctionT&& func) const; + + virtual SecureShell::ProcessingStrand* strand() const = 0; + + template + bool perform(FunctionT&& func) + { + if (auto* theStrand = strand(); theStrand) + return theStrand->pushTask(std::forward(func)); + else + { + Log::error("Operation: Cannot perform task on strand, no processing strand available."); + return false; + } + } + + Ids::OperationId id() const + { + return id_; + } + + using OperationState = SharedData::OperationState; + + OperationState state() const + { + return state_; + } + + /** + * @brief Can parallel actions go beyond this operation? + * + * @return true Cannot progress beyond this operation. + * @return false Can progress beyond this operation. + */ + virtual bool isBarrier() const noexcept = 0; + + /** + * @brief How much parallel work does this operation do. + * + * @param parallel Maximum parallelism allowed. + * @return The amount of parallel work that can be done maxed by parallel parameter. + */ + virtual int parallelWorkDoable(int parallel) const noexcept = 0; + + enum class WorkStatus + { + MoreWork, + Complete + }; + + /** + * @brief Performs work for the operation depending on the operation type. + * + * @return std::expected, true if it wants to be retriggered without delay. + */ + virtual std::expected work() = 0; + + /** + * @brief Cancels the operation. + * + * @return std::expected + */ + virtual std::expected cancel(bool adoptCancelState) = 0; + + template + std::expected enterErrorState(Error error) + { + state_ = OperationState::Failed; + error_ = std::move(error); + const auto cancelResult = cancel(false); + if (!cancelResult.has_value()) + { + Log::error("Operation: Failed to cancel operation: {}", cancelResult.error().toString()); + // If cancel fails, we still want to call the completion callback. + } + return std::unexpected(error_.value()); + } + + protected: + void enterState(OperationState newState) + { + state_ = newState; + } + + protected: + OperationState state_{OperationState::NotStarted}; + std::optional error_{std::nullopt}; + + private: + Ids::OperationId id_; +}; \ No newline at end of file diff --git a/backend/include/backend/sftp/operation_queue.hpp b/backend/include/backend/sftp/operation_queue.hpp new file mode 100644 index 00000000..fa873c6e --- /dev/null +++ b/backend/include/backend/sftp/operation_queue.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +class OperationQueue + : public RpcHelper::StrandRpc + , public std::enable_shared_from_this +{ + public: + using Error = SharedData::OperationErrorType; + using OperationCompleted = SharedData::OperationCompleted; + using CompletionReason = SharedData::OperationCompletionReason; + + public: + OperationQueue( + boost::asio::any_io_executor executor, + std::shared_ptr> strand, + Nui::Window& wnd, + Nui::RpcHub& hub, + Persistence::SftpOptions sftpOpts, + Ids::SessionId sessionId, + int parallelism = 1); + + void cancelAll(); + void cancel(Ids::OperationId id); + + /** + * @brief Returns true if it should be called without delay again. + * + * @return true + * @return false + */ + bool work(); + + std::expected addDownloadOperation( + SecureShell::SftpSession& sftp, + Ids::OperationId operationId, + std::filesystem::path const& localPath, + std::filesystem::path const& remotePath); + + void registerRpc(); + + bool paused() const; + void paused(bool pause); + + private: + void completeOperation(OperationCompleted&& operationCompleted); + + private: + Persistence::SftpOptions sftpOpts_{}; + Ids::SessionId sessionId_{}; + std::deque>> operations_{}; + std::atomic_bool paused_{true}; + int parallelism_{1}; +}; diff --git a/backend/include/backend/sftp/scan_operation.hpp b/backend/include/backend/sftp/scan_operation.hpp new file mode 100644 index 00000000..72ed7552 --- /dev/null +++ b/backend/include/backend/sftp/scan_operation.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +class ScanOperation : public Operation +{ + public: + struct ScanOperationOptions + { + std::function + progressCallback = [](auto, auto, auto) {}; + std::filesystem::path remotePath{}; + std::chrono::seconds futureTimeout{5}; + }; + + SecureShell::ProcessingStrand* strand() const override; + + ScanOperation(SecureShell::SftpSession& sftp, ScanOperationOptions options); + ~ScanOperation() override; + ScanOperation(ScanOperation const&) = delete; + ScanOperation(ScanOperation&&) = delete; + ScanOperation& operator=(ScanOperation const&) = delete; + ScanOperation& operator=(ScanOperation&&) = delete; + + std::expected work() override; + + SharedData::OperationType type() const override + { + return SharedData::OperationType::Scan; + } + + bool isBarrier() const noexcept override + { + return true; + } + + int parallelWorkDoable(int) const noexcept override + { + return 1; + } + + std::filesystem::path remotePath() const + { + return remotePath_; + } + + std::expected cancel(bool adoptCancelState) override; + + private: + std::expected scanOnce(std::filesystem::path const& path); + + private: + SecureShell::SftpSession* sftp_; + std::filesystem::path remotePath_; + std::function + progressCallback_; + std::chrono::seconds futureTimeout_; + std::vector entries_{}; + std::uint64_t currentIndex_{0}; + std::uint64_t totalBytes_{0}; +}; \ No newline at end of file diff --git a/backend/include/backend/ssh_session_manager.hpp b/backend/include/backend/ssh_session_manager.hpp deleted file mode 100644 index 05288df7..00000000 --- a/backend/include/backend/ssh_session_manager.hpp +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include -#include -// #include -// #include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include - -class SshSessionManager -{ - public: - SshSessionManager(); - ~SshSessionManager(); - SshSessionManager(SshSessionManager const&) = delete; - SshSessionManager& operator=(SshSessionManager const&) = delete; - SshSessionManager(SshSessionManager&&) = delete; - SshSessionManager& operator=(SshSessionManager&&) = delete; - - void registerRpc(Nui::Window& wnd, Nui::RpcHub& rpcHub); - - void addPasswordProvider(int priority, PasswordProvider* provider); - void joinSessionAdder(); - void addSession( - Persistence::SshTerminalEngine const& engine, - std::function const&)> onComplete); - - friend int askPassDefault(char const* prompt, char* buf, std::size_t length, int echo, int verify, void* userdata); - - private: - void registerRpcConnect(Nui::Window& wnd, Nui::RpcHub& hub); - void registerRpcCreateChannel(Nui::Window& wnd, Nui::RpcHub& hub); - void registerRpcStartChannelRead(Nui::Window& wnd, Nui::RpcHub& hub); - void registerRpcChannelClose(Nui::Window& wnd, Nui::RpcHub& hub); - void registerRpcEndSession(Nui::Window&, Nui::RpcHub& hub); - void registerRpcChannelWrite(Nui::Window&, Nui::RpcHub& hub); - void registerRpcChannelPtyResize(Nui::Window&, Nui::RpcHub& hub); - - private: - std::mutex passwordProvidersMutex_{}; - std::mutex addSessionMutex_{}; - std::unordered_map, Ids::IdHash> sessions_{}; - std::unordered_map, Ids::IdHash> channels_{}; - std::map passwordProviders_{}; - std::unique_ptr addSessionThread_{}; - std::vector pwCache_{}; -}; - -int askPassDefault(char const* prompt, char* buf, std::size_t length, int echo, int verify, void* userdata); \ No newline at end of file diff --git a/backend/source/backend/CMakeLists.txt b/backend/source/backend/CMakeLists.txt index b81aa3f8..c08aeec6 100644 --- a/backend/source/backend/CMakeLists.txt +++ b/backend/source/backend/CMakeLists.txt @@ -1,47 +1,59 @@ -target_sources( - ${PROJECT_NAME} - PRIVATE +add_library( + backend + STATIC password/password_prompter.cpp process/process.cpp process/process_store.cpp process/environment.cpp - ssh_session_manager.cpp - main.cpp + session_manager.cpp + session.cpp + sftp/operation_queue.cpp + sftp/download_operation.cpp + sftp/scan_operation.cpp + sftp/bulk_download_operation.cpp ) if (WIN32) target_sources( - ${PROJECT_NAME} + backend PRIVATE pty/windows/conpty.cpp pty/windows/pipe.cpp ) else() target_sources( - ${PROJECT_NAME} + backend PRIVATE pty/linux/pty.cpp ) endif() +target_sources( + ${PROJECT_NAME} + PRIVATE + main.cpp +) + find_package(Boost CONFIG 1.86.0 REQUIRED COMPONENTS system filesystem asio) -target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_SOURCE_DIR}/backend/include") +target_include_directories(backend PUBLIC "${CMAKE_SOURCE_DIR}/backend/include") -target_compile_definitions(${PROJECT_NAME} PRIVATE -DSSH_NO_CPP_EXCEPTIONS) +target_compile_definitions(backend PUBLIC -DSSH_NO_CPP_EXCEPTIONS) include("${CMAKE_CURRENT_LIST_DIR}/../../../_cmake/dependencies/libssh.cmake") # Link backend of nui outside of emscripten target_link_libraries( - ${PROJECT_NAME} - PRIVATE + backend + PUBLIC nui-backend + mplex efsw-static Boost::filesystem Boost::asio Boost::system Boost::process + roar-include-only shared_data utility log @@ -51,8 +63,16 @@ target_link_libraries( events ) +target_link_libraries( + ${PROJECT_NAME} + PRIVATE + backend +) + nui_set_target_output_directories(${PROJECT_NAME}) +add_dependencies(backend scp-resource-copy) + # Creates a target that is compiled through emscripten. This target becomes the frontend part. nui_add_emscripten_target( TARGET @@ -65,6 +85,4 @@ nui_add_emscripten_target( CMAKE_OPTIONS # I recommend to work with a release build by default because debug builds get big fast. -DCMAKE_BUILD_TYPE=Release -) - -add_dependencies(${PROJECT_NAME} scp-resource-copy) \ No newline at end of file +) \ No newline at end of file diff --git a/backend/source/backend/main.cpp b/backend/source/backend/main.cpp index 152e87dc..f254c23f 100644 --- a/backend/source/backend/main.cpp +++ b/backend/source/backend/main.cpp @@ -158,11 +158,11 @@ Main::Main(int const, char const* const* argv) , hub_{window_} , processes_{window_.getExecutor(), window_, hub_} , prompter_{hub_} - , sshSessionManager_{} + , sshSessionManager_{std::make_shared(window_.getExecutor(), stateHolder_, window_, hub_)} , shuttingDown_{false} , childSignalTimer_{window_.getExecutor()} { - sshSessionManager_.addPasswordProvider(-99, &prompter_); + sshSessionManager_->addPasswordProvider(-99, &prompter_); stateHolder_.load([](bool success, Persistence::StateHolder& holder) { if (!success) @@ -174,6 +174,7 @@ Main::Main(int const, char const* const* argv) Main::~Main() { shuttingDown_ = true; + // sshSessionManager_->stopUpdateDispatching(); childSignalTimer_.cancel(); } @@ -189,7 +190,7 @@ void Main::registerRpc() Log::setupBackendRpcHub(&hub_); stateHolder_.registerRpc(hub_); processes_.registerRpc(window_, hub_); - sshSessionManager_.registerRpc(window_, hub_); + sshSessionManager_->registerRpc(); } void Main::show() @@ -261,10 +262,12 @@ int main(int const argc, char const* const* argv) ssh_init(); - Main m{argc, argv}; - m.registerRpc(); - m.startChildSignalTimer(); - m.show(); + { + Main m{argc, argv}; + m.registerRpc(); + m.startChildSignalTimer(); + m.show(); + } ssh_finalize(); } \ No newline at end of file diff --git a/backend/source/backend/pty/windows/conpty.cpp b/backend/source/backend/pty/windows/conpty.cpp index 14b9efff..8c546b77 100644 --- a/backend/source/backend/pty/windows/conpty.cpp +++ b/backend/source/backend/pty/windows/conpty.cpp @@ -47,6 +47,7 @@ namespace ConPTY #pragma clang diagnostic push // This is winapi for you. #pragma clang diagnostic ignored "-Wcast-function-type-strict" +#pragma clang diagnostic ignored "-Wcast-function-type-mismatch" , createPseudoConsole{reinterpret_cast( GetProcAddress(kernel32, "CreatePseudoConsole"))} , resizePseudoConsole{reinterpret_cast( diff --git a/backend/source/backend/session.cpp b/backend/source/backend/session.cpp new file mode 100644 index 00000000..ab5107f1 --- /dev/null +++ b/backend/source/backend/session.cpp @@ -0,0 +1,505 @@ +#include + +#include +#include + +using namespace std::chrono_literals; + +Session::Session( + Ids::SessionId id, + std::unique_ptr session, + boost::asio::any_io_executor executor, + std::shared_ptr> strand, + Nui::Window& wnd, + Nui::RpcHub& hub, + Persistence::SftpOptions const& sftpOptions) + : RpcHelper::StrandRpc{executor, std::move(strand), wnd, hub} + , id_{std::move(id)} + , session_{std::move(session)} + , operationQueue_{std::make_shared< + OperationQueue>(executor_, strand_, wnd, hub, sftpOptions, id_, sftpOptions.concurrency.value_or(1))} +{} + +void Session::start() +{ + within_strand_do([weak = weak_from_this()]() { + auto self = weak.lock(); + if (!self) + return; + + if (!self->session_) + { + Log::error("Session is not valid, cannot start"); + return; + } + + self->session_->start(); + self->registerRpcCreateChannel(); + self->registerRpcStartChannelRead(); + self->registerRpcChannelClose(); + self->registerRpcChannelWrite(); + self->registerRpcChannelPtyResize(); + self->registerRpcSftpListDirectory(); + self->registerRpcSftpCreateDirectory(); + self->registerRpcSftpCreateFile(); + self->registerRpcSftpAddDownloadOperation(); + self->registerOperationQueuePauseUnpause(); + self->operationQueue_->registerRpc(); + + Log::info("Session '{}' connected", self->id_.value()); + + self->running_ = true; + self->doOperationQueueWork(); + }); +} + +void Session::stop() +{ + running_ = false; + timer_.cancel(); +} + +void Session::doOperationQueueWork() +{ + within_strand_do_no_recurse([weak = weak_from_this()]() { + auto self = weak.lock(); + if (!self) + return; + + if (!self->running_) + return; + + if (self->operationQueue_ && !self->operationQueue_->paused() && self->operationQueue_->work()) + { + ++self->unthrottledLimitCounter_; + if (self->unthrottledLimitCounter_ < 10) + { + self->doOperationQueueWork(); + self->operationThrottle_ = queueStartThrottle; + return; + } + self->unthrottledLimitCounter_ = 0; + } + + if (self->operationThrottle_ < queueMaxThrottle) + self->operationThrottle_ *= 2; + else + self->operationThrottle_ = queueMaxThrottle; + + self->queueThrottleTimerIsRunning_ = true; + self->within_strand_do_delayed( + [weak = self->weak_from_this()]() { + if (auto self = weak.lock(); self) + { + self->queueThrottleTimerIsRunning_ = false; + self->doOperationQueueWork(); + } + }, + self->operationThrottle_); + }); +} + +void Session::resetQueueThrottle() +{ + within_strand_do([weak = weak_from_this()]() { + auto self = weak.lock(); + if (!self) + return; + + self->operationThrottle_ = queueStartThrottle; + if (self->queueThrottleTimerIsRunning_) + self->timer_.cancel(); + }); +} + +void Session::removeChannel(Ids::ChannelId channelId) +{ + within_strand_do([weak = weak_from_this(), channelId = std::move(channelId)]() { + auto self = weak.lock(); + if (!self) + return; + + auto iter = self->channels_.find(channelId); + if (iter != self->channels_.end()) + { + if (auto locked = iter->second.lock(); locked) + locked->close(); + + self->channels_.erase(iter); + } + else + { + Log::warn("Cannot remove channel, no channel found with id: {}", channelId.value()); + } + }); +} + +void Session::removeSftpChannel(Ids::ChannelId channelId) +{ + within_strand_do([weak = weak_from_this(), channelId = std::move(channelId)]() { + auto self = weak.lock(); + if (!self) + return; + + auto iter = self->sftpChannels_.find(channelId); + if (iter != self->sftpChannels_.end()) + { + if (auto locked = iter->second.lock(); locked) + locked->close(); + + self->sftpChannels_.erase(iter); + } + else + { + Log::warn("Cannot remove sftp channel, no channel found with id: {}", channelId.value()); + } + }); +} + +void Session::registerRpcCreateChannel() +{ + on(fmt::format("Session::{}::Channel::create", id_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply, nlohmann::json const& parameters) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + RpcHelper::ParameterVerifyView verify{ + reply, fmt::format("Session::{}::Channel::create", self->id_.value()), parameters}; + + if (!verify.hasValueDeep("engine")) + return; + + const bool fileMode = parameters.contains("fileMode") && parameters["fileMode"].is_boolean() && + parameters["fileMode"].get(); + + if (!fileMode) + { + Log::info("Creating pty channel for session '{}'", self->id_.value()); + + const auto sessionOptions = + parameters["engine"]["sshSessionOptions"].get(); + + const auto weakChannel = + self->session_->createPtyChannel({.environment = sessionOptions.environment}).get(); + if (!weakChannel.has_value()) + { + Log::error("Failed to create pty channel: {}", weakChannel.error()); + return reply({{"error", "Failed to create pty channel"}}); + } + + const auto channelId = Ids::generateChannelId(); + self->channels_.emplace(channelId, std::move(weakChannel).value()); + + Log::info( + "Created pty channel with id '{}', channel total is now '{}'.", + channelId.value(), + self->channels_.size()); + + return reply({{"id", channelId.value()}}); + } + else + { + Log::info("Creating sftp channel for session '{}'", self->id_.value()); + + const auto weakChannel = self->session_->createSftpSession().get(); + if (!weakChannel.has_value()) + { + Log::error("Failed to create sftp channel: {}", weakChannel.error().toString()); + return reply({{"error", "Failed to create sftp channel"}}); + } + + const auto channelId = Ids::generateChannelId(); + self->sftpChannels_.emplace(channelId, std::move(weakChannel).value()); + + Log::info( + "Created sftp channel with id '{}', sftp channel total is now '{}'.", + channelId.value(), + self->sftpChannels_.size()); + + return reply({{"id", channelId.value()}}); + } + }); +} + +void Session::registerRpcStartChannelRead() +{ + on(fmt::format("Session::{}::Channel::startReading", id_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply, std::string const& channelIdString) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + const auto channelId = Ids::makeChannelId(channelIdString); + if (self->channels_.find(channelId) == self->channels_.end()) + { + Log::error("No channel found with id: {}", channelId.value()); + return reply({{"error", "No channel found with id"}}); + } + + auto locked = self->channels_[channelId].lock(); + if (!locked) + { + Log::error("Failed to lock channel with id: {}", channelId.value()); + self->removeChannel(channelId); + return reply({{"error", "Failed to lock channel"}}); + } + + // Do not use this in these functions, because they are called from a different thread, + // unless within_strand_do is used! + locked->startReading( + // Stdout + [sessionId = self->id_, + channelId, + stdOut = + RpcHelper::RpcInCorrectThread{ + *self->wnd_, + *self->hub_, + fmt::format("sshTerminalStdout_{}", channelId.value()), + }](std::string const& msg) { + stdOut( + nlohmann::json{ + {"sessionId", sessionId.value()}, + {"channelId", channelId.value()}, + {"data", Roar::base64Encode(msg)}, + }); + }, + // Stderr + [sessionId = self->id_, + channelId, + stdErr = + RpcHelper::RpcInCorrectThread{ + *self->wnd_, + *self->hub_, + fmt::format("sshTerminalStderr_{}", channelId.value()), + }](std::string const& data) { + stdErr( + nlohmann::json{ + {"sessionId", sessionId.value()}, + {"channelId", channelId.value()}, + {"data", Roar::base64Encode(data)}, + }); + }, + // On channel exit: + [removeChannel = + [weak = self->weak_from_this(), channelId]() { + if (auto self = weak.lock(); self) + self->removeChannel(channelId); + }, + sessionId = self->id_, + channelId, + onExit = RpcHelper::RpcInCorrectThread{ + *self->wnd_, + *self->hub_, + fmt::format("sshTerminalOnExit_{}", channelId.value()), + }]() { + Log::info("Channel for session '{}' lost with id: {}", sessionId.value(), channelId.value()); + removeChannel(); + onExit({{"sessionId", sessionId.value()}, {"channelId", channelId.value()}}); + }); + + return reply({{"success", true}}); + }); +} + +void Session::registerRpcChannelClose() +{ + on(fmt::format("Session::{}::Channel::close", id_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply, std::string const& channelIdString) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->removeChannel(Ids::makeChannelId(channelIdString)); + return reply({{"success", true}}); + }); +} + +void Session::registerRpcChannelWrite() +{ + on(fmt::format("Session::{}::Channel::write", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, std::string const& channelIdString, std::string&& data) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withChannelDo( + Ids::makeChannelId(channelIdString), + [data = std::move(data)](RpcHelper::RpcOnce&& reply, auto&& channel) { + channel->write(Roar::base64Decode(data)); + reply({{"success", true}}); + }, + std::move(reply)); + }); +} + +void Session::registerRpcChannelPtyResize() +{ + on(fmt::format("Session::{}::Channel::ptyResize", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, std::string const& channelIdString, int cols, int rows) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withChannelDo( + Ids::makeChannelId(channelIdString), + [cols, rows](RpcHelper::RpcOnce&& reply, auto&& channel) { + channel->resizePty(cols, rows); + reply({{"success", true}}); + }, + std::move(reply)); + }); +} + +void Session::registerRpcSftpListDirectory() +{ + on(fmt::format("Session::{}::sftp::listDirectory", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, std::string const& channelIdString, std::string const& path) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withSftpChannelDo( + Ids::makeChannelId(channelIdString), + [path](RpcHelper::RpcOnce&& reply, auto&& channel) { + auto fut = channel->listDirectory(path); + if (fut.wait_for(futureTimeout) != std::future_status::ready) + return reply({{"error", "Failed to list directory: timeout"}}); + + const auto result = fut.get(); + if (!result.has_value()) + return reply({{"error", result.error().message}}); + + Log::info("Listed directory '{}', got {} entries", path, result->size()); + reply({{"entries", *result}}); + }, + std::move(reply)); + }); +} + +void Session::registerRpcSftpCreateDirectory() +{ + on(fmt::format("Session::{}::sftp::createDirectory", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, std::string const& channelIdString, std::string const& path) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withSftpChannelDo( + Ids::makeChannelId(channelIdString), + [path](RpcHelper::RpcOnce&& reply, auto&& channel) { + auto fut = channel->createDirectory(path); + if (fut.wait_for(futureTimeout) != std::future_status::ready) + return reply({{"error", "Failed to create directory: timeout"}}); + + const auto result = fut.get(); + if (!result.has_value()) + return reply({{"error", result.error().message}}); + + Log::info("Created directory '{}'", path); + reply({{"success", true}}); + }, + std::move(reply)); + }); +} + +void Session::registerRpcSftpCreateFile() +{ + on(fmt::format("Session::{}::sftp::createFile", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, std::string const& channelIdString, std::string const& path) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withSftpChannelDo( + Ids::makeChannelId(channelIdString), + [path](RpcHelper::RpcOnce&& reply, auto&& channel) { + auto fut = channel->createFile(path); + if (fut.wait_for(futureTimeout) != std::future_status::ready) + return reply({{"error", "Failed to create file: timeout"}}); + + const auto result = fut.get(); + if (!result.has_value()) + return reply({{"error", result.error().message}}); + + Log::info("Created file '{}'", path); + reply({{"success", true}}); + }, + std::move(reply)); + }); +} + +void Session::registerRpcSftpAddDownloadOperation() +{ + on(fmt::format("Session::{}::sftp::addDownload", id_.value())) + .perform([weak = weak_from_this()]( + RpcHelper::RpcOnce&& reply, + std::string const& channelIdString, + std::string const& newOperationIdString, + std::string const& remotePath, + std::string const& localPath) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + self->withSftpChannelDo( + Ids::makeChannelId(channelIdString), + [weak = self->weak_from_this(), newOperationIdString, localPath, remotePath]( + RpcHelper::RpcOnce&& reply, auto&& channel) { + auto self = weak.lock(); + if (!self) + return reply({{"error", "Session no longer exists"}}); + + const auto result = self->operationQueue_->addDownloadOperation( + *channel, Ids::makeOperationId(newOperationIdString), localPath, remotePath); + + if (!result.has_value()) + { + Log::error( + "Failed to add download operation for file '{}' to '{}': {}", + remotePath, + localPath, + result.error().toString()); + return reply({{"error", result.error().toString()}}); + } + + Log::info( + "Added download operation with id '{}' for file '{}' to '{}'", + newOperationIdString, + remotePath, + localPath); + + self->resetQueueThrottle(); + reply({{"success", true}}); + }, + std::move(reply)); + }); +} + +void Session::registerOperationQueuePauseUnpause() +{ + on(fmt::format("OperationQueue::{}::pauseUnpause", id_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply, bool pause) { + auto self = weak.lock(); + if (!self) + { + Log::error("Session no longer exists, cannot pause/unpause operation queue"); + return reply(SharedData::error("Session no longer exists")); + } + + if (!self->operationQueue_) + { + Log::error("No operation queue available to pause/unpause"); + return reply(SharedData::error("No operation queue available")); + } + + self->operationQueue_->paused(pause); + self->resetQueueThrottle(); + return reply(SharedData::success()); + }); +} \ No newline at end of file diff --git a/backend/source/backend/session_manager.cpp b/backend/source/backend/session_manager.cpp new file mode 100644 index 00000000..31427827 --- /dev/null +++ b/backend/source/backend/session_manager.cpp @@ -0,0 +1,189 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +using namespace Detail; + +int askPassDefault(char const* prompt, char* buf, std::size_t length, int, int, void* userdata) +{ + std::pair const* data = + static_cast*>(userdata); + + auto* manager = data->first; + std::string whatFor = data->second; + + if (!manager->strand_->running_in_this_thread()) + { + throw std::runtime_error( + "askPassDefault called outside of strand - This critical bug will crash the application."); + } + + std::promise> pwPromise{}; + std::functionpasswordProviders_)::const_iterator)> askNextProvider; + + askNextProvider = [end = manager->passwordProviders_.end(), &askNextProvider, prompt, &pwPromise, &whatFor]( + decltype(manager->passwordProviders_)::const_iterator iter) { + if (iter == end) + { + pwPromise.set_value(std::nullopt); + return; + } + + iter->second->getPassword( + whatFor, prompt, [iter, &pwPromise, &askNextProvider](std::optional pw) mutable { + if (pw.has_value()) + pwPromise.set_value(pw); + else + { + ++iter; + askNextProvider(iter); + } + }); + }; + askNextProvider(manager->passwordProviders_.begin()); + + const auto pw = pwPromise.get_future().get(); + + if (pw.has_value()) + { + std::memset(buf, 0, length); + std::strncpy(buf, pw.value().c_str(), length - 1); + return 0; + } + return -1; +} + +SessionManager::SessionManager( + boost::asio::any_io_executor executor, + Persistence::StateHolder& stateHolder, + Nui::Window& wnd, + Nui::RpcHub& hub) + : RpcHelper::StrandRpc{executor, wnd, hub} + , stateHolder_{&stateHolder} +{} + +void SessionManager::addPasswordProvider(int priority, PasswordProvider* provider) +{ + within_strand_do([this, priority, provider]() { + passwordProviders_.emplace(priority, provider); + }); +} + +void SessionManager::addSession( + Persistence::SshTerminalEngine const& engine, + std::function const&)> onComplete) +{ + within_strand_do([this, engine, onComplete = std::move(onComplete)]() { + std::pair askPassUserDataKeyPhrase{this, "Key phrase"}; + std::pair askPassUserDataPassword{this, "Password"}; + auto maybeSshSession = + makeSession(engine, askPassDefault, &askPassUserDataKeyPhrase, &askPassUserDataPassword, &pwCache_); + + if (maybeSshSession) + { + const auto sessionId = Ids::SessionId{Ids::generateId()}; + const auto session = std::make_shared( + sessionId, + std::move(maybeSshSession).value(), + executor_, + strand_, + *wnd_, + *hub_, + engine.sshSessionOptions->sftpOptions.value()); + const auto emplaced = sessions_.emplace(sessionId, session); + if (!emplaced.second) + { + Log::error("Session id collision - This should never happen."); + return onComplete(std::nullopt); + } + + Log::info("Created session with id '{}', total is now '{}'.", sessionId.value(), sessions_.size()); + session->start(); + onComplete(sessionId); + } + else + { + Log::error("Failed to create session: {}", maybeSshSession.error()); + onComplete(std::nullopt); + } + }); +} + +void SessionManager::removeSession(Ids::SessionId sessionId) +{ + within_strand_do([this, sessionId]() { + if (auto iter = sessions_.find(sessionId); iter != sessions_.end()) + { + Log::info("Removing session with id: {}", sessionId.value()); + sessions_.erase(iter); + } + else + { + Log::warn("Cannot remove session, no session found with id: {}", sessionId.value()); + } + }); +} + +void SessionManager::registerRpc() +{ + registerRpcSessionConnect(); + registerRpcSessionDisconnect(); +} + +void SessionManager::registerRpcSessionConnect() +{ + on("SessionManager::connect").perform([this](RpcHelper::RpcOnce&& reply, nlohmann::json const& parameters) { + Log::info("Connecting to ssh server with parameters: {}", parameters.dump(4)); + + if (!RpcHelper::ParameterVerifyView{reply, "SessionManager::connect", parameters}.hasValueDeep( + "engine", "sshSessionOptions")) + { + return; + } + + auto onComplete = rpcSafe(std::move(reply), [](auto const& reply, auto const& maybeId) { + if (!maybeId) + { + Log::error("Failed to connect to ssh server"); + return reply({{"error", "Failed to connect to ssh server"}}); + } + + Log::info("Connected to ssh server with id: {}", maybeId->value()); + return reply({{"id", maybeId->value()}}); + }); + + addSession(parameters["engine"].get(), std::move(onComplete)); + }); +} + +void SessionManager::registerRpcSessionDisconnect() +{ + on("SessionManager::disconnect").perform([this](RpcHelper::RpcOnce&& reply, nlohmann::json const& parameters) { + if (!RpcHelper::ParameterVerifyView{reply, "SessionManager::disconnect", parameters}.hasValueDeep("sessionId")) + return; + + try + { + removeSession(Ids::makeSessionId(parameters["sessionId"].get())); + return reply({{"success", true}}); + } + catch (std::exception const& e) + { + Log::error("Error disconnecting to ssh server: {}", e.what()); + return reply({{"error", e.what()}}); + } + }); +} \ No newline at end of file diff --git a/backend/source/backend/sftp/bulk_download_operation.cpp b/backend/source/backend/sftp/bulk_download_operation.cpp new file mode 100644 index 00000000..675f62e5 --- /dev/null +++ b/backend/source/backend/sftp/bulk_download_operation.cpp @@ -0,0 +1,49 @@ +#include +#include +#include + +BulkDownloadOperation::BulkDownloadOperation(SecureShell::SftpSession& sftp, BulkDownloadOperationOptions options) + : Operation{} + , sftp_{&sftp} + , options_{std::move(options)} + , currentDownload_{nullptr} + , entries_{} + , totalBytes_{0} + , currentIndex_{0} + , currentBytes_{0} + , futureTimeout_{options_.individualOptions.futureTimeout} +{} + +BulkDownloadOperation::~BulkDownloadOperation() = default; + +std::expected BulkDownloadOperation::work() +{ + using enum OperationState; + + // TODO: Not implemented yet + + return enterErrorState({.type = Operation::ErrorType::CannotWorkFailedOperation}); +} + +SharedData::OperationType BulkDownloadOperation::type() const +{ + return SharedData::OperationType::BulkDownload; +} + +void BulkDownloadOperation::setScanResult(std::vector&& entries) +{ + entries_ = std::move(entries); + totalBytes_ = 0; +} + +std::expected BulkDownloadOperation::cancel(bool adoptCancelState) +{ + if (adoptCancelState) + enterState(OperationState::Canceled); + return {}; +} + +SecureShell::ProcessingStrand* BulkDownloadOperation::strand() const +{ + return sftp_->strand(); +} \ No newline at end of file diff --git a/backend/source/backend/sftp/download_operation.cpp b/backend/source/backend/sftp/download_operation.cpp new file mode 100644 index 00000000..b01cfe4c --- /dev/null +++ b/backend/source/backend/sftp/download_operation.cpp @@ -0,0 +1,405 @@ +#include + +#include +#include + +DownloadOperation::DownloadOperation( + std::weak_ptr fileStream, + DownloadOperationOptions options) + : Operation{} + , fileStream_{std::move(fileStream)} + , remotePath_{std::move(options.remotePath)} + , localPath_{std::move(options.localPath)} + , tempFileSuffix_{std::move(options.tempFileSuffix)} + , progressCallback_{std::move(options.progressCallback)} + , mayOverwrite_{options.mayOverwrite} + , reserveSpace_{options.reserveSpace} + , tryContinue_{options.tryContinue} + , inheritPermissions_{options.inheritPermissions} + , doCleanup_{options.doCleanup} + , permissions_{options.permissions} + , localFile_{} + , fileSize_{0} + , futureTimeout_{options.futureTimeout} +{ + if (tempFileSuffix_.empty()) + tempFileSuffix_ = ".filepart"; + if (tempFileSuffix_.find('/') != 0) + tempFileSuffix_ = ".filepart"; +} + +DownloadOperation::~DownloadOperation() +{ + std::ignore = cancel(false); + + if (auto stream = fileStream_.lock(); stream) + { + // wait for all tasks of the operation to finish + stream->strand()->pushPromiseTask([]() {}).get(); + } +} + +std::expected DownloadOperation::work() +{ + using enum OperationState; + + switch (state_) + { + case (NotStarted): + { + state_ = Preparing; + [[fallthrough]]; + } + case (Preparing): + { + const auto prepareResult = prepare(); + if (!prepareResult.has_value()) + { + Log::error("DownloadOperation: Failed to prepare operation: {}", prepareResult.error().toString()); + return enterErrorState(prepareResult.error()); + } + state_ = Prepared; + [[fallthrough]]; + } + case (Prepared): + { + state_ = Running; + [[fallthrough]]; + } + case (Running): + { + const auto result = readOnce(); + if (!result.has_value()) + { + Log::error("DownloadOperation: Failed to read file: {}", result.error().toString()); + return enterErrorState(result.error()); + } + if (result.value()) + { + return WorkStatus::MoreWork; + } + // No More to read? + else + { + Log::info("DownloadOperation: Data reading completed."); + state_ = Finalizing; + [[fallthrough]]; + } + } + case (Finalizing): + { + const auto finalizeResult = finalize(); + if (!finalizeResult.has_value()) + { + Log::error("DownloadOperation: Failed to finalize operation: {}", finalizeResult.error().toString()); + return enterErrorState(finalizeResult.error()); + } + state_ = Completed; + Log::info("DownloadOperation: Operation completed successfully."); + return WorkStatus::Complete; + } + case (Completed): + { + Log::warn("DownloadOperation: Operation already completed."); + // Dont enter error state here, it would overwrite the success state. + return std::unexpected(Error{.type = ErrorType::CannotWorkCompletedOperation}); + } + case (Failed): + { + Log::warn("DownloadOperation: Operation already failed."); + // Do not enter error state here, it would overwrite the error state. + return std::unexpected(Error{.type = ErrorType::CannotWorkFailedOperation}); + } + default: + { + } + } + Log::error("DownloadOperation: Unknown operation state: {}", static_cast(state_)); + return enterErrorState({.type = ErrorType::UnknownWorkState}); +} + +std::expected DownloadOperation::readOnce() +{ + if (state_ < OperationState::Prepared) + { + Log::error("DownloadOperation: Operation not prepared."); + return enterErrorState({.type = ErrorType::OperationNotPrepared}); + } + + if (!localFile_.is_open()) + { + Log::error("DownloadOperation: File is not open."); + return enterErrorState({.type = ErrorType::OpenFailure}); + } + + if (fileSize_ == 0) + { + Log::info("DownloadOperation: Remote file is empty, nothing to do."); + return false; + } + + auto stream = fileStream_.lock(); + if (!stream) + { + Log::error("DownloadOperation: File stream expired."); + return enterErrorState({.type = ErrorType::FileStreamExpired}); + } + + auto future = stream->readSome(buffer_.data(), buffer_.size()); + + const auto futureStatus = future.wait_for(futureTimeout_); + + if (futureStatus != std::future_status::ready) + { + Log::error("DownloadOperation: Future timed out while reading."); + return enterErrorState({.type = ErrorType::FutureTimeout}); + } + + const auto result = future.get(); + + if (!result.has_value()) + { + Log::error("DownloadOperation: Failed to read from remote file: {}", result.error().message); + return enterErrorState({.type = ErrorType::SftpError, .sftpError = result.error()}); + } + + const auto readAmount = result.value(); + + if (readAmount == 0) + { + Log::info("DownloadOperation: Remote file read complete or error."); + return false; + } + + std::uint64_t tellp = 0; + std::uint64_t fileSize = 0; + bool good = true; + { + localFile_.write(buffer_.data(), static_cast(readAmount)); + tellp = static_cast(localFile_.tellp()); + fileSize = fileSize_; + good = localFile_.good(); + progressCallback_(0ull, fileSize, tellp); + } + if (!good) + { + Log::error("DownloadOperation read cycle stopped: localFile_.good() == false"); + std::ignore = enterErrorState({ + .type = SharedData::OperationErrorType::TargetFileNotGood, + }); + return false; + } + return good && tellp < fileSize; +} + +std::expected DownloadOperation::openOrAdoptFile(SecureShell::IFileStream& stream) +{ + const auto tempPath = localPath_.generic_string() + tempFileSuffix_; + + if (tryContinue_ && std::filesystem::exists(tempPath)) + { + localFile_.open(tempPath, std::ios::binary | std::ios::app); + if (!localFile_.is_open()) + { + Log::error("DownloadOperation: Failed to open file for appending: {}", tempPath); + return enterErrorState({.type = ErrorType::OpenFailure}); + } + + // File complete but not renamed? just rename it in the finalize() step + if (static_cast(localFile_.tellp()) == fileSize_) + { + Log::info("DownloadOperation: File '{}' already complete, will be renamed in finalize() step.", tempPath); + localFile_.close(); + return {}; + } + // File is larger than expected? discard it and start over. + else if (static_cast(localFile_.tellp()) > fileSize_) + { + Log::info("DownloadOperation: File '{}' is larger than expected, discarding and starting over.", tempPath); + localFile_.close(); + // Reset the file + localFile_.open(tempPath, std::ios::binary | std::ios::trunc); + } + else + { + Log::info("DownloadOperation: File '{}' is incomplete, continuing download.", tempPath); + // Seek stream to position: + auto seekResult = stream.seek(localFile_.tellp()).get(); + if (!seekResult.has_value()) + return enterErrorState({.type = ErrorType::FileStatFailed}); + } + } + else + { + Log::info("DownloadOperation: Starting new download to '{}'.", tempPath); + localFile_.open(tempPath, std::ios::binary | std::ios::trunc); + } + + if (!localFile_.is_open()) + { + Log::error("DownloadOperation: Failed to open file: {}", tempPath); + return enterErrorState({.type = ErrorType::OpenFailure}); + } + + return {}; +} + +std::expected DownloadOperation::prepare() +{ + if (localPath_.empty()) + { + Log::error("DownloadOperation: Invalid local path."); + return enterErrorState({.type = ErrorType::InvalidPath}); + } + + // Initial check. Check again later before rename + if (std::filesystem::exists(localPath_)) + { + if (!mayOverwrite_) + { + Log::error( + "DownloadOperation: File '{}' already exists and may not be overwritten.", localPath_.generic_string()); + return enterErrorState({.type = ErrorType::FileExists}); + } + } + + auto stream = fileStream_.lock(); + if (!stream) + { + Log::error("DownloadOperation: File stream expired."); + return enterErrorState({.type = ErrorType::FileStreamExpired}); + } + + const auto fileInfo = stream->stat().get(); + if (!fileInfo.has_value()) + { + Log::error("DownloadOperation: Failed to stat file."); + return enterErrorState({.type = ErrorType::FileStatFailed, .sftpError = fileInfo.error()}); + } + + fileSize_ = fileInfo->size; + + auto openResult = openOrAdoptFile(*stream); + if (!openResult.has_value()) + { + Log::error("DownloadOperation: Failed to open file."); + return enterErrorState(std::move(openResult).error()); + } + + if (reserveSpace_ && fileSize_ != 0) + { + // Reserve space + Log::info("DownloadOperation: Reserving space for file."); + const auto pos = localFile_.tellp(); + localFile_.seekp(fileSize_ - 1); + localFile_.put('\0'); + if (localFile_.fail()) + { + Log::error("DownloadOperation: Failed to open file."); + return enterErrorState({.type = ErrorType::OpenFailure}); + } + localFile_.seekp(pos); + } + + Log::info( + "DownloadOperation: Prepared download of '{}' to '{}'.", + remotePath_.generic_string(), + localPath_.generic_string()); + + return {}; +} + +std::expected DownloadOperation::cancel(bool adoptCancelState) +{ + if (adoptCancelState) + { + Log::info( + "DownloadOperation: Download of '{}' to '{}' canceled.", + remotePath_.generic_string(), + localPath_.generic_string()); + state_ = OperationState::Canceled; + } + + cleanup(); + return {}; +} + +void DownloadOperation::cleanup() +{ + localFile_.close(); + + if (doCleanup_ && std::filesystem::exists(localPath_.generic_string() + tempFileSuffix_)) + std::filesystem::remove(localPath_.generic_string() + tempFileSuffix_); + + if (auto stream = fileStream_.lock(); stream) + stream->close(false); +} + +std::expected DownloadOperation::finalize() +{ + if (state_ == OperationState::Running) + { + Log::error("DownloadOperation: Cannot finalize while reading."); + return std::unexpected(Error{.type = ErrorType::CannotFinalizeDuringRead}); + } + + localFile_.close(); + + if (std::filesystem::exists(localPath_) && !mayOverwrite_) + { + Log::error( + "DownloadOperation: File '{}' already exists and may not be overwritten.", localPath_.generic_string()); + return std::unexpected(Error{.type = ErrorType::FileExists}); + } + + std::error_code ec{}; + std::filesystem::rename(localPath_.generic_string() + tempFileSuffix_, localPath_, ec); + if (ec) + { + Log::error("DownloadOperation: Failed to rename file: {}", ec.message()); + return std::unexpected(Error{.type = ErrorType::RenameFailure}); + } + + if (inheritPermissions_) + { + Log::info("DownloadOperation: Inheriting permissions from remote file."); + auto stream = fileStream_.lock(); + if (!stream) + { + Log::error("DownloadOperation: File stream expired."); + return std::unexpected(Error{.type = ErrorType::FileStreamExpired}); + } + + const auto fileInfo = stream->stat().get(); + if (!fileInfo.has_value()) + { + Log::error("DownloadOperation: Failed to stat file."); + return std::unexpected(Error{.type = ErrorType::FileStatFailed, .sftpError = fileInfo.error()}); + } + + std::error_code permissionsError{}; + std::filesystem::permissions(localPath_, fileInfo->permissions, permissionsError); + if (permissionsError) + { + Log::error("DownloadOperation: Failed to set permissions: {}", permissionsError.message()); + return std::unexpected(Error{.type = ErrorType::CannotSetFilePermissions}); + } + } + else if (permissions_) + { + Log::info("DownloadOperation: Setting permissions."); + std::error_code permissionsError{}; + std::filesystem::permissions(localPath_, *permissions_, permissionsError); + if (permissionsError) + { + Log::error("DownloadOperation: Failed to set permissions: {}", permissionsError.message()); + return std::unexpected(Error{.type = ErrorType::CannotSetFilePermissions}); + } + } + + Log::info( + "DownloadOperation: Finalized download of '{}' to '{}'.", + remotePath_.generic_string(), + localPath_.generic_string()); + return {}; +} \ No newline at end of file diff --git a/backend/source/backend/sftp/operation_queue.cpp b/backend/source/backend/sftp/operation_queue.cpp new file mode 100644 index 00000000..5147181c --- /dev/null +++ b/backend/source/backend/sftp/operation_queue.cpp @@ -0,0 +1,361 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace +{ + OperationQueue::OperationCompleted makeCompletedOperation( + OperationQueue::CompletionReason reason, + Ids::OperationId operationId, + Operation const& operation, + std::optional error = std::nullopt) + { + return operation.visit( + Utility::overloaded( + [reason, operationId, error](DownloadOperation const& op) { + return OperationQueue::OperationCompleted{ + .reason = reason, + .operationId = operationId, + .completionTime = std::chrono::system_clock::now(), + .localPath = op.localPath(), + .remotePath = op.remotePath(), + .error = error, + }; + }, + [reason, operationId, error](ScanOperation const& op) { + return OperationQueue::OperationCompleted{ + .reason = reason, + .operationId = operationId, + .completionTime = std::chrono::system_clock::now(), + .remotePath = op.remotePath(), + .error = error, + }; + }, + [reason, operationId, error](BulkDownloadOperation const&) { + return OperationQueue::OperationCompleted{ + .reason = reason, + .operationId = operationId, + .completionTime = std::chrono::system_clock::now(), + .error = error, + }; + }, + [reason, operationId](std::nullopt_t) { + return OperationQueue::OperationCompleted{ + .reason = reason, + .operationId = operationId, + .completionTime = std::chrono::system_clock::now(), + }; + })); + } +} + +OperationQueue::OperationQueue( + boost::asio::any_io_executor executor, + std::shared_ptr> strand, + Nui::Window& wnd, + Nui::RpcHub& hub, + Persistence::SftpOptions sftpOpts, + Ids::SessionId sessionId, + int parallelism) + : RpcHelper::StrandRpc{executor, strand, wnd, hub} + , sftpOpts_{std::move(sftpOpts)} + , sessionId_{std::move(sessionId)} + , parallelism_{parallelism} +{} + +void OperationQueue::cancelAll() +{ + within_strand_do([weak = weak_from_this()]() { + auto self = weak.lock(); + if (!self) + return; + + self->operations_.clear(); + Log::info("All operations in the queue have been canceled."); + }); +} +void OperationQueue::cancel(Ids::OperationId id) +{ + within_strand_do([weak = weak_from_this(), id = std::move(id)]() { + auto self = weak.lock(); + if (!self) + return; + + std::erase_if(self->operations_, [id](auto& op) { + bool isMatch = op.first == id; + if (isMatch) + op.second->cancel(true); + return isMatch; + }); + }); +} + +void OperationQueue::completeOperation(SharedData::OperationCompleted&& operationCompleted) +{ + within_strand_do([weak = weak_from_this(), operationCompleted = std::move(operationCompleted)]() { + auto self = weak.lock(); + if (!self) + return; + + if (operationCompleted.error) + Log::error("Operation failed: {}", operationCompleted.error->toString()); + + Log::info( + "Operation completed: id={}, reason={}, localPath='{}', remotePath='{}'", + operationCompleted.operationId.value(), + static_cast(operationCompleted.reason), + operationCompleted.localPath ? operationCompleted.localPath->generic_string() : "", + operationCompleted.remotePath ? operationCompleted.remotePath->generic_string() : ""); + + self->hub_->callRemote( + fmt::format("OperationQueue::{}::onOperationCompleted", self->sessionId_.value()), + SharedData::OperationCompleted{ + .reason = operationCompleted.reason, + .operationId = operationCompleted.operationId, + .completionTime = operationCompleted.completionTime, + .localPath = operationCompleted.localPath, + .remotePath = operationCompleted.remotePath, + .error = operationCompleted.error, + }); + }); +} + +bool OperationQueue::work() +{ + // Assumed in strand + + if (paused_) + return false; + + const auto updateCount = std::min(operations_.size(), static_cast(parallelism_)); + + bool moreWork = false; + if (updateCount == 0) + return false; + + for (std::size_t i = 0; i < updateCount; ++i) + { + auto& [id, operation] = operations_[i]; + const auto workResult = operation->work(); + if (!workResult.has_value()) + { + completeOperation( + makeCompletedOperation(OperationQueue::CompletionReason::Failed, id, *operation, workResult.error())); + operations_.erase(operations_.begin() + static_cast(i)); + // Exit loop and avoid any offset math. Just do another update cycle. + return true; + } + + const auto workStatus = workResult.value(); + if (workStatus == Operation::WorkStatus::Complete) + { + Log::info("Operation completed successfully: {}", id.value()); + completeOperation(makeCompletedOperation(OperationQueue::CompletionReason::Completed, id, *operation)); + operations_.pop_front(); + // Exit loop and avoid any offset math. Just do another update cycle. + return true; + } + else if (workStatus == Operation::WorkStatus::MoreWork) + { + moreWork = true; + continue; + } + } + return moreWork; +} + +bool OperationQueue::paused() const +{ + return paused_; +} +void OperationQueue::paused(bool pause) +{ + within_strand_do([weak = weak_from_this(), pause]() { + auto self = weak.lock(); + if (!self) + return; + + self->paused_ = pause; + }); +} + +std::expected OperationQueue::addDownloadOperation( + SecureShell::SftpSession& sftp, + Ids::OperationId operationId, + std::filesystem::path const& localPath, + std::filesystem::path const& remotePath) +{ + // Assumed in strand + + auto fut = sftp.stat(remotePath); + if (fut.wait_for(sftpOpts_.operationTimeout) != std::future_status::ready) + { + Log::error("Failed to stat remote sftp file: timeout"); + return std::unexpected(Operation::Error{.type = Operation::ErrorType::FutureTimeout}); + } + + const auto result = fut.get(); + if (!result.has_value()) + { + Log::error("Failed to stat remote sftp file: {}", result.error().message); + return std::unexpected(Operation::Error{.type = Operation::ErrorType::SftpError, .sftpError = result.error()}); + } + + if (result->isRegularFile()) + { + const auto fileSize = result->size; + + auto fut = sftp.openFile(remotePath, SecureShell::SftpSession::OpenType::Read, std::filesystem::perms::unknown); + if (fut.wait_for(std::chrono::seconds{5}) != std::future_status::ready) + { + Log::error("Failed to open remote sftp file: timeout"); + return std::unexpected(Operation::Error{.type = Operation::ErrorType::OpenFailure}); + } + + const auto openResult = fut.get(); + if (!openResult.has_value()) + { + Log::error("Failed to open remote sftp file: {}", openResult.error().message); + return std::unexpected( + Operation::Error{.type = Operation::ErrorType::SftpError, .sftpError = openResult.error()}); + } + + const auto transferOptions = sftpOpts_.downloadOptions.value_or(Persistence::TransferOptions{}); + const auto defaultOptions = DownloadOperation::DownloadOperationOptions{}; + + auto operation = std::make_unique( + std::move(openResult).value(), + DownloadOperation::DownloadOperationOptions{ + .progressCallback = + [weak = weak_from_this(), operationId](auto min, auto max, auto current) { + auto self = weak.lock(); + if (!self) + return; + + self->hub_->callRemote( + fmt::format("OperationQueue::{}::onDownloadProgress", self->sessionId_.value()), + SharedData::DownloadProgress{ + .operationId = operationId, + .min = min, + .max = max, + .current = current, + }); + + Log::debug( + "Downloaded {} / {} bytes ({}%)", + current - min, + max - min, + (current - min) * 100 / (max - min)); + }, + .remotePath = remotePath, + .localPath = localPath, + .tempFileSuffix = transferOptions.tempFileSuffix.value_or(defaultOptions.tempFileSuffix), + .mayOverwrite = transferOptions.mayOverwrite.value_or(defaultOptions.mayOverwrite), + .reserveSpace = transferOptions.reserveSpace.value_or(defaultOptions.reserveSpace), + .tryContinue = transferOptions.tryContinue.value_or(defaultOptions.tryContinue), + .inheritPermissions = transferOptions.inheritPermissions.value_or(defaultOptions.inheritPermissions), + .doCleanup = transferOptions.doCleanup.value_or(defaultOptions.doCleanup), + .permissions = + transferOptions.customPermissions ? transferOptions.customPermissions : defaultOptions.permissions, + }); + + operations_.emplace_back(operationId, std::move(operation)); + + Log::info("Calling OperationQueue::{}::onOperationAdded", sessionId_.value()); + hub_->callRemote( + fmt::format("OperationQueue::{}::onOperationAdded", sessionId_.value()), + SharedData::OperationAdded{ + .operationId = operationId, + .type = SharedData::OperationType::Download, + .totalBytes = fileSize, + .localPath = localPath, + .remotePath = remotePath}); + + return {}; + } + else if (result->isDirectory()) + { + auto scan = std::make_unique( + sftp, + ScanOperation::ScanOperationOptions{ + .progressCallback = + [](auto, auto, auto) { + // TODO: + }, + .remotePath = remotePath, + .futureTimeout = std::chrono::seconds{5}, + }); + + auto bulk = std::make_unique( + sftp, + BulkDownloadOperation::BulkDownloadOperationOptions{ + .overallProgressCallback = + [](auto const&, auto, auto, auto, auto) { + // TODO: + }, + .individualOptions = + DownloadOperation::DownloadOperationOptions{ + .progressCallback = + [](auto, auto, auto) { + // TODO: + }, + .remotePath = remotePath, + .localPath = localPath, + }, + }); + + operations_.emplace_back(operationId, std::move(scan)); + operations_.emplace_back(operationId, std::move(bulk)); + + hub_->callRemote( + fmt::format("OperationQueue::{}::{}", sessionId_.value(), "onOperationAdded"), + SharedData::OperationAdded{ + .operationId = operationId, + .type = SharedData::OperationType::Scan, + }); + hub_->callRemote( + fmt::format("OperationQueue::{}::{}", sessionId_.value(), "onOperationAdded"), + SharedData::OperationAdded{ + .operationId = operationId, + .type = SharedData::OperationType::BulkDownload, + }); + + return {}; + } + else + { + Log::error("Remote path is neither a file nor a directory: {}.", result->type); + return std::unexpected(Operation::Error{.type = Operation::ErrorType::OperationNotPossibleOnFileType}); + } +} + +void OperationQueue::registerRpc() +{ + on(fmt::format("OperationQueue::{}::isPaused", sessionId_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply) { + auto self = weak.lock(); + if (!self) + return reply(SharedData::error("OperationQueue no longer exists")); + + return reply( + SharedData::ErrorOrSuccess{SharedData::IsPaused{ + .paused = self->paused(), + }}); + }); + + on(fmt::format("OperationQueue::{}::cancel", sessionId_.value())) + .perform([weak = weak_from_this()](RpcHelper::RpcOnce&& reply, Ids::OperationId operationId) { + auto self = weak.lock(); + if (!self) + return reply(SharedData::error("OperationQueue no longer exists")); + + self->cancel(std::move(operationId)); + return reply(SharedData::success()); + }); +} \ No newline at end of file diff --git a/backend/source/backend/sftp/scan_operation.cpp b/backend/source/backend/sftp/scan_operation.cpp new file mode 100644 index 00000000..0c0bf6c1 --- /dev/null +++ b/backend/source/backend/sftp/scan_operation.cpp @@ -0,0 +1,125 @@ +#include +#include + +#include + +ScanOperation::ScanOperation(SecureShell::SftpSession& sftp, ScanOperationOptions options) + : sftp_(&sftp) + , remotePath_{std::move(options.remotePath)} + , progressCallback_{std::move(options.progressCallback)} + , futureTimeout_{options.futureTimeout} +{} + +ScanOperation::~ScanOperation() = default; + +std::expected ScanOperation::scanOnce(std::filesystem::path const& path) +{ + auto fut = sftp_->listDirectory(path); + fut.wait_for(futureTimeout_); + if (fut.wait_for(futureTimeout_) != std::future_status::ready) + return enterErrorState({.type = ErrorType::FutureTimeout}); + + const auto result = fut.get(); + if (!result.has_value()) + return enterErrorState({.type = ErrorType::SftpError, .sftpError = result.error()}); + + entries_.insert(entries_.end(), result->begin(), result->end()); + return {}; +} + +std::expected ScanOperation::work() +{ + using enum OperationState; + + switch (state_) + { + case (NotStarted): + { + state_ = Running; + totalBytes_ = 0; + currentIndex_ = 0; + Log::info("ScanOperation: Starting scan of '{}'.", remotePath_.generic_string()); + const auto result = scanOnce(remotePath_); + if (!result.has_value()) + { + Log::error("ScanOperation: Failed to scan directory: {}", result.error().toString()); + return enterErrorState(result.error()); + } + progressCallback_(totalBytes_, currentIndex_, static_cast(entries_.size())); + break; + } + case (Running): + { + if (currentIndex_ >= static_cast(entries_.size())) + { + Log::info("ScanOperation: Scan of '{}' completed.", remotePath_.generic_string()); + state_ = Completed; + return WorkStatus::Complete; + } + + if (entries_[currentIndex_].isDirectory()) + { + const auto result = scanOnce(entries_[currentIndex_].path); + if (!result.has_value()) + { + Log::error( + "ScanOperation: Failed to scan directory '{}': {}", + entries_[currentIndex_].path.generic_string(), + result.error().toString()); + return enterErrorState(result.error()); + } + ++currentIndex_; + } + + for (; currentIndex_ < static_cast(entries_.size()); ++currentIndex_) + { + auto& entry = entries_[currentIndex_]; + if (entry.isRegularFile()) + totalBytes_ += entry.size; + else if (entry.isDirectory()) + return WorkStatus::MoreWork; + } + + progressCallback_(totalBytes_, currentIndex_, static_cast(entries_.size())); + return WorkStatus::MoreWork; + } + case (Prepared): + case (Preparing): + case (Finalizing): + Log::error("ScanOperation: Invalid state: {}", static_cast(state_)); + return enterErrorState({.type = ErrorType::InvalidOperationState}); + case (Completed): + { + Log::warn("ScanOperation: Operation already completed."); + // Dont enter error state here, it would overwrite the success state. + return std::unexpected(Error{.type = ErrorType::CannotWorkCompletedOperation}); + } + case (Failed): + { + Log::warn("ScanOperation: Operation already failed."); + // Do not enter error state here, it would overwrite the error state. + return std::unexpected(Error{.type = ErrorType::CannotWorkFailedOperation}); + } + case (Canceled): + { + Log::warn("ScanOperation: Cannot work on canceled operation."); + return std::unexpected(Error{.type = ErrorType::CannotWorkCanceledOperation}); + } + } + return enterErrorState({.type = ErrorType::UnknownWorkState}); +} + +std::expected ScanOperation::cancel(bool adoptCancelState) +{ + if (adoptCancelState) + { + Log::info("ScanOperation: Scan of '{}' canceled.", remotePath_.generic_string()); + enterState(OperationState::Canceled); + } + return {}; +} + +SecureShell::ProcessingStrand* ScanOperation::strand() const +{ + return sftp_->strand(); +} \ No newline at end of file diff --git a/backend/source/backend/ssh_session_manager.cpp b/backend/source/backend/ssh_session_manager.cpp deleted file mode 100644 index 67f91769..00000000 --- a/backend/source/backend/ssh_session_manager.cpp +++ /dev/null @@ -1,557 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -using namespace Detail; - -int askPassDefault(char const* prompt, char* buf, std::size_t length, int, int, void* userdata) -{ - std::pair const* data = - static_cast*>(userdata); - auto* manager = data->first; - std::string whatFor = data->second; - - std::promise> pwPromise{}; - std::lock_guard lock{manager->passwordProvidersMutex_}; - std::functionpasswordProviders_)::const_iterator)> askNextProvider; - - askNextProvider = [end = manager->passwordProviders_.end(), &askNextProvider, prompt, &pwPromise, &whatFor]( - decltype(manager->passwordProviders_)::const_iterator iter) { - if (iter == end) - { - pwPromise.set_value(std::nullopt); - return; - } - - iter->second->getPassword( - whatFor, prompt, [iter, &pwPromise, &askNextProvider](std::optional pw) mutable { - if (pw.has_value()) - pwPromise.set_value(pw); - else - { - ++iter; - askNextProvider(iter); - } - }); - }; - askNextProvider(manager->passwordProviders_.begin()); - - const auto pw = pwPromise.get_future().get(); - - if (pw.has_value()) - { - std::memset(buf, 0, length); - std::strncpy(buf, pw.value().c_str(), length - 1); - return 0; - } - return -1; -} - -SshSessionManager::SshSessionManager() -{} - -SshSessionManager::~SshSessionManager() -{ - joinSessionAdder(); -} - -void SshSessionManager::registerRpcConnect(Nui::Window&, Nui::RpcHub& hub) -{ - /* - { - engine: { - sshSessionOptions: Persistence::SshSessionOptions, - environment: std::unordered_map - }, - } - */ - hub.registerFunction( - "SshSessionManager::connect", - [this, hub = &hub](std::string const& responseId, nlohmann::json const& parameters) { - try - { - Log::info("Connecting to ssh server with parameters: {}", parameters.dump(4)); - if (!parameters.contains("engine")) - { - Log::error("No engine specified for ssh connection"); - hub->callRemote(responseId, nlohmann::json{{"error", "No engine specified for ssh connection"}}); - return; - } - - const auto sessionOptions = - parameters["engine"]["sshSessionOptions"].get(); - - const auto engine = parameters["engine"].get(); - joinSessionAdder(); - addSession(engine, [responseId, hub](auto const& maybeId) { - if (!maybeId) - { - Log::error("Failed to connect to ssh server"); - hub->callRemote(responseId, nlohmann::json{{"error", "Failed to connect to ssh server"}}); - return; - } - - Log::info("Connected to ssh server with id: {}", maybeId->value()); - hub->callRemote(responseId, nlohmann::json{{"id", maybeId->value()}}); - }); - } - catch (std::exception const& e) - { - Log::error("Error connecting to ssh server: {}", e.what()); - hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - return; - } - }); -} - -void SshSessionManager::registerRpcCreateChannel(Nui::Window&, Nui::RpcHub& hub) -{ - /* - { - sessionId: string, - engine: { - sshSessionOptions: Persistence::SshSessionOptions, - environment: std::unordered_map - } - } - */ - hub.registerFunction( - "SshSessionManager::Session::createChannel", - [this, hub = &hub](std::string const& responseId, nlohmann::json const& parameters) { - if (!parameters.contains("sessionId")) - { - Log::error("No session id specified for channel creation"); - hub->callRemote(responseId, nlohmann::json{{"error", "No session id specified for channel creation"}}); - return; - } - - const auto sessionId = Ids::makeSessionId(parameters.at("sessionId").get()); - - if (sessions_.find(sessionId) == sessions_.end()) - { - Log::error("No session found with id: {}", sessionId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No session found with id"}}); - return; - } - - if (!parameters.contains("engine")) - { - Log::error("No engine specified for ssh channel"); - hub->callRemote(responseId, nlohmann::json{{"error", "No engine specified for ssh connection"}}); - return; - } - - const auto sessionOptions = parameters["engine"]["sshSessionOptions"].get(); - auto env = sessionOptions.environment; - - const auto weakChannel = sessions_[sessionId]->createPtyChannel({.environment = env}).get(); - if (!weakChannel.has_value()) - { - Log::error("Failed to create pty channel: {}", weakChannel.error()); - hub->callRemote(responseId, nlohmann::json{{"error", "Failed to create pty channel"}}); - return; - } - - const auto channelId = Ids::generateChannelId(); - channels_.emplace(channelId, std::move(weakChannel).value()); - - Log::info( - "Created pty channel with id '{}', channel total is now '{}'.", channelId.value(), channels_.size()); - - hub->callRemote(responseId, nlohmann::json{{"id", channelId.value()}}); - }); -} - -void SshSessionManager::registerRpcStartChannelRead(Nui::Window& wnd, Nui::RpcHub& hub) -{ - hub.registerFunction( - "SshSessionManager::Channel::startReading", - [this, hub = &hub, wnd = &wnd]( - std::string const& responseId, std::string const& sessionIdString, std::string const& channelIdString) { - const auto sessionId = Ids::makeSessionId(sessionIdString); - const auto channelId = Ids::makeChannelId(channelIdString); - - if (sessions_.find(sessionId) == sessions_.end()) - { - Log::error("No session found with id: {}", sessionId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No session found with id"}}); - return; - } - - if (channels_.find(channelId) == channels_.end()) - { - Log::error("No channel found with id: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No channel found with id"}}); - return; - } - - auto& channel = channels_[channelId]; - - auto locked = channel.lock(); - if (!locked) - { - Log::error("Failed to lock channel with id: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "Failed to lock channel"}}); - return; - } - - const std::string stdoutReceptacle{"sshTerminalStdout_" + channelId.value()}; - const std::string stderrReceptacle{"sshTerminalStderr_" + channelId.value()}; - const std::string exitReceptacle{"sshTerminalOnExit_" + channelId.value()}; - - locked->startReading( - [wnd, hub, sessionId, channelId, stdoutReceptacle](std::string const& msg) { - wnd->runInJavascriptThread([hub, stdoutReceptacle, sessionId, channelId, msg]() { - hub->callRemote( - stdoutReceptacle, - nlohmann::json{ - {"sessionId", sessionId.value()}, - {"channelId", channelId.value()}, - {"data", Roar::base64Encode(msg)}}); - }); - }, - [wnd, hub, sessionId, channelId, stderrReceptacle](std::string const& data) { - wnd->runInJavascriptThread([hub, stderrReceptacle, sessionId, channelId, data]() { - hub->callRemote( - stderrReceptacle, - nlohmann::json{ - {"sessionId", sessionId.value()}, - {"channelId", channelId.value()}, - {"data", Roar::base64Encode(data)}}); - }); - }, - [this, wnd, hub, sessionId, channelId, exitReceptacle]() { - Log::info("Channel for session '{}' lost with id: {}", sessionId.value(), channelId.value()); - - auto iter = channels_.find(channelId); - if (iter != channels_.end()) - { - if (auto locked = iter->second.lock(); locked) - { - locked->close(); - } - channels_.erase(iter); - } - - wnd->runInJavascriptThread([hub, sessionId, channelId, exitReceptacle]() { - hub->callRemote( - exitReceptacle, - nlohmann::json{{"sessionId", sessionId.value()}, {"channelId", channelId.value()}}); - }); - }); - - hub->callRemote(responseId, nlohmann::json{{"success", true}}); - }); -} - -void SshSessionManager::registerRpcChannelClose(Nui::Window&, Nui::RpcHub& hub) -{ - hub.registerFunction( - "SshSessionManager::Session::closeChannel", - [this, hub = &hub](std::string const& responseId, std::string const& channelIdString) { - try - { - const auto channelId = Ids::makeChannelId(channelIdString); - - if (auto iter = channels_.find(channelId); iter != channels_.end()) - { - if (auto channel = iter->second.lock(); channel) - { - channel->close(); - } - channels_.erase(iter); - Log::info( - "Closed channel with id '{}', now remaining channels total '{}'.", - channelId.value(), - channels_.size()); - } - else - { - Log::error("No channel found with id: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No channel found with id"}}); - return; - } - - hub->callRemote(responseId, nlohmann::json{{"success", true}}); - } - catch (std::exception const& e) - { - Log::error("Error closing channel: {}", e.what()); - hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - return; - } - }); -} - -void SshSessionManager::registerRpcEndSession(Nui::Window&, Nui::RpcHub& hub) -{ - hub.registerFunction( - "SshSessionManager::disconnect", - [this, hub = &hub](std::string const& responseId, std::string const& sessionIdString) { - try - { - const auto sessionId = Ids::makeSessionId(sessionIdString); - - if (sessions_.find(sessionId) == sessions_.end()) - { - // Do not log this, because multi delete is not an error - hub->callRemote(responseId, nlohmann::json{{"error", "No session found with id"}}); - return; - } - Log::info("Disconnecting from ssh server with id: {}", sessionId.value()); - sessions_.erase(sessionId); - hub->callRemote(responseId, nlohmann::json{{"success", true}}); - } - catch (std::exception const& e) - { - Log::error("Error disconnecting to ssh server: {}", e.what()); - hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - return; - } - }); -} - -void SshSessionManager::registerRpcChannelWrite(Nui::Window&, Nui::RpcHub& hub) -{ - hub.registerFunction( - "SshSessionManager::Channel::write", - [this, hub = &hub](std::string const& responseId, std::string const& channelIdString, std::string const& data) { - try - { - const auto channelId = Ids::makeChannelId(channelIdString); - - if (auto iter = channels_.find(channelId); iter != channels_.end()) - { - if (auto channel = iter->second.lock(); channel) - { - channel->write(Roar::base64Decode(data)); - } - else - { - Log::error("Failed to get channel: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "Failed to get channel"}}); - return; - } - } - else - { - Log::error("No channel found with id: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No channel found with id"}}); - return; - } - } - catch (std::exception const& e) - { - Log::error("Error writing to pty: {}", e.what()); - hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - return; - } - }); -} - -void SshSessionManager::registerRpcChannelPtyResize(Nui::Window&, Nui::RpcHub& hub) -{ - hub.registerFunction( - "SshSessionManager::Channel::ptyResize", - [this, hub = &hub](std::string const& responseId, std::string const& channelIdString, int cols, int rows) { - try - { - const auto channelId = Ids::makeChannelId(channelIdString); - - if (auto iter = channels_.find(channelId); iter != channels_.end()) - { - if (auto channel = iter->second.lock(); channel) - { - channel->resizePty(cols, rows); - // FIXME: Why does this always fail? - // if (fut.wait_for(std::chrono::milliseconds{500}) != std::future_status::ready) - // { - // Log::error("Failed to resize pty: timeout"); - // hub->callRemote(responseId, nlohmann::json{{"error", "Failed to resize pty: timeout"}}); - // return; - // } - // const auto result = fut.get(); - // if (result != 0) - // { - // Log::error("Failed to resize pty: {}", result); - // hub->callRemote(responseId, nlohmann::json{{"error", "Failed to resize pty"}}); - // return; - // } - hub->callRemote(responseId, nlohmann::json{{"success", true}}); - } - else - { - Log::error("Failed to get channel: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "Failed to get channel"}}); - return; - } - } - else - { - Log::error("No channel found with id: {}", channelId.value()); - hub->callRemote(responseId, nlohmann::json{{"error", "No channel found with id"}}); - return; - } - } - catch (std::exception const& e) - { - Log::error("Error resizing pty: {}", e.what()); - hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - return; - } - }); -} - -void SshSessionManager::registerRpc(Nui::Window& wnd, Nui::RpcHub& hub) -{ - registerRpcConnect(wnd, hub); - registerRpcCreateChannel(wnd, hub); - registerRpcStartChannelRead(wnd, hub); - registerRpcChannelClose(wnd, hub); - registerRpcEndSession(wnd, hub); - registerRpcChannelWrite(wnd, hub); - registerRpcChannelPtyResize(wnd, hub); - - // hub.registerFunction( - // "SshSessionManager::sftp::listDirectory", - // [this, hub = &hub](std::string const& responseId, std::string const& sessionIdString, std::string const& - // path) { - // try - // { - // const auto sessionId = Ids::makeSessionId(sessionIdString); - - // if (sessions_.find(sessionId) == sessions_.end()) - // { - // Log::error("No session found with id: {}", sessionId.value()); - // hub->callRemote(responseId, nlohmann::json{{"error", "No session found with id"}}); - // return; - // } - - // auto& session = sessions_[sessionId]; - // auto sftpSession = session->getSftpSession(); - // if (!sftpSession) - // { - // Log::error("Failed to create sftp session"); - // hub->callRemote(responseId, nlohmann::json{{"error", "Failed to create sftp session"}}); - // return; - // } - - // auto result = sftpSession->listDirectory(path); - // if (!result.has_value()) - // { - // Log::error("Failed to list directory: {}", result.error().message); - // hub->callRemote(responseId, nlohmann::json{{"error", result.error().message}}); - // return; - // } - - // hub->callRemote(responseId, nlohmann::json{{"entries", *result}}); - // } - // catch (std::exception const& e) - // { - // Log::error("Error listing directory: {}", e.what()); - // hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - // return; - // } - // }); - - // hub.registerFunction( - // "SshSessionManager::sftp::createDirectory", - // [this, hub = &hub](std::string const& responseId, std::string const& sessionIdString, std::string const& - // path) { - // try - // { - // const auto sessionId = Ids::makeSessionId(sessionIdString); - - // if (sessions_.find(sessionId) == sessions_.end()) - // { - // Log::error("No session found with id: {}", sessionId.value()); - // hub->callRemote(responseId, nlohmann::json{{"error", "No session found with id"}}); - // return; - // } - - // auto& session = sessions_[sessionId]; - // auto sftpSession = session->getSftpSession(); - // if (!sftpSession) - // { - // Log::error("Failed to create sftp session"); - // hub->callRemote(responseId, nlohmann::json{{"error", "Failed to create sftp session"}}); - // return; - // } - - // auto result = sftpSession->createDirectory(path); - // if (result) - // { - // Log::error("Failed to create directory: {}", result->message); - // hub->callRemote(responseId, nlohmann::json{{"error", result->message}}); - // return; - // } - - // hub->callRemote(responseId, nlohmann::json{{"success", true}}); - // } - // catch (std::exception const& e) - // { - // Log::error("Error creating directory: {}", e.what()); - // hub->callRemote(responseId, nlohmann::json{{"error", e.what()}}); - // return; - // } - // }); -} - -void SshSessionManager::addPasswordProvider(int priority, PasswordProvider* provider) -{ - std::lock_guard lock{passwordProvidersMutex_}; - passwordProviders_.emplace(priority, provider); -} - -void SshSessionManager::joinSessionAdder() -{ - std::scoped_lock guard{addSessionMutex_}; - if (addSessionThread_ && addSessionThread_->joinable()) - { - addSessionThread_->join(); - addSessionThread_.reset(); - } -} - -void SshSessionManager::addSession( - Persistence::SshTerminalEngine const& engine, - std::function const&)> onComplete) -{ - std::scoped_lock guard{addSessionMutex_}; - if (addSessionThread_ && addSessionThread_->joinable()) - { - Log::error("Session creation already in progress"); - onComplete(std::nullopt); - } - - addSessionThread_ = std::make_unique([this, engine, onComplete = std::move(onComplete)]() { - std::pair askPassUserDataKeyPhrase{this, "Key phrase"}; - std::pair askPassUserDataPassword{this, "Password"}; - auto maybeSession = - makeSession(engine, askPassDefault, &askPassUserDataKeyPhrase, &askPassUserDataPassword, &pwCache_); - if (maybeSession) - { - const auto sessionId = Ids::SessionId{Ids::generateId()}; - sessions_.emplace(sessionId, std::move(maybeSession).value()); - sessions_[sessionId]->start(); - onComplete(sessionId); - } - else - { - Log::error("Failed to create session: {}", maybeSession.error()); - onComplete(std::nullopt); - } - }); -} \ No newline at end of file diff --git a/backend/test/backend/CMakeLists.txt b/backend/test/backend/CMakeLists.txt new file mode 100644 index 00000000..d320906a --- /dev/null +++ b/backend/test/backend/CMakeLists.txt @@ -0,0 +1,41 @@ +include(GoogleTest) + +add_executable( + nui-scp-backend-test + main.cpp +) + +find_package(GTest CONFIG REQUIRED) + +find_package(Boost CONFIG REQUIRED COMPONENTS filesystem asio system process) + +target_link_libraries( + nui-scp-backend-test + PUBLIC + secure-shell + backend + GTest::GTest + GTest::gmock_main + GTest::Main + Boost::filesystem + Boost::asio + Boost::system + Boost::process +) + +if (WIN32) + target_link_libraries(nui-scp-backend-test PUBLIC ws2_32) +endif() + +set_target_properties(nui-scp-backend-test PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/test") + +gtest_discover_tests(nui-scp-backend-test) + +# If msys2, copy dynamic libraries to executable directory, visual studio does this automatically. +# And there is no need on linux. +if (DEFINED ENV{MSYSTEM}) + add_custom_command(TARGET nui-scp-backend-test POST_BUILD + COMMAND bash -c "ldd $" | "grep" "clang" | awk "NF == 4 { system(\"${CMAKE_COMMAND} -E copy \" \$3 \" $\") }" + VERBATIM + ) +endif() \ No newline at end of file diff --git a/backend/test/backend/main.cpp b/backend/test/backend/main.cpp new file mode 100644 index 00000000..cfe644c1 --- /dev/null +++ b/backend/test/backend/main.cpp @@ -0,0 +1,19 @@ +#include "test_download_operation.hpp" + +#include + +#include + +#include + +std::filesystem::path programDirectory; + +int main(int argc, char** argv) +{ + Log::setLevel(Log::Level::Off); + + programDirectory = std::filesystem::path{argv[0]}.parent_path(); + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/backend/test/backend/test_download_operation.hpp b/backend/test/backend/test_download_operation.hpp new file mode 100644 index 00000000..abf7a01f --- /dev/null +++ b/backend/test/backend/test_download_operation.hpp @@ -0,0 +1,619 @@ +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include + +using namespace std::chrono_literals; +using namespace std::string_literals; + +extern std::filesystem::path programDirectory; + +namespace Test +{ + class DownloadOperationTests : public ::testing::Test + { + protected: + void SetUp() override + { + for (int i = 0; i != 10; ++i) + fakeFileContent_ += "This is a test file content.\n"; + + processingThread_.start(5ms); + } + + void TearDown() override + { + processingThread_.stop(); + } + + std::shared_ptr<::testing::NiceMock> makeFileStreamMock() + { + auto mock = std::make_shared<::testing::NiceMock>(); + + ON_CALL(*mock, strand()).WillByDefault([this]() -> SecureShell::ProcessingStrand* { + return strand_.get(); + }); + + return mock; + } + + void giveMockDefaultStat( + std::shared_ptr<::testing::NiceMock> const& mock, + std::optional size = std::nullopt) + { + if (!size) + size = fakeFileContent_.size(); + + ON_CALL(*mock, stat()) + .WillByDefault( + [size = size.value()]() + -> std::future> { + std::promise> promise; + promise.set_value( + SecureShell::FileInformation{{ + .size = size, + .permissions = std::filesystem::perms::owner_all, + }}); + return promise.get_future(); + }); + } + + void giveMockExpectedRead(std::shared_ptr<::testing::NiceMock> const& mock) + { + EXPECT_CALL(*mock, readAll(testing::_)) + .WillRepeatedly( + [this](std::function cb) + -> std::future> { + onRead_ = std::move(cb); + readPromise_ = {}; + if (!readCycleQueue_.empty()) + { + readCycleQueue_.front()(); + readCycleQueue_.pop(); + } + return readPromise_.get_future(); + }); + } + + void enqueueFakeReadCycle(std::optional chunkSizeOpt = std::nullopt) + { + readCycleQueue_.push([this, chunkSizeOpt]() { + auto chunkSize = fakeFileContent_.size(); + if (chunkSizeOpt) + chunkSize = chunkSizeOpt.value(); + + if (readOffset_ + chunkSize > fakeFileContent_.size()) + chunkSize = fakeFileContent_.size() - readOffset_; + + // EOF: + if (chunkSize == 0) + { + readPromise_.set_value(readOffset_); + onReadResult_ = onRead_({}); + return; + } + + if (onRead_) + { + onReadResult_ = onRead_(std::string_view{fakeFileContent_}.substr(readOffset_, chunkSize)); + readOffset_ += chunkSize; + readPromise_.set_value(readOffset_); + return; + } + else + { + throw std::runtime_error("No read callback set."); + } + }); + } + + std::string readFile(std::filesystem::path const& path) + { + std::ifstream file{path}; + return std::string{std::istreambuf_iterator{file}, std::istreambuf_iterator{}}; + } + + protected: + std::string fakeFileContent_{}; + Utility::TemporaryDirectory isolateDirectory_{programDirectory / "temp", true}; + SecureShell::ProcessingThread processingThread_{}; + std::unique_ptr strand_{processingThread_.createStrand()}; + std::promise> readPromise_{}; + std::function onRead_{}; + std::size_t readOffset_{0}; + std::queue> readCycleQueue_{}; + bool onReadResult_{true}; + }; + + TEST_F(DownloadOperationTests, CanCreateDownloadOperation) + { + auto fileStream = makeFileStreamMock(); + auto options = DownloadOperation::DownloadOperationOptions{}; + DownloadOperation operation{fileStream, options}; + } + + TEST_F(DownloadOperationTests, DownloadPreparationFailsWithLocalPathEmpty) + { + auto fileStream = makeFileStreamMock(); + auto options = DownloadOperation::DownloadOperationOptions{}; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_FALSE(result.has_value()); + EXPECT_EQ(result.error().type, DownloadOperation::ErrorType::InvalidPath); + } + + TEST_F(DownloadOperationTests, DownloadPreparationFailsWhenFileExistsAndMayNotBeOverwritten) + { + auto fileStream = makeFileStreamMock(); + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .mayOverwrite = false, + }; + std::ofstream file{options.localPath}; + file.close(); + + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_FALSE(result.has_value()); + EXPECT_EQ(result.error().type, DownloadOperation::ErrorType::FileExists); + } + + TEST_F(DownloadOperationTests, DownloadPreparationFailsWhenStreamIsExpired) + { + std::weak_ptr fileWeak; + { + auto fileStream = makeFileStreamMock(); + fileWeak = fileStream; + } + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileWeak, options}; + + auto result = operation.prepare(); + EXPECT_FALSE(result.has_value()); + EXPECT_EQ(result.error().type, DownloadOperation::ErrorType::FileStreamExpired); + } + + TEST_F(DownloadOperationTests, DownloadPreparationFailsWhenStattingTheFileFails) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()).WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + std::unexpected( + SftpError{ + .message = "Stat failed", + })); + return promise.get_future(); + }); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_FALSE(result.has_value()); + EXPECT_EQ(result.error().type, DownloadOperation::ErrorType::FileStatFailed); + } + + // TODO: Test continuation download + + TEST_F(DownloadOperationTests, DownloadPreparationSucceedsWithEmptyRemoteFile) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 0, + }}); + return promise.get_future(); + }); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + } + + TEST_F(DownloadOperationTests, DownloadPreparationSucceedsWithNonEmptyRemoteFile) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 42, + }}); + return promise.get_future(); + }); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + } + + TEST_F(DownloadOperationTests, DownloadPreparationCreatesPartFile) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 42, + }}); + return promise.get_future(); + }); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + TEST_F(DownloadOperationTests, DestructorCleansUpIfOptionIsTrue) + { + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = true, + }; + + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 42, + }}); + return promise.get_future(); + }); + + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + EXPECT_FALSE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + TEST_F(DownloadOperationTests, DestructorDoesNotCleanUpIfOptionIsFalse) + { + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = false, + }; + + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 42, + }}); + return promise.get_future(); + }); + + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + TEST_F(DownloadOperationTests, CancelDoesNotRemoveTheFileIfCleanIsFalse) + { + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = false, + }; + + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream); + + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + + EXPECT_TRUE(operation.cancel(true).has_value()); + + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + TEST_F(DownloadOperationTests, CancelDoesRemoveTheFileIfCleanIsTrue) + { + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = true, + }; + + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + + EXPECT_CALL(*fileStream, stat()) + .WillOnce([]() -> std::future> { + std::promise> promise; + promise.set_value( + FileInformation{{ + .size = 42, + }}); + return promise.get_future(); + }); + + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + EXPECT_TRUE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + + EXPECT_TRUE(operation.cancel(true).has_value()); + + EXPECT_FALSE(std::filesystem::exists(options.localPath.generic_string() + ".filepart")); + } + + TEST_F(DownloadOperationTests, WorkWithEmptyRemoteFileSucceeds) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, 0); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + const auto result = operation.work(); + EXPECT_TRUE(result.has_value()); + } + + TEST_F(DownloadOperationTests, WorkFailsWithExpiredFileStream) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()) << boost::describe::enum_to_string(result.error().type, "INVALID_ENUM_VALUE"); + + fileStream.reset(); + + const auto workResult = operation.work(); + EXPECT_FALSE(workResult.has_value()); + EXPECT_EQ(workResult.error().type, DownloadOperation::ErrorType::FileStreamExpired); + } + + TEST_F(DownloadOperationTests, WorkCallsReadOnFileStream) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, 42); + giveMockExpectedRead(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + }; + DownloadOperation operation{fileStream, options}; + + enqueueFakeReadCycle(); + + const auto result = operation.work(); + EXPECT_TRUE(result.has_value()); + } + + TEST_F(DownloadOperationTests, PrepareReservesSpaceForFile) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, fakeFileContent_.size()); + giveMockExpectedRead(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .reserveSpace = true, + .doCleanup = false, + }; + DownloadOperation operation{fileStream, options}; + + auto result = operation.prepare(); + EXPECT_TRUE(result.has_value()); + + EXPECT_TRUE(operation.cancel(true).has_value()); + + EXPECT_EQ( + std::filesystem::file_size(options.localPath.generic_string() + ".filepart"), fakeFileContent_.size()); + } + + TEST_F(DownloadOperationTests, ReadCycleWritesDataToFile) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, fakeFileContent_.size()); + giveMockExpectedRead(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = false, + }; + DownloadOperation operation{fileStream, options}; + + enqueueFakeReadCycle(5); + + const auto result = operation.work(); + EXPECT_TRUE(result.has_value()); + + EXPECT_TRUE(operation.cancel(false).has_value()); + + EXPECT_EQ(std::filesystem::file_size(options.localPath.generic_string() + ".filepart"), 5); + EXPECT_EQ(readFile(options.localPath.generic_string() + ".filepart"), fakeFileContent_.substr(0, 5)); + } + + TEST_F(DownloadOperationTests, FinalizeFailsIfFileExistsButOverwriteIsForbidden) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, fakeFileContent_.size()); + giveMockExpectedRead(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .mayOverwrite = false, + }; + DownloadOperation operation{fileStream, options}; + + enqueueFakeReadCycle(); + + const auto result = operation.work(); + EXPECT_TRUE(result.has_value()); + + { + std::ofstream file{options.localPath}; + } + + auto fin = operation.finalize(); + EXPECT_FALSE(fin.has_value()); + EXPECT_EQ(fin.error().type, DownloadOperation::ErrorType::FileExists); + } + + TEST_F(DownloadOperationTests, FinalizeSucceedsIfFileExistsButOverwriteIsAllowed) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, fakeFileContent_.size()); + giveMockExpectedRead(fileStream); + + auto options = DownloadOperation::DownloadOperationOptions{ + .localPath = isolateDirectory_.path() / "file.txt", + .mayOverwrite = true, + }; + DownloadOperation operation{fileStream, options}; + + enqueueFakeReadCycle(); + enqueueFakeReadCycle(0); + + { + std::ofstream file{options.localPath}; + } + + auto result = operation.work(); + EXPECT_TRUE(result.has_value()); + } + + TEST_F(DownloadOperationTests, ProgressCallbackIsCalledDuringRead) + { + using namespace SecureShell; + + auto fileStream = makeFileStreamMock(); + giveMockDefaultStat(fileStream, fakeFileContent_.size()); + giveMockExpectedRead(fileStream); + + std::vector> progressCalls; + + auto options = DownloadOperation::DownloadOperationOptions{ + .progressCallback = + [&progressCalls](std::int64_t min, std::int64_t max, std::int64_t current) { + progressCalls.emplace_back(min, max, current); + }, + .localPath = isolateDirectory_.path() / "file.txt", + .doCleanup = false, + }; + DownloadOperation operation{fileStream, options}; + + for (std::size_t i = 0; i < fakeFileContent_.size(); ++i) + { + enqueueFakeReadCycle(1); + } + enqueueFakeReadCycle(0); + + decltype(operation.work()) result; + do + { + result = operation.work(); + ASSERT_TRUE(result.has_value()); + } while (result.value() == DownloadOperation::WorkStatus::MoreWork); + + EXPECT_TRUE(operation.cancel(true).has_value()); + + ASSERT_EQ(progressCalls.size(), fakeFileContent_.size()); + + for (std::size_t i = 0; i < fakeFileContent_.size(); ++i) + { + EXPECT_EQ(std::get<0>(progressCalls[i]), 0); + EXPECT_EQ(std::get<1>(progressCalls[i]), fakeFileContent_.size()); + EXPECT_EQ(std::get<2>(progressCalls[i]), i + 1); + } + + EXPECT_EQ(std::get<0>(progressCalls.back()), 0); + EXPECT_EQ(std::get<1>(progressCalls.back()), fakeFileContent_.size()); + EXPECT_EQ(std::get<2>(progressCalls.back()), fakeFileContent_.size()); + } +} \ No newline at end of file diff --git a/diagrams/bulk_download.puml b/diagrams/bulk_download.puml new file mode 100644 index 00000000..14753994 --- /dev/null +++ b/diagrams/bulk_download.puml @@ -0,0 +1,38 @@ +@startuml +participant "UI Thread" +participant "Session Strand" +participant "SSH Thread" + +"UI Thread" -> "Session Strand" : start download on directory +activate "Session Strand" +"Session Strand" -> "SSH Thread" : stat remote path +activate "Session Strand" #orange +"Session Strand" -> "Session Strand" : blocking wait +"SSH Thread" -> "Session Strand" : stat result +deactivate "Session Strand" +"Session Strand" -> "Session Strand" : process stat result (here: directory) +"Session Strand" -> "Session Strand" : create scan operation +group Scan Operation + loop work loop on strand + "Session Strand" -> "SSH Thread": list directory (current) + activate "Session Strand" #orange + "Session Strand" -> "Session Strand": blocking wait + "SSH Thread" -> "Session Strand": list directory result + deactivate "Session Strand" + "Session Strand" -> "Session Strand": process list result + "Session Strand" -> "UI Thread": progress update + end +end +group Bulk Download Operation + loop work loop + "Session Strand" -> "SSH Thread": download some bytes + activate "Session Strand" #orange + "Session Strand" -> "Session Strand": blocking wait + "SSH Thread" -> "Session Strand": download result + deactivate "Session Strand" + "Session Strand" -> "Session Strand": process download result + "Session Strand" -> "UI Thread": progress update + end +end +deactivate "Session Strand" +@enduml \ No newline at end of file diff --git a/events/source/events/CMakeLists.txt b/events/source/events/CMakeLists.txt index 839eb72a..543ac5cd 100644 --- a/events/source/events/CMakeLists.txt +++ b/events/source/events/CMakeLists.txt @@ -1,6 +1,15 @@ add_library(events STATIC app_event_context.cpp) -target_include_directories(events +target_include_directories(events PUBLIC "${CMAKE_CURRENT_LIST_DIR}/../../include") -target_link_libraries(events PUBLIC nui-events PRIVATE core-target) \ No newline at end of file +target_link_libraries(events PUBLIC nui-events PRIVATE core-target) + +if (EMSCRIPTEN) + set_target_properties( + nui-events + PROPERTIES + LINK_FLAGS "-sMEMORY64=1" + COMPILE_FLAGS "-sMEMORY64=1" + ) +endif() \ No newline at end of file diff --git a/frontend/include/frontend/components/progress_bar.hpp b/frontend/include/frontend/components/progress_bar.hpp new file mode 100644 index 00000000..ea5a8335 --- /dev/null +++ b/frontend/include/frontend/components/progress_bar.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +class ProgressBar +{ + public: + struct Settings + { + std::string height{"30px"}; + long long min{0}; + long long max{100}; + bool showMinMax{false}; + bool byteMode{false}; + }; + ProgressBar(Settings settings); + + ROAR_PIMPL_SPECIAL_FUNCTIONS(ProgressBar); + + Nui::ElementRenderer operator()() const; + + /** + * Set the progress of the progress bar. + * Cannot be set if the progress bar is not mounted. + */ + void setProgress(long long current); + + /** + * @brief Get the maximum value of the progress bar. + * + * @return long long + */ + long long max() const; + + private: + void updateText(); + + private: + struct Implementation; + std::unique_ptr impl_; +}; \ No newline at end of file diff --git a/frontend/include/frontend/components/ui5-list.hpp b/frontend/include/frontend/components/ui5-list.hpp new file mode 100644 index 00000000..efdbe1fd --- /dev/null +++ b/frontend/include/frontend/components/ui5-list.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +// clang-format off + +#ifdef NUI_INLINE +// @inline(js, ui5-list) +js_import "@ui5/webcomponents/dist/List.js"; +js_import "@ui5/webcomponents/dist/ListItemStandard.js"; +js_import "@ui5/webcomponents/dist/ListItemCustom.js"; +js_import "@ui5/webcomponents/dist/ListItemGroup.js"; +// @endinline +#endif + +// clang-format on + +namespace ui5 +{ + NUI_MAKE_HTML_ELEMENT_RENAME(list, "ui5-list") + NUI_MAKE_HTML_ELEMENT_RENAME(li, "ui5-li") + NUI_MAKE_HTML_ELEMENT_RENAME(li_custom, "ui5-li-custom") + NUI_MAKE_HTML_ELEMENT_RENAME(li_group, "ui5-li-group") +} diff --git a/frontend/include/frontend/dialog/confirm_dialog.hpp b/frontend/include/frontend/dialog/confirm_dialog.hpp index fb1e1291..a72a3ae2 100644 --- a/frontend/include/frontend/dialog/confirm_dialog.hpp +++ b/frontend/include/frontend/dialog/confirm_dialog.hpp @@ -20,6 +20,11 @@ class ConfirmDialog No = 0b0000'1000, }; + friend auto operator|(Button lhs, Button rhs) + { + return static_cast