Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4033a85

Browse files
authored
Revert PR 17767 for fixing GPU memory usage regression (#18283) (#18311)
* Revert "Fix and optimize handling of vectorized memory accesses (#17767)" This reverts commit 5542d03. * add license to reverted file
1 parent 1eefe66 commit 4033a85

19 files changed

+464
-1344
lines changed

3rdparty/mshadow/mshadow/base.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ extern "C" {
295295
}
296296

297297
#include "./half.h"
298+
#include "./half2.h"
298299
#include "./bfloat.h"
299300
#define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \
300301
MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \
@@ -409,6 +410,11 @@ struct DataType<half::half_t> {
409410
#endif
410411
};
411412
template<>  // DataType traits specialization for the packed two-lane half2_t type
413+
struct DataType<half::half2_t> {
414+
static const int kFlag = kFloat16;  // half2_t is tagged with the kFloat16 type flag
415+
static const int kLanes = 2;  // one half2_t object carries two half-precision lanes
416+
};
417+
template<>
412418
struct DataType<bfloat::bf16_t> {
413419
static const int kFlag = kBfloat16;
414420
static const int kLanes = 1;
@@ -1161,6 +1167,48 @@ struct minimum {
11611167
}
11621168
#endif
11631169

1170+
/*!
 * \brief Runtime type dispatch that binds DType to a concrete C++ type and
 *        expands the supplied statements once per supported type flag.
 *        Unlike a scalar float16 dispatch, kFloat16 binds DType to the packed
 *        mshadow::half::half2_t.
 * \param type   runtime type flag to switch on
 * \param DType  identifier the body uses to refer to the selected type
 * \param ...    statements expanded inside every case
 *
 * Handled flags: kFloat32, kFloat64, kFloat16, kUint8, kInt32, kInt64.
 * Any other value reaches LOG(FATAL).
 */
#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...)  \
  switch (type) {                                         \
    case mshadow::kFloat32: {                             \
      typedef float DType;                                \
      {__VA_ARGS__}                                       \
    } break;                                              \
    case mshadow::kFloat64: {                             \
      typedef double DType;                               \
      {__VA_ARGS__}                                       \
    } break;                                              \
    case mshadow::kFloat16: {                             \
      typedef mshadow::half::half2_t DType;               \
      {__VA_ARGS__}                                       \
    } break;                                              \
    case mshadow::kUint8: {                               \
      typedef uint8_t DType;                              \
      {__VA_ARGS__}                                       \
    } break;                                              \
    case mshadow::kInt32: {                               \
      typedef int32_t DType;                              \
      {__VA_ARGS__}                                       \
    } break;                                              \
    case mshadow::kInt64: {                               \
      typedef int64_t DType;                              \
      {__VA_ARGS__}                                       \
    } break;                                              \
    default:                                              \
      LOG(FATAL) << "Unknown type enum " << type;         \
  }
1211+
11641212
#define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \
11651213
switch (type) { \
11661214
case mshadow::kFloat32: \

3rdparty/mshadow/mshadow/half2.h

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* Copyright (c) 2017 by Contributors
22+
* \file half2.h
23+
* \brief definition of vector float16, half2 type.
24+
*
25+
* \author Antti-Pekka Hynninen
26+
*/
27+
#ifndef MSHADOW_HALF2_H_
28+
#define MSHADOW_HALF2_H_
29+
30+
#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
31+
#define MSHADOW_CUDA_HALF2 1
32+
#include <cuda_fp16.h>
33+
#else
34+
#define MSHADOW_CUDA_HALF2 0
35+
#endif
36+
37+
#include<math.h>
38+
39+
/*! \brief namespace for mshadow */
40+
namespace mshadow {
41+
/*! \brief namespace for host/device portable half-precision floats */
42+
namespace half {
43+
44+
/*!
 * \brief Emits a templated compound-assignment member (e.g. +=) for half2_t.
 *        The member evaluates (*this OP a), stores the result into *this and
 *        returns the updated value by copy.
 * \param AOP compound-assignment operator token being defined (+=, -=, ...)
 * \param OP  binary operator token used to compute the new value (+, -, ...)
 */
#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)                          \
  template<typename T>                                           \
  MSHADOW_XINLINE half2_t operator AOP (const T& a) {            \
    *this = half2_t(*this OP a);  /* NOLINT(*)*/                 \
    return *this;                                                \
  }
49+
50+
/*!
 * \brief Pair of half-precision values stored in one 4-byte aligned object.
 *
 * When MSHADOW_CUDA_HALF2 is enabled the pair is held in a CUDA `half2`;
 * otherwise it is held as an array of two `half_t`.
 */
class MSHADOW_ALIGNED(4) half2_t {
 public:
#if MSHADOW_CUDA_HALF2
  /*! \brief packed storage used when CUDA half2 support is available */
  half2 half2_;
#else
  /*! \brief fallback storage: lane 0 and lane 1 as separate half_t values */
  half_t half_t2[2];
#endif

  /*! \brief default constructor; performs no explicit initialization of the lanes */
  MSHADOW_XINLINE half2_t() {}

#if MSHADOW_CUDA_HALF2
  /*! \brief wrap an existing CUDA half2 value */
  MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {}
#else
  /*!
   * \brief build from two scalar halves
   * \param a value stored in lane 0
   * \param b value stored in lane 1
   */
  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) {
    half_t2[0] = a;
    half_t2[1] = b;
  }
#endif

  /*!
   * \brief broadcast an integer into both lanes
   * \param a integer converted to half precision and copied to each lane
   */
  MSHADOW_XINLINE explicit half2_t(int a) {
#if MSHADOW_CUDA_HALF2
    half2_ = __half2half2(__int2half_rz(a));
#else
    const half_t converted = static_cast<half_t>(a);
    half_t2[0] = converted;
    half_t2[1] = converted;
#endif
  }

  /*! \brief unary plus; returns a copy of this value (usable on const objects) */
  MSHADOW_XINLINE half2_t operator+() const {
    return *this;
  }

  /*! \brief unary minus; negates both lanes (usable on const objects) */
  MSHADOW_XINLINE half2_t operator-() const {
#if MSHADOW_CUDA_HALF2
    return half2_t(__hneg2(half2_));
#else
    // Negate non-const copies so the same half_t negation is selected as when
    // this operator was a non-const member.
    half_t lane0 = half_t2[0];
    half_t lane1 = half_t2[1];
    return half2_t(-lane0, -lane1);
#endif
  }

  /*!
   * \brief copy assignment
   * \param a value to copy from
   * \return reference to this object, so chained assignment targets the
   *         assigned object rather than a temporary
   */
  MSHADOW_XINLINE half2_t& operator=(const half2_t& a) {
#if MSHADOW_CUDA_HALF2
    half2_ = a.half2_;
#else
    half_t2[0] = a.half_t2[0];
    half_t2[1] = a.half_t2[1];
#endif
    return *this;
  }

  MSHADOW_HALF2_ASSIGNOP(+=, +)
  MSHADOW_HALF2_ASSIGNOP(-=, -)
  MSHADOW_HALF2_ASSIGNOP(*=, *)
  MSHADOW_HALF2_ASSIGNOP(/=, /)
};
105+
106+
/*! \brief overloaded + operator for half2_t */
107+
MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) {
108+
#if MSHADOW_CUDA_HALF2
109+
return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_),
110+
__high2float(a.half2_) + __high2float(b.half2_)));
111+
#else
112+
return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]);
113+
#endif
114+
}
115+
/*! \brief overloaded - operator for half2_t */
116+
MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) {
117+
#if MSHADOW_CUDA_HALF2
118+
return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_),
119+
__high2float(a.half2_) - __high2float(b.half2_)));
120+
#else
121+
return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]);
122+
#endif
123+
}
124+
/*! \brief overloaded * operator for half2_t */
125+
MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) {
126+
#if MSHADOW_CUDA_HALF2
127+
return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_),
128+
__high2float(a.half2_) * __high2float(b.half2_)));
129+
#else
130+
return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]);
131+
#endif
132+
}
133+
/*! \brief overloaded / operator for half2_t */
134+
MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) {
135+
#if MSHADOW_CUDA_HALF2
136+
return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_),
137+
__high2float(a.half2_) / __high2float(b.half2_)));
138+
#else
139+
return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]);
140+
#endif
141+
}
142+
/*! \brief overloaded % operator for half2_t */
143+
MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) {
144+
#if MSHADOW_CUDA_HALF2
145+
return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)),
146+
::fmod(__high2float(a.half2_), __high2float(b.half2_))));
147+
#else
148+
return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1]));
149+
#endif
150+
}
151+
/*! \brief overloaded == operator for half2_t */
152+
MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) {
153+
#if MSHADOW_CUDA_HALF2
154+
return __hbeq2(a.half2_, b.half2_);
155+
#else
156+
return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]);
157+
#endif
158+
}
159+
160+
} // namespace half
161+
} // namespace mshadow
162+
#endif // MSHADOW_HALF2_H_

0 commit comments

Comments
 (0)