diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index c95d578df686..270d81f218ea 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -178,6 +178,17 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } bool fail = false; if (t.is_float()) { + // Need to care about sizes and alignment of half3/float3 because tir representation might not + // be aware of Metal half3/float3 details and can treat them as just three elements, + // while sizes and alignments of half3/float3 are one element more (half3 - 8 bytes/ + // float3 - 16 bytes). + // Example of problematic pattern: filling of threadgroup packed array using float3 elements + // by threads concurrently can lead to data race and wrong data in threadgroup shared array. + // packed_(half3/float3) are exactly datatypes dealing with 3 elements and per-element + // alignment + if (lanes == 3) { + os << "packed_"; + } switch (t.bits()) { case 16: os << "half";