diff --git a/client-lite/src/download/download.cpp b/client-lite/src/download/download.cpp index bdc23e44..f7efc60f 100644 --- a/client-lite/src/download/download.cpp +++ b/client-lite/src/download/download.cpp @@ -253,18 +253,7 @@ void Download::_Start() THROW_HR_IF(DO_E_DOWNLOAD_NO_URI, _url.empty()); THROW_HR_IF(DO_E_FILE_DOWNLOADSINK_UNSPECIFIED, _destFilePath.empty()); - // TODO(shishirb) expect file to not exist - _fileStream = std::make_unique(); - _fileStream->exceptions(std::fstream::badbit | std::fstream::failbit); - try - { - _fileStream->open(_destFilePath, (std::fstream::out | std::fstream::binary | std::fstream::trunc)); - } - catch (const std::system_error& e) - { - THROW_HR_MSG(E_INVALIDARG, "Error: %d, %s, file: %s", e.code().value(), e.what(), _destFilePath.data()); - } - + _fileStream = DOFile::Create(_destFilePath); _httpAgent = std::make_unique(*this); _proxyList.Refresh(_url); @@ -275,18 +264,10 @@ void Download::_Start() void Download::_Resume() { DO_ASSERT(_httpAgent); - DO_ASSERT(_fileStream); // BytesTotal can be zero if the start request never completed due to an error/pause DO_ASSERT((_status.BytesTotal != 0) || (_status.BytesTransferred == 0)); - try - { - _fileStream->open(_destFilePath, (std::fstream::out | std::fstream::binary | std::fstream::app)); - } - catch (const std::system_error& e) - { - THROW_HR_MSG(E_INVALIDARG, "Error: %d, %s, file: %s", e.code().value(), e.what(), _destFilePath.data()); - } + _fileStream = DOFile::Open(_destFilePath); if ((_status.BytesTotal != 0) && (_status.BytesTransferred == _status.BytesTotal)) { @@ -310,14 +291,14 @@ void Download::_Pause() _httpAgent->Close(); // waits until all callbacks are complete _timer.Stop(); _fHttpRequestActive = false; - _fileStream->close(); // safe to close now that no callbacks are expected + _fileStream.Close(); // safe to close now that no callbacks are expected } void Download::_Finalize() { _httpAgent->Close(); // waits until all callbacks are complete _fHttpRequestActive = false; - _fileStream.reset(); // safe since no callbacks are expected + _fileStream.Close(); // safe since no callbacks are expected _CancelTasks(); } @@ -328,7 +309,7 @@ void Download::_Abort() try _httpAgent->Close(); } _timer.Stop(); - _fileStream.reset(); + _fileStream.Close(); _CancelTasks(); if (!_destFilePath.empty()) { @@ -501,12 +482,7 @@ HRESULT Download::OnHeadersAvailable(UINT64 httpContext, UINT64) try HRESULT Download::OnData(_In_reads_bytes_(cbData) BYTE* pData, UINT cbData, UINT64, UINT64) try { - const auto before = _fileStream->tellp(); - _fileStream->write(reinterpret_cast(pData), cbData); - const auto after = _fileStream->tellp(); - _fileStream->flush(); - DO_ASSERT(before < after); - RETURN_HR_IF(HRESULT_FROM_WIN32(ERROR_BAD_LENGTH), (after - before) != cbData); + _fileStream.Append(pData, cbData); _taskThread.Sched([this, cbData]() { _status.BytesTransferred += cbData; diff --git a/client-lite/src/download/download.h b/client-lite/src/download/download.h index 6625b325..f8523fc1 100644 --- a/client-lite/src/download/download.h +++ b/client-lite/src/download/download.h @@ -2,6 +2,7 @@ #include #include +#include "do_file.h" #include "do_guid.h" #include "download_progress_tracker.h" #include "download_status.h" @@ -89,7 +90,7 @@ class Download : public IHttpAgentEvents StopWatch _timer; - std::unique_ptr _fileStream; + DOFile _fileStream; std::unique_ptr _httpAgent; std::string _responseHeaders; UINT _httpStatusCode { 0 }; diff --git a/client-lite/src/include/hresult_helpers.h b/client-lite/src/include/hresult_helpers.h index b1ba4bb2..068bf8de 100644 --- a/client-lite/src/include/hresult_helpers.h +++ b/client-lite/src/include/hresult_helpers.h @@ -96,7 +96,7 @@ static_assert(FAILED(E_NOT_SET), "FAILED macro does not recognize failure code") #endif // Convert std c++ and boost errors to NTSTATUS-like values but with 0xD0 facility (0xC0D00005 for example). -#define HRESULT_FROM_XPLAT_SYSERR(err) (0xC0000000 | (FACILITY_DELIVERY_OPTIMIZATION << 16) | ((HRESULT)(err) & 0x0000FFFF)) +#define HRESULT_FROM_XPLAT_SYSERR(err) (HRESULT)(0xC0000000 | (FACILITY_DELIVERY_OPTIMIZATION << 16) | ((HRESULT)(err) & 0x0000FFFF)) inline HRESULT HRESULT_FROM_STDCPP(const std::error_code& ec) { diff --git a/client-lite/src/util/do_file.cpp b/client-lite/src/util/do_file.cpp new file mode 100644 index 00000000..2afce21a --- /dev/null +++ b/client-lite/src/util/do_file.cpp @@ -0,0 +1,57 @@ +#include "do_common.h" +#include "do_file.h" + +#include +#include +#include + +DOFile::DOFile(int fd) : + _fd(fd) +{ + DO_ASSERT(_fd >= 0); +} + +DOFile DOFile::Create(const std::string& path) +{ + // TODO(shishirb) expect file to not exist + int fd = open(path.data(), O_CREAT | O_WRONLY, S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP | S_IROTH); + if (fd == -1) + { + THROW_HR_MSG(HRESULT_FROM_XPLAT_SYSERR(errno), "Cannot create file at %s", path.data()); + } + + return DOFile{fd}; +} + +DOFile DOFile::Open(const std::string& path) +{ + int fd = open(path.data(), O_APPEND | O_WRONLY); + if (fd == -1) + { + THROW_HR_MSG(HRESULT_FROM_XPLAT_SYSERR(errno), "Cannot open file at %s", path.data()); + } + + return DOFile{fd}; +} + +void DOFile::Append(_In_reads_bytes_(cbData) BYTE* pData, UINT cbData) const +{ + const ssize_t cbWritten = write(_fd, pData, cbData); + if (cbWritten == -1) + { + THROW_HR(HRESULT_FROM_XPLAT_SYSERR(errno)); + } + THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_BAD_LENGTH), cbWritten != static_cast(cbData)); +} + +void DOFile::Close() +{ + if (_fd != -1) + { + if (close(_fd) == -1) + { + THROW_HR(HRESULT_FROM_XPLAT_SYSERR(errno)); + } + _fd = -1; + } +} diff --git a/client-lite/src/util/do_file.h b/client-lite/src/util/do_file.h new file mode 100644 index 00000000..17c37c97 --- /dev/null +++ b/client-lite/src/util/do_file.h @@ -0,0 +1,41 @@ +#pragma once + +#include "do_noncopyable.h" + +// Write-only, binary file wrapper. +// Uses POSIX APIs to provide better error codes than std::fstream/boost::fstream. +class DOFile : DONonCopyable +{ +private: + DOFile(int fd); + +public: + DOFile() = default; + + DOFile(DOFile&& other) noexcept : + DOFile() + { + DO_ASSERT(!IsValid()); + *this = std::move(other); + } + + DOFile& operator=(DOFile&& other) noexcept + { + Close(); + std::swap(_fd, other._fd); + DO_ASSERT(!other.IsValid()); + return *this; + } + + static DOFile Create(const std::string& path); + static DOFile Open(const std::string& path); + + void Append(_In_reads_bytes_(cbData) BYTE* pData, UINT cbData) const; + void Close(); + + operator bool() const noexcept { return IsValid(); } + bool IsValid() const noexcept { return (_fd >= 0) ;} + +private: + int _fd { -1 }; +}; diff --git a/client-lite/src/util/do_noncopyable.h b/client-lite/src/util/do_noncopyable.h index 396572f6..f7309a59 100644 --- a/client-lite/src/util/do_noncopyable.h +++ b/client-lite/src/util/do_noncopyable.h @@ -1,12 +1,15 @@ #pragma once -// Handy base class to create non-copyable classes +// Handy base class to create non-copyable but movable classes class DONonCopyable { public: DONonCopyable(const DONonCopyable&) = delete; DONonCopyable& operator=(const DONonCopyable&) = delete; + DONonCopyable(DONonCopyable&&) noexcept = default; + DONonCopyable& operator=(DONonCopyable&&) noexcept = default; + protected: DONonCopyable() {} }; diff --git a/client-lite/test/download_manager_tests.cpp b/client-lite/test/download_manager_tests.cpp index dbf30661..76c2c192 100644 --- a/client-lite/test/download_manager_tests.cpp +++ b/client-lite/test/download_manager_tests.cpp @@ -271,3 +271,23 @@ TEST_F(DownloadManagerTests, InvalidState) manager.SetDownloadProperty(id, DownloadProperty::LocalPath, destFile); }); } + +TEST_F(DownloadManagerTests, DownloadPathAccessDenied) +{ + const auto id = manager.CreateDownload(g_smallFileUrl, "/var/run/doagent-test.bin"); + VerifyDOResultException(HRESULT_FROM_XPLAT_SYSERR(EACCES), [&]() + { + manager.StartDownload(id); + }); + manager.AbortDownload(id); +} + +TEST_F(DownloadManagerTests, DownloadPathNotFound) +{ + const auto id = manager.CreateDownload(g_smallFileUrl, "/var2/run/doagent-test.bin"); + VerifyDOResultException(HRESULT_FROM_XPLAT_SYSERR(ENOENT), [&]() + { + manager.StartDownload(id); + }); + manager.AbortDownload(id); +} diff --git a/sdk-cpp/tests/download_tests.cpp b/sdk-cpp/tests/download_tests.cpp index 72ab14a4..42833d71 100644 --- a/sdk-cpp/tests/download_tests.cpp +++ b/sdk-cpp/tests/download_tests.cpp @@ -17,7 +17,8 @@ namespace msdo = microsoft::deliveryoptimization; using namespace std::chrono_literals; // NOLINT(build/namespaces) -#define HTTP_E_STATUS_NOT_FOUND ((int32_t)0x80190194L) +#define DO_ERROR_FROM_XPLAT_SYSERR(err) (int32_t)(0xC0000000 | (0xD0 << 16) | ((int32_t)(err) & 0x0000FFFF)) +#define HTTP_E_STATUS_NOT_FOUND ((int32_t)0x80190194L) void WaitForDownloadCompletion(msdo::download& simpleDownload) { @@ -38,31 +39,6 @@ class DownloadTests : public ::testing::Test public: void SetUp() override; void TearDown() override; - - void SimpleDownloadTest(); - void SimpleDownloadTest_With404Url(); - void SimpleDownloadTest_WithMalformedPath(); - void SimpleDownloadTest_With404UrlAndMalformedPath(); - - //void Download1PausedDownload2SameDestTest(); - void Download1PausedDownload2SameFileDownload1Resume(); - void Download1NeverStartedDownload2CancelledSameFileTest(); - void ResumeOnAlreadyDownloadedFileTest(); - - void CancelDownloadOnCompletedState(); - void CancelDownloadInTransferredState(); - - void PauseResumeTest(); - void PauseResumeTestWithDelayAfterStart(); - - void SimpleBlockingDownloadTest(); - void CancelBlockingDownloadTest(); - void MultipleConsecutiveDownloadTest(); - void MultipleConcurrentDownloadTest(); - void MultipleConcurrentDownloadTest_WithCancels(); - - void SimpleBlockingDownloadTest_ClientNotRunning(); - void MultipleRestPortFileExists_Download(); }; void DownloadTests::SetUp() @@ -149,7 +125,7 @@ TEST_F(DownloadTests, SimpleDownloadTest_WithMalformedPath) } catch (const msdo::exception& e) { - ASSERT_EQ(e.error_code(), static_cast(msdo::errc::invalid_arg)); + ASSERT_EQ(e.error_code(), DO_ERROR_FROM_XPLAT_SYSERR(ENOENT)); ASSERT_FALSE(boost::filesystem::exists(g_tmpFileName)); } } @@ -165,7 +141,7 @@ TEST_F(DownloadTests, SimpleDownloadTest_With404UrlAndMalformedPath) } catch (const msdo::exception& e) { - ASSERT_EQ(e.error_code(), static_cast(msdo::errc::invalid_arg)); + ASSERT_EQ(e.error_code(), DO_ERROR_FROM_XPLAT_SYSERR(ENOENT)); ASSERT_FALSE(boost::filesystem::exists(g_tmpFileName)); } }