From 472c435bb20358dc6c4f5194e9dc43dfd4e52ccc Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 15 Dec 2020 11:51:59 -0800 Subject: [PATCH 1/2] [metal] update language version --- src/runtime/metal/metal_module.mm | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 7d46811fe78d..981dd6129f9e 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -88,8 +88,7 @@ void SaveToBinary(dmlc::Stream* stream) final { if (e.lib == nil) { if (fmt_ == "metal") { MTLCompileOptions* opts = [MTLCompileOptions alloc]; - // Use the Metal 1.2 for now. - opts.languageVersion = MTLLanguageVersion1_2; + opts.languageVersion = MTLLanguageVersion2_3; opts.fastMathEnabled = YES; // opts = nil; e.lib = [w->devices[device_id] From 279e4cce2a9a8c63abb28c13a9b093df05c4c91a Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 15 Dec 2020 14:58:02 -0800 Subject: [PATCH 2/2] fix mps --- src/runtime/contrib/mps/conv.mm | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 3b16f0820d64..b860ee29bdf5 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -34,7 +34,8 @@ id dev = entry_ptr->metal_api->GetDevice(buf->ctx); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, - [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); + [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype, + nullptr); MPSImageDescriptor* desc = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 @@ -69,7 +70,8 @@ imageIndex:0]; entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, - [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); + [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype, + nullptr); }); TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -111,7 +113,8 @@ id bufB = (__bridge id)(weight->data); id tempB = rt->GetTempBuffer(weight->ctx, [bufB length]); entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, - [bufB length], weight -> ctx, weight -> ctx, nullptr); + [bufB length], weight -> ctx, weight -> ctx, tmp_in.dtype, + nullptr); float* ptr_w = (float*)[tempB contents]; // output to MPSImage DLTensor tmp_out;