diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 3776d18fafcc..7f8a9571f4c0 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -31,6 +31,7 @@ #ifndef TVM_RUNTIME_PACK_ARGS_H_ #define TVM_RUNTIME_PACK_ARGS_H_ +#include #include #include @@ -43,6 +44,8 @@ namespace runtime { * \brief argument union type of 32bit. */ union ArgUnion32 { + // uint16 useful for 16 bit types like FP16. + uint16_t v_uint16[2]; int32_t v_int32; uint32_t v_uint32; float v_float32; @@ -52,6 +55,8 @@ union ArgUnion32 { * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime. */ union ArgUnion64 { + // uint16 useful for 16 bit types like FP16. + uint16_t v_uint16[4]; int32_t v_int32[2]; uint32_t v_uint32[2]; float v_float32[2]; @@ -125,6 +130,7 @@ enum ArgConvertCode { INT64_TO_INT64, INT64_TO_INT32, INT64_TO_UINT32, + FLOAT64_TO_FLOAT16, FLOAT64_TO_FLOAT32, FLOAT64_TO_FLOAT64, HANDLE_TO_HANDLE @@ -140,6 +146,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { } else if (t.code == kDLFloat) { if (t.bits == 64U) return FLOAT64_TO_FLOAT64; if (t.bits == 32U) return FLOAT64_TO_FLOAT32; + if (t.bits == 16U) return FLOAT64_TO_FLOAT16; } else if (t.code == kTVMOpaqueHandle) { return HANDLE_TO_HANDLE; } @@ -178,6 +185,12 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code addr[i] = &(holder[i]); break; } + case FLOAT64_TO_FLOAT16: { + holder[i].v_float32 = (args.values[i].v_float64); + holder[i].v_uint16[0] = __gnu_f2h_ieee(holder[i].v_float32); + addr[i] = &(holder[i]); + break; + } } } f(args, ret, addr); @@ -213,6 +226,11 @@ inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector(args.values[base + i].v_float64); break; } + case FLOAT64_TO_FLOAT16: { + holder[i].v_uint16[0] = + __gnu_f2h_ieee(static_cast(args.values[base + i].v_float64)); + break; + } case HANDLE_TO_HANDLE: { LOG(FATAL) << "not reached"; break; @@ -261,6 +279,12 @@ inline PackedFunc PackFuncPackedArg_(F f, const std::vector& cod ++ptr; break; } + case FLOAT64_TO_FLOAT16: { + *reinterpret_cast(ptr) = + __gnu_f2h_ieee(static_cast(args.values[i].v_float64)); + ++ptr; + break; + } default: { LOG(FATAL) << "not reached"; break; @@ -337,4 +361,4 @@ inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_type } } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_PACK_ARGS_H_ +#endif // TVM_RUNTIME_PACK_ARGS_H_ \ No newline at end of file