Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/runtime/pack_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#ifndef TVM_RUNTIME_PACK_ARGS_H_
#define TVM_RUNTIME_PACK_ARGS_H_

#include <tvm/runtime/builtin_fp16.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>

Expand All @@ -43,6 +44,8 @@ namespace runtime {
* \brief argument union type of 32bit.
*/
union ArgUnion32 {
// uint16 useful for 16 bit types like FP16.
uint16_t v_uint16[2];
int32_t v_int32;
uint32_t v_uint32;
float v_float32;
Expand All @@ -52,6 +55,8 @@ union ArgUnion32 {
* \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
*/
union ArgUnion64 {
// uint16 useful for 16 bit types like FP16.
uint16_t v_uint16[4];
int32_t v_int32[2];
uint32_t v_uint32[2];
float v_float32[2];
Expand Down Expand Up @@ -125,6 +130,7 @@ enum ArgConvertCode {
INT64_TO_INT64,
INT64_TO_INT32,
INT64_TO_UINT32,
FLOAT64_TO_FLOAT16,
FLOAT64_TO_FLOAT32,
FLOAT64_TO_FLOAT64,
HANDLE_TO_HANDLE
Expand All @@ -140,6 +146,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) {
} else if (t.code == kDLFloat) {
if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
if (t.bits == 16U) return FLOAT64_TO_FLOAT16;
} else if (t.code == kTVMOpaqueHandle) {
return HANDLE_TO_HANDLE;
}
Expand Down Expand Up @@ -178,6 +185,12 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector<ArgConvertCode>& code
addr[i] = &(holder[i]);
break;
}
case FLOAT64_TO_FLOAT16: {
holder[i].v_float32 = (args.values[i].v_float64);
holder[i].v_uint16[0] = __gnu_f2h_ieee(holder[i].v_float32);
addr[i] = &(holder[i]);
break;
}
}
}
f(args, ret, addr);
Expand Down Expand Up @@ -213,6 +226,11 @@ inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConv
holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64);
break;
}
case FLOAT64_TO_FLOAT16: {
holder[i].v_uint16[0] =
__gnu_f2h_ieee(static_cast<float>(args.values[base + i].v_float64));
break;
}
case HANDLE_TO_HANDLE: {
LOG(FATAL) << "not reached";
break;
Expand Down Expand Up @@ -261,6 +279,12 @@ inline PackedFunc PackFuncPackedArg_(F f, const std::vector<ArgConvertCode>& cod
++ptr;
break;
}
case FLOAT64_TO_FLOAT16: {
*reinterpret_cast<uint16_t*>(ptr) =
__gnu_f2h_ieee(static_cast<float>(args.values[i].v_float64));
++ptr;
break;
}
default: {
LOG(FATAL) << "not reached";
break;
Expand Down Expand Up @@ -337,4 +361,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_type
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACK_ARGS_H_
#endif // TVM_RUNTIME_PACK_ARGS_H_