Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 45 additions & 38 deletions strings/base_com_ptr.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@

WINRT_EXPORT namespace winrt
{
template <typename T>
struct com_ptr;
}

namespace winrt::impl
{
template <typename T, typename F, typename...Args>
int32_t capture_to(void**result, F function, Args&& ...args)
{
return function(args..., guid_of<T>(), result);
}

template <typename T, typename O, typename M, typename...Args, std::enable_if_t<std::is_class_v<O> || std::is_union_v<O>, int> = 0>
int32_t capture_to(void** result, O* object, M method, Args&& ...args)
{
return (object->*method)(args..., guid_of<T>(), result);
}

template <typename T, typename O, typename M, typename...Args>
int32_t capture_to(void** result, com_ptr<O> const& object, M method, Args&& ...args);
}

WINRT_EXPORT namespace winrt
{
template <typename T>
Expand Down Expand Up @@ -162,28 +186,16 @@ WINRT_EXPORT namespace winrt
*other = m_ptr;
}

template <typename F, typename...Args>
bool try_capture(F function, Args&&...args)
{
return function(args..., guid_of<T>(), put_void()) >= 0;
}

template <typename O, typename M, typename...Args>
bool try_capture(com_ptr<O> const& object, M method, Args&&...args)
{
return (object.get()->*(method))(args..., guid_of<T>(), put_void()) >= 0;
}

template <typename F, typename...Args>
void capture(F function, Args&&...args)
template <typename...Args>
bool try_capture(Args&&...args)
{
check_hresult(function(args..., guid_of<T>(), put_void()));
return impl::capture_to<T>(put_void(), std::forward<Args>(args)...) >= 0;
}

template <typename O, typename M, typename...Args>
void capture(com_ptr<O> const& object, M method, Args&&...args)
template <typename...Args>
void capture(Args&&...args)
{
check_hresult((object.get()->*(method))(args..., guid_of<T>(), put_void()));
check_hresult(impl::capture_to<T>(put_void(), std::forward<Args>(args)...));
}

private:
Expand Down Expand Up @@ -225,33 +237,19 @@ WINRT_EXPORT namespace winrt
type* m_ptr{};
};

template <typename T, typename F, typename...Args>
impl::com_ref<T> try_capture(F function, Args&& ...args)
template <typename T, typename...Args>
impl::com_ref<T> try_capture(Args&& ...args)
{
void* result{};
function(args..., guid_of<T>(), &result);
impl::capture_to<T>(&result, std::forward<Args>(args)...);
return { result, take_ownership_from_abi };
}

template <typename T, typename O, typename M, typename...Args>
impl::com_ref<T> try_capture(com_ptr<O> const& object, M method, Args&& ...args)
{
void* result{};
(object.get()->*(method))(args..., guid_of<T>(), &result);
return { result, take_ownership_from_abi };
}
template <typename T, typename F, typename...Args>
impl::com_ref<T> capture(F function, Args&& ...args)
{
void* result{};
check_hresult(function(args..., guid_of<T>(), &result));
return { result, take_ownership_from_abi };
}
template <typename T, typename O, typename M, typename...Args>
impl::com_ref<T> capture(com_ptr<O> const& object, M method, Args && ...args)
template <typename T, typename...Args>
impl::com_ref<T> capture(Args&& ...args)
{
void* result{};
check_hresult((object.get()->*(method))(args..., guid_of<T>(), &result));
check_hresult(impl::capture_to<T>(&result, std::forward<Args>(args)...));
return { result, take_ownership_from_abi };
}

Expand Down Expand Up @@ -340,6 +338,15 @@ WINRT_EXPORT namespace winrt
}
}

namespace winrt::impl
{
template <typename T, typename O, typename M, typename...Args>
int32_t capture_to(void** result, com_ptr<O> const& object, M method, Args&& ...args)
{
return (object.get()->*(method))(args..., guid_of<T>(), result);
}
}

template <typename T>
void** IID_PPV_ARGS_Helper(winrt::com_ptr<T>* ptr) noexcept
{
Expand Down
25 changes: 25 additions & 0 deletions test/old_tests/UnitTests/capture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,46 +39,71 @@ HRESULT __stdcall CreateCapture(int value, GUID const& iid, void** object) noexc

TEST_CASE("capture")
{
// Capture from global function.
com_ptr<ICapture> a = capture<ICapture>(CreateCapture, 10);
REQUIRE(a->GetValue() == 10);
a = nullptr;
a.capture(CreateCapture, 20);
REQUIRE(a->GetValue() == 20);

// Capture from com_ptr + method.
auto b = capture<ICapture>(a, &ICapture::CreateMemberCapture, 30);
REQUIRE(b->GetValue() == 30);
b = nullptr;
b.capture(a, &ICapture::CreateMemberCapture, 40);
REQUIRE(b->GetValue() == 40);

// Capture from raw pointer + method.
b = nullptr;
b = capture<ICapture>(a.get(), &ICapture::CreateMemberCapture, 30);
REQUIRE(b->GetValue() == 30);
b = nullptr;
b.capture(a.get(), &ICapture::CreateMemberCapture, 40);
REQUIRE(b->GetValue() == 40);

com_ptr<IDispatch> d;

REQUIRE_THROWS_AS(capture<IDispatch>(CreateCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(d.capture(CreateCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(d.capture(a, &ICapture::CreateMemberCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(capture<IDispatch>(a.get(), &ICapture::CreateMemberCapture, 0), hresult_no_interface);
REQUIRE_THROWS_AS(d.capture(a.get(), &ICapture::CreateMemberCapture, 0), hresult_no_interface);
}

TEST_CASE("try_capture")
{
// Identical to the "capture" test above, just with different
// error handling.

// Capture from global function.
com_ptr<ICapture> a = try_capture<ICapture>(CreateCapture, 10);
REQUIRE(a->GetValue() == 10);
a = nullptr;
REQUIRE(a.try_capture(CreateCapture, 20));
REQUIRE(a->GetValue() == 20);

// Capture from com_ptr + method.
auto b = try_capture<ICapture>(a, &ICapture::CreateMemberCapture, 30);
REQUIRE(b->GetValue() == 30);
b = nullptr;
REQUIRE(b.try_capture(a, &ICapture::CreateMemberCapture, 40));
REQUIRE(b->GetValue() == 40);

// Capture from raw pointer + method.
b = nullptr;
b = try_capture<ICapture>(a.get(), &ICapture::CreateMemberCapture, 30);
REQUIRE(b->GetValue() == 30);
b = nullptr;
b.try_capture(a.get(), &ICapture::CreateMemberCapture, 40);
REQUIRE(b->GetValue() == 40);

com_ptr<IDispatch> d;

REQUIRE(!try_capture<IDispatch>(CreateCapture, 0));
REQUIRE(!d.try_capture(CreateCapture, 0));
REQUIRE(!try_capture<IDispatch>(a, &ICapture::CreateMemberCapture, 0));
REQUIRE(!d.try_capture(a, &ICapture::CreateMemberCapture, 0));
REQUIRE(!try_capture<IDispatch>(a.get(), &ICapture::CreateMemberCapture, 0));
REQUIRE(!d.try_capture(a.get(), &ICapture::CreateMemberCapture, 0));
}