Skip to content

cuda::std::complex specializations for half and bfloat#1140

Merged
miscco merged 37 commits intoNVIDIA:mainfrom
griwes:feature/small-complex
Mar 12, 2024
Merged

cuda::std::complex specializations for half and bfloat#1140
miscco merged 37 commits intoNVIDIA:mainfrom
griwes:feature/small-complex

Conversation

@griwes
Copy link
Copy Markdown
Contributor

@griwes griwes commented Nov 21, 2023

Description

Resolves #1139

Introduce specializations of complex<T> for half and bfloat.

Checklist

  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Additional checklist

  • The documentation contains the actual release this will be made available in.

@griwes griwes requested review from a team as code owners November 21, 2023 23:42
@griwes griwes requested review from alliepiper and ericniebler and removed request for a team November 21, 2023 23:42
@griwes griwes marked this pull request as draft November 21, 2023 23:45
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath Outdated
@griwes griwes force-pushed the feature/small-complex branch 2 times, most recently from 65e6f36 to 744f2d1 Compare November 22, 2023 19:57
Copy link
Copy Markdown
Contributor

@miscco miscco left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great job working around the quirks of those types 👏

I would love to move some of the traits around (e.g. into is_floating_point.h) and importantly add a proper named define that one can grep for.

Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex Outdated
Copy link
Copy Markdown
Contributor

@gonzalobg gonzalobg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM in general, thanks for working on this @griwes !
I think I missed some static_asserts for the size and alignment of complex half and bfloat, do we have these somewhere? Thanks!

@griwes griwes force-pushed the feature/small-complex branch 3 times, most recently from c2d87c2 to add3d52 Compare January 27, 2024 05:52
Copy link
Copy Markdown
Contributor

@miscco miscco left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether we should just keep all the _LIBCUDACXX_HAS_NO_NVFP16 in place and define it conditionally for host

Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath Outdated
griwes added 10 commits February 6, 2024 20:56
Specifically:
* disable BF16 when FP16 is disabled, since the former includes the
  latter;
* disable both when the toolkit version is lower than 12.2, since 12.2
  is when both types got the host versions of a lot of functions we need
  to make useful heterogeneous things with them;
* disable both in host-only TU, as there's no easy way I could find to
  detect the condition above. I've included an opt-in macro for
  asserting that the headers (if available) are from a sufficiently new
  CTK, will add that to docs in a later commit.
@griwes griwes force-pushed the feature/small-complex branch from f2893fa to 8121bba Compare February 7, 2024 04:57
NVCC is spewing code that makes various versions of clang unhappy about
a deprecated implicit copy constructor of a lambda wrapper, so just work
around that by not using one.
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__cuda/cmath_nvfp16.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h Outdated
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/cmath
Comment thread libcudacxx/include/cuda/std/detail/libcxx/include/complex
@griwes griwes added the libcu++ For all items related to libcu++ label Feb 27, 2024
@miscco miscco mentioned this pull request Feb 28, 2024
2 tasks
Comment thread libcudacxx/docs/standard_api/numerics_library/complex.md Outdated
Comment thread libcudacxx/docs/standard_api/numerics_library/complex.md Outdated
miscco and others added 2 commits March 11, 2024 18:46
@miscco miscco enabled auto-merge (squash) March 11, 2024 17:47
@leofang
Copy link
Copy Markdown
Member

leofang commented Mar 11, 2024

Note: As discussed offline, local tests show that at least on sm86/89 we need this patch for performance reasons. I haven't had a chance to test on sm70/80/90, though.

diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/complex b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
index 3ba249779..416c0e71d 100644
--- a/libcudacxx/include/cuda/std/detail/libcxx/include/complex
+++ b/libcudacxx/include/cuda/std/detail/libcxx/include/complex
@@ -1702,6 +1702,16 @@ atanh(const complex<_Tp>& __x)
     return complex<_Tp>(__constexpr_copysign(__z.real(), __x.real()), __constexpr_copysign(__z.imag(), __x.imag()));
 }
 
+// we add a specialization for fp16 atanh because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atanh(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atanh(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sinh
 
 template<class _Tp>
@@ -1815,6 +1825,16 @@ atan(const complex<_Tp>& __x)
     return complex<_Tp>(__z.imag(), -__z.real());
 }
 
+// we add a specialization for fp16 atanh because of performance issues
+template<>
+_LIBCUDACXX_INLINE_VISIBILITY complex<__half>
+atan(const complex<__half>& __x)
+{
+    complex<float> __temp(__x);
+    __temp = _CUDA_VSTD::atan(__temp);
+    return complex<__half>(__temp.real(), __temp.imag());
+}
+
 // sin
 
 template<class _Tp>

@miscco
Copy link
Copy Markdown
Contributor

miscco commented Mar 12, 2024

@leofang I added some workarounds for asinh acosh atanh and cosh

@miscco miscco merged commit ae0ee04 into NVIDIA:main Mar 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

libcu++ For all items related to libcu++

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

Specializations of complex<T> for half and bfloat

6 participants