From 94c3e0bf0244c5690c024543fd32e6a8f25de7d9 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Fri, 19 Jan 2024 13:06:52 -0800 Subject: [PATCH 1/3] Add new mps runtime with support for FP16, lifted and unlifted graphs (iOS15+, macOS12+) --- backends/apple/mps/CMakeLists.txt | 36 +- backends/apple/mps/mps_preprocess.py | 905 +++--------------- .../apple/mps/operations/ActivationOps.mm | 165 ---- backends/apple/mps/operations/BinaryOps.h | 70 -- backends/apple/mps/operations/BinaryOps.mm | 222 ----- backends/apple/mps/operations/BitwiseOps.mm | 19 - backends/apple/mps/operations/ClampOps.mm | 58 -- backends/apple/mps/operations/ConstantOps.mm | 55 -- .../apple/mps/operations/ConvolutionOps.mm | 80 -- backends/apple/mps/operations/Indexing.mm | 30 - .../apple/mps/operations/LinearAlgebraOps.mm | 48 - .../apple/mps/operations/NormalizationOps.mm | 93 -- backends/apple/mps/operations/PoolingOps.mm | 109 --- backends/apple/mps/operations/RangeOps.mm | 31 - backends/apple/mps/operations/ReduceOps.mm | 238 ----- backends/apple/mps/operations/ShapeOps.mm | 260 ----- backends/apple/mps/operations/UnaryOps.h | 17 - backends/apple/mps/operations/UnaryOps.mm | 30 - backends/apple/mps/operators/__init__.py | 70 ++ .../apple/mps/operators/activation_ops.py | 99 ++ backends/apple/mps/operators/binary_ops.py | 156 +++ backends/apple/mps/operators/clamp_ops.py | 60 ++ backends/apple/mps/operators/constant_ops.py | 104 ++ .../apple/mps/operators/convolution_ops.py | 70 ++ backends/apple/mps/operators/indexing_ops.py | 70 ++ .../apple/mps/operators/linear_algebra_ops.py | 52 + backends/apple/mps/operators/node_visitor.py | 404 ++++++++ .../apple/mps/operators/normalization_ops.py | 97 ++ backends/apple/mps/operators/op_clone.py | 35 + backends/apple/mps/operators/op_getitem.py | 29 + backends/apple/mps/operators/pad_ops.py | 37 + backends/apple/mps/operators/pooling_ops.py | 142 +++ backends/apple/mps/operators/range_ops.py | 53 + backends/apple/mps/operators/reduce_ops.py | 39 + 
backends/apple/mps/operators/shape_ops.py | 264 +++++ backends/apple/mps/operators/unary_ops.py | 122 +++ .../apple/mps/partition/mps_partitioner.py | 89 +- backends/apple/mps/runtime/MPSBackend.mm | 6 +- backends/apple/mps/runtime/MPSCompiler.h | 4 +- backends/apple/mps/runtime/MPSCompiler.mm | 80 +- backends/apple/mps/runtime/MPSExecutor.h | 6 +- backends/apple/mps/runtime/MPSExecutor.mm | 34 +- backends/apple/mps/runtime/MPSGraphBuilder.h | 183 ++++ backends/apple/mps/runtime/MPSGraphBuilder.mm | 115 +++ backends/apple/mps/runtime/MPSStream.h | 1 + backends/apple/mps/runtime/MPSStream.mm | 4 + .../mps/runtime/operations/ActivationOps.mm | 155 +++ .../apple/mps/runtime/operations/BinaryOps.mm | 285 ++++++ .../apple/mps/runtime/operations/ClampOps.mm | 94 ++ .../mps/runtime/operations/ConstantOps.mm | 63 ++ .../mps/runtime/operations/ConvolutionOps.mm | 146 +++ .../mps/runtime/operations/IndexingOps.mm | 114 +++ .../mps/runtime/operations/LinearAlgebra.mm | 86 ++ .../runtime/operations/NormalizationOps.mm | 129 +++ .../mps/runtime/operations/OperationUtils.h | 54 ++ .../mps/runtime/operations/OperationUtils.mm | 283 ++++++ .../mps/{ => runtime}/operations/PadOps.mm | 101 +- .../mps/runtime/operations/PoolingOps.mm | 107 +++ .../apple/mps/runtime/operations/RangeOps.mm | 53 + .../apple/mps/runtime/operations/ReduceOps.mm | 59 ++ .../apple/mps/runtime/operations/ShapeOps.mm | 275 ++++++ .../apple/mps/runtime/operations/UnaryOps.mm | 134 +++ .../mps/serialization/mps_graph_schema.py | 743 ++++++++++++++ .../mps/serialization/mps_graph_serialize.py | 30 + backends/apple/mps/serialization/schema.fbs | 443 +++++++++ backends/apple/mps/targets.bzl | 59 +- backends/apple/mps/test/test_mps.py | 239 +++-- backends/apple/mps/test/test_mps_utils.py | 124 ++- backends/apple/mps/utils/Bindings.mm | 345 ------- backends/apple/mps/utils/MPSGraphInterface.h | 259 ----- backends/apple/mps/utils/MPSGraphInterface.mm | 154 --- .../apple/mps/utils/MPSGraphPackageExport.h | 16 - 
backends/apple/mps/utils/OperationUtils.h | 101 -- backends/apple/mps/utils/OperationUtils.mm | 144 --- backends/apple/mps/utils/graph_bindings.py | 42 - backends/apple/mps/utils/mps_utils.py | 100 +- build/cmake_deps.toml | 8 + examples/apple/mps/CMakeLists.txt | 5 +- examples/apple/mps/README.md | 2 +- .../executor_runner/mps_executor_runner.mm | 76 +- .../apple/mps/executor_runner/targets.bzl | 3 +- examples/apple/mps/scripts/mps_example.py | 36 +- 82 files changed, 6269 insertions(+), 3689 deletions(-) delete mode 100644 backends/apple/mps/operations/ActivationOps.mm delete mode 100644 backends/apple/mps/operations/BinaryOps.h delete mode 100644 backends/apple/mps/operations/BinaryOps.mm delete mode 100644 backends/apple/mps/operations/BitwiseOps.mm delete mode 100644 backends/apple/mps/operations/ClampOps.mm delete mode 100644 backends/apple/mps/operations/ConstantOps.mm delete mode 100644 backends/apple/mps/operations/ConvolutionOps.mm delete mode 100644 backends/apple/mps/operations/Indexing.mm delete mode 100644 backends/apple/mps/operations/LinearAlgebraOps.mm delete mode 100644 backends/apple/mps/operations/NormalizationOps.mm delete mode 100644 backends/apple/mps/operations/PoolingOps.mm delete mode 100644 backends/apple/mps/operations/RangeOps.mm delete mode 100644 backends/apple/mps/operations/ReduceOps.mm delete mode 100644 backends/apple/mps/operations/ShapeOps.mm delete mode 100644 backends/apple/mps/operations/UnaryOps.h delete mode 100644 backends/apple/mps/operations/UnaryOps.mm create mode 100644 backends/apple/mps/operators/__init__.py create mode 100644 backends/apple/mps/operators/activation_ops.py create mode 100644 backends/apple/mps/operators/binary_ops.py create mode 100644 backends/apple/mps/operators/clamp_ops.py create mode 100644 backends/apple/mps/operators/constant_ops.py create mode 100644 backends/apple/mps/operators/convolution_ops.py create mode 100644 backends/apple/mps/operators/indexing_ops.py create mode 100644 
backends/apple/mps/operators/linear_algebra_ops.py create mode 100644 backends/apple/mps/operators/node_visitor.py create mode 100644 backends/apple/mps/operators/normalization_ops.py create mode 100644 backends/apple/mps/operators/op_clone.py create mode 100644 backends/apple/mps/operators/op_getitem.py create mode 100644 backends/apple/mps/operators/pad_ops.py create mode 100644 backends/apple/mps/operators/pooling_ops.py create mode 100644 backends/apple/mps/operators/range_ops.py create mode 100644 backends/apple/mps/operators/reduce_ops.py create mode 100644 backends/apple/mps/operators/shape_ops.py create mode 100644 backends/apple/mps/operators/unary_ops.py create mode 100644 backends/apple/mps/runtime/MPSGraphBuilder.h create mode 100644 backends/apple/mps/runtime/MPSGraphBuilder.mm create mode 100644 backends/apple/mps/runtime/operations/ActivationOps.mm create mode 100644 backends/apple/mps/runtime/operations/BinaryOps.mm create mode 100644 backends/apple/mps/runtime/operations/ClampOps.mm create mode 100644 backends/apple/mps/runtime/operations/ConstantOps.mm create mode 100644 backends/apple/mps/runtime/operations/ConvolutionOps.mm create mode 100644 backends/apple/mps/runtime/operations/IndexingOps.mm create mode 100644 backends/apple/mps/runtime/operations/LinearAlgebra.mm create mode 100644 backends/apple/mps/runtime/operations/NormalizationOps.mm create mode 100644 backends/apple/mps/runtime/operations/OperationUtils.h create mode 100644 backends/apple/mps/runtime/operations/OperationUtils.mm rename backends/apple/mps/{ => runtime}/operations/PadOps.mm (59%) create mode 100644 backends/apple/mps/runtime/operations/PoolingOps.mm create mode 100644 backends/apple/mps/runtime/operations/RangeOps.mm create mode 100644 backends/apple/mps/runtime/operations/ReduceOps.mm create mode 100644 backends/apple/mps/runtime/operations/ShapeOps.mm create mode 100644 backends/apple/mps/runtime/operations/UnaryOps.mm create mode 100644 
backends/apple/mps/serialization/mps_graph_schema.py create mode 100644 backends/apple/mps/serialization/mps_graph_serialize.py create mode 100644 backends/apple/mps/serialization/schema.fbs delete mode 100644 backends/apple/mps/utils/Bindings.mm delete mode 100644 backends/apple/mps/utils/MPSGraphInterface.h delete mode 100644 backends/apple/mps/utils/MPSGraphInterface.mm delete mode 100644 backends/apple/mps/utils/MPSGraphPackageExport.h delete mode 100644 backends/apple/mps/utils/OperationUtils.h delete mode 100644 backends/apple/mps/utils/OperationUtils.mm delete mode 100644 backends/apple/mps/utils/graph_bindings.py diff --git a/backends/apple/mps/CMakeLists.txt b/backends/apple/mps/CMakeLists.txt index f836a4d5104..8cbf7072b8f 100644 --- a/backends/apple/mps/CMakeLists.txt +++ b/backends/apple/mps/CMakeLists.txt @@ -20,14 +20,46 @@ if(NOT EXECUTORCH_ROOT) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) endif() +if(NOT FLATC_EXECUTABLE) + set(FLATC_EXECUTABLE flatc) +endif() + set(_common_compile_options -Wno-deprecated-declarations) set(_common_include_directories ${EXECUTORCH_ROOT}/..) +set(_mps_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include") + +# Paths to headers generated from the .fbs files. +set(_mps_schema__outputs) +foreach(fbs_file ${_mps_schema__srcs}) + string(REGEX REPLACE "serialization/([^/]+)[.]fbs$" "\\1_generated.h" + generated "${fbs_file}") + list(APPEND _mps_schema__outputs + "${_mps_schema__include_dir}/executorch/${generated}") +endforeach() + +# Generate the headers from the .fbs files. 
+add_custom_command( + OUTPUT ${_mps_schema__outputs} + COMMAND + ${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o + "${_mps_schema__include_dir}/executorch/backends/apple/mps" + ${_mps_schema__srcs} + WORKING_DIRECTORY ${EXECUTORCH_ROOT} + COMMENT "Generating mps_schema headers" + VERBATIM) + +add_library(mps_schema INTERFACE ${_mps_schema__outputs}) +set_target_properties(mps_schema PROPERTIES LINKER_LANGUAGE CXX) +target_include_directories( +mps_schema INTERFACE ${_mps_schema__include_dir} + ${EXECUTORCH_ROOT}/third-party/flatbuffers/include + ${_common_include_directories}) + list(TRANSFORM _mps_backend__srcs PREPEND "${EXECUTORCH_ROOT}/") add_library(mpsdelegate ${_mps_backend__srcs}) -target_link_libraries(mpsdelegate PRIVATE baundled_program +target_link_libraries(mpsdelegate PRIVATE baundled_program mps_schema ${_executor_runner_libs}) -target_include_directories(mpsdelegate PUBLIC ${_common_include_directories}) install( TARGETS mpsdelegate diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index c8887cee4c5..0e543d7e079 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -3,782 +3,181 @@ # Provided subject to the LICENSE file in the top level directory. 
# -from typing import Any, Dict, final, List, Optional, Union +import logging +from typing import Dict, final, List import torch -from executorch.backends.apple.mps.utils.graph_bindings import graph_bindings -from executorch.backends.apple.mps.utils.mps_utils import get_mps_data_type +from executorch.backends.apple.mps.operators.node_visitor import ( + get_node_visitors, + NodeVisitor, + process_output_node, + process_placeholder_nodes, +) + +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSGraph, + MPSTensor, +) + +from executorch.backends.apple.mps.serialization.mps_graph_serialize import ( + convert_to_flatbuffer, +) +from executorch.backends.apple.mps.utils.mps_utils import is_parameter from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, PreprocessResult, ) - -from executorch.exir.dialects._ops import ops as exir_ops from torch._export.exported_program import ExportedProgram -from torch._subclasses import FakeTensor - -def get_param_from_node( - node: torch.fx.Node, edge_program: ExportedProgram -) -> Optional[torch.nn.Parameter]: - """ - Returns the parameter associated with the given node in the edge program. 
- Returns None if the node is not a parameter within the edge_program - """ - if node.name in edge_program.graph_signature.inputs_to_parameters: - parameter_name = edge_program.graph_signature.inputs_to_parameters[node.name] - return edge_program.state_dict[parameter_name] - elif node.name in edge_program.graph_signature.inputs_to_buffers: - buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name] - return edge_program.state_dict[buffer_name] - return None - - -def create_mpsgraph_constant_tensor(tensor: torch.Tensor, mpsGraph): - if tensor.dim() == 0: - return mpsGraph.constant(tensor.item(), get_mps_data_type(tensor.dtype)) - else: - return mpsGraph.constantTensor(tensor, get_mps_data_type(tensor.dtype)) +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) @final class MPSBackend(BackendDetails): @staticmethod - def fetch_attr(node: torch.fx.Node, edge_program: ExportedProgram): - attr_itr = edge_program - - attr_itr = getattr(node.graph.owning_module, node.target) - return attr_itr - - @staticmethod - def eval_shape(node): - def eval_expr(symint: Union[int, torch.SymInt, FakeTensor]) -> Optional[int]: - if isinstance(symint, int): - return symint - - return None - - """ - Evaluate the shape of a node. - - symint can of of type `SymInt`, `FakeTensor`, a `List[Union[FakeTensor, SymInt]]`, or `None` - """ - if isinstance(node, FakeTensor): - return node.shape - - new_shape = [] - for _, s in enumerate(node): - new_shape.append(eval_expr(s)) - return new_shape - - @staticmethod - def preprocess( # noqa: C901 + def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], - ) -> bytes: - # C++ MPSGraph bindings. 
- mpsGraph = graph_bindings.MPSGraphModule() - - unaryOps = { - exir_ops.edge.aten.exp.default: mpsGraph.exp, - exir_ops.edge.aten.exp2.default: mpsGraph.exp2, - exir_ops.edge.aten.reciprocal.default: mpsGraph.reciprocal, - exir_ops.edge.aten.sqrt.default: mpsGraph.sqrt, - exir_ops.edge.aten.neg.default: mpsGraph.neg, - exir_ops.edge.aten.log.default: mpsGraph.log, - exir_ops.edge.aten.log10.default: mpsGraph.log10, - exir_ops.edge.aten.log2.default: mpsGraph.log2, - exir_ops.edge.aten.erf.default: mpsGraph.erf, - exir_ops.edge.aten.floor.default: mpsGraph.floor, - exir_ops.edge.aten.ceil.default: mpsGraph.ceil, - exir_ops.edge.aten.rsqrt.default: mpsGraph.rsqrt, - exir_ops.edge.aten.sigmoid.default: mpsGraph.sigmoid, - exir_ops.edge.aten.sin.default: mpsGraph.sin, - exir_ops.edge.aten.sign.default: mpsGraph.sign, - exir_ops.edge.aten.cos.default: mpsGraph.cos, - exir_ops.edge.aten.tan.default: mpsGraph.tan, - exir_ops.edge.aten.abs.default: mpsGraph.abs, - exir_ops.edge.aten.asin.default: mpsGraph.asin, - exir_ops.edge.aten.acos.default: mpsGraph.acos, - exir_ops.edge.aten.atan.default: mpsGraph.atan, - exir_ops.edge.aten.sinh.default: mpsGraph.sinh, - exir_ops.edge.aten.cosh.default: mpsGraph.cosh, - exir_ops.edge.aten.tanh.default: mpsGraph.tanh, - exir_ops.edge.aten.asinh.default: mpsGraph.asinh, - exir_ops.edge.aten.acosh.default: mpsGraph.acosh, - exir_ops.edge.aten.atanh.default: mpsGraph.atanh, - exir_ops.edge.aten.bitwise_not.default: mpsGraph.bitwise_not, - exir_ops.edge.aten.isnan.default: mpsGraph.isnan, - exir_ops.edge.aten.isinf.default: mpsGraph.isinf, - exir_ops.edge.aten.round.default: mpsGraph.round, - } - - binaryOps = { - exir_ops.edge.aten.mm.default: mpsGraph.mm, - exir_ops.edge.aten.bmm.default: mpsGraph.bmm, - exir_ops.edge.aten.mul.Tensor: mpsGraph.mul, - exir_ops.edge.aten.div.Tensor: mpsGraph.div, - exir_ops.edge.aten.div.Tensor_mode: mpsGraph.div, - exir_ops.edge.aten.floor_divide.default: mpsGraph.floor_divide, - 
exir_ops.edge.aten.fmod.Tensor: mpsGraph.fmod, - exir_ops.edge.aten.remainder.Tensor: mpsGraph.remainder, - exir_ops.edge.aten.bitwise_and.Tensor: mpsGraph.bitwise_and, - exir_ops.edge.aten.bitwise_or.Tensor: mpsGraph.bitwise_or, - exir_ops.edge.aten.bitwise_xor.Tensor: mpsGraph.bitwise_xor, - exir_ops.edge.aten.eq.Tensor: mpsGraph.eq, - exir_ops.edge.aten.ne.Tensor: mpsGraph.ne, - exir_ops.edge.aten.ge.Tensor: mpsGraph.ge, - exir_ops.edge.aten.gt.Tensor: mpsGraph.gt, - exir_ops.edge.aten.le.Tensor: mpsGraph.le, - exir_ops.edge.aten.lt.Tensor: mpsGraph.lt, - exir_ops.edge.aten.pow.Tensor_Tensor: mpsGraph.pow, - exir_ops.edge.aten.minimum.default: mpsGraph.minimum, - } - - binaryOpsWithScalar = { - exir_ops.edge.aten.mul.Scalar: mpsGraph.mulWithScalar, - exir_ops.edge.aten.remainder.Scalar: mpsGraph.remainder, - exir_ops.edge.aten.eq.Scalar: mpsGraph.eq, - exir_ops.edge.aten.ne.Scalar: mpsGraph.ne, - exir_ops.edge.aten.ge.Scalar: mpsGraph.ge, - exir_ops.edge.aten.gt.Scalar: mpsGraph.gt, - exir_ops.edge.aten.le.Scalar: mpsGraph.le, - exir_ops.edge.aten.lt.Scalar: mpsGraph.lt, - exir_ops.edge.aten.bitwise_and.Scalar: mpsGraph.bitwise_and, - exir_ops.edge.aten.bitwise_or.Scalar: mpsGraph.bitwise_or, - exir_ops.edge.aten.bitwise_xor.Scalar: mpsGraph.bitwise_xor, - exir_ops.edge.aten.pow.Tensor_Scalar: mpsGraph.pow, + ) -> PreprocessResult: + # The EdgeIR nodes are processed in the following order: + # 1. Process first the input feeds to the graph (in the same + # order as args from forward(*args)), and generate a unique + # id for each input placeholder. Each input id is appended to + # `input_ids` array from the FlatBuffer schema. + # 2. Process the nodes in the graph (e.g. `call_function`). For each + # EdgeIR node, create an equivalent MPS node in the FlatBuffer, + # based on which the MPSGraph is constructed at runtime. During + # this process, any visited constant in the EdgeIR is added to the + # final MPS FlatBuffer schema. 
Each constant id is appended to the + # `constant_ids` FlatBuffer schema. + # 3. After all the inputs, nodes and constants are added to the + # FlatBuffer graph, process the `output` nodes and add their id to + # the `output_ids` array in the schema. + + mps_graph = MPSGraph( + version="0", + mps_nodes=[], + mps_values=[], + input_ids=[], + output_ids=[], + constant_ids=[], + ) + + convert_model_to_fp16 = True + for spec in compile_specs: + if spec.key == "use_fp16": + convert_model_to_fp16 = bool(list(bytes(spec.value))[0]) + + logging.debug(f"Convert model to FP16: {convert_model_to_fp16}") + + node_visitors = get_node_visitors(edge_program, convert_model_to_fp16) + if logging.DEBUG >= logging.root.level: + edge_program.graph.print_tabular() + + process_placeholder_nodes( + edge_program, + edge_program.graph_module, + mps_graph, + node_visitors["placeholder"], + ) + + op_handler = { + "call_function": MPSBackend.handle_call_function, + "placeholder": MPSBackend.handle_placeholder, + "output": MPSBackend.handle_output, + "get_attr": MPSBackend.handle_get_attr, } - # `graph_nodes` dictionary is made out of : - graphNodes: Dict[str, Any] = {} - - for node in edge_program.graph.nodes: - if node.op == "get_attr": - attr = MPSBackend.fetch_attr(node, edge_program) - graphNodes[node.name] = create_mpsgraph_constant_tensor( - tensor=attr, mpsGraph=mpsGraph - ) - - # Handle inputs to the graph. 
- elif node.op == "placeholder": - # Check if this is a lifted parameter / buffer - # If so, bundle the constants in the graph instead of creating placeholders - lifted_param_or_buffer = get_param_from_node(node, edge_program) - if lifted_param_or_buffer is not None: - graphNodes[node.name] = create_mpsgraph_constant_tensor( - tensor=lifted_param_or_buffer, mpsGraph=mpsGraph - ) - else: - if node.meta["val"] is None: - continue - shape = MPSBackend.eval_shape(node.meta["val"]) - if shape is None: - graphNodes[node.name] = mpsGraph.mpsGraphUnrankedPlaceHolder( - get_mps_data_type(node.meta["val"].dtype) - ) - else: - graphNodes[node.name] = mpsGraph.mpsGraphRankedPlaceHolder( - get_mps_data_type(node.meta["val"].dtype), shape - ) - - # Handle `call_function` calls. - elif node.op == "call_function": - if node.target == exir_ops.edge.aten.mm.default: - graphNodes[node.name] = mpsGraph.mm( - graphNodes[node.args[0].name], graphNodes[node.args[1].name] - ) - elif node.target == exir_ops.edge.aten.bmm.default: - graphNodes[node.name] = mpsGraph.bmm( - graphNodes[node.args[0].name], graphNodes[node.args[1].name] - ) - elif node.target == exir_ops.edge.aten.add.Tensor: - alpha = 1.0 - if node.kwargs and node.kwargs["alpha"] is not None: - alpha = node.kwargs["alpha"] - graphNodes[node.name] = mpsGraph.add( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - alpha, - ) - elif node.target == exir_ops.edge.aten.add.Scalar: - graphNodes[node.name] = mpsGraph.add( - graphNodes[node.args[0].name], node.args[1] - ) - elif node.target == exir_ops.edge.aten.sub.Tensor: - alpha = 1.0 - if node.kwargs and node.kwargs["alpha"] is not None: - alpha = node.kwargs["alpha"] - graphNodes[node.name] = mpsGraph.sub( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - alpha, - ) - elif node.target == exir_ops.edge.aten.sub.Scalar: - graphNodes[node.name] = mpsGraph.sub( - graphNodes[node.args[0].name], node.args[1] - ) - elif node.target == 
exir_ops.edge.aten.mul.Tensor: - graphNodes[node.name] = mpsGraph.mul( - graphNodes[node.args[0].name], graphNodes[node.args[1].name] - ) - elif node.target == exir_ops.edge.aten.mul.Scalar: - graphNodes[node.name] = mpsGraph.mulWithScalar( - graphNodes[node.args[0].name], node.args[1] - ) - elif node.target in binaryOps: - graphNodes[node.name] = binaryOps[node.target]( - graphNodes[node.args[0].name], graphNodes[node.args[1].name] - ) - elif node.target in binaryOpsWithScalar: - graphNodes[node.name] = binaryOpsWithScalar[node.target]( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.full.default: - if len(node.args) < 2: - raise AssertionError( - "Full op requires at least size & fill_value args" - ) - dtype = get_mps_data_type(torch.float32) - if len(node.args) >= 3: - dtype = get_mps_data_type(node.args[2]) - if len(node.args) >= 4: - raise AssertionError("Unexpected number of input parameters") - graphNodes[node.name] = mpsGraph.full( - node.args[0], node.args[1], dtype - ) - - elif node.target == exir_ops.edge.aten.full_like.default: - if len(node.args) < 2: - raise AssertionError("Too few input parameters") - graphNodes[node.name] = mpsGraph.full_like( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.convolution.default: - from typing import cast - - input_node = cast(torch.fx.Node, node.args[0]).meta["val"] - sizes = input_node.size() - dim0 = sizes[0] - dim1 = sizes[1] - groups = int(node.args[8]) - group_in_channels = dim1 - group_out_channels = int(dim0 / groups) - - # Convolution is depthwise if groups = input channels and output channel - # is a positive multiple of input channels - is_depthwise_conv = (group_in_channels == 1) and ( - group_out_channels % group_in_channels == 0 - ) - - if node.args[2] is None: - graphNodes[node.name] = mpsGraph.conv2D( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - None, - node.args[3], - node.args[4], - 
node.args[5], - node.args[6], - node.args[7], - node.args[8], - is_depthwise_conv, - ) - else: - graphNodes[node.name] = mpsGraph.conv2D( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - graphNodes[node.args[2].name], - node.args[3], - node.args[4], - node.args[5], - node.args[6], - node.args[7], - node.args[8], - is_depthwise_conv, - ) - - elif node.target == exir_ops.edge.aten.max_pool2d_with_indices.default: - n_args = len(node.args) - if n_args > 6: - raise AssertionError("Unexpected number of input parameters") - - padding = [0, 0] - dilation = [1, 1] - ceil_mode = False - if n_args >= 4: - padding = node.args[3] - if n_args >= 5: - dilation = node.args[4] - if n_args == 6: - ceil_mode = node.args[5] - - graphNodes[node.name] = mpsGraph.maxPool2DWithIndices( - graphNodes[node.args[0].name], - node.args[1], - node.args[2], - padding, - dilation, - ceil_mode, - ) - - elif node.target == exir_ops.edge.aten.avg_pool2d.default: - stride = node.args[1] - padding = [0, 0] - ceil_mode = False - count_include_pad = True - divisor_override = None - - n_args = len(node.args) - if n_args >= 3: - stride = node.args[2] - if n_args >= 4: - padding = node.args[3] - if n_args >= 5: - ceil_mode = node.args[4] - if n_args >= 6: - count_include_pad = node.args[5] - if n_args == 7: - divisor_override = node.args[6] - if n_args > 7: - raise AssertionError("Unexpected number of arguments") - - graphNodes[node.name] = mpsGraph.avgPool2D( - graphNodes[node.args[0].name], - node.args[1], - stride, - padding, - ceil_mode, - count_include_pad, - divisor_override, - ) - - elif ( - node.target - == exir_ops.edge.aten._native_batch_norm_legit_no_training.default - ): - graphNodes[node.name] = mpsGraph.batchNorm( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - graphNodes[node.args[2].name], - graphNodes[node.args[3].name], - graphNodes[node.args[4].name], - node.args[5], - node.args[6], - ) - - elif node.target == 
exir_ops.edge.aten.native_layer_norm.default: - graphNodes[node.name] = mpsGraph.layerNorm( - graphNodes[node.args[0].name], - node.args[1], - graphNodes[node.args[2].name], - graphNodes[node.args[3].name], - node.args[4], - ) - - elif node.target == exir_ops.edge.aten.hardtanh.default: - graphNodes[node.name] = mpsGraph.hardTanh( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - - elif node.target == exir_ops.edge.aten.relu.default: - graphNodes[node.name] = mpsGraph.relu(graphNodes[node.args[0].name]) - - elif node.target == exir_ops.edge.aten.leaky_relu.default: - n_args = len(node.args) - if n_args > 2: - raise AssertionError("Unexpected number of input parameters") - - negative_slope = 0.01 - if n_args == 2: - negative_slope = node.args[1] - graphNodes[node.name] = mpsGraph.leaky_relu( - graphNodes[node.args[0].name], negative_slope - ) - - elif node.target == exir_ops.edge.aten.gelu.default: - approximate = "none" - if len(node.args) > 1: - approximate = node.args[1] - graphNodes[node.name] = mpsGraph.gelu( - graphNodes[node.args[0].name], approximate - ) - - elif node.target == exir_ops.edge.aten.glu.default: - dim = -1 - if len(node.args) > 1: - dim = node.args[1] - graphNodes[node.name] = mpsGraph.glu( - graphNodes[node.args[0].name], dim - ) - - elif node.target == exir_ops.edge.aten.index_select.default: - dim = node.args[1] - index = graphNodes[node.args[2].name] - graphNodes[node.name] = mpsGraph.index_select( - graphNodes[node.args[0].name], dim, index - ) - - elif node.target == exir_ops.edge.aten._softmax.default: - graphNodes[node.name] = mpsGraph.softmax( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - - elif node.target == exir_ops.edge.aten._log_softmax.default: - graphNodes[node.name] = mpsGraph.log_softmax( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - - elif node.target == exir_ops.edge.aten._to_copy.default: - graphNodes[node.name] = mpsGraph.identity( - graphNodes[node.args[0].name] - ) 
- - elif node.target == exir_ops.edge.aten.min.dim: - keep_dim = False - if len(node.args) == 3: - keep_dim = node.args[2] - graphNodes[node.name] = mpsGraph.minDim( - graphNodes[node.args[0].name], node.args[1], keep_dim - ) - - elif node.target == exir_ops.edge.aten.max.dim: - keep_dim = False - if len(node.args) == 3: - keep_dim = node.args[2] - graphNodes[node.name] = mpsGraph.maxDim( - graphNodes[node.args[0].name], node.args[1], keep_dim - ) - - elif node.target == exir_ops.edge.aten.amax.default: - if len(node.args) == 2: - graphNodes[node.name] = mpsGraph.amax( - graphNodes[node.args[0].name], node.args[1], False - ) - elif len(node.args) == 3: - graphNodes[node.name] = mpsGraph.amax( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - else: - raise AssertionError("Unexpected number of input parameters") - - elif node.target == exir_ops.edge.aten.amin.default: - if len(node.args) == 2: - graphNodes[node.name] = mpsGraph.amin( - graphNodes[node.args[0].name], node.args[1], False - ) - elif len(node.args) == 3: - graphNodes[node.name] = mpsGraph.amin( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - else: - raise AssertionError("Unexpected number of input parameters") - - elif node.target == exir_ops.edge.aten.argmax.default: - n_args = len(node.args) - if n_args == 1: - graphNodes[node.name] = mpsGraph.argmax( - graphNodes[node.args[0].name], 0, False, True - ) - elif len(node.args) == 2: - graphNodes[node.name] = mpsGraph.argmax( - graphNodes[node.args[0].name], node.args[1], False, False - ) - elif len(node.args) == 3: - graphNodes[node.name] = mpsGraph.argmax( - graphNodes[node.args[0].name], - node.args[1], - node.args[2], - False, - ) - else: - raise AssertionError("Unexpected number of input parameters") - - elif node.target == exir_ops.edge.aten.argmin.default: - n_args = len(node.args) - if n_args == 1: - graphNodes[node.name] = mpsGraph.argmin( - graphNodes[node.args[0].name], 0, False, True - ) - elif len(node.args) 
== 2: - graphNodes[node.name] = mpsGraph.argmin( - graphNodes[node.args[0].name], node.args[1], False, False - ) - elif len(node.args) == 3: - graphNodes[node.name] = mpsGraph.argmin( - graphNodes[node.args[0].name], - node.args[1], - node.args[2], - False, - ) - else: - raise AssertionError("Unexpected number of input parameters") - - elif node.target == exir_ops.edge.aten.mean.dim: - if len(node.args) == 2: - graphNodes[node.name] = mpsGraph.mean( - graphNodes[node.args[0].name], node.args[1], False - ) - elif len(node.args) == 3: - graphNodes[node.name] = mpsGraph.mean( - graphNodes[node.args[0].name], node.args[1], node.args[2] - ) - else: - raise AssertionError("Unexpected number of input parameters") - - elif node.target == exir_ops.edge.aten.pixel_shuffle.default: - torch._assert( - len(node.args) == 2, "Unexpected number of input parameters" - ) - graphNodes[node.name] = mpsGraph.pixel_shuffle( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.split_with_sizes_copy.default: - dim = 0 - torch._assert( - len(node.args) >= 2, "Unexpected number of input parameters" - ) - split_sizes = node.args[1] - if len(node.args) >= 2: - dim = node.args[2] - graphNodes[node.name] = mpsGraph.split_size( - graphNodes[node.args[0].name], split_sizes, dim - ) - - elif node.target == exir_ops.edge.aten.split_copy.Tensor: - dim = 0 - torch._assert( - len(node.args) >= 2, "Unexpected number of input parameters" - ) - split_sizes = node.args[1] - if len(node.args) >= 3: - dim = node.args[2] - graphNodes[node.name] = mpsGraph.split( - graphNodes[node.args[0].name], split_sizes, dim - ) - - elif node.target == exir_ops.edge.aten.unbind_copy.int: - dim = 0 - if len(node.args) >= 2: - dim = node.args[1] - graphNodes[node.name] = mpsGraph.unbind( - graphNodes[node.args[0].name], dim - ) - - elif node.target == exir_ops.edge.aten.stack.default: - stackTensors = [] - dim = 0 - if len(node.args) > 1: - dim = node.args[1] - for inputTensor in 
node.args[0]: - stackTensors.append(graphNodes[inputTensor.name]) - graphNodes[node.name] = mpsGraph.stack(dim, *stackTensors) - - elif node.target == exir_ops.edge.aten.cat.default: - catTensors = [] - dim = 0 - if len(node.args) > 1: - dim = node.args[1] - for inputTensor in node.args[0]: - catTensors.append(graphNodes[inputTensor.name]) - graphNodes[node.name] = mpsGraph.cat(dim, *catTensors) - - elif node.target == exir_ops.edge.aten.slice_copy.Tensor: - dim = 0 - step = 1 - start = None - end = None - if len(node.args) >= 2: - dim = node.args[1] - if len(node.args) >= 4: - end = node.args[3] - start = node.args[2] - if len(node.args) >= 5: - step = node.args[4] - graphNodes[node.name] = mpsGraph.slice( - graphNodes[node.args[0].name], dim, start, end, step - ) - - elif node.target == exir_ops.edge.aten.expand_copy.default: - graphNodes[node.name] = mpsGraph.expand( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.view_copy.default: - graphNodes[node.name] = mpsGraph.view( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.clone.default: - # Per exir documentation the executorch memory format and layout is still WIP - # So we'll assume the memory layout option of clone is to be ignored - # TODO: adjust once the memory format and layout have been finalized - graphNodes[node.name] = mpsGraph.identity( - graphNodes[node.args[0].name] - ) - - elif node.target == exir_ops.edge.aten.select_copy.int: - idx = torch.sym_int(node.args[2]) - graphNodes[node.name] = mpsGraph.select( - graphNodes[node.args[0].name], node.args[1], idx - ) - - elif node.target == exir_ops.edge.aten.permute_copy.default: - graphNodes[node.name] = mpsGraph.permute( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.squeeze_copy.default: - graphNodes[node.name] = mpsGraph.squeeze( - graphNodes[node.args[0].name] - ) - - elif node.target == 
exir_ops.edge.aten.squeeze_copy.dim: - graphNodes[node.name] = mpsGraph.squeeze( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.squeeze_copy.dims: - graphNodes[node.name] = mpsGraph.squeeze( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.unsqueeze_copy.default: - graphNodes[node.name] = mpsGraph.unsqueeze( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target == exir_ops.edge.aten.constant_pad_nd.default: - padding = node.args[1] - c_value = node.args[2] - graphNodes[node.name] = mpsGraph.constant_pad_nd( - graphNodes[node.args[0].name], padding, c_value - ) - - elif node.target == exir_ops.edge.aten.addmm.default: - beta = 1.0 - alpha = 1.0 - if len(node.args) == 4: - beta = node.args[3] - if len(node.args) == 5: - alpha = node.args[4] - graphNodes[node.name] = mpsGraph.addmm( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - graphNodes[node.args[2].name], - beta, - alpha, - ) - elif node.target == exir_ops.edge.aten.clamp.default: - min_value = 0.0 - use_min = False - if len(node.args) >= 2 and node.args[1] is not None: - min_value = node.args[1] - use_min = True - - max_value = 1.0 - use_max = False - if len(node.args) >= 3 and node.args[2] is not None: - max_value = node.args[2] - use_max = True - graphNodes[node.name] = mpsGraph.clamp( - graphNodes[node.args[0].name], - min_value, - max_value, - use_min, - use_max, - ) - - elif node.target == exir_ops.edge.aten.cumsum.default: - if len(node.args) != 2: - raise AssertionError("Unexpected number of input parameters") - graphNodes[node.name] = mpsGraph.cumsum( - graphNodes[node.args[0].name], node.args[1] - ) - - elif node.target in unaryOps: - graphNodes[node.name] = unaryOps[node.target]( - graphNodes[node.args[0].name] - ) - - elif node.target == exir_ops.edge.aten.arange.start_step: - step = 1 - if len(node.args) > 2 and node.args[2] is not None: - step = node.args[2] - dtype = 
get_mps_data_type(node.meta["val"].dtype) - shape = node.meta["val"].shape[0] - graphNodes[node.name] = mpsGraph.arange( - node.args[0], node.args[1], step, dtype, shape - ) + for node in edge_program.graph_module.graph.nodes: + if node.op not in op_handler: + raise RuntimeError(f"{node.op} is not supported in MPS") + else: + op_handler[node.op](edge_program, node_visitors, node, mps_graph) - elif node.target == exir_ops.edge.aten.where.self: - graphNodes[node.name] = mpsGraph.where( - graphNodes[node.args[0].name], - graphNodes[node.args[1].name], - graphNodes[node.args[2].name], - ) - elif node.target == exir_ops.edge.aten.scalar_tensor.default: - graphNodes[node.name] = mpsGraph.scalar_out( - node.args[0], get_mps_data_type(node.meta["val"].dtype) - ) + if logging.DEBUG >= logging.root.level: + pretty_print(mps_graph) - elif node.target == exir_ops.edge.aten.empty.memory_format: - dtype = get_mps_data_type(torch.float32) - if len(node.args) >= 2: - dtype = get_mps_data_type(node.args[1]) - graphNodes[node.name] = mpsGraph.empty(node.args[0], dtype) - elif node.target == exir_ops.edge.aten.embedding.default: - if len(node.args) == 2: - graphNodes[node.name] = mpsGraph.index_select( - graphNodes[node.args[0].name], - 0, - graphNodes[node.args[1].name], - ) - elif len(node.args) > 2 and node.args[2] is not None: - r1 = mpsGraph.unsqueeze( - mpsGraph.ne(graphNodes[node.args[1].name], node.args[2]), -1 - ) - r2 = mpsGraph.index_select( - graphNodes[node.args[0].name], - 0, - graphNodes[node.args[1].name], - ) - graphNodes[node.name] = mpsGraph.where( - r1, r2, mpsGraph.full_like(r2, 0) - ) - # Cant check for target with getitem - # Arg[0] target node, arg[1] target index - elif "getitem" in node.name: - graphNodes[node.name] = graphNodes[node.args[0].name][node.args[1]] - else: - raise AssertionError(f"Unknown op: {node.target}") + return PreprocessResult(processed_bytes=convert_to_flatbuffer(mps_graph)) - # Handle `call_method` calls. 
- elif node.op == "call_method": - raise AssertionError("Not yet implemented") + @staticmethod + def handle_call_function( + _: ExportedProgram, + node_visitors: Dict[str, NodeVisitor], + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + logging.info(f"Visiting: {node}, {node.target.__name__}") + if node.target.__name__ in node_visitors: + node_visitors[node.target.__name__].define_node(node, mps_graph) + else: + pretty_print(mps_graph) + raise RuntimeError( + f"For {node}, {node.op}:{node.target.__name__} is not supported in MPS delegate" + ) - # Handle `call_method` calls. - elif node.op == "call_module": - raise AssertionError("Not yet implemented") + @staticmethod + def handle_placeholder( + edge_program: ExportedProgram, + node_visitors: Dict[str, NodeVisitor], + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # Handle only constants. Placeholders have already + # been visited in `process_input_placeholders` + if is_parameter(edge_program, node): + node_visitors[node.op].define_tensor(node, mps_graph) - # Handle output nodes in the graph. 
- elif node.op == "output": - output_nodes = [] - for i in range(len(node.args)): - for j in range(len(node.args[i])): - output_nodes.append(graphNodes[node.args[i][j].name]) - mpsGraph.set_outputs(*output_nodes) - else: - torch._assert( - False, - f"Unsupported operator: {node.op}, {node.name}, {node.target}", - ) + @staticmethod + def handle_output( + edge_program: ExportedProgram, + node_visitors: Dict[str, NodeVisitor], + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + for output_nodes in node.args: + for output_node in output_nodes: + process_output_node(output_node, mps_graph, node_visitors[node.op]) - mpsGraphExecutableBytes = mpsGraph.serialize() - return PreprocessResult(processed_bytes=bytes(mpsGraphExecutableBytes)) + @staticmethod + def handle_get_attr( + edge_program: ExportedProgram, + node_visitors: Dict[str, NodeVisitor], + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + pass + + +def tensor_to_str(mps_tensor: MPSTensor): + tensor_str = "MPSTensor(" + tensor_str += "datatype=" + str(mps_tensor.datatype) + ", " + tensor_str += "num_dims=" + str(mps_tensor.num_dims) + ", " + tensor_str += "dims=" + str(mps_tensor.dims) + ", " + tensor_str += "constant_buffer_size=" + str(mps_tensor.constant_buffer_size) + tensor_str += ")" + + return tensor_str + + +def pretty_print(mps_graph: MPSGraph): + logging.info("Serialized MPSGraph:") + logging.info(f" Version: {mps_graph.version}") + logging.info(" MPS nodes: ") + for i in range(len(mps_graph.mps_nodes)): + logging.info(f" [{i}]: {mps_graph.mps_nodes[i]}") + logging.info(" MPS values: ") + for i in range(len(mps_graph.mps_values)): + logging.info(f" [{i}]: {tensor_to_str(mps_graph.mps_values[i])}") + logging.info(" Input ids:") + for in_id in mps_graph.input_ids: + logging.info(f" {in_id}") + logging.info(" Constant ids:") + for constant_id in mps_graph.constant_ids: + logging.info(f" {constant_id}") + logging.info(" Output ids:") + for out_id in mps_graph.output_ids: + logging.info(f" 
{out_id}") diff --git a/backends/apple/mps/operations/ActivationOps.mm b/backends/apple/mps/operations/ActivationOps.mm deleted file mode 100644 index 214d438caab..00000000000 --- a/backends/apple/mps/operations/ActivationOps.mm +++ /dev/null @@ -1,165 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::relu(MPSGraphTensor* inputTensor) { - return [mpsGraph reLUWithTensor:inputTensor - name:@"relu"]; - -} - -PyMPSGraphTensor* -MPSGraphModule::leaky_relu(MPSGraphTensor* inputTensor, float negative_slope) { - return [mpsGraph leakyReLUWithTensor:inputTensor - alpha:negative_slope - name:@"leaky_relu"]; - -} - -MPSGraphTensor* tanh(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto dataType = [inputTensor dataType]; - constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5; - constexpr float kKappa = 0.044715f; - MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: kappaf - name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: inputTensor - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: betaf - name : nil]; - 
erfTensor = [mpsGraph tanhWithTensor: erfTensor - name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: onef - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: halff - name : nil]; - return erfTensor; -} - -MPSGraphTensor* normcdf(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { - auto dataType = [inputTensor dataType]; - const float SQRT1_2 = 0.707106781186547524400844362104849039f; - MPSGraphTensor *sqrt1_2 = [mpsGraph constantWithScalar: SQRT1_2 - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f - shape: @[@1] - dataType: dataType]; - MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f - shape: @[@1] - dataType: dataType]; - - MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor - secondaryTensor: sqrt1_2 - name : nil]; - erfTensor = [mpsGraph erfWithTensor: erfTensor name : nil]; - erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor - secondaryTensor: onef - name : nil]; - erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor - secondaryTensor: halff - name : nil]; - - return erfTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::gelu(MPSGraphTensor* inputTensor, - const std::string &approximation=nil) { - MPSGraphTensor* result; - if (approximation == "tanh") { - result = tanh(mpsGraph, inputTensor); - } else { - result = normcdf(mpsGraph, inputTensor); - } - return [mpsGraph multiplicationWithPrimaryTensor:result - secondaryTensor:inputTensor - name:nil]; -} - -PyMPSGraphTensor* -MPSGraphModule::softmax(MPSGraphTensor* inputTensor, const int dim, const bool half_to_float) { - TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); - return [mpsGraph softMaxWithTensor:inputTensor - axis:dim - name:@"softmax"]; -} - -PyMPSGraphTensor* -MPSGraphModule::log_softmax(MPSGraphTensor* inputTensor, const int dim, const bool half_to_float) { - 
TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on MPS"); - MPSGraphTensor* softmaxTensor = [mpsGraph softMaxWithTensor:inputTensor - axis:dim - name:@"softmax"]; - return [mpsGraph logarithmWithTensor:softmaxTensor - name:@"log_softmax"]; -} - -PyMPSGraphTensor* -MPSGraphModule::hardTanh(MPSGraphTensor* inputTensor, - float min_value, - float max_value) { - MPSDataType inputType = [inputTensor dataType]; - MPSShape* inputShape = [inputTensor shape]; - MPSGraphTensor* minTensor = [mpsGraph constantWithScalar:min_value shape:inputShape dataType:inputType]; - MPSGraphTensor* maxTensor = [mpsGraph constantWithScalar:max_value shape:inputShape dataType:inputType]; - MPSGraphTensor* lessThanMinPredicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor - secondaryTensor:minTensor - name:@"LessThanPredicate"]; - MPSGraphTensor* greaterThanMaxPredicateTensor = [mpsGraph greaterThanWithPrimaryTensor:inputTensor - secondaryTensor:maxTensor - name:@"MoreThanPredicate"]; - - MPSGraphTensor* temp = [mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor - truePredicateTensor:minTensor - falsePredicateTensor:inputTensor - name:@"minOutput"]; - MPSGraphTensor* result = [mpsGraph selectWithPredicateTensor:greaterThanMaxPredicateTensor - truePredicateTensor:maxTensor - falsePredicateTensor:temp - name:@"hardTanh"]; - return result; -} - -PyMPSGraphTensor* -MPSGraphModule::glu(MPSGraphTensor* inputTensor, int64_t dim) { - auto wrap_dim = maybe_wrap_dim(dim, inputTensor.shape.count); - auto splitTensors = [mpsGraph splitTensor:inputTensor - numSplits:2 - axis:wrap_dim - name:nil]; - return [mpsGraph multiplicationWithPrimaryTensor:splitTensors[0] - secondaryTensor:[mpsGraph sigmoidWithTensor:splitTensors[1] name:nil] - name:nil]; -} - -}//namespace mps diff --git a/backends/apple/mps/operations/BinaryOps.h b/backends/apple/mps/operations/BinaryOps.h deleted file mode 100644 index 81759b3493f..00000000000 --- 
a/backends/apple/mps/operations/BinaryOps.h +++ /dev/null @@ -1,70 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// -// clang-format off -#pragma once - -#define REGISTER_PYBIND11_MPS_BINARY_OP(py11_export_name, graph_op) \ -.def(py11_export_name, [](MPSGraphModule& self, PyMPSGraphTensor* input, PyMPSGraphTensor* other) { \ -return self.binaryOpTensor( \ - static_cast(input), static_cast(other), py11_export_name, \ - [&](MPSGraphTensor* primaryCastTensor, MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \ - return [self.getMPSGraph() graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - }); \ -}) \ -.def(py11_export_name, [](MPSGraphModule& self, PyMPSGraphTensor* input, float sc) { \ -return self.binaryOpWithScalar( \ - static_cast(input), sc, py11_export_name, \ - [&](MPSGraphTensor* primaryCastTensor, MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \ - MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \ - return [self.getMPSGraph() graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - }); \ -}) \ -.def(py11_export_name, [](MPSGraphModule& self, PyMPSGraphTensor* input, int sc) { \ -return self.binaryOpWithScalar( \ - static_cast(input), sc, py11_export_name, \ - [&](MPSGraphTensor* primaryCastTensor, MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \ - MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \ - return [self.getMPSGraph() graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - }); \ -}) \ - -#define REGISTER_PYBIND11_MPS_BITWISE_BINARY_OP(py11_export_name, graph_op) \ -.def(py11_export_name, [](MPSGraphModule& self, PyMPSGraphTensor* input, PyMPSGraphTensor* other) { \ -return self.binaryOpTensor( \ - static_cast(input), static_cast(other), py11_export_name, \ - 
[&](MPSGraphTensor* primaryCastTensor, MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \ - MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \ - if (getScalarType(mpsInputDataType) == ScalarType::Bool) { \ - return [self.getMPSGraph() logical##graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - } \ - return [self.getMPSGraph() bitwise##graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - }); \ -}) \ -.def(py11_export_name, [](MPSGraphModule& self, PyMPSGraphTensor* input, int sc) { \ -return self.binaryOpWithScalar( \ - static_cast(input), sc, py11_export_name, \ - [&](MPSGraphTensor* primaryCastTensor, MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \ - MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \ - if (getScalarType(mpsInputDataType) == ScalarType::Bool) { \ - return [self.getMPSGraph() logical##graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - } \ - return [self.getMPSGraph() bitwise##graph_op##WithPrimaryTensor:primaryCastTensor \ - secondaryTensor:secondaryCastTensor \ - name:nil]; \ - }); \ -}) \ -// clang-format on diff --git a/backends/apple/mps/operations/BinaryOps.mm b/backends/apple/mps/operations/BinaryOps.mm deleted file mode 100644 index 190cc79d61e..00000000000 --- a/backends/apple/mps/operations/BinaryOps.mm +++ /dev/null @@ -1,222 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -#include "BinaryOps.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::binaryOpTensor( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - const std::string& op_name, - std::function binaryOpFunction) { - MPSDataType mpsInputDataType = [primaryTensor dataType]; - MPSDataType mpsOtherDataType = [secondaryTensor dataType]; - - ScalarType inputDataType = getScalarType(mpsInputDataType); - ScalarType otherDataType = getScalarType(mpsOtherDataType); - - MPSGraphTensor* primaryCastTensor = primaryTensor; - MPSGraphTensor* secondaryCastTensor = secondaryTensor; - ScalarType common_dtype = c10::promoteTypes(inputDataType, otherDataType); - if (inputDataType != common_dtype) { - primaryCastTensor = castMPSTensor(mpsGraph, primaryTensor, common_dtype); - } - if (otherDataType != common_dtype) { - secondaryCastTensor = castMPSTensor(mpsGraph, secondaryTensor, common_dtype); - } - - return binaryOpFunction(primaryCastTensor, secondaryCastTensor); -} - -PyMPSGraphTensor* -MPSGraphModule::additionWithTensor(MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - Scalar alpha) { - MPSGraphTensor* primaryCastTensor = primaryTensor; - MPSGraphTensor* secondaryCastTensor = secondaryTensor; - auto _alpha = alpha.isFloatingPoint() ? 
alpha.to() : alpha.to(); - - ScalarType primaryDataType = getScalarType(primaryCastTensor.dataType); - ScalarType secondaryDataType = getScalarType(secondaryCastTensor.dataType); - - MPSDataType commonDataType = getMPSDataType(c10::promoteTypes(primaryDataType, secondaryDataType)); - - if(primaryCastTensor.dataType != commonDataType) { - primaryCastTensor = [mpsGraph castTensor:primaryCastTensor - toType:commonDataType - name:nil]; - } - - if(secondaryCastTensor.dataType != commonDataType) { - secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor - toType:commonDataType - name:nil]; - } - - if(_alpha!=1.0) { - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:_alpha - shape:@[@1] - dataType:primaryCastTensor.dataType]; - secondaryCastTensor = [mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor - secondaryTensor:alphaTensor - name:nil]; - } - MPSGraphTensor* resultTensor = [mpsGraph additionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryCastTensor - name:nil]; - return resultTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::subtractionWithTensor(MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - Scalar alpha) { - MPSGraphTensor* primaryCastTensor = primaryTensor; - MPSGraphTensor* secondaryCastTensor = secondaryTensor; - auto _alpha = alpha.isFloatingPoint() ? 
alpha.to() : alpha.to(); - - ScalarType primaryDataType = getScalarType(primaryCastTensor.dataType); - ScalarType secondaryDataType = getScalarType(secondaryCastTensor.dataType); - - MPSDataType commonDataType = getMPSDataType(c10::promoteTypes(primaryDataType, secondaryDataType)); - - if(primaryCastTensor.dataType != commonDataType) { - primaryCastTensor = [mpsGraph castTensor:primaryCastTensor - toType:commonDataType - name:nil]; - } - - if(secondaryCastTensor.dataType != commonDataType) { - secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor - toType:commonDataType - name:nil]; - } - - if(_alpha!=1.0) { - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:_alpha - shape:@[@1] - dataType:primaryCastTensor.dataType]; - secondaryCastTensor = [mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor - secondaryTensor:alphaTensor - name:nil]; - } - MPSGraphTensor* resultTensor = [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryCastTensor - name:nil]; - return resultTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::multiplicationWithScalar(MPSGraphTensor* inputTensor, Scalar scalar) { - auto value = scalar.isFloatingPoint() ? 
scalar.to() : scalar.to(); - MPSGraphTensor* constantTensor = [mpsGraph constantWithScalar:value - shape:@[@1] - dataType:inputTensor.dataType]; - MPSGraphTensor* resultTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:constantTensor - name:nil]; - return resultTensor; -} - -MPSGraphTensor* -MPSGraphModule::trunc_tensor(MPSGraphTensor* inputTensor) { - // Rounding is a no-op for integral types, and also a reasonable workaround - // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library` - // See https://github.com/pytorch/pytorch/issues/84995 - bool isFloatInput = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0; - if (!isFloatInput) { - return inputTensor; - } - - return [mpsGraph truncateWithTensor:inputTensor - name:nil]; -}; - -PyMPSGraphTensor* -MPSGraphModule::div_mode_template( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - c10::optional rounding_mode, - const string& op_name) { - MPSDataType mpsInputDataType = [primaryTensor dataType]; - MPSDataType mpsOtherDataType = [secondaryTensor dataType]; - - ScalarType inputDataType = getScalarType(mpsInputDataType); - ScalarType otherDataType = getScalarType(mpsOtherDataType); - - if(rounding_mode.has_value() && *rounding_mode == "trunc"){ - TORCH_CHECK(inputDataType != ScalarType::Half, - "MPS: does not support trunc_divide op with float16 input"); - } - - auto divOpFunc = [&](MPSGraphTensor* primaryCastTensor, - MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { - bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0; - if(!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) { - primaryCastTensor = [mpsGraph castTensor:primaryCastTensor - toType:MPSDataTypeFloat32 - name:@"primaryCastTensor"]; - secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor - toType:MPSDataTypeFloat32 - name:@"secondaryCastTensor"]; - } - 
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor - secondaryTensor:secondaryCastTensor - name:nil]; - - // Rounding is a no-op for integral types, and also a reasonable workaround - // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library` - // See https://github.com/pytorch/pytorch/issues/84995 - bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0; - if (!rounding_mode.has_value() || !isFloatOutput) { - return divTensor; - } else if (*rounding_mode == "trunc") { - auto truncTensor = trunc_tensor(divTensor); - if (op_name == "fmod_mps_out") { - auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor - secondaryTensor:secondaryCastTensor - name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:mulTensor - name:nil]; - } - return truncTensor; - } else if (*rounding_mode == "floor") { - MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil]; - if (op_name == "remainder_out_mps") { - auto mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor - secondaryTensor:secondaryCastTensor - name:nil]; - return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor - secondaryTensor:mulTensor - name:nil]; - } - return floorTensor; - } else { - assert(0 && "Invalid rounding mode\n"); - } - return nullptr; - }; - return binaryOpTensor(primaryTensor, secondaryTensor, op_name, divOpFunc); -} - -PyMPSGraphTensor* -MPSGraphModule::binaryOpWithScalar(MPSGraphTensor *inputTensor, Scalar scalar, - const std::string &op_name, - std::function binaryOpFunction) { - auto value = scalar.isFloatingPoint() ? 
scalar.to() : scalar.to(); - MPSGraphTensor* constantTensor = [mpsGraph constantWithScalar:value - shape:@[@1] - dataType:inputTensor.dataType]; - return binaryOpTensor(inputTensor, constantTensor, op_name, binaryOpFunction); -} - -} // namespace mps - diff --git a/backends/apple/mps/operations/BitwiseOps.mm b/backends/apple/mps/operations/BitwiseOps.mm deleted file mode 100644 index 2ba3d9c11fb..00000000000 --- a/backends/apple/mps/operations/BitwiseOps.mm +++ /dev/null @@ -1,19 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor *MPSGraphModule::bitwiseNotTensor(MPSGraphTensor *inputTensor, - const std::string &op_name) { - MPSDataType mpsInputDataType = [inputTensor dataType]; - if (getScalarType(mpsInputDataType) == ScalarType::Bool) { - return [getMPSGraph() notWithTensor:inputTensor name:nil]; - } - return [getMPSGraph() bitwiseNOTWithTensor:inputTensor name:nil]; -} -} // namespace mps diff --git a/backends/apple/mps/operations/ClampOps.mm b/backends/apple/mps/operations/ClampOps.mm deleted file mode 100644 index a0daffa89a6..00000000000 --- a/backends/apple/mps/operations/ClampOps.mm +++ /dev/null @@ -1,58 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::clamp(MPSGraphTensor* inputTensor, float min, float max, bool use_min, bool use_max) { - if(use_min && use_max) { - MPSGraphTensor* minTensor = [mpsGraph constantWithScalar:min - shape:inputTensor.shape - dataType:inputTensor.dataType]; - MPSGraphTensor* maxTensor = [mpsGraph constantWithScalar:max - shape:inputTensor.shape - dataType:inputTensor.dataType]; - return [mpsGraph clampWithTensor:inputTensor - minValueTensor:minTensor - maxValueTensor:maxTensor - name:@"clamp"]; - } else if(use_min && !use_max) { - MPSGraphTensor* minTensor = [mpsGraph constantWithScalar:min - shape:inputTensor.shape - dataType:inputTensor.dataType]; - return [mpsGraph maximumWithPrimaryTensor:inputTensor - secondaryTensor:minTensor - name:nil]; - } else if(!use_min && use_max) { - MPSGraphTensor* maxTensor = [mpsGraph constantWithScalar:max - shape:inputTensor.shape - dataType:inputTensor.dataType]; - return [mpsGraph minimumWithPrimaryTensor:inputTensor - secondaryTensor:maxTensor - name:nil]; - } - - //For the case that neither min nor max is given? Nothing in the documentation forbids this. - return inputTensor; -} - -PyMPSGraphTensor* MPSGraphModule::where(MPSGraphTensor* condition, - MPSGraphTensor* input, MPSGraphTensor* other) { - if ([condition dataType] != MPSDataTypeBool) { - condition = [mpsGraph castTensor:condition - toType:MPSDataTypeBool - name:@"condition"]; - } - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:condition - truePredicateTensor:input - falsePredicateTensor:other - name:nil]; - return outputTensor; -} - -}//namespace mps diff --git a/backends/apple/mps/operations/ConstantOps.mm b/backends/apple/mps/operations/ConstantOps.mm deleted file mode 100644 index 89c98c81762..00000000000 --- a/backends/apple/mps/operations/ConstantOps.mm +++ /dev/null @@ -1,55 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. 
All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::constantWithScalar(MPSDataType dataType, const IntArrayRef& sizes, double scalar) { - TORCH_CHECK(!sizes.empty(), "No sizes passed to create a constant with scalar"); - if (sizes.back() == 0) { - // Cannot create a zero-sized dimension through mpsGraph - return nil; - } - return [mpsGraph constantWithScalar:scalar - shape:getMPSShape(sizes) - dataType:dataType]; -} - -PyMPSGraphTensor* -MPSGraphModule::full(IntArrayRef size, double scalar, MPSDataType dataType) { - if (size.back() == 0) { - // Cannot create a zero-sized dimension through mpsGraph - return nil; - } - return [mpsGraph constantWithScalar:scalar - shape:getMPSShape(size) - dataType:dataType]; -} - -PyMPSGraphTensor* -MPSGraphModule::full_like(MPSGraphTensor* inputTensor, double scalar) { - return [mpsGraph constantWithScalar:scalar - shape:inputTensor.shape - dataType:inputTensor.dataType]; -} - -PyMPSGraphTensor* -MPSGraphModule::constant(double scalar, MPSDataType dataType) { - return [mpsGraph constantWithScalar:scalar - dataType:dataType]; -} - -PyMPSGraphTensor* -MPSGraphModule::constantTensor(Tensor constant_tensor, MPSDataType dataType) { - NSData* dataBuffer = [[NSData alloc] initWithBytes:constant_tensor.data_ptr() - length:constant_tensor.nbytes()]; - return [mpsGraph constantWithData:dataBuffer - shape:getMPSShape(constant_tensor) - dataType:dataType]; -} -}//namespace mps diff --git a/backends/apple/mps/operations/ConvolutionOps.mm b/backends/apple/mps/operations/ConvolutionOps.mm deleted file mode 100644 index 029e91b404a..00000000000 --- a/backends/apple/mps/operations/ConvolutionOps.mm +++ /dev/null @@ -1,80 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::conv2D(MPSGraphTensor* primaryTensor, MPSGraphTensor* secondaryTensor, - MPSGraphTensor* biasTensor, IntArrayRef stride, - IntArrayRef padding, IntArrayRef dilation, bool transpose, - IntArrayRef outputPadding, int64_t groups, bool is_depthwise) { - - if(is_depthwise){ - MPSGraphDepthwiseConvolution2DOpDescriptor* desc = [MPSGraphDepthwiseConvolution2DOpDescriptor - descriptorWithStrideInX:stride[0] - strideInY:stride[1] - dilationRateInX:dilation[0] - dilationRateInY:dilation[1] - paddingLeft:padding[1] - paddingRight:padding[1] - paddingTop:padding[0] - paddingBottom:padding[0] - paddingStyle:MPSGraphPaddingStyleExplicit - dataLayout:MPSGraphTensorNamedDataLayoutNCHW - weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; - - MPSGraphTensor* depthwiseConv2DTensor = [mpsGraph depthwiseConvolution2DWithSourceTensor:primaryTensor - weightsTensor:secondaryTensor - descriptor:desc - name:@"depthwiseConv2D"]; - //Can be a nullptr - if(biasTensor){ - //Need to add correct dimension to bias to avoid broadcasting issues - biasTensor = [mpsGraph expandDimsOfTensor:biasTensor - axes:@[@0, @2, @3] - name:nil]; - depthwiseConv2DTensor = [mpsGraph additionWithPrimaryTensor:depthwiseConv2DTensor - secondaryTensor:biasTensor - name:@"depthwiseConv2DWithBiasAdd"]; - } - - return depthwiseConv2DTensor; - } else { - MPSGraphConvolution2DOpDescriptor* desc = [MPSGraphConvolution2DOpDescriptor - descriptorWithStrideInX:stride[0] - strideInY:stride[1] - dilationRateInX:dilation[0] - dilationRateInY:dilation[1] - groups:groups - paddingLeft:padding[1] - paddingRight:padding[1] - paddingTop:padding[0] - paddingBottom:padding[0] - paddingStyle:MPSGraphPaddingStyleExplicit - dataLayout:MPSGraphTensorNamedDataLayoutNCHW - weightsLayout:MPSGraphTensorNamedDataLayoutOIHW]; - - MPSGraphTensor* conv2DTensor = [mpsGraph convolution2DWithSourceTensor:primaryTensor - 
weightsTensor:secondaryTensor - descriptor:desc - name:@"conv2D"]; - - //Can be a nullptr - if(biasTensor){ - //Need to add correct dimension to bias to avoid broadcasting issues - biasTensor = [mpsGraph expandDimsOfTensor:biasTensor - axes:@[@0,@2,@3] - name:nil]; - conv2DTensor = [mpsGraph additionWithPrimaryTensor:conv2DTensor - secondaryTensor:biasTensor - name:@"conv2DWithBiasAdd"]; - } - return conv2DTensor; - } -} -} //namespace mps diff --git a/backends/apple/mps/operations/Indexing.mm b/backends/apple/mps/operations/Indexing.mm deleted file mode 100644 index 047739a9638..00000000000 --- a/backends/apple/mps/operations/Indexing.mm +++ /dev/null @@ -1,30 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::index_select(MPSGraphTensor* inputTensor, int64_t dim, MPSGraphTensor* indexTensor) { - dim = maybe_wrap_dim(dim, inputTensor.shape.count); - - MPSGraphTensor* castIndexTensor = indexTensor; - if(castIndexTensor.dataType != MPSDataTypeInt32) { - castIndexTensor = [mpsGraph castTensor:indexTensor - toType:MPSDataTypeInt32 - name:nil]; - } - - MPSGraphTensor* outputTensor = [mpsGraph gatherWithUpdatesTensor: inputTensor - indicesTensor: castIndexTensor - axis: dim - batchDimensions: 0 - name: nil]; - return outputTensor; -} - -} // namespace at::native diff --git a/backends/apple/mps/operations/LinearAlgebraOps.mm b/backends/apple/mps/operations/LinearAlgebraOps.mm deleted file mode 100644 index 88ffbc307c8..00000000000 --- a/backends/apple/mps/operations/LinearAlgebraOps.mm +++ /dev/null @@ -1,48 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::addmm(MPSGraphTensor* biasTensor, - MPSGraphTensor* inputTensor, - MPSGraphTensor* weightTensor, - const float beta, - const float alpha) { - - MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta - dataType:inputTensor.dataType]; - MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha - dataType:inputTensor.dataType]; - - if(inputTensor.shape == weightTensor.shape) { - weightTensor = [mpsGraph transposeTensor:weightTensor - dimension:0 - withDimension:1 - name:@"addmm/transposedWeightTensor"]; - } - - MPSGraphTensor* multiplyTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor - secondaryTensor:weightTensor - name:@"addmm/matmul"]; - MPSGraphTensor* alphaTimesMultiply = [mpsGraph multiplicationWithPrimaryTensor:multiplyTensor - secondaryTensor:alphaTensor - name:@"addmm/alpha*matmul"]; - MPSGraphTensor* betaBiasTensor = biasTensor; - if(beta!=0.0) { - betaBiasTensor = [mpsGraph multiplicationWithPrimaryTensor:biasTensor - secondaryTensor:betaTensor - name:@"addmm/beta*bias"]; - } - MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:alphaTimesMultiply - secondaryTensor:betaBiasTensor - name:@"addmm/beta*bias*alpha*matmul"]; - - return outputTensor; -} -}//namespace diff --git a/backends/apple/mps/operations/NormalizationOps.mm b/backends/apple/mps/operations/NormalizationOps.mm deleted file mode 100644 index 20ff998e219..00000000000 --- a/backends/apple/mps/operations/NormalizationOps.mm +++ /dev/null @@ -1,93 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -std::tuple -MPSGraphModule::batchNorm(MPSGraphTensor* inputTensor, - MPSGraphTensor* meanTensor, - MPSGraphTensor* varTensor, - MPSGraphTensor* weightTensor, - MPSGraphTensor* biasTensor, - float momentum, - float epsilon) { - - - //Shapes are NCHW so the input parameters to normalization are 1xCx1x1 - NSMutableArray* newShape = [NSMutableArray array]; - [newShape addObject:[NSNumber numberWithInt:1]]; - [newShape addObject:inputTensor.shape[1]]; - for(int i = 2; i<[inputTensor.shape count]; ++i) { - [newShape addObject:[NSNumber numberWithInt:1]]; - } - //No need for momentum since we are not training for now since not training? - MPSGraphTensor* reshapedMeanTensor = [mpsGraph reshapeTensor:meanTensor - withShape:newShape - name:nil]; - MPSGraphTensor* reshapedVarTensor = [mpsGraph reshapeTensor:varTensor - withShape:newShape - name:nil]; - MPSGraphTensor* reshapedWeightTensor = [mpsGraph reshapeTensor:weightTensor - withShape:newShape - name:nil]; - MPSGraphTensor* reshapedBiasTensor = [mpsGraph reshapeTensor:biasTensor - withShape:newShape - name:nil]; - - MPSGraphTensor* result = [mpsGraph normalizationWithTensor:inputTensor - meanTensor:reshapedMeanTensor - varianceTensor:reshapedVarTensor - gammaTensor:reshapedWeightTensor - betaTensor:reshapedBiasTensor - epsilon:epsilon - name:@"batch_norm"]; - MPSGraphTensor* saveVarTensor = [mpsGraph identityWithTensor:varTensor name:nil]; - MPSGraphTensor* saveMeanTensor = [mpsGraph identityWithTensor:meanTensor name:nil]; - - //For now just return meanTensor and varTensor assuming this isn't training - auto out_tuple = std::make_tuple(result, saveMeanTensor, saveVarTensor); - return out_tuple; -} - -//Normalizes over the last ndim=normalized_shape.size() dimensions scaling -//with weight and bias tensors if they are non-nil -std::tuple -MPSGraphModule::layerNorm(MPSGraphTensor* inputTensor, - IntArrayRef normalized_shape, - 
MPSGraphTensor* weightTensor, - MPSGraphTensor* biasTensor, - float eps) { - - const int input_ndim = [inputTensor.shape count]; - const int normalized_shape_ndim = normalized_shape.size(); - const int ndim_to_normalize = input_ndim-normalized_shape_ndim; - - NSMutableArray* axesArray = [NSMutableArray arrayWithCapacity:normalized_shape_ndim]; - for (const auto idx : c10::irange(ndim_to_normalize, input_ndim)) { - [axesArray addObject:[NSNumber numberWithInt:idx]]; - } - - MPSGraphTensor* meanTensor = [mpsGraph meanOfTensor:inputTensor - axes:axesArray - name:@"LayerNorm/MeanTensor"]; - - MPSGraphTensor* varianceTensor = [mpsGraph varianceOfTensor:inputTensor - meanTensor:meanTensor - axes:axesArray - name:@"LayerNorm/varianceTensor"]; - MPSGraphTensor* normalizedTensor = [mpsGraph normalizationWithTensor:inputTensor - meanTensor:meanTensor - varianceTensor:varianceTensor - gammaTensor:weightTensor - betaTensor:biasTensor - epsilon:eps - name:@"LayerNorm/resultTensor"]; - - return std::make_tuple(normalizedTensor, meanTensor, varianceTensor); -} -}//namespace mps diff --git a/backends/apple/mps/operations/PoolingOps.mm b/backends/apple/mps/operations/PoolingOps.mm deleted file mode 100644 index 4f57f409000..00000000000 --- a/backends/apple/mps/operations/PoolingOps.mm +++ /dev/null @@ -1,109 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" -#include - -namespace mps { -using namespace torch; - -std::tuple -MPSGraphModule::maxPool2DWithIndices(MPSGraphTensor* inputTensor, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode) { - - int padH = padding[0]; - int padW = padding.size() == 1 ? padH : padding[1]; - const int kH = kernel_size[0]; - const int kW = kernel_size.size() == 1 ? kH : kernel_size[1]; - const int dH = stride.empty() ? 
kH : stride[0]; - const int dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1]; - const int dilationH = dilation[0]; - const int dilationW = dilation.size() == 1 ? dilationH : dilation[1]; - - MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor - descriptorWithKernelWidth:kW - kernelHeight:kH - strideInX:dW - strideInY:dH - dilationRateInX:dilationW - dilationRateInY:dilationH - paddingLeft:padW - paddingRight:ceil_mode ? padW * dW : padW - paddingTop:padH - paddingBottom:ceil_mode ? padH * dH : padH - paddingStyle:MPSGraphPaddingStyleExplicit - dataLayout:MPSGraphTensorNamedDataLayoutNCHW]; - desc.ceilMode = (padW == 0 && padH == 0) ? ceil_mode : false; - desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D; - desc.returnIndicesDataType = MPSDataTypeInt32; - - NSArray* outputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:inputTensor - descriptor:desc - name:@"MaxPool2DWithIndices"]; - - - return std::make_tuple(outputs[0], outputs[1]); -} - - -PyMPSGraphTensor* -MPSGraphModule::avgPool2D(MPSGraphTensor* inputTensor, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override) { - int padH = padding[0]; - int padW = padding.size() == 1 ? padH : padding[1]; - const int kH = kernel_size[0]; - const int kW = kernel_size.size() == 1 ? kH : kernel_size[1]; - const int dH = stride.empty() ? kH : stride[0]; - const int dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1]; - const int dilationH = 1; - const int dilationW = 1; - - MPSGraphPooling2DOpDescriptor* desc = [MPSGraphPooling2DOpDescriptor - descriptorWithKernelWidth:kW - kernelHeight:kH - strideInX:dW - strideInY:dH - dilationRateInX:dilationW - dilationRateInY:dilationH - paddingLeft:padW - paddingRight:ceil_mode ? padW * dW : padW - paddingTop:padH - paddingBottom:ceil_mode ? 
padH * dH : padH - paddingStyle:MPSGraphPaddingStyleExplicit - dataLayout:MPSGraphTensorNamedDataLayoutNCHW]; - - const bool use_divisor = divisor_override.has_value() && divisor_override.value() != 0; - - //if overriding divisor, zeroPads must be included to the average for correct behavior - desc.includeZeroPadToAverage = use_divisor ? true : count_include_pad; - - MPSGraphTensor* avgPoolTensor = [mpsGraph avgPooling2DWithSourceTensor:inputTensor - descriptor:desc - name:@"AvgPool2DTensor"]; - if(use_divisor) { - //here we rescale the average due to MPSGraph not supporting custom divisor directly - const float divisor = float(kH * kW) / (float)divisor_override.value(); - MPSGraphTensor* constantTensor = [mpsGraph constantWithScalar:divisor - shape:@[@1] - dataType:MPSDataTypeFloat32]; - avgPoolTensor = [mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor - secondaryTensor:constantTensor - name:@"AvgPool2DTensor/divisor_override"]; - - } - - return avgPoolTensor; - -} -} //namespace diff --git a/backends/apple/mps/operations/RangeOps.mm b/backends/apple/mps/operations/RangeOps.mm deleted file mode 100644 index fa4df5ac8b0..00000000000 --- a/backends/apple/mps/operations/RangeOps.mm +++ /dev/null @@ -1,31 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::arange(Scalar start, Scalar end, Scalar step, MPSDataType dataType, const int numEle) { - auto shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:&numEle length:sizeof(int32_t)] - shape:@[ @1 ] - dataType:MPSDataTypeInt32]; - auto startScalar = start.isFloatingPoint() ? start.to() : start.to(); - auto stepScalar = step.isFloatingPoint() ? 
step.to() : step.to(); - auto coordsTensor = [mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil]; - coordsTensor = [mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"]; - - auto startTensor = [mpsGraph constantWithScalar:startScalar - dataType:dataType]; - auto multiplyTensor = [mpsGraph constantWithScalar:stepScalar - dataType:dataType]; - auto scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor - secondaryTensor:multiplyTensor - name:nil]; - auto outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil]; - return outputTensor; -} -} // namespace mps diff --git a/backends/apple/mps/operations/ReduceOps.mm b/backends/apple/mps/operations/ReduceOps.mm deleted file mode 100644 index 0ebc51234c5..00000000000 --- a/backends/apple/mps/operations/ReduceOps.mm +++ /dev/null @@ -1,238 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "utils/MPSGraphInterface.h" - -namespace mps { -using namespace torch; - -std::tuple -MPSGraphModule::minDim(MPSGraphTensor* inputTensor, - int dim, - bool keep_dims) { - - const int input_dims = inputTensor.shape.count; - int wrapped_dim = torch::maybe_wrap_dim(dim, input_dims); - - MPSGraphTensor* minTensor = [mpsGraph reductionMinimumWithTensor:inputTensor - axis:wrapped_dim - name:@"minTensor"]; - - MPSGraphTensor* indicesTensor = [mpsGraph reductionArgMinimumWithTensor:inputTensor - axis:wrapped_dim - name:@"argminTensor"]; - - if ([indicesTensor dataType] != MPSDataTypeInt64) { - indicesTensor = [mpsGraph castTensor:indicesTensor - toType:MPSDataTypeInt64 - name:@"argminTensor/cast"]; - } - - if(!keep_dims) { - minTensor = [mpsGraph squeezeTensor:minTensor - axis:wrapped_dim - name:@"minTensor/squeezed"]; - indicesTensor = [mpsGraph squeezeTensor:indicesTensor - axis:wrapped_dim - name:@"argminTensor/squeezed"]; - } - - return std::make_tuple(minTensor, indicesTensor); -} - -std::tuple -MPSGraphModule::maxDim(MPSGraphTensor* inputTensor, - int dim, - bool keep_dims) { - - const int input_dims = inputTensor.shape.count; - int wrapped_dim = torch::maybe_wrap_dim(dim, input_dims); - - MPSGraphTensor* maxTensor = [mpsGraph reductionMaximumWithTensor:inputTensor - axis:wrapped_dim - name:@"maxTensor"]; - - MPSGraphTensor* indicesTensor = [mpsGraph reductionArgMaximumWithTensor:inputTensor - axis:wrapped_dim - name:@"argmaxTensor"]; - - if ([indicesTensor dataType] != MPSDataTypeInt64) { - indicesTensor = [mpsGraph castTensor:indicesTensor - toType:MPSDataTypeInt64 - name:@"argmaxTensor/cast"]; - } - - if(!keep_dims) { - maxTensor = [mpsGraph squeezeTensor:maxTensor - axis:wrapped_dim - name:@"minTensor/squeezed"]; - indicesTensor = [mpsGraph squeezeTensor:indicesTensor - axis:wrapped_dim - name:@"argminTensor/squeezed"]; - } - - return std::make_tuple(maxTensor, indicesTensor); -} - -PyMPSGraphTensor* -MPSGraphModule::amax(MPSGraphTensor* 
inputTensor, - IntArrayRef dims, - bool keep_dims) { - - const int input_dims = inputTensor.shape.count; - NSMutableArray* dimArray = [NSMutableArray array]; - for(int dim: dims) { - int wrapped_dim = torch::maybe_wrap_dim(dim, input_dims); - [dimArray addObject:[NSNumber numberWithInt:wrapped_dim]]; - } - - MPSGraphTensor* amaxTensor = [mpsGraph reductionMaximumWithTensor:inputTensor - axes:dimArray - name:@"AmaxTensor"]; - if(!keep_dims) { - amaxTensor = [mpsGraph squeezeTensor:amaxTensor - axes:dimArray - name:@"AmaxTensor/squeezed"]; - } - - return amaxTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::amin(MPSGraphTensor* inputTensor, - IntArrayRef dims, - bool keep_dims) { - - const int input_dims = inputTensor.shape.count; - NSMutableArray* dimArray = [NSMutableArray array]; - for(int dim: dims) { - int wrapped_dim = torch::maybe_wrap_dim(dim, input_dims); - [dimArray addObject:[NSNumber numberWithInt:wrapped_dim]]; - } - - MPSGraphTensor* aminTensor = [mpsGraph reductionMinimumWithTensor:inputTensor - axes:dimArray - name:@"AminTensor"]; - if(!keep_dims) { - aminTensor = [mpsGraph squeezeTensor:aminTensor - axes:dimArray - name:@"AminTensor/squeezed"]; - } - - return aminTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::argmax(MPSGraphTensor* inputTensor, - int64_t dim, - bool keep_dims, - bool flatten) { - - auto dim_ = maybe_wrap_dim(dim, inputTensor.shape.count); - - MPSGraphTensor* output = inputTensor; - //In case the dimension is not specified, expectation is to return index - //of entry in flattened input tensor - if(flatten) { - NSInteger nElems = 0; - for (NSNumber *num in inputTensor.shape) { - nElems *= [num intValue]; - } - dim_ = 0; - output = [mpsGraph reshapeTensor:inputTensor - withShape:@[@-1] - name:nil]; - } - - output = [mpsGraph reductionArgMaximumWithTensor:output - axis:dim_ - name:@"ArgmaxTensor"]; - if(!keep_dims && !flatten) { - output = [mpsGraph squeezeTensor:output - axis:dim_ - name:@"ArgmaxTensor/squeezed"]; - } - - if ([output 
dataType] != MPSDataTypeInt64) { - output = [mpsGraph castTensor:output - toType:MPSDataTypeInt64 - name:@"ArgmaxTensor/cast"]; - } - - return output; -} - -PyMPSGraphTensor* -MPSGraphModule::argmin(MPSGraphTensor* inputTensor, - int64_t dim, - bool keep_dims, - bool flatten) { - - auto dim_ = maybe_wrap_dim(dim, inputTensor.shape.count); - - MPSGraphTensor* output = inputTensor; - //In case the dimension is not specified, expectation is to return index - //of entry in flattened input tensor - if(flatten) { - NSInteger nElems = 0; - for (NSNumber *num in inputTensor.shape) { - nElems *= [num intValue]; - } - dim_ = 0; - output = [mpsGraph reshapeTensor:inputTensor - withShape:@[@-1] - name:nil]; - } - - output = [mpsGraph reductionArgMinimumWithTensor:output - axis:dim_ - name:@"ArgminTensor"]; - if(!keep_dims && !flatten) { - output = [mpsGraph squeezeTensor:output - axis:dim_ - name:@"ArgminTensor/squeezed"]; - } - - if ([output dataType] != MPSDataTypeInt64) { - output = [mpsGraph castTensor:output - toType:MPSDataTypeInt64 - name:@"ArgminTensor/cast"]; - } - - return output; -} - -PyMPSGraphTensor* -MPSGraphModule::mean(MPSGraphTensor* inputTensor, - IntArrayRef dims, - bool keep_dims) { - - //MPSGraph wants negative axes to be converted to positive - const int input_dims = [inputTensor.shape count]; - NSMutableArray* dimArray = [NSMutableArray array]; - for(int i = 0; i= 3, "pixel_shuffle requires tensor with at least 3 dimensions."); - if (upscale_factor == 1) { - return inputTensor; - } - TORCH_CHECK(inputTensor.shape[ndims - 3].intValue % (upscale_factor * upscale_factor) == 0, - "pixel_shuffle channels must be divisible by upscale factor squared."); - - return [mpsGraph depthToSpace2DTensor:inputTensor - widthAxis:ndims - 1 - heightAxis:ndims - 2 - depthAxis:ndims - 3 - blockSize:upscale_factor - usePixelShuffleOrder:true - name:@"pixel_shuffle"]; - -} - -std::vector -MPSGraphModule::split_size(MPSGraphTensor* inputTensor, IntArrayRef split_sizes, int dim) 
{ - - TORCH_CHECK(dim >=0 && dim < inputTensor.shape.count, - "split_copy: dim ", dim, " out of range for input tensor with ", inputTensor.shape.count, " dimensions"); - - std::vector splitResults; - NSArray* mpsGraphResults; - - mpsGraphResults = [mpsGraph splitTensor:inputTensor - splitSizes:getMPSShape(split_sizes) - axis:dim - name:@"split_size"]; - - for (MPSGraphTensor* splitTensor in mpsGraphResults) { - splitResults.push_back(splitTensor); - } - return splitResults; -} - -std::vector -MPSGraphModule::split(MPSGraphTensor* inputTensor, int split_size, int dim) { - - TORCH_CHECK(dim >=0 && dim < inputTensor.shape.count, - "split_copy: dim ", dim, " out of range for input tensor with ", inputTensor.shape.count, " dimensions"); - TORCH_CHECK(split_size > 0 && split_size <= inputTensor.shape[dim].intValue, - "split_copy: split_size ", split_size, " invalid for inputTensor dimension ", dim, " with length ", inputTensor.shape[dim].intValue); - - NSMutableArray* splits = [NSMutableArray array]; - NSNumber* splitSize = [NSNumber numberWithInt:split_size]; - int i = 1; - - while(split_size * i < inputTensor.shape[dim].intValue) { - [splits addObject:splitSize]; - i++; - } - - int splits_adjust = inputTensor.shape[dim].intValue - (split_size * i); - if (splits_adjust < 0) { - splits[i - 1] = [NSNumber numberWithInt:(split_size + splits_adjust)]; - } - - std::vector splitResults; - NSArray* mpsGraphResults; - - mpsGraphResults = [mpsGraph splitTensor:inputTensor - splitSizes:splits - axis:dim - name:@"split"]; - - for (MPSGraphTensor* splitTensor in mpsGraphResults) { - splitResults.push_back(splitTensor); - } - return splitResults; -} - -std::vector -MPSGraphModule::unbind(MPSGraphTensor* inputTensor, int dim) { - - std::vector unbindResults; - - for (int i = 0; i < inputTensor.shape[dim].intValue; i++) { - unbindResults.push_back( - [mpsGraph sliceTensor:inputTensor - dimension:dim - start:i - length:1 - name:@"unbind"] - ); - } - return unbindResults; -} - 
-PyMPSGraphTensor* -MPSGraphModule::slice(MPSGraphTensor* inputTensor, - int64_t dim, - c10::optional start, - c10::optional end, - int64_t step) { - - int64_t dim_len = inputTensor.shape[dim].intValue; - // Unwrap optional values - int64_t start_val = start.has_value() ? start.value() : 0; - int64_t end_val = end.has_value() ? end.value() : dim_len; - // Convert python style indices to compatible values - start_val = start_val < 0 ? start_val + dim_len : start_val; - end_val = end_val < 0 ? end_val + dim_len : end_val; - start_val = start_val < 0 ? 0 : start_val; - end_val = end_val < 0 ? 0 : end_val; - start_val = start_val > dim_len ? dim_len : start_val; - end_val = end_val > dim_len ? dim_len : end_val; - - // Define input arrays as required by MPSGraph api - NSMutableArray* start_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; - NSMutableArray* end_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; - NSMutableArray* step_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; - // Step needs to be set to one for all other dims - for (int i = 0; i < inputTensor.shape.count; i++) { - step_arr[i] = @1; - end_arr[i] = inputTensor.shape[i]; - start_arr[i] = @0; - } - - start_arr[dim] = [NSNumber numberWithInteger:start_val]; - end_arr[dim] = [NSNumber numberWithInteger:end_val]; - step_arr[dim] = [NSNumber numberWithInteger:step]; - - return [mpsGraph sliceTensor:inputTensor - starts:start_arr - ends:end_arr - strides:step_arr - name:@"strided_slice"]; -} - -PyMPSGraphTensor* -MPSGraphModule::cat(int dim, py::args catTensors) { - NSMutableArray* inputTensors = [NSMutableArray array]; - for (const auto i: c10::irange(catTensors.size())) { - MPSGraphTensor* catTensor = static_cast(pybind11::cast(catTensors[i])); - if (catTensor != nil) - [inputTensors addObject:static_cast(pybind11::cast(catTensors[i]))]; - } - return [mpsGraph concatTensors:inputTensors - dimension:dim - name:@"cat"]; -} - -PyMPSGraphTensor* 
-MPSGraphModule::stack(int dim, py::args stackTensors) { - NSMutableArray* inputTensors = [NSMutableArray array]; - for (const auto i: c10::irange(stackTensors.size())) { - [inputTensors addObject:static_cast(pybind11::cast(stackTensors[i]))]; - } - return [mpsGraph stackTensors:inputTensors - axis:dim - name:@"stack"]; -} - -PyMPSGraphTensor* -MPSGraphModule::expand(MPSGraphTensor* inputTensor, - IntArrayRef sizes) { - // In torch, -1 is passed for dimensions which are to stay the same size - NSMutableArray* mpsSizes = [NSMutableArray array]; - [mpsSizes addObjectsFromArray:getMPSShape(sizes)]; - for (int64_t i = 0; i < mpsSizes.count; i++) { - if ([mpsSizes[i] isEqualToNumber:[NSNumber numberWithInt:-1]]) { - mpsSizes[i] = inputTensor.shape[i]; - } - } - return [mpsGraph broadcastTensor:inputTensor - toShape:mpsSizes - name:@"expand_copy"]; -} - -PyMPSGraphTensor* -MPSGraphModule::select(MPSGraphTensor* inputTensor, int dim, int index) { - // Support python-style negative indexing - // MPSGraph already handles negative indexing for start param - if (dim < 0) { - dim += inputTensor.shape.count; - } - MPSGraphTensor* slicedTensor = [mpsGraph sliceTensor:inputTensor - dimension:dim - start:index - length:1 - name:@"slice"]; - slicedTensor = [mpsGraph squeezeTensor:slicedTensor - axis:dim - name:@"slice/squeezed"]; - return slicedTensor; -} - -PyMPSGraphTensor* -MPSGraphModule::view(MPSGraphTensor* inputTensor, - IntArrayRef shape) { - // MPS_TODO: Implement view functionality instead of just copying & reshaping - return [mpsGraph reshapeTensor:inputTensor - withShape:getMPSShape(shape) - name:@"view_copy"]; -} - -PyMPSGraphTensor* -MPSGraphModule::permute(MPSGraphTensor* inputTensor, - IntArrayRef axes) { - NSMutableArray* permutation = [NSMutableArray array]; - for(int64_t i = 0; i* wrappedAxes = [NSMutableArray array]; - for(int64_t i = 0; i(input), py11_export_name, \ - [&](MPSGraphTensor* inputTensor) -> MPSGraphTensor* { \ - return [self.getMPSGraph() 
graph_op##WithTensor:inputTensor \ - name:nil]; \ - }); \ -}) -// clang-format on diff --git a/backends/apple/mps/operations/UnaryOps.mm b/backends/apple/mps/operations/UnaryOps.mm deleted file mode 100644 index 2ea2322d22f..00000000000 --- a/backends/apple/mps/operations/UnaryOps.mm +++ /dev/null @@ -1,30 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "utils/MPSGraphInterface.h" -#include "UnaryOps.h" - -namespace mps { -using namespace torch; - -PyMPSGraphTensor* -MPSGraphModule::unaryOpTensor( - MPSGraphTensor* inputTensor, - const std::string& op_name, - std::function unaryOpFunction) { - return unaryOpFunction(inputTensor); - } - -PyMPSGraphTensor* -MPSGraphModule::cumsum( - MPSGraphTensor* inputTensor, - int dim -) { - return [mpsGraph cumulativeSumWithTensor:inputTensor - axis:dim - name:@"cumsum"]; -} - -}//namespace diff --git a/backends/apple/mps/operators/__init__.py b/backends/apple/mps/operators/__init__.py new file mode 100644 index 00000000000..4c5c09c00bc --- /dev/null +++ b/backends/apple/mps/operators/__init__.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +from . 
import ( # noqa + # Activation ops + activation_ops, + # binary ops + binary_ops, + # Clamp ops + clamp_ops, + # Constant ops + constant_ops, + # Convolution ops + convolution_ops, + # Indexing ops + indexing_ops, + # Linear algebra ops + linear_algebra_ops, + # Normalization ops + normalization_ops, + op_clone, + op_getitem, + # Pad ops + pad_ops, + # Pooling ops + pooling_ops, + # Range ops + range_ops, + # Reduce ops + reduce_ops, + # Shape ops + shape_ops, + # unary ops + unary_ops, +) + +__all__ = [ + op_getitem, + op_clone, + # Binary ops + binary_ops, + # Unary ops + unary_ops, + # Activation ops + activation_ops, + # Linear algebra ops + linear_algebra_ops, + # Constant ops + constant_ops, + # Clamp ops + clamp_ops, + # Indexing ops + indexing_ops, + # Reduce ops + reduce_ops, + # Shape ops + shape_ops, + # Conv ops + convolution_ops, + # Normalization ops + normalization_ops, + # Pooling ops + pooling_ops, + # Pad ops + pad_ops, + # Range ops + range_ops, +] diff --git a/backends/apple/mps/operators/activation_ops.py b/backends/apple/mps/operators/activation_ops.py new file mode 100644 index 00000000000..0519ca69491 --- /dev/null +++ b/backends/apple/mps/operators/activation_ops.py @@ -0,0 +1,99 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. 
+# + +from typing import cast + +import torch +from executorch.backends.apple.mps.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSGELU, + MPSGraph, + MPSHardTanh, + MPSLeakyReLU, + MPSLogSoftmax, + MPSReLU, + MPSSoftmax, +) +from executorch.backends.apple.mps.utils.mps_utils import get_scalar_val +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_node_visitor +class HardTanhVisitor(NodeVisitor): + target = "aten.hardtanh.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSHardTanh) + mps_node.mpsnode_union.min_value = get_scalar_val(node, 1) + mps_node.mpsnode_union.max_value = get_scalar_val(node, 2) + + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class ReLU_LeakyReLU_GELU_Visitor(NodeVisitor): + target = ["aten.relu.default", "aten.leaky_relu.default", "aten.gelu.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + self.activation_ops = { + exir_ops.edge.aten.relu.default: MPSReLU, + exir_ops.edge.aten.leaky_relu.default: MPSLeakyReLU, + exir_ops.edge.aten.gelu.default: MPSGELU, + } + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + node_type = self.activation_ops[node.target] + mps_node = self.create_unary_node(node, mps_graph, node_type) + + if node_type is MPSLeakyReLU and len(node.args) == 2: + mps_node.mpsnode_union.negative_slope = cast(float, node.args[1]) + elif ( + node_type is MPSGELU + and node.kwargs + and node.kwargs["approximate"] is not None + ): + mps_node.mpsnode_union.approximate = node.kwargs["approximate"] + + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class Softmax_LogSoftmax_Visitor(NodeVisitor): + target = ["aten._softmax.default", 
"aten._log_softmax.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + node_type = ( + MPSSoftmax + if node.target == exir_ops.edge.aten._softmax.default + else MPSLogSoftmax + ) + mps_node = self.create_unary_node(node, mps_graph, node_type) + + mps_node.mpsnode_union.dim = cast(int, node.args[1]) + mps_node.mpsnode_union.half_to_float = cast(bool, node.args[2]) + + mps_graph.mps_nodes.append(mps_node) diff --git a/backends/apple/mps/operators/binary_ops.py b/backends/apple/mps/operators/binary_ops.py new file mode 100644 index 00000000000..a9216aa7654 --- /dev/null +++ b/backends/apple/mps/operators/binary_ops.py @@ -0,0 +1,156 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +import torch +from executorch.backends.apple.mps.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSAdd, + MPSBitwiseAnd, + MPSBitwiseOr, + MPSBitwiseXor, + MPSDiv, + MPSEq, + MPSFmod, + MPSGe, + MPSGraph, + MPSGt, + MPSLe, + MPSLt, + MPSMinimum, + MPSMul, + MPSNe, + MPSPow, + MPSRemainder, + MPSSub, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_node_visitor +class BinaryOpVisitor(NodeVisitor): + target = [ + # Arithmetic Binary Ops + "aten.add.Tensor", + "aten.add.Scalar", + "aten.sub.Tensor", + "aten.sub.Scalar", + "aten.div.Tensor", + "aten.div.Tensor_mode", + "aten.mul.Tensor", + "aten.mul.Scalar", + "aten.pow.Tensor_Tensor", + "aten.pow.Tensor_Scalar", + "aten.floor_divide.default", + "aten.fmod.Tensor", + "aten.fmod.Scalar", + "aten.remainder.Tensor", + "aten.remainder.Scalar", + "aten.bitwise_and.Tensor", + "aten.bitwise_and.Scalar", + "aten.bitwise_or.Tensor", + "aten.bitwise_or.Scalar", + "aten.bitwise_xor.Tensor", + "aten.bitwise_xor.Scalar", + 
"aten.minimum.default", + ] + + def __init__(self, *args) -> None: + super().__init__(*args) + self.op_mapping = { + exir_ops.edge.aten.add.Tensor: MPSAdd, + exir_ops.edge.aten.add.Scalar: MPSAdd, + exir_ops.edge.aten.sub.Tensor: MPSSub, + exir_ops.edge.aten.sub.Scalar: MPSSub, + exir_ops.edge.aten.div.Tensor: MPSDiv, + exir_ops.edge.aten.div.Tensor_mode: MPSDiv, + exir_ops.edge.aten.mul.Tensor: MPSMul, + exir_ops.edge.aten.mul.Scalar: MPSMul, + exir_ops.edge.aten.pow.Tensor_Tensor: MPSPow, + exir_ops.edge.aten.pow.Tensor_Scalar: MPSPow, + exir_ops.edge.aten.floor_divide.default: MPSDiv, + exir_ops.edge.aten.fmod.Tensor: MPSFmod, + exir_ops.edge.aten.fmod.Scalar: MPSFmod, + exir_ops.edge.aten.remainder.Tensor: MPSRemainder, + exir_ops.edge.aten.remainder.Scalar: MPSRemainder, + exir_ops.edge.aten.bitwise_and.Tensor: MPSBitwiseAnd, + exir_ops.edge.aten.bitwise_and.Scalar: MPSBitwiseAnd, + exir_ops.edge.aten.bitwise_or.Tensor: MPSBitwiseOr, + exir_ops.edge.aten.bitwise_or.Scalar: MPSBitwiseOr, + exir_ops.edge.aten.bitwise_xor.Tensor: MPSBitwiseXor, + exir_ops.edge.aten.bitwise_xor.Scalar: MPSBitwiseXor, + exir_ops.edge.aten.minimum.default: MPSMinimum, + } + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_binary_node( + node, mps_graph, self.op_mapping[node.target] + ) + + if node.kwargs and "alpha" in node.kwargs and node.kwargs["alpha"] is not None: + mps_node.mpsnode_union.alpha = node.kwargs["alpha"] + + if ( + node.kwargs + and "rounding_mode" in node.kwargs + and node.kwargs["rounding_mode"] is not None + ): + mps_node.mpsnode_union.rounding_mode = node.kwargs["rounding_mode"] + + mps_graph.mps_nodes.append(mps_node) + + +## +## Boolean Binary Ops +## +@register_node_visitor +class ComparasionOpVisitor(NodeVisitor): + target = [ + "aten.eq.Tensor", + "aten.ne.Tensor", + "aten.ge.Tensor", + "aten.gt.Tensor", + "aten.le.Tensor", + "aten.lt.Tensor", + "aten.eq.Scalar", + "aten.ne.Scalar", + 
"aten.ge.Scalar", + "aten.gt.Scalar", + "aten.le.Scalar", + "aten.lt.Scalar", + ] + + def __init__(self, *args) -> None: + super().__init__(*args) + self.comparison_ops = { + exir_ops.edge.aten.eq.Tensor: MPSEq, + exir_ops.edge.aten.ne.Tensor: MPSNe, + exir_ops.edge.aten.ge.Tensor: MPSGe, + exir_ops.edge.aten.gt.Tensor: MPSGt, + exir_ops.edge.aten.le.Tensor: MPSLe, + exir_ops.edge.aten.lt.Tensor: MPSLt, + exir_ops.edge.aten.eq.Scalar: MPSEq, + exir_ops.edge.aten.ne.Scalar: MPSNe, + exir_ops.edge.aten.ge.Scalar: MPSGe, + exir_ops.edge.aten.gt.Scalar: MPSGt, + exir_ops.edge.aten.le.Scalar: MPSLe, + exir_ops.edge.aten.lt.Scalar: MPSLt, + } + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + + mps_graph.mps_nodes.append( + self.create_binary_node(node, mps_graph, self.comparison_ops[node.target]) + ) diff --git a/backends/apple/mps/operators/clamp_ops.py b/backends/apple/mps/operators/clamp_ops.py new file mode 100644 index 00000000000..6a1b452fe81 --- /dev/null +++ b/backends/apple/mps/operators/clamp_ops.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. 
@register_node_visitor
class ClampVisitor(NodeVisitor):
    """Serializes ``aten.clamp.default`` as an ``MPSClamp`` node with min/max bounds."""

    target = "aten.clamp.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append a clamp node for ``node`` to ``mps_graph``.

        Args:
            node: Edge IR clamp node; ``args[1]``/``args[2]`` are the optional
                lower/upper bounds.
            mps_graph: Graph being populated for serialization.
        """
        clamp_node = self.create_unary_node(node, mps_graph, MPSClamp)

        call_args = node.args
        lower = call_args[1] if len(call_args) >= 2 else None
        upper = call_args[2] if len(call_args) >= 3 else None

        # An absent (or None) bound is encoded as the string "-inf" / "inf".
        clamp_node.min_max = MPSMinMax(
            min_value="-inf" if lower is None else cast(float, lower),
            max_value="inf" if upper is None else cast(float, upper),
        )
        mps_graph.mps_nodes.append(clamp_node)


@register_node_visitor
class WhereVisitor(NodeVisitor):
    """Serializes ``aten.where.self`` as a three-input ``MPSWhere`` node."""

    target = "aten.where.self"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the ``MPSWhere`` node built from the first three inputs of ``node``."""
        where_node = self.create_tertiary_node(node, mps_graph, MPSWhere)
        mps_graph.mps_nodes.append(where_node)
@register_node_visitor
class ConstantOpVisitor(NodeVisitor):
    """Serializes ``full``, ``empty`` and ``scalar_tensor`` as an ``MPSFull`` node."""

    target = [
        "aten.full.default",
        "aten.empty.memory_format",
        "aten.scalar_tensor.default",
    ]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSFull`` node describing the constant produced by ``node``.

        Raises:
            AssertionError: if ``node`` carries three or more positional args.
        """
        if len(node.args) >= 3:
            raise AssertionError("Unexpected number of input parameters")

        op = node.target
        if op == exir_ops.edge.aten.scalar_tensor.default:
            # scalar_tensor: args[0] is the value itself, emitted with shape [1].
            shape = [1]
            fill_value = float(node.args[0])
        else:
            shape = eval_shape(node.args[0])
            if op == exir_ops.edge.aten.full.default:
                fill_value = cast(float, node.args[1])
            elif op == exir_ops.edge.aten.empty.memory_format:
                # empty() is serialized as a zero-filled tensor.
                fill_value = 0

        # float32 unless an explicit, non-None dtype kwarg is given.
        dtype = MPSDataType.mps_data_type_float32
        requested_dtype = node.kwargs.get("dtype") if node.kwargs else None
        if requested_dtype is not None:
            dtype = edge_dtype_to_mps_dtype(requested_dtype)

        output_id = self.define_tensor(node, mps_graph)
        full_op = MPSFull(
            output_id=output_id,
            shape=shape,
            fill_value=fill_value,
            dtype=dtype,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=full_op))


@register_node_visitor
class FullLikeVisitor(NodeVisitor):
    """Serializes ``aten.full_like.default`` as an ``MPSFullLike`` node."""

    target = "aten.full_like.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSFullLike`` node; dtype follows the input unless overridden.

        Raises:
            AssertionError: if fewer than two or more than two positional args
                are present.
        """
        arg_count = len(node.args)
        if arg_count < 2:
            raise AssertionError("Full op requires at least size & fill_value args")

        full_like_node = self.create_unary_node(node, mps_graph, MPSFullLike)
        payload = full_like_node.mpsnode_union

        payload.fill_value = cast(float, node.args[1])
        # Start from the dtype of the input tensor ...
        payload.dtype = self.get_serialized_dtype(get_input_node(node, 0))
        # ... and let an explicit, non-None dtype kwarg override it.
        dtype_override = node.kwargs.get("dtype") if node.kwargs else None
        if dtype_override is not None:
            payload.dtype = edge_dtype_to_mps_dtype(dtype_override)

        if arg_count >= 3:
            raise AssertionError("Unexpected number of input parameters")

        mps_graph.mps_nodes.append(full_like_node)


@register_node_visitor
class Conv2D(NodeVisitor):
    """Serializes ``aten.convolution.default`` as ``MPSConv2D`` or ``MPSDepthwiseConv2D``."""

    target = "aten.convolution.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the convolution node with stride, dilation, groups and padding set."""
        input_dims = get_shape(get_input_node(node, 0))
        weight_dims = get_shape(get_input_node(node, 1))
        groups = cast(int, node.args[8])

        # Depthwise is selected when groups > 1, the weight's dim 1 is 1, and
        # both input and weight have at least 4 dims.
        grouped_single_channel = groups > 1 and weight_dims[1] == 1
        depthwise = grouped_single_channel and (
            len(input_dims) >= 4 and len(weight_dims) >= 4
        )

        conv_op = MPSDepthwiseConv2D if depthwise else MPSConv2D
        conv_node = self.create_tertiary_node(node, mps_graph, conv_op)

        raw_stride, raw_padding, raw_dilation = (
            cast(List[int], node.args[position]) for position in (3, 4, 5)
        )

        def as_pair(values: List[int], leading: int) -> List[int]:
            # A single-element list is expanded to [leading, value].
            return [leading, values[0]] if len(values) == 1 else values

        stride = as_pair(raw_stride, 1)
        padding = as_pair(raw_padding, 0)
        dilation = as_pair(raw_dilation, 1)

        params = conv_node.mpsnode_union
        params.stride_y, params.stride_x = stride[0], stride[1]
        params.dilation_y, params.dilation_x = dilation[0], dilation[1]
        params.groups = groups
        # Padding is symmetric: the same amount on both sides of each axis.
        params.padding_top = params.padding_bottom = padding[0]
        params.padding_right = params.padding_left = padding[1]

        mps_graph.mps_nodes.append(conv_node)
@register_node_visitor
class IndexSelectVisitor(NodeVisitor):
    """Serializes ``aten.index_select.default`` as an ``MPSIndexSelect`` node."""

    target = "aten.index_select.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the index-select node; ``args[1]`` is the dim, ``args[2]`` the index tensor."""
        select_node = self.create_unary_node(node, mps_graph, MPSIndexSelect)
        payload = select_node.mpsnode_union
        payload.dim = cast(int, node.args[1])
        payload.index_id = self.define_tensor(get_input_node(node, 2), mps_graph)
        mps_graph.mps_nodes.append(select_node)


@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
    """Serializes ``aten.embedding.default`` as an ``MPSEmbedding`` node."""

    target = "aten.embedding.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the embedding node, copying any optional positional args that are present."""
        arg_count = len(node.args)
        embedding_node = self.create_binary_node(
            node,
            mps_graph,
            MPSEmbedding,
        )
        payload = embedding_node.mpsnode_union

        # Optional args are positional, so each one implies the previous ones.
        if arg_count >= 3:
            payload.padding_idx = eval_expr(cast(torch.SymInt, node.args[2]))
            if arg_count >= 4:
                payload.scale_grad_by_freq = cast(bool, node.args[3])
                if arg_count >= 5:
                    payload.sparse = cast(bool, node.args[4])

        mps_graph.mps_nodes.append(embedding_node)
@register_node_visitor
class MatMulVisitor(NodeVisitor):
    """Serializes ``aten.mm.default`` and ``aten.bmm.default`` as an ``MPSMatMul`` node."""

    target = ["aten.mm.default", "aten.bmm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the binary ``MPSMatMul`` node built from the two inputs of ``node``."""
        mps_graph.mps_nodes.append(self.create_binary_node(node, mps_graph, MPSMatMul))


@register_node_visitor
class AddmmVisitor(NodeVisitor):
    """Serializes ``aten.addmm.default`` as an ``MPSAddmm`` node."""

    target = "aten.addmm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the ``MPSAddmm`` node, forwarding ``beta``/``alpha`` when given.

        ``beta`` and ``alpha`` are keyword-only in the aten schema, so they are
        looked up in ``node.kwargs``; positional values (``args[3]`` and
        ``args[4]``) are still honoured and take precedence. Values that are
        absent or ``None`` leave the schema defaults untouched.
        """
        mps_node = self.create_tertiary_node(node, mps_graph, MPSAddmm)
        payload = mps_node.mpsnode_union

        positional = node.args
        # Use >= rather than ==: with five positional args both beta and alpha
        # are present, and beta must not be dropped.
        beta = positional[3] if len(positional) >= 4 else node.kwargs.get("beta")
        alpha = positional[4] if len(positional) >= 5 else node.kwargs.get("alpha")

        if beta is not None:
            payload.beta = beta
        if alpha is not None:
            payload.alpha = alpha

        mps_graph.mps_nodes.append(mps_node)
class NodeVisitor:
    """
    Node visitor pattern for visiting nodes in an edge IR graph and
    serializing them using the mps serialization schema.
    """

    # Class-level on purpose: every visitor instance shares one mapping, so a
    # tensor defined by one visitor resolves to the same id in all the others.
    # Besides fx nodes mapped to a single id, it also holds fx nodes mapped to
    # a list of ids (see define_tensor_list) and scalar-constant keys (see
    # define_scalar).
    _tensor_to_id: Dict[torch.fx.Node, int] = {}
    _convert_model_to_fp16: bool = True

    def __init__(
        self, exported_program: ExportedProgram, convert_model_to_fp16: bool = True
    ):
        self._exported_program = exported_program
        self._convert_model_to_fp16 = convert_model_to_fp16

    @property
    def tensor_to_id(self) -> Dict[torch.fx.Node, int]:
        """Shared mapping from graph values to their serialized ids."""
        return self._tensor_to_id

    @property
    def convert_model_to_fp16(self) -> bool:
        """Whether float32 constants are serialized as float16."""
        return self._convert_model_to_fp16

    @property
    def exported_program(self) -> ExportedProgram:
        """The exported program this visitor serializes."""
        return self._exported_program

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Serialize ``node`` into ``mps_graph``; must be overridden by subclasses."""
        raise NotImplementedError("NodeVisitor must be extended!")

    def define_tensor(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> int:
        """Defines a tensor value into the MPSGraph serialization schema

        Args:
            node (torch.fx.Node): EdgeIR tensor to define into mps_graph
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            int: id of the tensor in ``mps_graph.mps_values``; ``-1`` when
            ``node`` is None. A node that was already defined returns its
            existing id without being serialized again.
        """

        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        # Get a unique id for the node.
        tensor_id = self.get_serialized_id(node, mps_graph)
        cb_size, constant_buffer, mps_data_type = self.get_serialized_buffer(
            node, mps_graph, tensor_id
        )
        dims = get_shape(node)

        logging.debug(
            f"Serializing: {node}, data type: {node.meta['val'].dtype}, dims: {dims}"
        )
        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=len(dims),
            dims=dims,
            constant_buffer_size=cb_size,
            constant_buffer=constant_buffer,
        )

        mps_graph.mps_values.append(mps_tensor)
        return tensor_id

    def define_tensor_list(
        self, node: torch.fx.Node, mps_graph: MPSGraph
    ) -> Union[int, List[int]]:
        """Defines every output of a multi-output node as a separate tensor.

        One ``MPSTensor`` (with an empty constant buffer) is appended to
        ``mps_graph.mps_values`` per entry of ``node.meta["val"]``.

        Args:
            node (torch.fx.Node): EdgeIR node whose ``meta["val"]`` is a
                sequence of tensors
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            The list of ids, one per output in order. ``-1`` when ``node`` is
            None. A node that was already defined returns its stored list.
        """
        if node is None:
            return -1

        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        self.tensor_to_id[node] = []
        for i in range(len(node.meta["val"])):
            tensor_id = len(mps_graph.mps_values)
            self.tensor_to_id[node].append(tensor_id)

            tensor = node.meta["val"][i]
            dims = eval_shape(tensor.shape)
            mps_data_type = edge_dtype_to_mps_dtype(tensor.dtype)
            logging.debug(
                f"Serializing: [{i}]: {node}, data type: {tensor.dtype}, dims: {dims}"
            )

            mps_tensor = MPSTensor(
                datatype=mps_data_type,
                num_dims=len(dims),
                dims=dims,
                constant_buffer_size=0,
                constant_buffer=Buffer(storage=b""),
            )
            logging.debug(f"  Serialized tensor: {mps_tensor}")
            mps_graph.mps_values.append(mps_tensor)
        return self.tensor_to_id[node]

    def define_scalar(
        self,
        val: Union[float, int],
        mps_data_type: MPSDataType,
        mps_graph: MPSGraph,
    ):
        """Defines a scalar value into the MPSGraph serialization schema

        The scalar is stored as a constant tensor with ``dims=[1]`` and its id
        is recorded in ``mps_graph.constant_ids``.

        Args:
            val (Union[float, int]): scalar to serialize
            mps_data_type (MPSDataType): data type to record for the constant;
                float32 is replaced by float16 when ``convert_model_to_fp16``
                is set
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer

        Returns:
            int: id of the constant in ``mps_graph.mps_values``.
        """
        assert isinstance(val, int) or isinstance(val, float)

        # Cache on (python type, value, requested dtype) rather than on the
        # bare value: 1 and 1.0 compare and hash equal, so keying on the value
        # alone would let an integer op pick up a constant that was serialized
        # as a float (or with another dtype).
        scalar_key = (type(val), val, mps_data_type)
        if scalar_key in self.tensor_to_id:
            return self.tensor_to_id[scalar_key]

        scalar_id = len(mps_graph.mps_values)
        self.tensor_to_id[scalar_key] = scalar_id

        if (
            self.convert_model_to_fp16
            and mps_data_type == MPSDataType.mps_data_type_float32
        ):
            mps_data_type = MPSDataType.mps_data_type_float16

        # NOTE(review): the payload is always 4 bytes (c_int32 / c_float), even
        # when the recorded data type is float16 or a 64-bit integer type -
        # confirm the runtime expects this.
        if isinstance(val, int):
            array = bytes(ctypes.c_int32(val))
        elif isinstance(val, float):
            array = bytes(ctypes.c_float(val))
        else:
            raise RuntimeError("Unknown data type!")

        constant_buffer = Buffer(storage=array)
        constant_buffer_size = len(array)

        mps_tensor = MPSTensor(
            datatype=mps_data_type,
            num_dims=1,
            dims=[1],
            constant_buffer_size=constant_buffer_size,
            constant_buffer=constant_buffer,
        )

        if scalar_id not in mps_graph.constant_ids:
            mps_graph.constant_ids.append(scalar_id)

        mps_graph.mps_values.append(mps_tensor)
        return scalar_id

    def get_serialized_buffer(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
        node_id: int,
    ) -> Tuple[int, Buffer, MPSDataType]:
        """
        If the node is a parameter of the exported program, serialize its data
        and register ``node_id`` in ``mps_graph.constant_ids``.

        Args:
            node (torch.fx.Node): node whose constant data (if any) is serialized
            mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer
            node_id (int): serialized id already assigned to ``node``

        Returns:
            Tuple[int, Buffer, MPSDataType]: byte size of the serialized
            storage, the buffer holding those bytes, and the data type to
            record. Nodes that are not parameters yield ``0`` and an empty
            buffer.
        """
        mps_data_type = self.get_serialized_dtype(node)

        # Check if this node is a lifted parameter
        if not is_parameter(self.exported_program, node):
            return 0, Buffer(storage=b""), mps_data_type

        tensor = get_param_tensor(self.exported_program, node)
        assert tensor is not None and isinstance(tensor, torch.Tensor)
        tensor = tensor.contiguous()
        if self.convert_model_to_fp16 and tensor.dtype == torch.float32:
            tensor = tensor.half()
            mps_data_type = MPSDataType.mps_data_type_float16

        if node_id not in mps_graph.constant_ids:
            mps_graph.constant_ids.append(node_id)

        # Copy the raw bytes of the tensor's storage through a ctypes view.
        array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
        array = ctypes.cast(
            tensor.untyped_storage().data_ptr(),
            ctypes.POINTER(array_type),
        ).contents
        buffer = Buffer(storage=bytes(array))

        return tensor.untyped_storage().nbytes(), buffer, mps_data_type

    def get_serialized_id(
        self, node: Union[torch.fx.Node, float, int], mps_graph: MPSGraph
    ) -> int:
        """
        Map a tensor to a unique id. If the tensor was already mapped, return
        the existent id.

        Args:
            node (Union[torch.fx.Node, float, int]): value to look up or register
            mps_graph (MPSGraph): graph whose current number of values becomes
                the id of a newly registered value

        Returns:
            int: the id associated with ``node``
        """
        if node in self.tensor_to_id:
            return self.tensor_to_id[node]

        new_id = len(mps_graph.mps_values)
        self.tensor_to_id[node] = new_id

        return new_id

    def get_serialized_dtype(
        self,
        node: torch.fx.Node,
    ) -> MPSDataType:
        """Return the MPS data type matching the dtype of ``node.meta["val"]``."""
        return edge_dtype_to_mps_dtype(node.meta["val"].dtype)

    def create_tertiary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, tertiary_op: MPSNodeUnion
    ) -> MPSNode:
        """Define the first three inputs and the output of ``node`` and wrap
        them in an ``MPSNode`` holding ``tertiary_op``."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        input3_id = self.define_tensor(get_input_node(node, 2), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=tertiary_op(
                input1_id=input1_id,
                input2_id=input2_id,
                input3_id=input3_id,
                output_id=output_id,
            )
        )

    def create_binary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, binary_op: MPSNodeUnion
    ) -> MPSNode:
        """Define both inputs and the output of ``node`` and wrap them in an
        ``MPSNode`` holding ``binary_op``. The second input may be a scalar."""
        input1_node = get_input_node(node, 0)
        input1_id = self.define_tensor(input1_node, mps_graph)

        # Handle both tensor and scalar variants of the op.
        # In case of scalar ops, manually define a constant and serialize it in the FlatBuffer.
        if isinstance(node.args[1], torch.fx.Node):
            # Second argument is a node.
            input2_id = self.define_tensor(get_input_node(node, 1), mps_graph)
        else:
            # Second argument is a scalar.
            scalar_val = get_scalar_val(node, 1)
            if input1_node.meta["val"].dtype == torch.float32:
                scalar_val = float(scalar_val)
            input2_id = self.define_scalar(
                scalar_val, self.get_serialized_dtype(input1_node), mps_graph
            )

        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(
            mpsnode_union=binary_op(
                input1_id=input1_id, input2_id=input2_id, output_id=output_id
            )
        )

    def create_unary_node(
        self, node: torch.fx.Node, mps_graph: MPSGraph, unary_op: MPSNodeUnion
    ) -> MPSNode:
        """Define the first input and the output of ``node`` and wrap them in
        an ``MPSNode`` holding ``unary_op``."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output_id = self.define_tensor(node, mps_graph)
        return MPSNode(mpsnode_union=unary_op(input1_id=input1_id, output_id=output_id))


# This will hold mapping of all node names to the visitor class.
_node_visitor_dict = {}


def register_node_visitor(visitor):
    """Class decorator registering ``visitor`` under each of its ``target`` names.

    ``visitor.target`` is either a single name or a list of names. The class
    is returned unchanged so the decorated name stays bound to the class.
    """
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
    if isinstance(visitor.target, list):
        for elem in visitor.target:
            _node_visitor_dict[elem] = visitor
    else:
        _node_visitor_dict[visitor.target] = visitor
    # Without this the decorator would rebind every decorated class name to None.
    return visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
    """
    Create a new class instance at runtime, and put them in a dict

    Every registered visitor class is instantiated with ``*args``. A plain
    ``NodeVisitor`` instance is additionally stored under both "placeholder"
    and "output".
    """
    node_visitors = {}
    for target, visitor in _node_visitor_dict.items():
        assert callable(
            visitor
        ), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
        node_visitors[target] = visitor(*args)

    placeholder_output_visitor = NodeVisitor(*args)
    node_visitors["placeholder"] = placeholder_output_visitor
    node_visitors["output"] = placeholder_output_visitor
    return node_visitors


def process_placeholder_nodes(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
    mps_graph: MPSGraph,
    placeholder_visitor: NodeVisitor,
) -> None:
    """Register every non-parameter placeholder of the graph as a graph input.

    When ``placeholder_visitor.convert_model_to_fp16`` is set, an in-place
    ``MPSCast`` to float16 (same input and output id) is appended for each
    registered input.
    """
    # Visit the placeholder nodes in the same order they are passed to the
    # forward function - forward(*args). When lifted graphs are being used,
    # parameters/buffers are lifted as placeholders and the order of the args
    # no longer matches the original graph. The original order is recovered by
    # walking all the placeholder nodes and skipping the ones that are
    # constant tensors.
    #
    # Constant tensors will be bundled directly in the FlatBuffer and they won't be
    # provided by ExecuTorch during runtime.

    for node in edge_graph_module.graph.nodes:
        if node.op == "placeholder" and not is_parameter(
            exp_prog=exported_program, node=node
        ):
            if node.meta["val"] is None:
                continue

            input_id = placeholder_visitor.define_tensor(node, mps_graph)
            mps_graph.input_ids.append(input_id)

            # NOTE(review): the cast is emitted for every input regardless of
            # its dtype - confirm this is intended for non-float inputs.
            if placeholder_visitor.convert_model_to_fp16:
                mps_node = MPSNode(
                    mpsnode_union=MPSCast(
                        input1_id=input_id,
                        output_id=input_id,
                        dtype=MPSDataType.mps_data_type_float16,
                    )
                )
                mps_graph.mps_nodes.append(mps_node)


def process_output_node(
    output_node,
    mps_graph: MPSGraph,
    output_visitor: NodeVisitor,
) -> None:
    """Register ``output_node`` as a graph output.

    When ``output_visitor.convert_model_to_fp16`` is set, an in-place
    ``MPSCast`` back to float32 (same input and output id) is appended.
    """
    output_id = output_visitor.define_tensor(output_node, mps_graph)
    mps_graph.output_ids.append(output_id)

    if output_visitor.convert_model_to_fp16:
        mps_node = MPSNode(
            mpsnode_union=MPSCast(
                input1_id=output_id,
                output_id=output_id,
                dtype=MPSDataType.mps_data_type_float32,
            )
        )
        mps_graph.mps_nodes.append(mps_node)
@register_node_visitor
class BatchNorm(NodeVisitor):
    """Serializes ``aten._native_batch_norm_legit_no_training.default`` as ``MPSBatchNorm``."""

    target = "aten._native_batch_norm_legit_no_training.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSBatchNorm`` node with five tensor inputs and three outputs."""
        # Positional inputs 0..4 are input, weight, bias, running mean and
        # running variance; they are defined in that order.
        input_id, weight_id, bias_id, mean_id, var_id = (
            self.define_tensor(get_input_node(node, position), mps_graph)
            for position in range(5)
        )
        momentum: float = get_scalar_val(node, 5)
        epsilon: float = get_scalar_val(node, 6)

        output1_id, output2_id, output3_id = self.define_tensor_list(node, mps_graph)

        batch_norm_op = MPSBatchNorm(
            input_id=input_id,
            mean_id=mean_id,
            var_id=var_id,
            weight_id=weight_id,
            bias_id=bias_id,
            momentum=momentum,
            epsilon=epsilon,
            output1_id=output1_id,
            output2_id=output2_id,
            output3_id=output3_id,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=batch_norm_op))


@register_node_visitor
class LayerNorm(NodeVisitor):
    """Serializes ``aten.native_layer_norm.default`` as ``MPSLayerNorm``."""

    target = "aten.native_layer_norm.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSLayerNorm`` node with three tensor inputs and three outputs."""
        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        normalized_shape = eval_shape(cast(List[torch.SymInt], node.args[1]))
        weight_id = self.define_tensor(get_input_node(node, 2), mps_graph)
        bias_id = self.define_tensor(get_input_node(node, 3), mps_graph)
        epsilon: float = get_scalar_val(node, 4)
        output1_id, output2_id, output3_id = self.define_tensor_list(node, mps_graph)

        layer_norm_op = MPSLayerNorm(
            input1_id=input1_id,
            normalized_shape=normalized_shape,
            weight_id=weight_id,
            bias_id=bias_id,
            eps=epsilon,
            output1_id=output1_id,
            output2_id=output2_id,
            output3_id=output3_id,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=layer_norm_op))


@register_node_visitor
class CloneVisitor(NodeVisitor):
    """Handles ``aten.clone.default`` and ``aten._to_copy.default`` by aliasing the input."""

    target = ["aten.clone.default", "aten._to_copy.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Map ``node`` to the id of its input; no node is added to ``mps_graph``.

        Raises:
            RuntimeError: for ``_to_copy`` with more than one positional arg.
        """
        # TODO
        is_to_copy = node.target == exir_ops.edge.aten._to_copy.default
        if is_to_copy and len(node.args) > 1:
            raise RuntimeError(
                "aten._to_copy not supported with more than one argument currently"
            )

        source_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        self.tensor_to_id[node] = source_id
@register_node_visitor
class GetItemVisitor(NodeVisitor):
    """Resolves ``getitem`` by aliasing one output of an already-defined multi-output node."""

    target = "getitem"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Map ``node`` to the id stored at index ``args[1]`` for the producer node.

        No tensor or node is added to ``mps_graph``.
        """
        producer = get_input_node(node, 0)
        producer_ids = self.tensor_to_id[producer]
        position = get_scalar_val(node, 1)
        self.tensor_to_id[node] = producer_ids[position]
@register_node_visitor
class ConstantPadNDVisitor(NodeVisitor):
    """Serializes ``aten.constant_pad_nd.default`` as an ``MPSConstantPadND`` node."""

    target = "aten.constant_pad_nd.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the padding node for ``node``.

        ``args[1]`` holds the pad amounts and ``args[2]`` the optional fill
        value. ``value`` defaults to 0 in the aten schema, so a node that omits
        it (or passes None) is serialized with 0.0 instead of failing on a
        missing positional arg.
        """
        mps_node = self.create_unary_node(node, mps_graph, MPSConstantPadND)

        fill = node.args[2] if len(node.args) > 2 else None
        if fill is None:
            fill = 0

        mps_node.mpsnode_union.pad = eval_shape(cast(torch.SymInt, node.args[1]))
        mps_node.mpsnode_union.value = float(fill)

        mps_graph.mps_nodes.append(mps_node)
@register_node_visitor
class MaxPool2DWithIndicesVisitor(NodeVisitor):
    """Serializes ``aten.max_pool2d_with_indices.default`` as ``MPSMaxPool2DWithIndices``."""

    target = "aten.max_pool2d_with_indices.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the max-pool node (two outputs) for ``node``.

        Positional args are input, kernel_size, stride, padding, dilation and
        ceil_mode; everything after kernel_size is optional.

        Raises:
            AssertionError: if more than six positional args are present.
        """
        n_args = len(node.args)
        if n_args > 6:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)

        padding = [0, 0]
        dilation = [1, 1]
        ceil_mode = False
        kernel_size = cast(List[int], node.args[1])
        # An omitted or empty stride means "same as the kernel size".
        stride = cast(List[int], node.args[2]) if n_args >= 3 else []
        if not stride:
            stride = kernel_size
        if n_args >= 4:
            padding = cast(List[int], node.args[3])
        if n_args >= 5:
            dilation = cast(List[int], node.args[4])
        if n_args == 6:
            ceil_mode = cast(bool, node.args[5])

        padding_top = padding[0]
        padding_left = padding[1]
        # With ceil_mode the bottom/right padding is scaled by the stride.
        padding_bottom = padding[0] * stride[0] if ceil_mode else padding[0]
        padding_right = padding[1] * stride[1] if ceil_mode else padding[1]

        output1_id, output2_id = self.define_tensor_list(node, mps_graph)
        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSMaxPool2DWithIndices(
                    input1_id=input1_id,
                    kernel_height=kernel_size[0],
                    kernel_width=kernel_size[1],
                    stride_height=stride[0],
                    stride_width=stride[1],
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation[0],
                    dilation_width=dilation[1],
                    ceil_mode=ceil_mode,
                    output1_id=output1_id,
                    output2_id=output2_id,
                )
            )
        )


@register_node_visitor
class AvgPool2DVisitor(NodeVisitor):
    """Serializes ``aten.avg_pool2d.default`` as ``MPSAvgPool2D``."""

    target = "aten.avg_pool2d.default"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append the average-pool node for ``node``.

        Positional args are input, kernel_size, stride, padding, ceil_mode,
        count_include_pad and divisor_override; everything after kernel_size
        is optional. A missing or None divisor_override is serialized as 0.

        Raises:
            AssertionError: if more than seven positional args are present.
        """
        n_args = len(node.args)
        if n_args > 7:
            raise AssertionError(
                f"Unexpected number of input parameters for {self.target}"
            )

        input1_id = self.define_tensor(get_input_node(node, 0), mps_graph)
        output1_id = self.define_tensor(node, mps_graph)

        padding_top, padding_left = [0, 0]
        dilation_height, dilation_width = [1, 1]

        ceil_mode = False
        count_include_pad = True
        divisor_override = 0
        kernel_height, kernel_width = cast(List[int], node.args[1])
        # An omitted or empty stride means "same as the kernel size".
        stride = cast(List[int], node.args[2]) if n_args >= 3 else []
        stride_height, stride_width = (
            stride if stride else (kernel_height, kernel_width)
        )
        if n_args >= 4:
            padding_top, padding_left = cast(List[int], node.args[3])
        if n_args >= 5:
            ceil_mode = cast(bool, node.args[4])
        # >= rather than ==: count_include_pad is also present whenever
        # divisor_override is passed positionally.
        if n_args >= 6:
            count_include_pad = cast(bool, node.args[5])
        if n_args >= 7 and node.args[6] is not None:
            divisor_override = cast(int, node.args[6])

        # With ceil_mode the bottom/right padding is scaled by the stride.
        padding_bottom = padding_top * stride_height if ceil_mode else padding_top
        padding_right = padding_left * stride_width if ceil_mode else padding_left

        mps_graph.mps_nodes.append(
            MPSNode(
                mpsnode_union=MPSAvgPool2D(
                    input1_id=input1_id,
                    kernel_height=kernel_height,
                    kernel_width=kernel_width,
                    stride_height=stride_height,
                    stride_width=stride_width,
                    padding_left=padding_left,
                    padding_right=padding_right,
                    padding_top=padding_top,
                    padding_bottom=padding_bottom,
                    dilation_height=dilation_height,
                    dilation_width=dilation_width,
                    ceil_mode=ceil_mode,
                    count_include_pad=count_include_pad,
                    divisor_override=divisor_override,
                    output1_id=output1_id,
                )
            )
        )
@register_node_visitor
class ArangeVisitor(NodeVisitor):
    """Serializes ``aten.arange.start_step`` as an ``MPSArange`` node."""

    target = "aten.arange.start_step"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        """Append an ``MPSArange`` node; start, end and step are stored as floats."""
        bounds = node.args

        # The step is optional; 1.0 is used when it is absent or None.
        step_given = len(bounds) > 2 and bounds[2] is not None
        step = float(bounds[2]) if step_given else 1.0
        start = float(bounds[0])
        end = float(bounds[1])

        # Default to the dtype of the traced value; an explicit, non-None
        # dtype kwarg wins.
        dtype = edge_dtype_to_mps_dtype(node.meta["val"].dtype)
        requested_dtype = node.kwargs.get("dtype") if node.kwargs else None
        if requested_dtype is not None:
            dtype = edge_dtype_to_mps_dtype(requested_dtype)

        output_id = self.define_tensor(node, mps_graph)

        arange_op = MPSArange(
            output_id=output_id,
            start=start,
            end=end,
            step=step,
            dtype=dtype,
        )
        mps_graph.mps_nodes.append(MPSNode(mpsnode_union=arange_op))
+# + +from typing import cast, List + +import torch +from executorch.backends.apple.mps.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSGraph, + MPSMean, +) + + +@register_node_visitor +class MeanVisitor(NodeVisitor): + target = "aten.mean.dim" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSMean) + + dims = cast(List[int], node.args[1]) + mps_node.mpsnode_union.num_dims = len(dims) + mps_node.mpsnode_union.dims = dims + if len(node.args) == 3: + mps_node.mpsnode_union.keep_dims = node.args[2] + + mps_graph.mps_nodes.append(mps_node) diff --git a/backends/apple/mps/operators/shape_ops.py b/backends/apple/mps/operators/shape_ops.py new file mode 100644 index 00000000000..b6c9bad692e --- /dev/null +++ b/backends/apple/mps/operators/shape_ops.py @@ -0,0 +1,264 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. 
+# + +from typing import cast, List + +import torch +from executorch.backends.apple.mps.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSCat, + MPSExpand, + MPSGraph, + MPSNode, + MPSPermute, + MPSPixelShuffle, + MPSSelect, + MPSSlice, + MPSSplitWithSizes, + MPSSqueeze, + MPSUnsqueeze, + MPSView, +) +from executorch.backends.apple.mps.utils.mps_utils import get_input_node +from executorch.backends.transforms import get_shape +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.sym_util import eval_expr, eval_shape + + +@register_node_visitor +class PermuteVisitor(NodeVisitor): + target = "aten.permute_copy.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSPermute) + + permute_order = cast(List[int], node.args[1]) + mps_node.mpsnode_union.num_dims = len(permute_order) + mps_node.mpsnode_union.perm = permute_order + + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class ViewExpandVisitor(NodeVisitor): + target = ["aten.view_copy.default", "aten.expand_copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + node_type = ( + MPSView + if node.target is exir_ops.edge.aten.view_copy.default + else MPSExpand + ) + mps_node = self.create_unary_node(node, mps_graph, node_type) + + view_shape = cast(List[int], node.args[1]) + mps_node.mpsnode_union.num_dims = len(view_shape) + mps_node.mpsnode_union.shape = view_shape + + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class CatVisitor(NodeVisitor): + target = "aten.cat.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: 
torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + tensors = cast(List[torch.fx.Node], node.args[0]) + output_id = self.define_tensor(node, mps_graph) + input_ids: List[int] = [] + + for tensor in tensors: + input_ids.append(self.define_tensor(tensor, mps_graph)) + + dim = 0 + if len(node.args) > 1: + dim = cast(int, node.args[1]) + if dim < 0 and len(tensors) > 0: + dim += len(get_shape(tensors[0])) + + mps_graph.mps_nodes.append( + MPSNode( + mpsnode_union=MPSCat(input_ids=input_ids, output_id=output_id, dim=dim), + ), + ) + + +@register_node_visitor +class SqueezeUnsqueezeVisitor(NodeVisitor): + target = ["aten.unsqueeze_copy.default", "aten.squeeze_copy.dims"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + node_type = ( + MPSUnsqueeze + if node.target is exir_ops.edge.aten.unsqueeze_copy.default + else MPSSqueeze + ) + + mps_node = self.create_unary_node(node, mps_graph, node_type) + + if node_type is MPSUnsqueeze: + mps_node.mpsnode_union.dim = cast(int, node.args[1]) + else: + dims = cast(List[int], node.args[1]) + input_shape = get_shape(get_input_node(node, 0)) + new_dims = [] + for dim in dims: + if input_shape[dim] == 1: + new_dims.append(dim) + mps_node.mpsnode_union.dims = new_dims + + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class SelectVisitor(NodeVisitor): + target = "aten.select_copy.int" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSSelect) + mps_node.mpsnode_union.dim = cast(int, node.args[1]) + mps_node.mpsnode_union.index = eval_expr(cast(torch.SymInt, node.args[2])) + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class PixelShuffleVisitor(NodeVisitor): + target = "aten.pixel_shuffle.default" + + def __init__(self, *args) -> None: + 
super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSPixelShuffle) + mps_node.mpsnode_union.upscale_factor = cast(int, node.args[1]) + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class SliceVisitor(NodeVisitor): + target = "aten.slice_copy.Tensor" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSSlice) + + def maybe_wrap_dim(dim: int, n: int) -> List[int]: + if dim < 0: + wrapped_dim = dim + n + if wrapped_dim < 0: + wrapped_dim = 0 + return wrapped_dim + elif dim > n: + return n + return dim + + start = None + end = None + if len(node.args) >= 2: + mps_node.mpsnode_union.dim = cast(int, node.args[1]) + if len(node.args) >= 4: + end = cast(int, node.args[3]) + start = cast(int, node.args[2]) + if len(node.args) >= 5: + mps_node.mpsnode_union.step = cast(int, node.args[4]) + + input_shape = get_shape(get_input_node(node, 0)) + dim_len = input_shape[ + maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape)) + ] + + start_val = start if start is not None else 0 + end_val = end if end is not None else dim_len + + mps_node.mpsnode_union.start = maybe_wrap_dim(start_val, dim_len) + mps_node.mpsnode_union.end = maybe_wrap_dim(end_val, dim_len) + mps_graph.mps_nodes.append(mps_node) + + +@register_node_visitor +class SplitWithSizesVisitor(NodeVisitor): + target = "aten.split_with_sizes_copy.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # mps_node = self.create_unary_node( + # node, mps_graph, MPSSlice + # ) + + input1_id = self.define_tensor(get_input_node(node, 0), mps_graph) + output_ids = self.define_tensor_list(node, mps_graph) + split_sizes = 
eval_shape(cast(torch.SymInt, node.args[1])) + dim = cast(int, node.args[2]) + input_shape = get_shape(get_input_node(node, 0)) + + if dim < 0 or dim >= len(input_shape): + raise RuntimeError( + f"split_copy: dim {dim} out of range for input tensor with {len(input_shape)} dimensions" + ) + + mps_node = MPSNode( + mpsnode_union=MPSSplitWithSizes( + input1_id=input1_id, + output_ids=output_ids, + split_sizes=split_sizes, + dim=dim, + ) + ) + mps_graph.mps_nodes.append(mps_node) diff --git a/backends/apple/mps/operators/unary_ops.py b/backends/apple/mps/operators/unary_ops.py new file mode 100644 index 00000000000..a8b957dc44f --- /dev/null +++ b/backends/apple/mps/operators/unary_ops.py @@ -0,0 +1,122 @@ +import torch +from executorch.backends.apple.mps.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.apple.mps.serialization.mps_graph_schema import ( + MPSAbs, + MPSAcos, + MPSAcosh, + MPSAsin, + MPSAsinh, + MPSAtan, + MPSAtanh, + MPSBitwiseNot, + MPSCeil, + MPSCos, + MPSCosh, + MPSErf, + MPSExp, + MPSExp2, + MPSFloor, + MPSGraph, + MPSIsinf, + MPSIsnan, + MPSLog, + MPSLog10, + MPSLog2, + MPSNeg, + MPSReciprocal, + MPSRound, + MPSRsqrt, + MPSSigmoid, + MPSSign, + MPSSin, + MPSSinh, + MPSSqrt, + MPSTan, + MPSTanh, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_node_visitor +class UnaryOpVisitor(NodeVisitor): + target = [ + "aten.exp.default", + "aten.exp2.default", + "aten.reciprocal.default", + "aten.sqrt.default", + "aten.neg.default", + "aten.log.default", + "aten.log10.default", + "aten.log2.default", + "aten.erf.default", + "aten.floor.default", + "aten.ceil.default", + "aten.rsqrt.default", + "aten.sigmoid.default", + "aten.sin.default", + "aten.sign.default", + "aten.cos.default", + "aten.tan.default", + "aten.abs.default", + "aten.asin.default", + "aten.acos.default", + "aten.atan.default", + "aten.sinh.default", + "aten.cosh.default", + "aten.tanh.default", + 
"aten.asinh.default", + "aten.acosh.default", + "aten.atanh.default", + "aten.bitwise_not.default", + "aten.isnan.default", + "aten.isinf.default", + "aten.round.default", + ] + + def __init__(self, *args) -> None: + super().__init__(*args) + self.unary_op = { + exir_ops.edge.aten.exp.default: MPSExp, + exir_ops.edge.aten.exp2.default: MPSExp2, + exir_ops.edge.aten.reciprocal.default: MPSReciprocal, + exir_ops.edge.aten.sqrt.default: MPSSqrt, + exir_ops.edge.aten.neg.default: MPSNeg, + exir_ops.edge.aten.log.default: MPSLog, + exir_ops.edge.aten.log10.default: MPSLog10, + exir_ops.edge.aten.log2.default: MPSLog2, + exir_ops.edge.aten.erf.default: MPSErf, + exir_ops.edge.aten.floor.default: MPSFloor, + exir_ops.edge.aten.ceil.default: MPSCeil, + exir_ops.edge.aten.rsqrt.default: MPSRsqrt, + exir_ops.edge.aten.sigmoid.default: MPSSigmoid, + exir_ops.edge.aten.sin.default: MPSSin, + exir_ops.edge.aten.sign.default: MPSSign, + exir_ops.edge.aten.cos.default: MPSCos, + exir_ops.edge.aten.tan.default: MPSTan, + exir_ops.edge.aten.abs.default: MPSAbs, + exir_ops.edge.aten.asin.default: MPSAsin, + exir_ops.edge.aten.acos.default: MPSAcos, + exir_ops.edge.aten.atan.default: MPSAtan, + exir_ops.edge.aten.sinh.default: MPSSinh, + exir_ops.edge.aten.cosh.default: MPSCosh, + exir_ops.edge.aten.tanh.default: MPSTanh, + exir_ops.edge.aten.asinh.default: MPSAsinh, + exir_ops.edge.aten.acosh.default: MPSAcosh, + exir_ops.edge.aten.atanh.default: MPSAtanh, + exir_ops.edge.aten.bitwise_not.default: MPSBitwiseNot, + exir_ops.edge.aten.isnan.default: MPSIsnan, + exir_ops.edge.aten.isinf.default: MPSIsinf, + exir_ops.edge.aten.round.default: MPSRound, + } + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_graph.mps_nodes.append( + self.create_unary_node(node, mps_graph, self.unary_op[node.target]) + ) diff --git a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py index a9246219982..3fdc09028f9 
100644 --- a/backends/apple/mps/partition/mps_partitioner.py +++ b/backends/apple/mps/partition/mps_partitioner.py @@ -4,63 +4,78 @@ # import logging +from typing import Any, Dict, List, Union import torch +from executorch.backends.apple.mps.mps_preprocess import MPSBackend +from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_partitions_from_list_of_nodes, +) from executorch.exir.backend.partitioner import ( DelegationSpec, Partitioner, PartitionResult, ) - from torch._export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase -logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) +class MPSOperatorSupport(OperatorSupportBase): + def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs): + self.node_visitors = get_node_visitors(edge_program) -class OperatorsSupportedForMpsBackend(OperatorSupportBase): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - supported_mps_ops = [ - torch.ops.aten.add.Tensor, - torch.ops.aten.mm.default, - torch.ops.aten.div.default, - ] - ret_val = ( - (node.op == "call_function" and node.target in supported_mps_ops) - or node.op == "get_attr" - or node.op == "output" - ) - return ret_val + if node.op != "call_function": + return False + + if node.target.__name__ not in self.node_visitors: + return False + + return True -# TODO MPSPartitioner is work in progress currently. -# Use whole graph delegation instead when lowering to MPS. 
class MPSPartitioner(Partitioner): - compile_spec = [] + compile_spec: List[CompileSpec] = [] def __init__(self) -> None: - self.delegation_spec = DelegationSpec("MPSBackend", self.compile_spec) + self.delegation_spec = DelegationSpec(MPSBackend.__name__, self.compile_spec) + self.partition_tags: Dict[str, DelegationSpec] = {} - def partition(self, exported_program: ExportedProgram) -> PartitionResult: - # Run the CapabilityBasedPartitioner to return the largest possible - # subgraphs containing the nodes with the tags - logger.info("MpsPartitioner::partition") - partition_tags = {} - - capability_partitioner = CapabilityBasedPartitioner( - exported_program.graph_module, - OperatorsSupportedForMpsBackend(), - allows_single_node_partition=True, + def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: + self.supported_ops = MPSOperatorSupport( + edge_program=edge_program, compiler_specs=self.delegation_spec.compile_specs + ) + return generate_partitions_from_list_of_nodes( + edge_program.graph_module, + op_support=self.supported_ops, ) - partition_list = capability_partitioner.propose_partitions() - for partition in partition_list: + + def tag_nodes(self, partitions: List[Partition]) -> None: + for partition in partitions: for node in partition.nodes: - tag = f"tag{partition.id}" - node.meta["delegation_tag"] = tag - partition_tags[tag] = self.delegation_spec + delegation_tag = f"mps_{partition.id}" + node.meta["delegation_tag"] = delegation_tag + self.partition_tags[delegation_tag] = self.delegation_spec - return PartitionResult( - tagged_exported_program=exported_program, partition_tags=partition_tags + @staticmethod + def check_partitions(partitions: Union[dict, list]) -> bool: + pl = len(partitions) + if pl == 0: + logging.warning("Nothing can be partitioned!") + else: + logging.info(f"Found {pl} subgraphs to be partitioned.") + return pl != 0 + + # override + def partition(self, edge_program: ExportedProgram) -> PartitionResult: + partitions 
= self.generate_partitions(edge_program=edge_program) + if self.check_partitions(partitions): + self.tag_nodes(partitions) + x = PartitionResult( + tagged_exported_program=edge_program, partition_tags=self.partition_tags ) + + return x diff --git a/backends/apple/mps/runtime/MPSBackend.mm b/backends/apple/mps/runtime/MPSBackend.mm index 4afa37b30d1..e2070bf2c62 100644 --- a/backends/apple/mps/runtime/MPSBackend.mm +++ b/backends/apple/mps/runtime/MPSBackend.mm @@ -37,7 +37,11 @@ bool is_available() const override { // destructible, we must call the destructor manually in destroy(). new (executor) mps::delegate::MPSExecutor; Error err = mps::delegate::MPSCompiler::compileModel( - processed->data(), processed->size(), executor, context.get_runtime_allocator(), compile_specs); + processed->data(), + processed->size(), + executor, + context.get_runtime_allocator(), + compile_specs); ET_CHECK_OR_RETURN_ERROR( err == Error::Ok, Internal, diff --git a/backends/apple/mps/runtime/MPSCompiler.h b/backends/apple/mps/runtime/MPSCompiler.h index 4d8e39e6b86..4bc2f60ee42 100644 --- a/backends/apple/mps/runtime/MPSCompiler.h +++ b/backends/apple/mps/runtime/MPSCompiler.h @@ -5,12 +5,14 @@ #pragma once +#include + #include #include #include + #include #include -#include "MPSExecutor.h" namespace torch { namespace executor { diff --git a/backends/apple/mps/runtime/MPSCompiler.mm b/backends/apple/mps/runtime/MPSCompiler.mm index b9c07430170..65a261d0f5d 100644 --- a/backends/apple/mps/runtime/MPSCompiler.mm +++ b/backends/apple/mps/runtime/MPSCompiler.mm @@ -3,14 +3,23 @@ // Provided subject to the LICENSE file in the top level directory. 
// +// Obj-C headers #import #import #import -#include "MPSCompiler.h" -#include + +// MPS headers +#include +#include +#include +#include + +// Runtime headers #include + #include #include +#include #define MPS_UNUSED(x) ( (void)(x) ) @@ -24,50 +33,6 @@ @interface MPSGraphExecutable() namespace mps { namespace delegate { -void printLoadedGraph(MPSGraphExecutable* executable) { - NSLog(@"Loaded graph: %@", [executable debugDescription]); -} - -MPSGraphExecutable* loadExecutable( - const void* buffer_pointer, - size_t num_bytes) { - ExirMPSGraphPackage* exirMPSGraphPackage = (ExirMPSGraphPackage*)buffer_pointer; - NSData *new_manifest_plist_data = [NSData dataWithBytes:exirMPSGraphPackage->data length:exirMPSGraphPackage->model_0_offset]; - NSData *new_model_0_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset length:exirMPSGraphPackage->model_1_offset - exirMPSGraphPackage->model_0_offset]; - NSData *new_model_1_data = [NSData dataWithBytes:exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset length:exirMPSGraphPackage->total_bytes - sizeof(ExirMPSGraphPackage) - exirMPSGraphPackage->model_1_offset]; - - NSError* error = nil; - NSString* packageName = [NSString stringWithUTF8String:( - std::string("%@/mpsgraphmodule_") + std::to_string(arc4random_uniform(INT_MAX)) + ".mpsgraphpackage").c_str()]; -#if TARGET_OS_IPHONE - NSArray *paths = NSSearchPathForDirectoriesInDomains - (NSDocumentDirectory, NSUserDomainMask, YES); - NSString *documentsDirectory = [paths objectAtIndex:0]; -#else - NSString *documentsDirectory = @"/tmp"; -#endif - - NSString *dataFileNSStr = [NSString stringWithFormat:packageName, - documentsDirectory]; - - NSString* manifestFileStr = [NSString stringWithFormat:@"%@/manifest.plist", dataFileNSStr]; - NSString* model0FileStr = [NSString stringWithFormat:@"%@/model_0.mpsgraph", dataFileNSStr]; - NSString* model1FileStr = [NSString stringWithFormat:@"%@/model_1.mpsgraph", dataFileNSStr]; - - 
NSFileManager *fileManager= [NSFileManager defaultManager]; - [fileManager createDirectoryAtPath:dataFileNSStr withIntermediateDirectories:NO attributes:nil error:&error]; - - [new_manifest_plist_data writeToFile:manifestFileStr options:NSDataWritingAtomic error:&error]; - [new_model_0_data writeToFile:model0FileStr options:NSDataWritingAtomic error:&error]; - [new_model_1_data writeToFile:model1FileStr options:NSDataWritingAtomic error:&error]; - - NSURL *bundleURL = [NSURL fileURLWithPath:dataFileNSStr]; - MPSGraphCompilationDescriptor *compilationDescriptor = [MPSGraphCompilationDescriptor new]; - MPSGraphExecutable *newExec = [[MPSGraphExecutable new] initWithMPSGraphPackageAtURL:bundleURL compilationDescriptor:compilationDescriptor]; - - return newExec; -} - /* Builds the mps runtime object using the buffer pointer. The buffer pointer must be a valid pointer to the serialized mps object. @@ -82,26 +47,23 @@ void printLoadedGraph(MPSGraphExecutable* executable) { Error err = Error::Ok; - id mpsCD = NSClassFromString(@"MPSGraph"); - static bool _macos_14_0_plus = [mpsCD instancesRespondToSelector:@selector(imToColWithSourceTensor:descriptor:name:)] == YES; + std::unique_ptr mpsGraphBuilder(new MPSGraphBuilder(buffer_pointer)); + err = mpsGraphBuilder->compileModel(); ET_CHECK_OR_RETURN_ERROR( - _macos_14_0_plus, - NotSupported, - "MPS Executorch runtime is supported only from macOS 14.0 and above."); + err == Error::Ok, Internal, "Failed to construct the MPS graph object"); - MPSGraphExecutable* executable = loadExecutable(buffer_pointer, num_bytes); + executor->executable_ = mpsGraphBuilder->getMPSGraphExecutable(); ET_CHECK_OR_RETURN_ERROR( - executable != nil, + executor->executable_ != nil, InvalidProgram, - "Invalid flatbuffer contents - could not deserialize MPSGraphExecutable"); + "Invalid FlatBuffer contents - could not create MPSGraphExecutable"); - executor->inputShapes_ = [[executable getInputShapes] retain]; - executor->outputShapes_ = [[executable 
getOutputShapes] retain]; + executor->inputShapes_ = [[executor->executable_ getInputShapes] retain]; + executor->outputShapes_ = [[executor->executable_ getOutputShapes] retain]; - ET_LOG(Info, "Num inputs: %lu", [executor->inputShapes_ count]); - ET_LOG(Info, "Num outputs: %lu", [executor->outputShapes_ count]); + ET_LOG(Debug, "MPSGraphExecutable num inputs: %lu", [executor->inputShapes_ count]); + ET_LOG(Debug, "MPSGraphExecutable num outputs: %lu", [executor->outputShapes_ count]); - executor->executable_ = executable; return err; } diff --git a/backends/apple/mps/runtime/MPSExecutor.h b/backends/apple/mps/runtime/MPSExecutor.h index 5aefa53b22f..172ae37e401 100644 --- a/backends/apple/mps/runtime/MPSExecutor.h +++ b/backends/apple/mps/runtime/MPSExecutor.h @@ -5,13 +5,13 @@ // clang-format off #pragma once -#import -#include -#include #include #include +#include +#include + #include #include #include diff --git a/backends/apple/mps/runtime/MPSExecutor.mm b/backends/apple/mps/runtime/MPSExecutor.mm index f83a61f2efe..0311214bf58 100644 --- a/backends/apple/mps/runtime/MPSExecutor.mm +++ b/backends/apple/mps/runtime/MPSExecutor.mm @@ -3,12 +3,10 @@ // Provided subject to the LICENSE file in the top level directory. 
// -#define EXIR_MPS_DELEGATE 1 - -#include -#include - -#include "MPSExecutor.h" +#include +#include +#include +#include @interface MPSNDArray () -(nonnull instancetype) initWithBuffer:(id _Nonnull) buffer @@ -47,7 +45,7 @@ @interface MPSNDArrayDescriptor () MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:[inputShapes_[i] dataType] shape:[inputShapes_[i] shape]]; tensorDesc.preferPackedRows = YES; - id inputBuffer = ::mps::getMTLBufferStorage(*inputs[i]); + id inputBuffer = getMTLBufferStorage(*inputs[i]); MPSNDArray *ndArrayData = [[MPSNDArray alloc] initWithBuffer:inputBuffer descriptor:tensorDesc]; MPSGraphTensorData* tensorData = [[MPSGraphTensorData alloc] initWithMPSNDArray:ndArrayData]; [inputsArray_ addObject:tensorData]; @@ -57,7 +55,7 @@ @interface MPSNDArrayDescriptor () MPSNDArrayDescriptor *tensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:[outputShapes_[i] dataType] shape:[outputShapes_[i] shape]]; tensorDesc.preferPackedRows = YES; - id outputBuffer = ::mps::getMTLBufferStorage(*outputs[i]); + id outputBuffer = getMTLBufferStorage(*outputs[i]); MPSNDArray *ndArrayData = [[MPSNDArray alloc] initWithBuffer:outputBuffer descriptor:tensorDesc]; MPSGraphTensorData* tensorData = [[MPSGraphTensorData alloc] initWithMPSNDArray:ndArrayData]; [outputsArray_ addObject:tensorData]; @@ -72,12 +70,20 @@ @interface MPSNDArrayDescriptor () __ET_NODISCARD Error MPSExecutor::forward(std::vector& outputs) { Error err = Error::Ok; MPSStream* mpsStream = getDefaultMPSStream(); - id commandBuffer = mpsStream->commandBuffer(); - [executable_ encodeToCommandBuffer:commandBuffer - inputsArray:inputsArray_ - resultsArray:outputsArray_ - executionDescriptor:nil]; - if (mps::delegate::getDefaultMPSStream()->commitAndContinueEnabled()) { + if (mpsStream->commitAndContinueEnabled() || mpsStream->hasLiveCommandBuffer() || true) { + id commandBuffer = mpsStream->commandBuffer(); + [executable_ encodeToCommandBuffer:commandBuffer + 
inputsArray:inputsArray_ + resultsArray:outputsArray_ + executionDescriptor:nil]; + } else { + [executable_ runWithMTLCommandQueue:mpsStream->commandQueue() + inputsArray:inputsArray_ + resultsArray:outputsArray_ + executionDescriptor:nil]; + } + + if (mpsStream->commitAndContinueEnabled()) { err = mpsStream->synchronize(SyncType::COMMIT_AND_CONTINUE); } else { err = mpsStream->synchronize(SyncType::COMMIT_AND_WAIT); diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.h b/backends/apple/mps/runtime/MPSGraphBuilder.h new file mode 100644 index 00000000000..7a710e3f560 --- /dev/null +++ b/backends/apple/mps/runtime/MPSGraphBuilder.h @@ -0,0 +1,183 @@ +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. +// + +#pragma once + +// Obj-C headers +#include +#include +#include +#include + +// Runtime headers +#include +#include + +// MPS headers +#include +#include + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + +using DataType = mpsgraph::MPSDataType; +using TensorPtr = const mpsgraph::MPSTensor *; +using NodePtr = const mpsgraph::MPSNode *; + +#define _DEFINE_MPS_OP(name) Error mps##name##Op(NodePtr nodePtr); + +/** + * Helper class to construct a MPSGraph object from a serialized MPS FlatBuffer model. + * It records all the input placeholders, lifted weights/biases and output feeds. 
+ */ +class MPSGraphBuilder { +public: + MPSGraphBuilder(const void *buffer_pointer); + ~MPSGraphBuilder() = default; + + Error compileModel(); + MPSGraph *getMPSGraph(); + MPSGraphExecutable *getMPSGraphExecutable(); + +private: + // Input feeds & constant ops + Error mpsGraphRankedPlaceholder(int32_t id); + Error mpsConstantOp(int32_t id); + // Activation ops + _DEFINE_MPS_OP(HardTanh); + _DEFINE_MPS_OP(ReLU); + _DEFINE_MPS_OP(GELU); + _DEFINE_MPS_OP(LeakyReLU); + _DEFINE_MPS_OP(Softmax); + _DEFINE_MPS_OP(LogSoftmax); + // Arithmetic Binary Ops + _DEFINE_MPS_OP(Add); + _DEFINE_MPS_OP(Sub); + _DEFINE_MPS_OP(Mul); + _DEFINE_MPS_OP(Div); + _DEFINE_MPS_OP(Pow); + _DEFINE_MPS_OP(Fmod); + _DEFINE_MPS_OP(Remainder); + _DEFINE_MPS_OP(BitwiseAnd); + _DEFINE_MPS_OP(BitwiseOr); + _DEFINE_MPS_OP(BitwiseXor); + _DEFINE_MPS_OP(Minimum); + // Comparison ops + _DEFINE_MPS_OP(Eq); + _DEFINE_MPS_OP(Ne); + _DEFINE_MPS_OP(Ge); + _DEFINE_MPS_OP(Gt); + _DEFINE_MPS_OP(Le); + _DEFINE_MPS_OP(Lt); + // Unary ops + _DEFINE_MPS_OP(Exp); + _DEFINE_MPS_OP(Exp2); + _DEFINE_MPS_OP(Reciprocal); + _DEFINE_MPS_OP(Sqrt); + _DEFINE_MPS_OP(Neg); + _DEFINE_MPS_OP(Log); + _DEFINE_MPS_OP(Log10); + _DEFINE_MPS_OP(Log2); + _DEFINE_MPS_OP(Erf); + _DEFINE_MPS_OP(Floor); + _DEFINE_MPS_OP(Ceil); + _DEFINE_MPS_OP(Rsqrt); + _DEFINE_MPS_OP(Sigmoid); + _DEFINE_MPS_OP(Sin); + _DEFINE_MPS_OP(Sign); + _DEFINE_MPS_OP(Cos); + _DEFINE_MPS_OP(Tan); + _DEFINE_MPS_OP(Abs); + _DEFINE_MPS_OP(Asin); + _DEFINE_MPS_OP(Acos); + _DEFINE_MPS_OP(Atan); + _DEFINE_MPS_OP(Sinh); + _DEFINE_MPS_OP(Cosh); + _DEFINE_MPS_OP(Tanh); + _DEFINE_MPS_OP(Asinh); + _DEFINE_MPS_OP(Acosh); + _DEFINE_MPS_OP(Atanh); + _DEFINE_MPS_OP(BitwiseNot); + _DEFINE_MPS_OP(Isnan); + _DEFINE_MPS_OP(Isinf); + _DEFINE_MPS_OP(Round); + _DEFINE_MPS_OP(NormCdf); + // Clamp ops + _DEFINE_MPS_OP(Clamp); + _DEFINE_MPS_OP(Where); + // BitWise ops + // Convolution ops + _DEFINE_MPS_OP(Conv2D); + _DEFINE_MPS_OP(DepthwiseConv2D); + // Indexing ops + 
_DEFINE_MPS_OP(IndexSelect); + _DEFINE_MPS_OP(Embedding); + // Linear algebra ops + _DEFINE_MPS_OP(MatMul); + _DEFINE_MPS_OP(Addmm); + // Constant ops + _DEFINE_MPS_OP(Full); + _DEFINE_MPS_OP(FullLike); + // Normalization ops + _DEFINE_MPS_OP(BatchNorm); + _DEFINE_MPS_OP(LayerNorm); + // Reduce ops + _DEFINE_MPS_OP(Mean); + // Shape ops + _DEFINE_MPS_OP(Permute); + _DEFINE_MPS_OP(View); + _DEFINE_MPS_OP(Expand); + _DEFINE_MPS_OP(Cat); + _DEFINE_MPS_OP(Squeeze); + _DEFINE_MPS_OP(Unsqueeze); + _DEFINE_MPS_OP(Select); + _DEFINE_MPS_OP(Slice); + _DEFINE_MPS_OP(PixelShuffle); + _DEFINE_MPS_OP(SplitWithSizes); + _DEFINE_MPS_OP(Cast); + // Pooling ops + _DEFINE_MPS_OP(MaxPool2DWithIndices); + _DEFINE_MPS_OP(AvgPool2D); + // Pad ops + _DEFINE_MPS_OP(ConstantPadND); + // Range ops + _DEFINE_MPS_OP(Arange); + + // Helper functions + Error addNodeToMPSGraph(NodePtr nodePtr); + MPSShape *getMPSShape(int32_t id); + MPSShape *getMPSShape(const flatbuffers::Vector *shape); + int64_t numel(const flatbuffers::Vector *shape); + MPSDataType getMPSDataType(int32_t id); + MPSDataType getMPSDataType(DataType serializedDataType); + MPSGraphTensor *getMPSGraphTensor(int32_t id); + NSData *getConstantData(int32_t id); + std::pair getMinMaxValues(NodePtr nodePtr); + + // Each MPSGraph op results in at least one MPSGraphTensor being + // produced, which will be stored in this structure. Other ops + // can reference the saved tensor by the AOT id (1:1 mapping). + std::vector _idToMPSGraphTensor; + // FlatBuffer serialized graph containing the nodes from the original model. + const mpsgraph::MPSGraph *_flatBufferGraph; + // FlatBuffer raw bytes of the serialized MPS model.
+ const void *_buffer_pointer; + + MPSGraph *_mpsGraph; + MPSGraphExecutable *_mpsGraphExecutable; + NSMutableDictionary *_feeds; + NSMutableArray *_targetTensors; +}; + +#undef _DEFINE_MPS_OP + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.mm b/backends/apple/mps/runtime/MPSGraphBuilder.mm new file mode 100644 index 00000000000..df9aa15f654 --- /dev/null +++ b/backends/apple/mps/runtime/MPSGraphBuilder.mm @@ -0,0 +1,115 @@ +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. +// + +#include +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + +MPSGraphBuilder::MPSGraphBuilder(const void* buffer_pointer) : _buffer_pointer(buffer_pointer) { + _mpsGraph = [MPSGraph new]; + _feeds = [NSMutableDictionary dictionary]; + _targetTensors = [NSMutableArray new]; + + _mpsGraphExecutable = nil; +} + +Error +MPSGraphBuilder::compileModel() { + Error err = Error::Ok; + + ET_CHECK(_buffer_pointer != nullptr); + ET_CHECK_OR_RETURN_ERROR( + mpsgraph::MPSGraphBufferHasIdentifier(_buffer_pointer), + DelegateInvalidCompatibility, + "MPS Delegate Serialization Format version identifier '%.4s' != expected '%.4s'", + flatbuffers::GetBufferIdentifier(_buffer_pointer), + mpsgraph::MPSGraphIdentifier()); + + _flatBufferGraph = mpsgraph::GetMPSGraph(_buffer_pointer); + _idToMPSGraphTensor.resize(_flatBufferGraph->mps_values()->size(), nullptr); + + // Add the placeholder nodes to the graph. + for (auto in_id : *_flatBufferGraph->input_ids()) { + err = mpsGraphRankedPlaceholder(in_id); + if (err != Error::Ok) { + return err; + } + } + + // Parse all the serialized constant values and add them to MPSGraph. 
+ for (auto constant_id : *_flatBufferGraph->constant_ids()) { + err = mpsConstantOp(constant_id); + if (err != Error::Ok) { + return err; + } + } + + // Create the corresponding MPSGraph ops of the serialized nodes from the FlatBuffer. + for (auto node : *_flatBufferGraph->mps_nodes()) { + err = addNodeToMPSGraph(node); + if (err != Error::Ok) { + return err; + } + } + + // Add the output nodes to the MPSGraphExecutable. + for (auto out_id : *_flatBufferGraph->output_ids()) { + ET_CHECK_OR_RETURN_ERROR( + _idToMPSGraphTensor[out_id] != nil, + InvalidState, + "Failed to deserialize the model"); + + [_targetTensors addObject: _idToMPSGraphTensor[out_id]]; + } + + return err; +} + +Error +MPSGraphBuilder::mpsGraphRankedPlaceholder(int32_t id) { + ET_LOG(Debug, "%s: %d", __FUNCTION__, id); + MPSShape* mpsShape = getMPSShape(id); + MPSDataType mpsDataType = getMPSDataType(id); + _idToMPSGraphTensor[id] = [_mpsGraph placeholderWithShape:mpsShape + dataType:mpsDataType + name:nil]; + _feeds[_idToMPSGraphTensor[id]] = [[MPSGraphShapedType alloc] initWithShape:mpsShape + dataType:mpsDataType]; + return Error::Ok; +} + +MPSGraph* +MPSGraphBuilder::getMPSGraph() { + return _mpsGraph; +} + +MPSGraphExecutable* +MPSGraphBuilder::getMPSGraphExecutable() { + if (_mpsGraphExecutable) { + return _mpsGraphExecutable; + } + + _mpsGraphExecutable = [_mpsGraph compileWithDevice:[MPSGraphDevice deviceWithMTLDevice:MPSDevice::getInstance()->device()] + feeds:_feeds + targetTensors:_targetTensors + targetOperations:nil + compilationDescriptor:nil]; + + + // [_mpsGraphExecutable specializeWithDevice:[MPSGraphDevice deviceWithMTLDevice:MPSDevice::getInstance()->device()] + // inputTypes:[_feeds allValues] + // compilationDescriptor:nil]; + + return _mpsGraphExecutable; +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/MPSStream.h b/backends/apple/mps/runtime/MPSStream.h index 18c386a1b39..84d9748b57f 
100644 --- a/backends/apple/mps/runtime/MPSStream.h +++ b/backends/apple/mps/runtime/MPSStream.h @@ -37,6 +37,7 @@ class MPSStream { return _serialQueue; } + bool hasLiveCommandBuffer(); MPSCommandBuffer* commandBuffer(); id commandEncoder(); void endKernelCoalescing(); diff --git a/backends/apple/mps/runtime/MPSStream.mm b/backends/apple/mps/runtime/MPSStream.mm index d9445c3eb62..75ff6eac2b3 100644 --- a/backends/apple/mps/runtime/MPSStream.mm +++ b/backends/apple/mps/runtime/MPSStream.mm @@ -53,6 +53,10 @@ @interface MPSGraphExecutionDescriptor () assert(_commandBuffer == nil); } +bool MPSStream::hasLiveCommandBuffer() { + return _commandBuffer; +} + MPSCommandBuffer* MPSStream::commandBuffer() { if (!_commandBuffer) { _commandBuffer = [MPSCommandBuffer commandBufferFromCommandQueue:_commandQueue].retain; diff --git a/backends/apple/mps/runtime/operations/ActivationOps.mm b/backends/apple/mps/runtime/operations/ActivationOps.mm new file mode 100644 index 00000000000..002c1cd83cb --- /dev/null +++ b/backends/apple/mps/runtime/operations/ActivationOps.mm @@ -0,0 +1,155 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Adds a hardtanh node: elements below min_value become min_value, elements above
+// max_value become max_value, everything else passes through. Built from two
+// compare/select pairs. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsHardTanhOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSHardTanh();
+
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
+  );
+
+  float minValue = graphNode->min_value();
+  float maxValue = graphNode->max_value();
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+
+  // The bounds are materialized with the input's own shape and data type.
+  MPSDataType inputType = [inputTensor dataType];
+  MPSShape* inputShape = [inputTensor shape];
+  MPSGraphTensor* minTensor = [_mpsGraph constantWithScalar:minValue shape:inputShape dataType:inputType];
+  MPSGraphTensor* maxTensor = [_mpsGraph constantWithScalar:maxValue shape:inputShape dataType:inputType];
+  MPSGraphTensor* lessThanMinPredicateTensor = [_mpsGraph lessThanWithPrimaryTensor:inputTensor
+                                                                    secondaryTensor:minTensor
+                                                                               name:@"LessThanPredicate"];
+  MPSGraphTensor* greaterThanMaxPredicateTensor = [_mpsGraph greaterThanWithPrimaryTensor:inputTensor
+                                                                          secondaryTensor:maxTensor
+                                                                                     name:@"MoreThanPredicate"];
+
+  // First apply the lower bound, then the upper bound on top of that result.
+  MPSGraphTensor* temp = [_mpsGraph selectWithPredicateTensor:lessThanMinPredicateTensor
+                                          truePredicateTensor:minTensor
+                                         falsePredicateTensor:inputTensor
+                                                         name:@"minOutput"];
+
+  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph selectWithPredicateTensor:greaterThanMaxPredicateTensor
+                                                                 truePredicateTensor:maxTensor
+                                                                falsePredicateTensor:temp
+                                                                                name:@"hardTanh"];
+
+  return Error::Ok;
+}
+
+// Adds a ReLU node for the serialized input/output pair. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsReLUOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSReLU();
+
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
+  );
+
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph reLUWithTensor:getMPSGraphTensor(graphNode->input1_id())
+                         name:@"relu"];
+
+  return Error::Ok;
+}
+
+// Adds a GELU node. mpsTanhOp (approximate == "tanh") or mpsNormCdfOp (anything else)
+// is run on the same node first; the tensor then found at output_id is multiplied by
+// the input. Returns Internal if that first step fails.
+// NOTE(review): this relies on mpsTanhOp/mpsNormCdfOp accepting a GELU node and storing
+// their result at output_id -- confirm against their definitions.
+Error
+MPSGraphBuilder::mpsGELUOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSGELU();
+  // An unset FlatBuffer string is returned as nullptr; treat it like any non-"tanh" mode.
+  const auto* approximate = graphNode->approximate();
+  const std::string approximation = approximate != nullptr ? approximate->str() : std::string("none");
+  Error status = Error::Ok;
+
+  ET_LOG(
+    Debug, "%s: %d (%s) -> %d",
+    __FUNCTION__, graphNode->input1_id(), approximation.c_str(), graphNode->output_id()
+  );
+
+  if (approximation == "tanh") {
+    status = mpsTanhOp(nodePtr);
+  } else {
+    status = mpsNormCdfOp(nodePtr);
+  }
+
+  ET_CHECK_OR_RETURN_ERROR(
+    status == Error::Ok,
+    Internal,
+    "[ERROR] Couldn't add GELU node to MPSGraph");
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph multiplicationWithPrimaryTensor:_idToMPSGraphTensor[graphNode->output_id()]
+                               secondaryTensor:getMPSGraphTensor(graphNode->input1_id())
+                                          name:nil];
+
+  return status;
+}
+
+// Adds a leaky ReLU node using the serialized negative_slope as alpha.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsLeakyReLUOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSLeakyReLU();
+
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
+  );
+
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph leakyReLUWithTensor:getMPSGraphTensor(graphNode->input1_id())
+                             alpha:graphNode->negative_slope()
+                              name:@"leaky_relu"];
+
+  return Error::Ok;
+}
+
+// Adds a softmax node along the serialized dim. Aborts (ET_CHECK_MSG) when the node
+// requests half-to-float conversion. Always returns Error::Ok otherwise.
+Error
+MPSGraphBuilder::mpsSoftmaxOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSSoftmax();
+
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
+  );
+
+  ET_CHECK_MSG(!graphNode->half_to_float(), "softmax with half to float conversion is not supported on MPS");
+
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph softMaxWithTensor:getMPSGraphTensor(graphNode->input1_id())
+                            axis:graphNode->dim()
+                            name:@"softmax"];
+  return Error::Ok;
+}
+
+// Adds a log-softmax node along the serialized dim. Aborts (ET_CHECK_MSG) when the
+// node requests half-to-float conversion. Always returns Error::Ok otherwise.
+//
+// The result is computed as (x - max) - log(sum(exp(x - max))) instead of
+// log(softmax(x)): the latter yields -inf as soon as a softmax value underflows to 0.
+Error
+MPSGraphBuilder::mpsLogSoftmaxOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSLogSoftmax();
+
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, graphNode->input1_id(), graphNode->output_id()
+  );
+
+  ET_CHECK_MSG(!graphNode->half_to_float(), "log_softmax with half to float conversion is not supported on MPS");
+
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+
+  // Normalize a negative axis when the rank is known.
+  NSInteger axis = graphNode->dim();
+  if (axis < 0 && inputTensor.shape != nil) {
+    axis += static_cast<NSInteger>([inputTensor.shape count]);
+  }
+
+  // MPSGraph reductions keep the reduced dimension (size 1), so the results broadcast
+  // against the input without reshaping.
+  MPSGraphTensor* maxTensor = [_mpsGraph reductionMaximumWithTensor:inputTensor
+                                                               axis:axis
+                                                               name:nil];
+  MPSGraphTensor* shiftedTensor = [_mpsGraph subtractionWithPrimaryTensor:inputTensor
+                                                          secondaryTensor:maxTensor
+                                                                     name:nil];
+  MPSGraphTensor* expTensor = [_mpsGraph exponentWithTensor:shiftedTensor
+                                                       name:nil];
+  MPSGraphTensor* sumTensor = [_mpsGraph reductionSumWithTensor:expTensor
+                                                           axis:axis
+                                                           name:nil];
+  MPSGraphTensor* logSumTensor = [_mpsGraph logarithmWithTensor:sumTensor
+                                                           name:nil];
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph subtractionWithPrimaryTensor:shiftedTensor
+                            secondaryTensor:logSumTensor
+                                       name:@"log_softmax"];
+
+  return Error::Ok;
+}
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/BinaryOps.mm b/backends/apple/mps/runtime/operations/BinaryOps.mm
new file mode 100644
index 00000000000..a0771ff16b6
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/BinaryOps.mm
@@ -0,0 +1,285 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Applies binaryOpFunction to the two operands after casting whichever of them differs
+// from promoteTypes(input type, other type) to that common type.
+// Returns whatever binaryOpFunction returns.
+MPSGraphTensor*
+binaryOpTensor(
+  MPSGraphTensor* primaryTensor,
+  MPSGraphTensor* secondaryTensor,
+  MPSGraph* mpsGraph,
+  std::function<MPSGraphTensor*(MPSGraphTensor*, MPSGraphTensor*)> binaryOpFunction) {
+  exec_aten::ScalarType inputDataType = getScalarType([primaryTensor dataType]);
+  exec_aten::ScalarType otherDataType = getScalarType([secondaryTensor dataType]);
+  exec_aten::ScalarType commonDataType = promoteTypes(inputDataType, otherDataType);
+
+  MPSGraphTensor* primaryCastTensor = primaryTensor;
+  MPSGraphTensor* secondaryCastTensor = secondaryTensor;
+  if (inputDataType != commonDataType) {
+    primaryCastTensor = castMPSTensor(mpsGraph, primaryTensor, commonDataType);
+  }
+  if (otherDataType != commonDataType) {
+    secondaryCastTensor = castMPSTensor(mpsGraph, secondaryTensor, commonDataType);
+  }
+
+  return binaryOpFunction(primaryCastTensor, secondaryCastTensor);
+}
+
+/*
+Helper macro to create an MPSGraph node based on the serialized data from the FlatBuffer.
+It takes 2 inputs and an alpha parameter, and returns one output. A few PyTorch operators,
+such as torch.sub and torch.add, take an additional alpha parameter: when alpha is not 1,
+the second input is multiplied by alpha before the operation is applied.
+Both inputs are first cast to their promoted common type by binaryOpTensor.
+More info at https://pytorch.org/docs/stable/generated/torch.sub.html.
+*/
+#define REGISTER_BINARY_WITH_ALPHA_OP(aot_name, graph_op) \
+Error \
+MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
+auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
+  ET_LOG( \
+    Debug, "%s: (%d, %d) -> %d", \
+    __FUNCTION__, \
+    graphNode->input1_id(), \
+    graphNode->input2_id(), \
+    graphNode->output_id() \
+  ); \
+ \
+  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
+    getMPSGraphTensor(graphNode->input1_id()), \
+    getMPSGraphTensor(graphNode->input2_id()), \
+    _mpsGraph, \
+    [&](MPSGraphTensor* primaryCastTensor, \
+        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
+      /* Fold alpha into the second operand; skipped entirely when alpha == 1. */ \
+      if (graphNode->alpha() != 1.0) { \
+        MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:graphNode->alpha() \
+                                                              shape:@[@1] \
+                                                           dataType:primaryCastTensor.dataType]; \
+        secondaryCastTensor = [_mpsGraph multiplicationWithPrimaryTensor:secondaryCastTensor \
+                                                         secondaryTensor:alphaTensor \
+                                                                    name:nil]; \
+      } \
+      return [_mpsGraph graph_op##WithPrimaryTensor:primaryCastTensor \
+                                    secondaryTensor:secondaryCastTensor \
+                                               name:nil]; \
+    } \
+  ); \
+  return Error::Ok; \
+}
+
+/*
+Helper macro to create an MPSGraph node based on the serialized data from the FlatBuffer.
+It takes 2 inputs and returns one output.
+*/
+#define REGISTER_BINARY_OP(aot_name, graph_op) \
+Error \
+MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
+auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
+  ET_LOG( \
+    Debug, "%s: (%d, %d) -> %d", \
+    __FUNCTION__, \
+    graphNode->input1_id(), \
+    graphNode->input2_id(), \
+    graphNode->output_id() \
+  ); \
+ \
+  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
+    getMPSGraphTensor(graphNode->input1_id()), \
+    getMPSGraphTensor(graphNode->input2_id()), \
+    _mpsGraph, \
+    [&](MPSGraphTensor* primaryCastTensor, \
+        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
+      return [_mpsGraph graph_op##WithPrimaryTensor:primaryCastTensor \
+                                    secondaryTensor:secondaryCastTensor \
+                                               name:nil]; \
+    } \
+  ); \
+ \
+  return Error::Ok; \
+}
+
+/*
+Helper macro for the element-wise AND/OR/XOR operators. It takes 2 inputs and returns one
+output: Bool operands use the logical<graph_op> MPSGraph op, every other type uses the
+bitwise<graph_op> one.
+*/
+#define REGISTER_BITWISE_BINARY_OP(aot_name, graph_op) \
+Error \
+MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
+auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
+  ET_LOG( \
+    Debug, "%s: (%d, %d) -> %d", \
+    __FUNCTION__, \
+    graphNode->input1_id(), \
+    graphNode->input2_id(), \
+    graphNode->output_id() \
+  ); \
+ \
+  _idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
+    getMPSGraphTensor(graphNode->input1_id()), \
+    getMPSGraphTensor(graphNode->input2_id()), \
+    _mpsGraph, \
+    [&](MPSGraphTensor* primaryCastTensor, \
+        MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* { \
+      MPSDataType mpsInputDataType = [primaryCastTensor dataType]; \
+      if (getScalarType(mpsInputDataType) == ScalarType::Bool) { \
+        return [_mpsGraph logical##graph_op##WithPrimaryTensor:primaryCastTensor \
+                                               secondaryTensor:secondaryCastTensor \
+                                                          name:nil]; \
+      } \
+      return [_mpsGraph bitwise##graph_op##WithPrimaryTensor:primaryCastTensor \
+                                             secondaryTensor:secondaryCastTensor \
+                                                        name:nil]; \
+    } \
+  ); \
+ \
+  return Error::Ok; \
+}
+
+// Arithmetic Binary Ops
+REGISTER_BINARY_WITH_ALPHA_OP(Add, addition)
+REGISTER_BINARY_WITH_ALPHA_OP(Sub, subtraction)
+REGISTER_BINARY_OP(Mul, multiplication)
+REGISTER_BINARY_OP(Pow, power)
+REGISTER_BINARY_OP(Minimum, minimum)
+
+// Boolean Binary ops
+REGISTER_BINARY_OP(Eq, equal)
+REGISTER_BINARY_OP(Ne, notEqual)
+REGISTER_BINARY_OP(Ge, greaterThanOrEqualTo)
+REGISTER_BINARY_OP(Gt, greaterThan)
+REGISTER_BINARY_OP(Le, lessThanOrEqualTo)
+REGISTER_BINARY_OP(Lt, lessThan)
+
+// Bitwise Binary ops
+REGISTER_BITWISE_BINARY_OP(BitwiseAnd, AND)
+REGISTER_BITWISE_BINARY_OP(BitwiseOr, OR)
+REGISTER_BITWISE_BINARY_OP(BitwiseXor, XOR)
+
+#undef REGISTER_BINARY_WITH_ALPHA_OP
+#undef REGISTER_BINARY_OP
+#undef REGISTER_BITWISE_BINARY_OP
+
+// Truncates a floating point tensor towards zero; non-float tensors are returned as-is.
+static
+MPSGraphTensor* mpsTruncTensor(MPSGraphTensor* inputTensor, MPSGraph* mpsGraph) {
+  // Rounding is a no-op for integral types, and also a reasonable workaround
+  // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
+  // See https://github.com/pytorch/pytorch/issues/84995
+  bool isFloatInput = ([inputTensor dataType] & MPSDataTypeFloatBit) != 0;
+  if (!isFloatInput) {
+    return inputTensor;
+  }
+
+  return [mpsGraph truncateWithTensor:inputTensor
+                                 name:nil];
+}
+
+// Shared implementation of Div, Fmod and Remainder.
+//
+// rounding_mode: no value -> plain division; "trunc" -> quotient truncated towards zero;
+//                "floor"  -> quotient rounded down. Any other value aborts.
+// op_name:       "Fmod" (with "trunc") and "Remainder" (with "floor") return
+//                primary - rounded_quotient * secondary instead of the quotient itself.
+//
+// Non-float operands are cast to Float32 before dividing when a "floor"/"trunc" mode is
+// requested. "trunc" with a Half primary input aborts (unsupported).
+static
+MPSGraphTensor* divModeTemplate(
+  MPSGraphTensor* primaryTensor,
+  MPSGraphTensor* secondaryTensor,
+  std::optional<std::string_view> rounding_mode,
+  MPSGraph* mpsGraph,
+  const std::string& op_name) {
+  ScalarType inputDataType = getScalarType([primaryTensor dataType]);
+
+  if (rounding_mode.has_value() && *rounding_mode == "trunc") {
+    ET_CHECK_MSG(inputDataType != ScalarType::Half,
+                 "MPS: does not support trunc_divide op with float16 input");
+  }
+
+  auto divOpFunc = [&](MPSGraphTensor* primaryCastTensor,
+                       MPSGraphTensor* secondaryCastTensor) -> MPSGraphTensor* {
+    bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;
+    if (!isFloatInput && rounding_mode.has_value() && (*rounding_mode == "floor" || *rounding_mode == "trunc")) {
+      primaryCastTensor = [mpsGraph castTensor:primaryCastTensor
+                                        toType:MPSDataTypeFloat32
+                                          name:@"primaryCastTensor"];
+      secondaryCastTensor = [mpsGraph castTensor:secondaryCastTensor
+                                          toType:MPSDataTypeFloat32
+                                            name:@"secondaryCastTensor"];
+    }
+    MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryCastTensor
+                                                    secondaryTensor:secondaryCastTensor
+                                                               name:nil];
+
+    // Rounding is a no-op for integral types, and also a reasonable workaround
+    // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
+    // See https://github.com/pytorch/pytorch/issues/84995
+    bool isFloatOutput = ([divTensor dataType] & MPSDataTypeFloatBit) != 0;
+    if (!rounding_mode.has_value() || !isFloatOutput) {
+      return divTensor;
+    }
+
+    if (*rounding_mode == "trunc") {
+      MPSGraphTensor* truncTensor = mpsTruncTensor(divTensor, mpsGraph);
+      if (op_name == "Fmod") {
+        // fmod(a, b) = a - trunc(a / b) * b
+        MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:truncTensor
+                                                              secondaryTensor:secondaryCastTensor
+                                                                         name:nil];
+        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
+                                      secondaryTensor:mulTensor
+                                                 name:nil];
+      }
+      return truncTensor;
+    }
+
+    if (*rounding_mode == "floor") {
+      MPSGraphTensor* floorTensor = [mpsGraph floorWithTensor:divTensor name:nil];
+      if (op_name == "Remainder") {
+        // remainder(a, b) = a - floor(a / b) * b
+        MPSGraphTensor* mulTensor = [mpsGraph multiplicationWithPrimaryTensor:floorTensor
+                                                              secondaryTensor:secondaryCastTensor
+                                                                         name:nil];
+        return [mpsGraph subtractionWithPrimaryTensor:primaryCastTensor
+                                      secondaryTensor:mulTensor
+                                                 name:nil];
+      }
+      return floorTensor;
+    }
+
+    // Abort in every build type; a plain assert() would let release builds store a
+    // null tensor as the node output.
+    ET_CHECK_MSG(false, "Invalid rounding mode");
+    return nullptr;
+  };
+  return binaryOpTensor(primaryTensor, secondaryTensor, mpsGraph, divOpFunc);
+}
+
+/*
+Helper macro for the division family. The serialized rounding_mode is used when the node
+carries one; otherwise round_mode (the macro argument) is the default.
+*/
+#define REGISTER_DIV_OP(aot_name, round_mode) \
+Error \
+MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \
+  auto graphNode = nodePtr->mpsnode_union_as_MPS##aot_name(); \
+  ET_LOG( \
+    Debug, "%s: (%d, %d) -> %d", \
+    __FUNCTION__, \
+    graphNode->input1_id(), \
+    graphNode->input2_id(), \
+    graphNode->output_id() \
+  ); \
+ \
+  auto strView = graphNode->rounding_mode() != nullptr ? \
+    std::make_optional(graphNode->rounding_mode()->string_view()) : round_mode; \
+ \
+  _idToMPSGraphTensor[graphNode->output_id()] = divModeTemplate( \
+    getMPSGraphTensor(graphNode->input1_id()), \
+    getMPSGraphTensor(graphNode->input2_id()), \
+    strView, \
+    _mpsGraph, \
+    #aot_name \
+  ); \
+ \
+  return Error::Ok; \
+}
+
+REGISTER_DIV_OP(Div, std::nullopt)
+REGISTER_DIV_OP(Fmod, "trunc")
+REGISTER_DIV_OP(Remainder, "floor")
+
+#undef REGISTER_DIV_OP
+
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/ClampOps.mm b/backends/apple/mps/runtime/operations/ClampOps.mm
new file mode 100644
index 00000000000..c993232a292
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/ClampOps.mm
@@ -0,0 +1,94 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Adds a clamp node. getMinMaxValues() reports -INF / INF for a bound that is not set:
+//   both bounds set -> clamp(input, min, max)
+//   only min set    -> maximum(input, min)
+//   only max set    -> minimum(input, max)
+//   neither set     -> the input is forwarded unchanged, so the output id is always
+//                      populated.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsClampOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSClamp();
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->output_id()
+  );
+
+  const auto minMaxValues = getMinMaxValues(nodePtr);
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+  const bool useMin = minMaxValues.first != -INF;
+  const bool useMax = minMaxValues.second != INF;
+
+  // Bounds are materialized with the input's own shape and data type.
+  auto makeBoundTensor = [&](double value) -> MPSGraphTensor* {
+    return [_mpsGraph constantWithScalar:value
+                                   shape:inputTensor.shape
+                                dataType:inputTensor.dataType];
+  };
+
+  MPSGraphTensor* outputTensor = nil;
+  if (useMin && useMax) {
+    // Both min and max values are set
+    MPSGraphTensor* minTensor = makeBoundTensor(minMaxValues.first);
+    MPSGraphTensor* maxTensor = makeBoundTensor(minMaxValues.second);
+    outputTensor = [_mpsGraph clampWithTensor:inputTensor
+                               minValueTensor:minTensor
+                               maxValueTensor:maxTensor
+                                         name:@"clamp"];
+  } else if (useMin) {
+    // Only min is set
+    outputTensor = [_mpsGraph maximumWithPrimaryTensor:inputTensor
+                                       secondaryTensor:makeBoundTensor(minMaxValues.first)
+                                                  name:nil];
+  } else if (useMax) {
+    // Only max is set
+    outputTensor = [_mpsGraph minimumWithPrimaryTensor:inputTensor
+                                       secondaryTensor:makeBoundTensor(minMaxValues.second)
+                                                  name:nil];
+  } else {
+    // Neither bound is set: clamp is the identity.
+    outputTensor = [_mpsGraph identityWithTensor:inputTensor name:nil];
+  }
+
+  _idToMPSGraphTensor[graphNode->output_id()] = outputTensor;
+  return Error::Ok;
+}
+
+// Adds a where/select node: picks from input2 where the condition (input1) is true and
+// from input3 otherwise. A non-Bool condition is cast to Bool first.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsWhereOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSWhere();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->input3_id(),
+    graphNode->output_id()
+  );
+
+  MPSGraphTensor* condition = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* input = getMPSGraphTensor(graphNode->input2_id());
+  MPSGraphTensor* other = getMPSGraphTensor(graphNode->input3_id());
+
+  if ([condition dataType] != MPSDataTypeBool) {
+    condition = [_mpsGraph castTensor:condition
+                               toType:MPSDataTypeBool
+                                 name:@"condition"];
+  }
+  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph selectWithPredicateTensor:condition
+                                                                 truePredicateTensor:input
+                                                                falsePredicateTensor:other
+                                                                                name:nil];
+  return Error::Ok;
+}
+
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/ConstantOps.mm b/backends/apple/mps/runtime/operations/ConstantOps.mm
new file mode 100644
index 00000000000..158a98c0383
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/ConstantOps.mm
@@ -0,0 +1,63 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Registers serialized constant `id` as an MPSGraph constant built from the data,
+// shape and data type looked up for that id. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsConstantOp(int32_t id) {
+  MPSGraphTensor* constantTensor = [_mpsGraph constantWithData:getConstantData(id)
+                                                         shape:getMPSShape(id)
+                                                      dataType:getMPSDataType(id)];
+  _idToMPSGraphTensor[id] = constantTensor;
+
+  return Error::Ok;
+}
+
+// Adds a constant tensor of the serialized shape and dtype in which every element is
+// fill_value. A shape with zero elements produces no tensor: the output slot is set to
+// nil instead. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsFullOp(NodePtr nodePtr) {
+  auto fullNode = nodePtr->mpsnode_union_as_MPSFull();
+  ET_LOG(
+    Debug, "%s: - -> %d",
+    __FUNCTION__, fullNode->output_id()
+  );
+
+  const bool hasElements = numel(fullNode->shape()) != 0;
+  if (hasElements) {
+    _idToMPSGraphTensor[fullNode->output_id()] =
+      [_mpsGraph constantWithScalar:fullNode->fill_value()
+                              shape:getMPSShape(fullNode->shape())
+                           dataType:getMPSDataType(fullNode->dtype())];
+  } else {
+    _idToMPSGraphTensor[fullNode->output_id()] = nil;
+  }
+
+  return Error::Ok;
+}
+
+// Adds a constant tensor filled with fill_value that takes its shape from the input
+// tensor and its data type from the serialized dtype. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsFullLikeOp(NodePtr nodePtr) {
+  auto fullLikeNode = nodePtr->mpsnode_union_as_MPSFullLike();
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__, fullLikeNode->input1_id(), fullLikeNode->output_id()
+  );
+
+  MPSGraphTensor* filledTensor =
+    [_mpsGraph constantWithScalar:fullLikeNode->fill_value()
+                            shape:getMPSGraphTensor(fullLikeNode->input1_id()).shape
+                         dataType:getMPSDataType(fullLikeNode->dtype())];
+  _idToMPSGraphTensor[fullLikeNode->output_id()] = filledTensor;
+
+  return Error::Ok;
+}
+
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/ConvolutionOps.mm b/backends/apple/mps/runtime/operations/ConvolutionOps.mm
new file mode 100644
index 00000000000..991a050e665
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/ConvolutionOps.mm
@@ -0,0 +1,146 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Adds a depthwise 2D convolution, expressed through MPSGraph's depthwise 3D
+// convolution with a unit stride/dilation and zero padding on the extra leading axis.
+// input1 = source, input2 = weights, input3 = optional bias (-1 when absent).
+// Aborts (ET_CHECK) when the weights are rank 3, i.e. a 1D convolution.
+// Always returns Error::Ok otherwise.
+Error
+MPSGraphBuilder::mpsDepthwiseConv2DOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSDepthwiseConv2D();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->input3_id(),
+    graphNode->output_id()
+  );
+
+  bool isConv1D = ([getMPSShape(graphNode->input2_id()) count] == 3);
+  ET_CHECK(!isConv1D);
+
+  MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor =
+    [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
+
+  // Boxed literals (@(...)) are autoreleased; the previous alloc/init NSNumbers were
+  // never released.
+  depthWiseConv3dDescriptor.strides =
+    @[ @1, @(graphNode->stride_y()), @(graphNode->stride_x()) ];
+
+  depthWiseConv3dDescriptor.dilationRates =
+    @[ @1, @(graphNode->dilation_y()), @(graphNode->dilation_x()) ];
+
+  depthWiseConv3dDescriptor.paddingStyle = MPSGraphPaddingStyleExplicit;
+  depthWiseConv3dDescriptor.paddingValues = @[
+    @0,
+    @0,
+    @(graphNode->padding_top()),
+    @(graphNode->padding_bottom()),
+    @(graphNode->padding_left()),
+    @(graphNode->padding_right())
+  ];
+  depthWiseConv3dDescriptor.channelDimensionIndex = -3LL;
+  MPSGraphTensor* weightTransposeTensor = [_mpsGraph transposeTensor:getMPSGraphTensor(graphNode->input2_id())
+                                                           dimension:-3
+                                                       withDimension:-4
+                                                                name:nil];
+  MPSGraphTensor* depthwiseConvTensor = [_mpsGraph depthwiseConvolution3DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id())
+                                                                            weightsTensor:weightTransposeTensor
+                                                                               descriptor:depthWiseConv3dDescriptor
+                                                                                     name:nil];
+  // Bias is optional
+  if (graphNode->input3_id() != -1) {
+    // Need to add correct dimension to bias to avoid broadcasting issues
+    MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input3_id());
+    biasTensor = [_mpsGraph expandDimsOfTensor:biasTensor
+                                          axes:@[@0, @2, @3]
+                                          name:nil];
+    depthwiseConvTensor = [_mpsGraph additionWithPrimaryTensor:depthwiseConvTensor
+                                               secondaryTensor:biasTensor
+                                                          name:@"depthwiseConv2DWithBiasAdd"];
+  }
+
+  _idToMPSGraphTensor[graphNode->output_id()] = depthwiseConvTensor;
+  return Error::Ok;
+}
+
+// Adds a 2D convolution (NCHW data, explicit padding).
+// input1 = source, input2 = weights in OIHW order, input3 = optional bias (-1 when
+// absent). Rank-3 weights are treated as a 1D convolution: source and weights get a
+// unit axis inserted at position 2, which is squeezed out of the result again.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsConv2DOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSConv2D();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->input3_id(),
+    graphNode->output_id()
+  );
+
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input2_id());
+
+  bool isConv1D = ([weightTensor.shape count] == 3);
+  if (isConv1D) {
+    inputTensor = [_mpsGraph expandDimsOfTensor:inputTensor
+                                           axis:2
+                                           name:@"unsqueezeInput"];
+    weightTensor = [_mpsGraph expandDimsOfTensor:weightTensor
+                                            axis:2
+                                            name:@"unsqueezeWeight"];
+  }
+
+  MPSGraphConvolution2DOpDescriptor* desc =
+    [MPSGraphConvolution2DOpDescriptor descriptorWithStrideInX:graphNode->stride_x()
+                                                     strideInY:graphNode->stride_y()
+                                               dilationRateInX:graphNode->dilation_x()
+                                               dilationRateInY:graphNode->dilation_y()
+                                                        groups:graphNode->groups()
+                                                   paddingLeft:graphNode->padding_left()
+                                                  paddingRight:graphNode->padding_right()
+                                                    paddingTop:graphNode->padding_top()
+                                                 paddingBottom:graphNode->padding_bottom()
+                                                  paddingStyle:MPSGraphPaddingStyleExplicit
+                                                    dataLayout:MPSGraphTensorNamedDataLayoutNCHW
+                                                 weightsLayout:MPSGraphTensorNamedDataLayoutHWIO];
+  // Convert weights from OIHW to HWIO.
+  MPSGraphTensor* weightTransposeTensor = [_mpsGraph transposeTensor:weightTensor
+                                                         permutation:@[@2, @3, @1, @0]
+                                                                name:nil];
+
+  MPSGraphTensor* conv2DTensor = [_mpsGraph convolution2DWithSourceTensor:inputTensor
+                                                            weightsTensor:weightTransposeTensor
+                                                               descriptor:desc
+                                                                     name:@"conv2D"];
+
+  // Bias is optional
+  if (graphNode->input3_id() != -1) {
+    // Need to add correct dimension to bias to avoid broadcasting issues
+    MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input3_id());
+    biasTensor = [_mpsGraph expandDimsOfTensor:biasTensor
+                                          axes:@[@0,@2,@3]
+                                          name:nil];
+    conv2DTensor = [_mpsGraph additionWithPrimaryTensor:conv2DTensor
+                                        secondaryTensor:biasTensor
+                                                   name:@"conv2DWithBiasAdd"];
+  }
+
+  if (isConv1D) {
+    conv2DTensor = [_mpsGraph squeezeTensor:conv2DTensor
+                                       axis:2
+                                       name:@"squeeze"];
+  }
+
+  _idToMPSGraphTensor[graphNode->output_id()] = conv2DTensor;
+  return Error::Ok;
+}
+
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/IndexingOps.mm b/backends/apple/mps/runtime/operations/IndexingOps.mm
new file mode 100644
index 00000000000..224c16b06da
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/IndexingOps.mm
@@ -0,0 +1,114 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+
+// Gathers slices of inputTensor along `dim` at the positions given by indexTensor.
+// Indices that are not already Int32 are cast to Int32 before the gather.
+MPSGraphTensor* indexSelect(
+  MPSGraphTensor* inputTensor,
+  int64_t dim,
+  MPSGraphTensor* indexTensor,
+  MPSGraph* mpsGraph) {
+
+  MPSGraphTensor* castIndexTensor = indexTensor;
+  if(castIndexTensor.dataType != MPSDataTypeInt32) {
+    castIndexTensor = [mpsGraph castTensor:indexTensor
+                                    toType:MPSDataTypeInt32
+                                      name:nil];
+  }
+
+  return [mpsGraph gatherWithUpdatesTensor:inputTensor
+                             indicesTensor:castIndexTensor
+                                      axis:dim
+                           batchDimensions:0
+                                      name:nil];
+}
+
+// Adds an index_select node: gathers from input1 along the serialized dim using the
+// tensor at index_id. Delegates to indexSelect() so the Int32 cast and the gather live
+// in one place. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsIndexSelectOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSIndexSelect();
+  ET_LOG(
+    Debug, "%s: %d -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->output_id()
+  );
+
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* indexTensor = getMPSGraphTensor(graphNode->index_id());
+
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    indexSelect(inputTensor, graphNode->dim(), indexTensor, _mpsGraph);
+  return Error::Ok;
+}
+
+// Adds an embedding lookup: rows of the weight tensor (input1) are gathered along axis 0
+// with the indices tensor (input2). When padding_idx is not -1, rows looked up with an
+// index equal to padding_idx are replaced by zeros. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsEmbeddingOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSEmbedding();
+  ET_LOG(
+    Debug, "%s: (%d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->output_id()
+  );
+
+
+  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->input2_id());
+  int padding_idx = graphNode->padding_idx();
+
+  // The lookup itself is the same with and without a padding index.
+  MPSGraphTensor* valTensor = indexSelect(weightTensor, 0, indicesTensor, _mpsGraph);
+
+  if (padding_idx == -1) {
+    _idToMPSGraphTensor[graphNode->output_id()] = valTensor;
+    return Error::Ok;
+  }
+
+  MPSGraphTensor* constantTensor = [_mpsGraph constantWithScalar:padding_idx
+                                                           shape:@[@1]
+                                                        dataType:indicesTensor.dataType];
+
+  MPSGraphTensor* notEqualTensor = [_mpsGraph notEqualWithPrimaryTensor:indicesTensor
+                                                        secondaryTensor:constantTensor
+                                                                   name:nil];
+  // Trailing unit axis so the per-index mask broadcasts over each gathered row.
+  MPSGraphTensor* condition = [_mpsGraph expandDimsOfTensor:notEqualTensor
+                                                       axis:-1
+                                                       name:@"unsqueeze"];
+  MPSGraphTensor* zeroTensor = [_mpsGraph constantWithScalar:0
+                                                       shape:valTensor.shape
+                                                    dataType:valTensor.dataType];
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph selectWithPredicateTensor:condition
+                     truePredicateTensor:valTensor
+                    falsePredicateTensor:zeroTensor
+                                    name:nil];
+
+  return Error::Ok;
+}
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/LinearAlgebra.mm b/backends/apple/mps/runtime/operations/LinearAlgebra.mm
new file mode 100644
index 00000000000..21ecc27b297
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/LinearAlgebra.mm
@@ -0,0 +1,86 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): both include targets were lost when this patch text was extracted;
+// restore them before applying.
+#include
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Adds a matrix multiplication node: output = input1 x input2.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsMatMulOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSMatMul();
+  ET_LOG(
+    Debug, "%s: (%d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->output_id()
+  );
+
+  _idToMPSGraphTensor[graphNode->output_id()] =
+    [_mpsGraph matrixMultiplicationWithPrimaryTensor:getMPSGraphTensor(graphNode->input1_id())
+                                     secondaryTensor:getMPSGraphTensor(graphNode->input2_id())
+                                                name:nil];
+
+  return Error::Ok;
+}
+
+// Adds an addmm node: output = beta * input1 + alpha * (input2 x input3), where input1
+// is the bias. The multiplications by alpha / beta are skipped when the factor is 1.
+// Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsAddmmOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSAddmm();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d) -> %d",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->input2_id(),
+    graphNode->input3_id(),
+    graphNode->output_id()
+  );
+
+  MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input2_id());
+  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->input3_id());
+  float beta = graphNode->beta();
+  float alpha = graphNode->alpha();
+
+  MPSGraphTensor* multiplyTensor = [_mpsGraph matrixMultiplicationWithPrimaryTensor:inputTensor
+                                                                    secondaryTensor:weightTensor
+                                                                               name:@"addmm/matmul"];
+  MPSGraphTensor* alphaTimesMultiply = multiplyTensor;
+  if (alpha != 1.0) {
+    MPSGraphTensor* alphaTensor = [_mpsGraph constantWithScalar:alpha
+                                                       dataType:inputTensor.dataType];
+
+    alphaTimesMultiply = [_mpsGraph multiplicationWithPrimaryTensor:multiplyTensor
+                                                    secondaryTensor:alphaTensor
+                                                               name:@"addmm/alpha*matmul"];
+  }
+
+  MPSGraphTensor* betaBiasTensor = biasTensor;
+  if (beta != 1.0) {
+    MPSGraphTensor* betaTensor = [_mpsGraph constantWithScalar:beta
+                                                      dataType:inputTensor.dataType];
+
+    betaBiasTensor = [_mpsGraph multiplicationWithPrimaryTensor:biasTensor
+                                                secondaryTensor:betaTensor
+                                                           name:@"addmm/beta*bias"];
+  }
+
+  _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:alphaTimesMultiply
+                                                                     secondaryTensor:betaBiasTensor
+                                                                                name:@"addmm/beta*bias*alpha*matmul"];
+
+  return Error::Ok;
+}
+
+} // namespace delegate
+} // namespace mps
+} // namespace executor
+} // namespace torch
diff --git a/backends/apple/mps/runtime/operations/NormalizationOps.mm b/backends/apple/mps/runtime/operations/NormalizationOps.mm
new file mode 100644
index 00000000000..e868482eec8
--- /dev/null
+++ b/backends/apple/mps/runtime/operations/NormalizationOps.mm
@@ -0,0 +1,129 @@
+
+//
+// Copyright (c) 2023 Apple Inc. All rights reserved.
+// Provided subject to the LICENSE file in the top level directory.
+//
+
+// NOTE(review): the include target was lost when this patch text was extracted;
+// restore it before applying.
+#include
+
+namespace torch {
+namespace executor {
+namespace mps {
+namespace delegate {
+
+// Adds an inference-mode batch normalization node.
+//   output1 = normalized input, using the supplied mean, variance, weight (gamma),
+//             bias (beta) and epsilon
+//   output2 = identity of the supplied mean
+//   output3 = identity of the supplied variance
+// Mean, variance, weight and bias are reshaped to 1 x C x 1 x ... so they broadcast
+// along the channel axis (axis 1) of the input. Always returns Error::Ok.
+Error
+MPSGraphBuilder::mpsBatchNormOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSBatchNorm();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d, %d, %d) -> (%d, %d, %d)",
+    __FUNCTION__,
+    graphNode->input_id(),
+    graphNode->mean_id(),
+    graphNode->var_id(),
+    graphNode->weight_id(),
+    graphNode->bias_id(),
+    graphNode->output1_id(),
+    graphNode->output2_id(),
+    graphNode->output3_id()
+  );
+
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input_id());
+  MPSGraphTensor* meanTensor = getMPSGraphTensor(graphNode->mean_id());
+  MPSGraphTensor* varTensor = getMPSGraphTensor(graphNode->var_id());
+  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->weight_id());
+  MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->bias_id());
+  float epsilon = graphNode->epsilon();
+
+  // Shapes are NCHW so the input parameters to normalization are 1xCx1x1
+  const NSUInteger inputRank = [inputTensor.shape count];
+  NSMutableArray* newShape = [NSMutableArray arrayWithCapacity:inputRank];
+  [newShape addObject:@1];
+  [newShape addObject:inputTensor.shape[1]];
+  for (NSUInteger i = 2; i < inputRank; ++i) {
+    [newShape addObject:@1];
+  }
+  // No need for momentum since we are not training for now
+  // TODO: Check if momentum is needed
+  MPSGraphTensor* reshapedMeanTensor = [_mpsGraph reshapeTensor:meanTensor
+                                                      withShape:newShape
+                                                           name:nil];
+  MPSGraphTensor* reshapedVarTensor = [_mpsGraph reshapeTensor:varTensor
+                                                     withShape:newShape
+                                                          name:nil];
+  MPSGraphTensor* reshapedWeightTensor = [_mpsGraph reshapeTensor:weightTensor
+                                                        withShape:newShape
+                                                             name:nil];
+  MPSGraphTensor* reshapedBiasTensor = [_mpsGraph reshapeTensor:biasTensor
+                                                      withShape:newShape
+                                                           name:nil];
+
+  _idToMPSGraphTensor[graphNode->output1_id()] = [_mpsGraph normalizationWithTensor:inputTensor
+                                                                         meanTensor:reshapedMeanTensor
+                                                                     varianceTensor:reshapedVarTensor
+                                                                        gammaTensor:reshapedWeightTensor
+                                                                         betaTensor:reshapedBiasTensor
+                                                                            epsilon:epsilon
+                                                                               name:@"batch_norm"];
+
+  // For now just return meanTensor and varTensor assuming this isn't training.
+  // Previously both tensors were written to output2_id (the mean, written last, won)
+  // and output3_id was never populated.
+
+  // saveMeanTensor
+  _idToMPSGraphTensor[graphNode->output2_id()] = [_mpsGraph identityWithTensor:meanTensor name:nil];
+  // saveVarTensor
+  _idToMPSGraphTensor[graphNode->output3_id()] = [_mpsGraph identityWithTensor:varTensor name:nil];
+
+  return Error::Ok;
+}
+
+Error
+MPSGraphBuilder::mpsLayerNormOp(NodePtr nodePtr) {
+  auto graphNode = nodePtr->mpsnode_union_as_MPSLayerNorm();
+  ET_LOG(
+    Debug, "%s: (%d, %d, %d) -> (%d, %d, %d)",
+    __FUNCTION__,
+    graphNode->input1_id(),
+    graphNode->weight_id(),
+    graphNode->bias_id(),
+    graphNode->output1_id(),
+    graphNode->output2_id(),
+    graphNode->output3_id()
+  );
+
+  MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
+  MPSGraphTensor* weightTensor = getMPSGraphTensor(graphNode->weight_id());
+  MPSGraphTensor* biasTensor = getMPSGraphTensor(graphNode->bias_id());
+  const int input_ndim = [inputTensor.shape count];
+  const int normalized_shape_ndim = graphNode->normalized_shape()->size();
+  const int ndim_to_normalize = input_ndim - normalized_shape_ndim;
+
+  NSMutableArray* axesArray = [NSMutableArray arrayWithCapacity:normalized_shape_ndim];
+  for (int32_t idx =
ndim_to_normalize; idx < input_ndim; idx++) { + [axesArray addObject:[NSNumber numberWithInt:idx]]; + } + + MPSGraphTensor* meanTensor = [_mpsGraph meanOfTensor:inputTensor + axes:axesArray + name:@"LayerNorm/MeanTensor"]; + + MPSGraphTensor* varianceTensor = [_mpsGraph varianceOfTensor:inputTensor + meanTensor:meanTensor + axes:axesArray + name:@"LayerNorm/varianceTensor"]; + MPSGraphTensor* normalizedTensor = [_mpsGraph normalizationWithTensor:inputTensor + meanTensor:meanTensor + varianceTensor:varianceTensor + gammaTensor:weightTensor + betaTensor:biasTensor + epsilon:graphNode->eps() + name:@"LayerNorm/resultTensor"]; + + _idToMPSGraphTensor[graphNode->output1_id()] = normalizedTensor; + _idToMPSGraphTensor[graphNode->output2_id()] = meanTensor; + _idToMPSGraphTensor[graphNode->output3_id()] = varianceTensor; + + return Error::Ok; +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/OperationUtils.h b/backends/apple/mps/runtime/operations/OperationUtils.h new file mode 100644 index 00000000000..34c59b8b7f2 --- /dev/null +++ b/backends/apple/mps/runtime/operations/OperationUtils.h @@ -0,0 +1,54 @@ +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +#pragma once + +#import +#include +#include +#include +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + +#define INF std::numeric_limits::infinity() + +MPSDataType getMPSScalarType(exec_aten::ScalarType scalar_type); +exec_aten::ScalarType getScalarType(MPSDataType mpsDataType); +MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, exec_aten::ScalarType toType); +MPSGraphTensor *castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor *tensor, MPSDataType toType); +std::vector getMPSShapeVec(const MPSShape *shape); + +template std::vector flatbufferDimsToVector(const flatbuffers::Vector *dims) { + std::vector dimsData; + dimsData.reserve(dims->size()); + for (auto dim : *dims) { + dimsData.push_back(static_cast(dim)); + } + return dimsData; +} + +static inline id getMTLBufferStorage(const Tensor &tensor) { +#if TARGET_OS_SIMULATOR + // Simulator crashes in newBufferWithBytesNoCopy, so we're making a copy of + // the data. + uint8_t *data = tensor.mutable_data_ptr(); + return [MPSDevice::getInstance()->device() newBufferWithBytes:data length:tensor.nbytes() options:0]; +#else + uint8_t *data = tensor.mutable_data_ptr(); + return [MPSDevice::getInstance()->device() newBufferWithBytesNoCopy:data + length:tensor.nbytes() + options:0 + deallocator:nil]; +#endif // TARGET_OS_SIMULATOR +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/OperationUtils.mm b/backends/apple/mps/runtime/operations/OperationUtils.mm new file mode 100644 index 00000000000..524ad1a5b53 --- /dev/null +++ b/backends/apple/mps/runtime/operations/OperationUtils.mm @@ -0,0 +1,283 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + +MPSDataType +MPSGraphBuilder::getMPSDataType(int32_t id) { + return getMPSDataType(_flatBufferGraph->mps_values()->Get(id)->datatype()); +} + +MPSDataType +MPSGraphBuilder::getMPSDataType(DataType serializedDataType) { + switch (serializedDataType) { + case DataType::mps_data_type_float16: + return MPSDataTypeFloat16; + case DataType::mps_data_type_float32: + return MPSDataTypeFloat32; + case DataType::mps_data_type_int8: + return MPSDataTypeInt8; + case DataType::mps_data_type_int16: + return MPSDataTypeInt16; + case DataType::mps_data_type_int32: + return MPSDataTypeInt32; + case DataType::mps_data_type_int64: + return MPSDataTypeInt64; + case DataType::mps_data_type_bool: + return MPSDataTypeBool; + default: + ET_CHECK_MSG(false, "[ERROR] Invalid MPS data type: %d!", (int32_t)serializedDataType); + return MPSDataTypeInvalid; + } +} + +MPSShape* +MPSGraphBuilder::getMPSShape(int32_t id) { + TensorPtr mpsTensor = _flatBufferGraph->mps_values()->Get(id); + auto sizes = mpsTensor->dims(); + const int sz = mpsTensor->num_dims(); + const int sz_ = (sz > 0) ? sz : 1; + + std::vector numbers(sz_); + + for (int i = 0; i < sz_; i++) { + NSInteger sz_i = (i < sz) ? sizes->Get(i) : 1; + NSNumber* number = [NSNumber numberWithInteger:sz_i]; + numbers[i] = number; + } + return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; +} + +MPSShape* +MPSGraphBuilder::getMPSShape(const flatbuffers::Vector* shape) { + const int sz = shape->size(); + const int sz_ = (sz > 0) ? sz : 1; + + std::vector numbers(sz_); + + for (int i = 0; i < sz_; i++) { + NSInteger sz_i = (i < sz) ? 
shape->Get(i) : 1; + NSNumber* number = [NSNumber numberWithInteger:sz_i]; + numbers[i] = number; + } + return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; +} + +int64_t +MPSGraphBuilder::numel(const flatbuffers::Vector* shape) { + int64_t numel = 1; + for (auto dim : *shape) { + numel = numel * dim; + } + return numel; +} + +NSData* +MPSGraphBuilder::getConstantData(int32_t id) { + TensorPtr mpsTensor = _flatBufferGraph->mps_values()->Get(id); + int32_t constantBufferSize = mpsTensor->constant_buffer_size(); + const unsigned char* constantBuffer = mpsTensor->constant_buffer()->storage()->data(); + ET_CHECK_MSG(constantBufferSize > 0 && constantBuffer != nullptr, "[ERROR] Invalid constant buffer"); + return [[NSData alloc] initWithBytes:constantBuffer + length:constantBufferSize]; +} + +std::pair +MPSGraphBuilder::getMinMaxValues(NodePtr nodePtr) { + float minValue = -INF; + float maxValue = INF; + auto minMaxValues = nodePtr->min_max(); + if (minMaxValues != nullptr) { + minValue = minMaxValues->min_value(); + maxValue = minMaxValues->max_value(); + } + + return {minValue, maxValue}; +} + +#define _DEFINE_MPS_NODE(node) \ + case mpsgraph::MPSNodeUnion::MPS##node: \ + return mps##node##Op(nodePtr); + +Error +MPSGraphBuilder::addNodeToMPSGraph(NodePtr nodePtr) { + switch (nodePtr->mpsnode_union_type()) { + // Activation ops + _DEFINE_MPS_NODE(HardTanh); + _DEFINE_MPS_NODE(ReLU); + _DEFINE_MPS_NODE(GELU); + _DEFINE_MPS_NODE(LeakyReLU); + _DEFINE_MPS_NODE(Softmax); + _DEFINE_MPS_NODE(LogSoftmax); + // Binary ops + _DEFINE_MPS_NODE(Add); + _DEFINE_MPS_NODE(Sub); + _DEFINE_MPS_NODE(Mul); + _DEFINE_MPS_NODE(Div); + _DEFINE_MPS_NODE(Pow); + _DEFINE_MPS_NODE(Fmod); + _DEFINE_MPS_NODE(Remainder); + _DEFINE_MPS_NODE(BitwiseAnd); + _DEFINE_MPS_NODE(BitwiseOr); + _DEFINE_MPS_NODE(BitwiseXor); + _DEFINE_MPS_NODE(Minimum); + // Unary ops + _DEFINE_MPS_NODE(Exp); + _DEFINE_MPS_NODE(Exp2); + _DEFINE_MPS_NODE(Reciprocal); + _DEFINE_MPS_NODE(Sqrt); + 
_DEFINE_MPS_NODE(Neg); + _DEFINE_MPS_NODE(Log); + _DEFINE_MPS_NODE(Log10); + _DEFINE_MPS_NODE(Log2); + _DEFINE_MPS_NODE(Erf); + _DEFINE_MPS_NODE(Floor); + _DEFINE_MPS_NODE(Ceil); + _DEFINE_MPS_NODE(Rsqrt); + _DEFINE_MPS_NODE(Sigmoid); + _DEFINE_MPS_NODE(Sin); + _DEFINE_MPS_NODE(Sign); + _DEFINE_MPS_NODE(Cos); + _DEFINE_MPS_NODE(Tan); + _DEFINE_MPS_NODE(Abs); + _DEFINE_MPS_NODE(Asin); + _DEFINE_MPS_NODE(Acos); + _DEFINE_MPS_NODE(Atan); + _DEFINE_MPS_NODE(Sinh); + _DEFINE_MPS_NODE(Cosh); + _DEFINE_MPS_NODE(Tanh); + _DEFINE_MPS_NODE(Asinh); + _DEFINE_MPS_NODE(Acosh); + _DEFINE_MPS_NODE(Atanh); + _DEFINE_MPS_NODE(BitwiseNot); + _DEFINE_MPS_NODE(Isnan); + _DEFINE_MPS_NODE(Isinf); + _DEFINE_MPS_NODE(Round); + // Clamp ops + _DEFINE_MPS_NODE(Clamp); + _DEFINE_MPS_NODE(Where); + // Linear algebra ops + _DEFINE_MPS_NODE(MatMul); + _DEFINE_MPS_NODE(Addmm); + // Constant ops + _DEFINE_MPS_NODE(Full); + _DEFINE_MPS_NODE(FullLike); + //Indexing ops + _DEFINE_MPS_NODE(IndexSelect); + _DEFINE_MPS_NODE(Embedding); + // Reduce ops + _DEFINE_MPS_NODE(Mean); + // Shape ops + _DEFINE_MPS_NODE(Permute); + _DEFINE_MPS_NODE(View); + _DEFINE_MPS_NODE(Expand); + _DEFINE_MPS_NODE(Cat); + _DEFINE_MPS_NODE(Squeeze); + _DEFINE_MPS_NODE(Unsqueeze); + _DEFINE_MPS_NODE(Select); + _DEFINE_MPS_NODE(Slice); + _DEFINE_MPS_NODE(PixelShuffle); + _DEFINE_MPS_NODE(SplitWithSizes); + _DEFINE_MPS_NODE(Cast); + // Convolution ops + _DEFINE_MPS_NODE(Conv2D); + _DEFINE_MPS_NODE(DepthwiseConv2D); + // Comparison ops + _DEFINE_MPS_NODE(Eq); + _DEFINE_MPS_NODE(Ne); + _DEFINE_MPS_NODE(Ge); + _DEFINE_MPS_NODE(Gt); + _DEFINE_MPS_NODE(Le); + _DEFINE_MPS_NODE(Lt); + // Normalization ops + _DEFINE_MPS_NODE(BatchNorm); + _DEFINE_MPS_NODE(LayerNorm); + // Pooling ops + _DEFINE_MPS_NODE(MaxPool2DWithIndices); + _DEFINE_MPS_NODE(AvgPool2D); + // Pad ops + _DEFINE_MPS_NODE(ConstantPadND); + // Range ops + _DEFINE_MPS_NODE(Arange); + + case mpsgraph::MPSNodeUnion::NONE: + default: + ET_CHECK_OR_RETURN_ERROR( + false, + 
NotImplemented, + "[ERROR] Unhandled node type: %s!", + mpsgraph::EnumNameMPSNodeUnion(nodePtr->mpsnode_union_type())); + } +} + +#undef _DEFINE_MPS_NODE + +MPSGraphTensor* +MPSGraphBuilder::getMPSGraphTensor(int32_t id) { + static int32_t cacheEntries = _idToMPSGraphTensor.size(); + return _idToMPSGraphTensor[id]; +} + +MPSDataType getMPSScalarType(exec_aten::ScalarType scalar_type) { + switch (scalar_type) { + // This is an intentional fallthrough supporting Double for Scalar + // types as they are casted to Float32 currently. + case exec_aten::ScalarType::Float: + return MPSDataTypeFloat32; + case exec_aten::ScalarType::Half: + return MPSDataTypeFloat16; + default: + ET_CHECK_MSG(false, "Unhandled ExecuTorch scalar type!"); + } +} + +exec_aten::ScalarType getScalarType(MPSDataType mpsDataType) { + switch (mpsDataType) { + case MPSDataTypeFloat16: + return exec_aten::ScalarType::Half; + case MPSDataTypeFloat32: + return exec_aten::ScalarType::Float; + case MPSDataTypeInt8: + return exec_aten::ScalarType::Char; + case MPSDataTypeInt16: + return exec_aten::ScalarType::Short; + case MPSDataTypeInt32: + return exec_aten::ScalarType::Int; + case MPSDataTypeInt64: + return exec_aten::ScalarType::Long; + case MPSDataTypeBool: + return exec_aten::ScalarType::Bool; + default: + ET_CHECK_MSG(false, "Unhandled MPS data type!"); + } +} + +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, exec_aten::ScalarType toType) { + return castMPSTensor(mpsGraph, tensor, getMPSScalarType(toType)); +} + +MPSGraphTensor* castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) { + return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"]; +} + +std::vector getMPSShapeVec(const MPSShape* shape) { + __block std::vector shapeVec; + shapeVec.reserve([shape count]); + [shape enumerateObjectsUsingBlock:^(NSNumber * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { + shapeVec.push_back(obj.intValue); + }]; + return shapeVec; +} + +} // 
namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/operations/PadOps.mm b/backends/apple/mps/runtime/operations/PadOps.mm similarity index 59% rename from backends/apple/mps/operations/PadOps.mm rename to backends/apple/mps/runtime/operations/PadOps.mm index e68fb5d55f4..7728d7edf48 100644 --- a/backends/apple/mps/operations/PadOps.mm +++ b/backends/apple/mps/runtime/operations/PadOps.mm @@ -1,32 +1,39 @@ + // // Copyright (c) 2023 Apple Inc. All rights reserved. // Provided subject to the LICENSE file in the top level directory. // -#include "utils/MPSGraphInterface.h" -#include +#include +namespace torch { +namespace executor { namespace mps { -using namespace torch; +namespace delegate { // Pad operations (1D/2D/3D forward) -static PyMPSGraphTensor* -pad_out_template(MPSGraph* mpsGraph, - MPSGraphTensor* input, IntArrayRef padding, - MPSGraphPaddingMode mode, double constantValue) -{ +static +MPSGraphTensor* padOutTemplate( + MPSGraph* mpsGraph, + MPSGraphTensor* input, + std::vector padding, + MPSGraphPaddingMode mode, + float constantValue) { + const int padding_size = (int) padding.size(); int padding_dim = padding_size / 2; // either 1D, 2D, or 3D - TORCH_CHECK(padding_size == 2 || padding_size == 4 || padding_size == 6, - "invalid padding argument of size ", padding_size); + ET_CHECK_MSG(padding_size == 2 || padding_size == 4 || padding_size == 6, + "invalid padding argument of size %d", padding_size); auto input_sizes = getMPSShapeVec(input.shape); int64_t nbatch = 1; int64_t ndims = input_sizes.size(); - TORCH_CHECK(ndims >= (int64_t)padding_dim, "Length of pad should be no more than twice the number of " - "dimensions of the input. Pad length is ", padding_size, "while the input has ", ndims, "dimensions."); + ET_CHECK_MSG( + ndims >= (int64_t)padding_dim, + "Length of pad should be no more than twice the number of " + "dimensions of the input. 
Pad length is %d while the input has %lld dimensions.", padding_size, ndims); // number of input dims with ConstantPad could be less than 2 int dim_w = padding_dim; @@ -36,9 +43,9 @@ if (mode != MPSGraphPaddingModeConstant && ndims > padding_dim) { bool valid_dims = input_sizes[1] != 0 && input_sizes[padding_dim] != 0; - TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) || + ET_CHECK_MSG((ndims == 1 + padding_dim && valid_dims) || (ndims == 2 + padding_dim && valid_dims && input_sizes[1 + padding_dim] != 0), - "3D or 4D (batch mode) tensor expected for input, but got: ", input); + "3D or 4D (batch mode) tensor expected for input, but got: %zu", input_sizes.size()); } if (ndims == padding_dim) { @@ -71,28 +78,29 @@ int64_t input_d = padding_dim > 2 ? input_sizes[dim_d] : 0; int64_t output_d = padding_dim > 2 ? input_d + pad_front + pad_back : 0; - TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1, - "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated " - "output H: ", output_h, " W: ", output_w); + ET_CHECK_MSG( + output_w >= 1 || output_h >= padding_dim - 1, + "input (H: %lld, W: %lld) is too small. 
Calculated " + "output H: %lld, W: %lld", input_h, input_w, output_h, output_w); // these checks are only relevant for reflection padding (code taken from ReflectionPad.cpp) if (mode == MPSGraphPaddingModeReflect) { - TORCH_CHECK(pad_l < input_w && pad_r < input_w, + ET_CHECK_MSG(pad_l < input_w && pad_r < input_w, "Argument #4: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_l, ", ", pad_r, - ") at dimension ", dim_w, " of input ", ndims); + "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld", + pad_l, pad_r, dim_w, ndims); if (padding_dim > 1) { - TORCH_CHECK(pad_t < input_h && pad_b < input_h, + ET_CHECK_MSG(pad_t < input_h && pad_b < input_h, "Argument #6: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_t, ", ", pad_b, - ") at dimension ", dim_h, " of input ", ndims); + "input dimension, but got: padding (%lld, %lld) at dimension %d of input %lld", + pad_t, pad_b, dim_h, ndims); } if (padding_dim > 2) { - TORCH_CHECK(pad_front < input_d && pad_back < input_d, + ET_CHECK_MSG(pad_front < input_d && pad_back < input_d, "Argument #8: Padding size should be less than the corresponding " - "input dimension, but got: padding (", pad_front, ", ", pad_back, - ") at dimension ", dim_d, " of input ", ndims); + "input dimension, but got: padding (%lld, %lld) at dimension %lld of input %lld", + pad_front, input_d, pad_back, input_d); } } @@ -121,12 +129,39 @@ return padTensor; } -PyMPSGraphTensor* -MPSGraphModule::constant_pad_nd( - MPSGraphTensor* input, - IntArrayRef pad, - const double value) { - return pad_out_template(mpsGraph, input, pad, MPSGraphPaddingModeConstant, value); +Error +MPSGraphBuilder::mpsConstantPadNDOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSConstantPadND(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output_id() + ); + + + 
_idToMPSGraphTensor[graphNode->output_id()] = + padOutTemplate( + _mpsGraph, + getMPSGraphTensor(graphNode->input1_id()), + flatbufferDimsToVector(graphNode->pad()), + MPSGraphPaddingModeConstant, + graphNode->value() + ); + + return Error::Ok; } -} // namespace at::native +// PyMPSGraphTensor* +// MPSGraphModule::constant_pad_nd( +// MPSGraphTensor* input, +// IntArrayRef pad, +// const double value) { +// return pad_out_template(mpsGraph, input, pad, MPSGraphPaddingModeConstant, value); +// } + + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/PoolingOps.mm b/backends/apple/mps/runtime/operations/PoolingOps.mm new file mode 100644 index 00000000000..06d21ae1116 --- /dev/null +++ b/backends/apple/mps/runtime/operations/PoolingOps.mm @@ -0,0 +1,107 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. +// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + + +Error +MPSGraphBuilder::mpsMaxPool2DWithIndicesOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSMaxPool2DWithIndices(); + ET_LOG( + Debug, "%s: %d -> (%d, %d)", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output1_id(), + graphNode->output2_id() + ); + + MPSGraphPooling2DOpDescriptor* desc = + [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:graphNode->kernel_width() + kernelHeight:graphNode->kernel_height() + strideInX:graphNode->stride_width() + strideInY:graphNode->stride_height() + dilationRateInX:graphNode->dilation_width() + dilationRateInY:graphNode->dilation_height() + paddingLeft:graphNode->padding_left() + paddingRight:graphNode->padding_right() + paddingTop:graphNode->padding_top() + paddingBottom:graphNode->padding_bottom() + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW]; + desc.ceilMode = 
graphNode->ceil_mode(); + desc.returnIndicesMode = MPSGraphPoolingReturnIndicesGlobalFlatten2D; + desc.returnIndicesDataType = MPSDataTypeInt32; + + NSArray* outputs = + [_mpsGraph maxPooling2DReturnIndicesWithSourceTensor:getMPSGraphTensor(graphNode->input1_id()) + descriptor:desc + name:@"MaxPool2DWithIndices"]; + + + _idToMPSGraphTensor[graphNode->output1_id()] = outputs[0]; + _idToMPSGraphTensor[graphNode->output2_id()] = outputs[1]; + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsAvgPool2DOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSAvgPool2D(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output1_id() + ); + + MPSGraphPooling2DOpDescriptor* desc = + [MPSGraphPooling2DOpDescriptor descriptorWithKernelWidth:graphNode->kernel_width() + kernelHeight:graphNode->kernel_height() + strideInX:graphNode->stride_width() + strideInY:graphNode->stride_height() + dilationRateInX:graphNode->dilation_width() + dilationRateInY:graphNode->dilation_height() + paddingLeft:graphNode->padding_left() + paddingRight:graphNode->padding_right() + paddingTop:graphNode->padding_top() + paddingBottom:graphNode->padding_bottom() + paddingStyle:MPSGraphPaddingStyleExplicit + dataLayout:MPSGraphTensorNamedDataLayoutNCHW]; + const bool useDivisor = graphNode->divisor_override() != 0; + + // If overriding divisor, zeroPads must be included to the average for correct behavior + desc.includeZeroPadToAverage = useDivisor ? 
true : graphNode->count_include_pad(); + + MPSGraphTensor* avgPoolTensor = [_mpsGraph avgPooling2DWithSourceTensor:getMPSGraphTensor(graphNode->input1_id()) + descriptor:desc + name:@"AvgPool2DTensor"]; + if (useDivisor) { + // Here we rescale the average due to MPSGraph not supporting custom divisor directly + const float divisor = float(graphNode->kernel_height() * graphNode->kernel_width()) / (float)graphNode->divisor_override(); + MPSGraphTensor* constantTensor = [_mpsGraph constantWithScalar:divisor + shape:@[@1] + dataType:MPSDataTypeFloat32]; + avgPoolTensor = [_mpsGraph multiplicationWithPrimaryTensor:avgPoolTensor + secondaryTensor:constantTensor + name:@"AvgPool2DTensor/divisor_override"]; + + } + + _idToMPSGraphTensor[graphNode->output1_id()] = avgPoolTensor; + + return Error::Ok; +} + + + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/RangeOps.mm b/backends/apple/mps/runtime/operations/RangeOps.mm new file mode 100644 index 00000000000..ef8ec9ae56d --- /dev/null +++ b/backends/apple/mps/runtime/operations/RangeOps.mm @@ -0,0 +1,53 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + + +Error +MPSGraphBuilder::mpsArangeOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSArange(); + ET_LOG( + Debug, "%s: () -> %d", + __FUNCTION__, + graphNode->output_id() + ); + + auto start = graphNode->start(); + auto end = graphNode->end(); + auto step = graphNode->step(); + MPSDataType dataType = getMPSDataType(graphNode->dtype()); + + int32_t size_d = std::ceil(static_cast(end - start) / step); + auto shapeTensor = [_mpsGraph constantWithData:[NSData dataWithBytes:&size_d length:sizeof(int32_t)] + shape:@[ @1 ] + dataType:MPSDataTypeInt32]; + auto startScalar = start; + auto stepScalar = step; + auto coordsTensor = [_mpsGraph coordinateAlongAxis:0 withShapeTensor:shapeTensor name:nil]; + coordsTensor = [_mpsGraph castTensor:coordsTensor toType:dataType name:@"coords"]; + + auto startTensor = [_mpsGraph constantWithScalar:startScalar + dataType:dataType]; + auto multiplyTensor = [_mpsGraph constantWithScalar:stepScalar + dataType:dataType]; + auto scaledCoords = [_mpsGraph multiplicationWithPrimaryTensor:coordsTensor + secondaryTensor:multiplyTensor + name:nil]; + _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph additionWithPrimaryTensor:scaledCoords secondaryTensor:startTensor name:nil]; + + return Error::Ok; +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/ReduceOps.mm b/backends/apple/mps/runtime/operations/ReduceOps.mm new file mode 100644 index 00000000000..1335adf85bb --- /dev/null +++ b/backends/apple/mps/runtime/operations/ReduceOps.mm @@ -0,0 +1,59 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + + +Error +MPSGraphBuilder::mpsMeanOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSMean(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output_id() + ); + + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + + //MPSGraph wants negative axes to be converted to positive + const int inputDims = [inputTensor.shape count]; + + NSMutableArray* dimArray = [NSMutableArray array]; + for(int64_t i = 0; i < graphNode->num_dims(); i++) { + int32_t dim = graphNode->dims()->Get(i); + if (dim < 0) { + dim = inputDims + dim; + } + [dimArray addObject:[NSNumber numberWithInt:dim]]; + } + + // Reverting back to get the ordering back to slowest axis first as MPSGraph expects + dimArray = [[[dimArray reverseObjectEnumerator] allObjects] mutableCopy]; + + MPSGraphTensor* meanTensor = [_mpsGraph meanOfTensor:inputTensor + axes:dimArray + name:@"Mean"]; + if (!graphNode->keep_dims()) { + meanTensor = [_mpsGraph squeezeTensor:meanTensor + axes:dimArray + name:@"Mean/squeezed"]; + } + + _idToMPSGraphTensor[graphNode->output_id()] = meanTensor; + return Error::Ok; +} + + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/ShapeOps.mm b/backends/apple/mps/runtime/operations/ShapeOps.mm new file mode 100644 index 00000000000..e6ca79934a5 --- /dev/null +++ b/backends/apple/mps/runtime/operations/ShapeOps.mm @@ -0,0 +1,275 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + + +Error +MPSGraphBuilder::mpsPermuteOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSPermute(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output_id() + ); + + NSMutableArray* permutation = [NSMutableArray array]; + for(int64_t i = 0; i < graphNode->num_dims(); i++) { + [permutation addObject:[NSNumber numberWithInteger:graphNode->perm()->Get(i)]]; + } + _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph transposeTensor:getMPSGraphTensor(graphNode->input1_id()) + permutation:permutation + name:@"permutation"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsViewOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSView(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + NSMutableArray* shape = [NSMutableArray array]; + for (int32_t i = 0; i < graphNode->num_dims(); i++) { + [shape addObject:[NSNumber numberWithInteger:graphNode->shape()->Get(i)]]; + } + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph reshapeTensor:getMPSGraphTensor(graphNode->input1_id()) + withShape:shape + name:@"view_copy"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsExpandOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSExpand(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + NSMutableArray* shape = [NSMutableArray array]; + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + + // In torch, -1 is passed for dimensions which are to stay the same size + for (int32_t i = 0; i < inputTensor.shape.count; i++) { + int expandDimVal = graphNode->shape()->Get(i); + if (expandDimVal == -1) { + [shape addObject:inputTensor.shape[i]]; + } else { + [shape addObject:[NSNumber numberWithInteger:expandDimVal]]; + } + } + + 
_idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph broadcastTensor:inputTensor + toShape:shape + name:@"expand_copy"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsCatOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSCat(); + ET_LOG( + Debug, "%s: %d", + __FUNCTION__, graphNode->output_id() + ); + + NSMutableArray* inputTensors = [NSMutableArray array]; + for (auto id : *graphNode->input_ids()) { + MPSGraphTensor* catTensor = getMPSGraphTensor(id); + if (catTensor != nil) + [inputTensors addObject:catTensor]; + } + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph concatTensors:inputTensors + dimension:graphNode->dim() + name:@"cat"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsSqueezeOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSSqueeze(); + ET_LOG( + Debug, "%s: %d", + __FUNCTION__, graphNode->output_id() + ); + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph squeezeTensor:getMPSGraphTensor(graphNode->input1_id()) + axes:getMPSShape(graphNode->dims()) + name:@"squeeze"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsUnsqueezeOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSUnsqueeze(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph expandDimsOfTensor:getMPSGraphTensor(graphNode->input1_id()) + axis:graphNode->dim() + name:@"unsqueeze"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsSelectOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSSelect(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + MPSGraphTensor* slicedTensor = [_mpsGraph sliceTensor:getMPSGraphTensor(graphNode->input1_id()) + dimension:graphNode->dim() + start:graphNode->index() + length:1 + name:@"slice"]; + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph 
squeezeTensor:slicedTensor + axis:graphNode->dim() + name:@"slice/squeezed"]; + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsPixelShuffleOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSPixelShuffle(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + const int ndims = inputTensor.shape.count; + MPSGraphTensor* outputTensor = nil; + int32_t upscaleFactor = graphNode->upscale_factor(); + + ET_CHECK_OR_RETURN_ERROR( + ndims >= 3, Internal, "pixel_shuffle requires tensor with at least 3 dimensions."); + if (upscaleFactor == 1) { + // TODO: move this to AOT + outputTensor = inputTensor; + } else { + ET_CHECK_OR_RETURN_ERROR( + inputTensor.shape[ndims - 3].intValue % (upscaleFactor * upscaleFactor) == 0, + Internal, + "pixel_shuffle channels must be divisible by upscale factor squared."); + + outputTensor = [_mpsGraph depthToSpace2DTensor:inputTensor + widthAxis:ndims - 1 + heightAxis:ndims - 2 + depthAxis:ndims - 3 + blockSize:upscaleFactor + usePixelShuffleOrder:true + name:@"pixel_shuffle"]; + } + + _idToMPSGraphTensor[graphNode->output_id()] = outputTensor; + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsSliceOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSSlice(); + ET_LOG( + Debug, "%s %d: %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + int64_t dim = graphNode->dim(); + int64_t dimLen = inputTensor.shape[dim].intValue; + + // Define input arrays as required by MPSGraph API + NSMutableArray* start_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; + NSMutableArray* end_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; + NSMutableArray* step_arr = [NSMutableArray arrayWithCapacity: inputTensor.shape.count]; + // Step needs to be set to one for all 
other dims + for (int i = 0; i < inputTensor.shape.count; i++) { + step_arr[i] = @1; + end_arr[i] = inputTensor.shape[i]; + start_arr[i] = @0; + } + + start_arr[dim] = [NSNumber numberWithInteger:graphNode->start()]; + end_arr[dim] = [NSNumber numberWithInteger:graphNode->end()]; + step_arr[dim] = [NSNumber numberWithInteger:graphNode->step()]; + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph sliceTensor:inputTensor + starts:start_arr + ends:end_arr + strides:step_arr + name:@"strided_slice"]; + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsSplitWithSizesOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSSplitWithSizes(); + ET_LOG( + Debug, "%s: %d -> len(output)=%d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_ids()->size() + ); + + std::vector splitResults; + NSArray* mpsGraphResults; + + mpsGraphResults = [_mpsGraph splitTensor:getMPSGraphTensor(graphNode->input1_id()) + splitSizes:getMPSShape(graphNode->split_sizes()) + axis:graphNode->dim() + name:@"split_size"]; + + int crtIdx = 0; + for (auto outId : *graphNode->output_ids()) { + _idToMPSGraphTensor[outId] = mpsGraphResults[crtIdx++]; + } + + return Error::Ok; +} + +Error +MPSGraphBuilder::mpsCastOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSCast(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + + _idToMPSGraphTensor[graphNode->output_id()] = castMPSTensor( + _mpsGraph, getMPSGraphTensor(graphNode->input1_id()), getMPSDataType(graphNode->dtype())); + + return Error::Ok; +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/runtime/operations/UnaryOps.mm b/backends/apple/mps/runtime/operations/UnaryOps.mm new file mode 100644 index 00000000000..099afb63cc9 --- /dev/null +++ b/backends/apple/mps/runtime/operations/UnaryOps.mm @@ -0,0 +1,134 @@ + +// +// Copyright (c) 2023 Apple Inc. All rights reserved. 
+// Provided subject to the LICENSE file in the top level directory. +// + +#include + +namespace torch { +namespace executor { +namespace mps { +namespace delegate { + +MPSGraphTensor* +unaryOpTensor( + MPSGraphTensor* inputTensor, + MPSGraph* mpsGraph, + std::function unaryOpFunction) { + return unaryOpFunction(inputTensor); +} + +Error +MPSGraphBuilder::mpsBitwiseNotOp(NodePtr nodePtr) { + auto graphNode = nodePtr->mpsnode_union_as_MPSBitwiseNot(); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSDataType mpsInputDataType = [inputTensor dataType]; + if (getScalarType(mpsInputDataType) == ScalarType::Bool) { + _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph notWithTensor:inputTensor name:nil]; + } else { + _idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph bitwiseNOTWithTensor:inputTensor name:nil]; + } + + return Error::Ok; +} + +#define REGISTER_UNARY_OP(aot_name, graph_op) \ +Error \ +MPSGraphBuilder::mps##aot_name##Op(NodePtr nodePtr) { \ + auto graphNode = static_cast(nodePtr->mpsnode_union()); \ + ET_LOG( \ + Debug, "%s: %d -> %d", \ + __FUNCTION__, \ + graphNode->input1_id(), \ + graphNode->output_id() \ + ); \ + _idToMPSGraphTensor[graphNode->output_id()] = unaryOpTensor( \ + getMPSGraphTensor(graphNode->input1_id()), \ + _mpsGraph, \ + [&](MPSGraphTensor* inputTensor) -> MPSGraphTensor* { \ + return [_mpsGraph graph_op##WithTensor:inputTensor \ + name:nil]; \ + } \ + ); \ + return Error::Ok; \ +} + +REGISTER_UNARY_OP(Exp, exponent) +REGISTER_UNARY_OP(Exp2, exponentBase2) +REGISTER_UNARY_OP(Reciprocal, reciprocal) +REGISTER_UNARY_OP(Sqrt, squareRoot) +REGISTER_UNARY_OP(Neg, negative) +REGISTER_UNARY_OP(Log, logarithm) +REGISTER_UNARY_OP(Log10, logarithmBase10) +REGISTER_UNARY_OP(Log2, logarithmBase2) +REGISTER_UNARY_OP(Erf, erf) +REGISTER_UNARY_OP(Floor, floor) +REGISTER_UNARY_OP(Ceil, ceil) 
+REGISTER_UNARY_OP(Rsqrt, reverseSquareRoot) +REGISTER_UNARY_OP(Sigmoid, sigmoid) +REGISTER_UNARY_OP(Sin, sin) +REGISTER_UNARY_OP(Sign, sign) +REGISTER_UNARY_OP(Cos, cos) +REGISTER_UNARY_OP(Tan, tan) +REGISTER_UNARY_OP(Abs, absolute) +REGISTER_UNARY_OP(Asin, asin) +REGISTER_UNARY_OP(Acos, acos) +REGISTER_UNARY_OP(Atan, atan) +REGISTER_UNARY_OP(Sinh, sinh) +REGISTER_UNARY_OP(Cosh, cosh) +REGISTER_UNARY_OP(Tanh, tanh) +REGISTER_UNARY_OP(Asinh, asinh) +REGISTER_UNARY_OP(Acosh, acosh) +REGISTER_UNARY_OP(Atanh, atanh) +REGISTER_UNARY_OP(Isnan, isNaN) +REGISTER_UNARY_OP(Isinf, isInfinite) +REGISTER_UNARY_OP(Round, round) + + +Error +MPSGraphBuilder::mpsNormCdfOp(NodePtr nodePtr) { + auto graphNode = static_cast(nodePtr->mpsnode_union()); + ET_LOG( + Debug, "%s: %d -> %d", + __FUNCTION__, + graphNode->input1_id(), + graphNode->output_id() + ); + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + auto dataType = [inputTensor dataType]; + const float SQRT1_2 = 0.707106781186547524400844362104849039f; + MPSGraphTensor *sqrt1_2 = [_mpsGraph constantWithScalar:SQRT1_2 + shape:@[@1] + dataType:dataType]; + MPSGraphTensor *onef = [_mpsGraph constantWithScalar:1.0f + shape:@[@1] + dataType:dataType]; + MPSGraphTensor *halff = [_mpsGraph constantWithScalar:0.5f + shape:@[@1] + dataType:dataType]; + + MPSGraphTensor *erfTensor = [_mpsGraph multiplicationWithPrimaryTensor:inputTensor + secondaryTensor:sqrt1_2 + name:nil]; + erfTensor = [_mpsGraph erfWithTensor:erfTensor name:nil]; + erfTensor = [_mpsGraph additionWithPrimaryTensor:erfTensor + secondaryTensor:onef + name:nil]; + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph multiplicationWithPrimaryTensor:erfTensor + secondaryTensor:halff + name:nil]; + + return Error::Ok; +} + +} // namespace delegate +} // namespace mps +} // namespace executor +} // namespace torch diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py new file 
mode 100644 index 00000000000..04a41abaa1c --- /dev/null +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -0,0 +1,743 @@ +# +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +""" +Please refer to executorch/backends/apple/mps/serialization/schema.fbs for the schema definitions +""" + +from dataclasses import dataclass, field +from enum import IntEnum +from typing import List, Optional, Union + + +class MPSDataType(IntEnum): + mps_data_type_invalid = 0 + mps_data_type_float16 = 1 + mps_data_type_float32 = 2 + mps_data_type_bfloat16 = 3 + mps_data_type_int8 = 4 + mps_data_type_int16 = 5 + mps_data_type_int32 = 6 + mps_data_type_int64 = 7 + mps_data_type_uint8 = 8 + mps_data_type_bool = 9 + mps_data_type_complex_float16 = 10 + mps_data_type_complex_float32 = 11 + + +@dataclass +class MPSNode1x1: + input1_id: int + output_id: int + + +@dataclass +class MPSNode2x1: + input1_id: int + input2_id: int + output_id: int + + +@dataclass +class MPSDivNode2x1(MPSNode2x1): + rounding_mode: str = None + + +@dataclass +class MPSNode3x1: + input1_id: int + input2_id: int + input3_id: int + output_id: int + + +@dataclass +class MPSConv(MPSNode3x1): + stride_x: int = 0 + stride_y: int = 0 + dilation_x: int = 0 + dilation_y: int = 0 + groups: int = 0 + padding_left: int = 0 + padding_right: int = 0 + padding_top: int = 0 + padding_bottom: int = 0 + + +@dataclass +class MPSPooling2D: + input1_id: int + kernel_height: int + kernel_width: int + stride_height: int + stride_width: int + padding_left: int + padding_right: int + padding_top: int + padding_bottom: int + dilation_height: int + dilation_width: int + ceil_mode: bool + output1_id: int + output2_id: int = -1 + count_include_pad: bool = True + divisor_override: int = 0 + + +@dataclass +class MPSMinMax: + min_value: Union[float, str] = "-inf" + max_value: Union[float, str] = "inf" + + +## +## Activation ops +## +@dataclass +class 
MPSHardTanh(MPSNode1x1): + min_value: float = 0.0 + max_value: float = 0.0 + + +@dataclass +class MPSReLU(MPSNode1x1): + pass + + +@dataclass +class MPSGELU(MPSNode1x1): + approximate: str = "none" + + +@dataclass +class MPSLeakyReLU(MPSNode1x1): + negative_slope: float = 0.01 + + +@dataclass +class MPSSoftmax(MPSNode1x1): + dim: int = 0 + half_to_float: bool = False + + +@dataclass +class MPSLogSoftmax(MPSNode1x1): + dim: int = 0 + half_to_float: bool = False + + +## +## Binary ops +## +@dataclass +class MPSAdd(MPSNode2x1): + alpha: float = 1.0 + + +@dataclass +class MPSSub(MPSNode2x1): + alpha: float = 1.0 + + +@dataclass +class MPSMul(MPSNode2x1): + pass + + +@dataclass +class MPSDiv(MPSDivNode2x1): + pass + + +@dataclass +class MPSFmod(MPSDivNode2x1): + pass + + +@dataclass +class MPSRemainder(MPSNode2x1): + pass + + +@dataclass +class MPSMin(MPSNode2x1): + pass + + +@dataclass +class MPSMax(MPSNode2x1): + pass + + +@dataclass +class MPSPow(MPSNode2x1): + pass + + +@dataclass +class MPSAtan2(MPSNode2x1): + pass + + +@dataclass +class MPSBitwiseAnd(MPSNode2x1): + pass + + +@dataclass +class MPSBitwiseOr(MPSNode2x1): + pass + + +@dataclass +class MPSBitwiseXor(MPSNode2x1): + pass + + +@dataclass +class MPSMinimum(MPSNode2x1): + pass + + +## +## Unary ops +## +@dataclass +class MPSExp(MPSNode1x1): + pass + + +@dataclass +class MPSExp2(MPSNode1x1): + pass + + +@dataclass +class MPSReciprocal(MPSNode1x1): + pass + + +@dataclass +class MPSSqrt(MPSNode1x1): + pass + + +@dataclass +class MPSNeg(MPSNode1x1): + pass + + +@dataclass +class MPSLog(MPSNode1x1): + pass + + +@dataclass +class MPSLog10(MPSNode1x1): + pass + + +@dataclass +class MPSLog2(MPSNode1x1): + pass + + +@dataclass +class MPSErf(MPSNode1x1): + pass + + +@dataclass +class MPSFloor(MPSNode1x1): + pass + + +@dataclass +class MPSCeil(MPSNode1x1): + pass + + +@dataclass +class MPSRsqrt(MPSNode1x1): + pass + + +@dataclass +class MPSSigmoid(MPSNode1x1): + pass + + +@dataclass +class MPSSin(MPSNode1x1): + pass + 
+ +@dataclass +class MPSSign(MPSNode1x1): + pass + + +@dataclass +class MPSCos(MPSNode1x1): + pass + + +@dataclass +class MPSTan(MPSNode1x1): + pass + + +@dataclass +class MPSAbs(MPSNode1x1): + pass + + +@dataclass +class MPSAsin(MPSNode1x1): + pass + + +@dataclass +class MPSAcos(MPSNode1x1): + pass + + +@dataclass +class MPSAtan(MPSNode1x1): + pass + + +@dataclass +class MPSSinh(MPSNode1x1): + pass + + +@dataclass +class MPSCosh(MPSNode1x1): + pass + + +@dataclass +class MPSTanh(MPSNode1x1): + pass + + +@dataclass +class MPSAsinh(MPSNode1x1): + pass + + +@dataclass +class MPSAcosh(MPSNode1x1): + pass + + +@dataclass +class MPSAtanh(MPSNode1x1): + pass + + +@dataclass +class MPSBitwiseNot(MPSNode1x1): + pass + + +@dataclass +class MPSIsnan(MPSNode1x1): + pass + + +@dataclass +class MPSIsinf(MPSNode1x1): + pass + + +@dataclass +class MPSRound(MPSNode1x1): + pass + + +@dataclass +class MPSBitwise(MPSNode1x1): + pass + + +## +## Linear algebra ops +## +@dataclass +class MPSMatMul(MPSNode2x1): + pass + + +@dataclass +class MPSAddmm(MPSNode3x1): + beta: float = 1.0 + alpha: float = 1.0 + + +## +## Constant ops +## +@dataclass +class MPSFull: + output_id: int + shape: List[int] + fill_value: float + dtype: MPSDataType + + +@dataclass +class MPSFullLike(MPSNode1x1): + fill_value: float = 0.0 + dtype: MPSDataType = MPSDataType.mps_data_type_float32 + + +## +## Clamp ops +## +@dataclass +class MPSClamp(MPSNode1x1): + pass + + +@dataclass +class MPSWhere(MPSNode3x1): + pass + + +## +## Reduce ops +## +@dataclass +class MPSMean(MPSNode1x1): + num_dims: int = 0 + dims: List[int] = field(default_factory=list) + keep_dims: bool = False + + +## +## Indexing ops +## +@dataclass +class MPSIndexSelect(MPSNode1x1): + dim: int = 0 + index_id: int = -1 + + +@dataclass +class MPSEmbedding(MPSNode2x1): + padding_idx: int = -1 + scale_grad_by_freq: bool = False + sparse: bool = False + + +## +## Shape ops +## +@dataclass +class MPSPermute(MPSNode1x1): + num_dims: int = 0 + perm: List[int] 
= field(default_factory=list) + + +@dataclass +class MPSView(MPSNode1x1): + num_dims: int = 0 + shape: List[int] = field(default_factory=list) + + +@dataclass +class MPSExpand(MPSNode1x1): + num_dims: int = 0 + shape: List[int] = field(default_factory=list) + + +@dataclass +class MPSCat: + input_ids: List[int] + output_id: int + dim: int + + +@dataclass +class MPSSqueeze(MPSNode1x1): + dims: List[int] = field(default_factory=list) + + +@dataclass +class MPSUnsqueeze(MPSNode1x1): + dim: int = 0 + + +@dataclass +class MPSSelect(MPSNode1x1): + dim: int = 0 + index: int = 0 + + +@dataclass +class MPSSlice(MPSNode1x1): + dim: int = 0 + start: int = -1 + end: int = -1 + step: int = 1 + + +@dataclass +class MPSPixelShuffle(MPSNode1x1): + upscale_factor: int = 1 + + +@dataclass +class MPSSplitWithSizes: + input1_id: int + output_ids: List[int] + split_sizes: List[int] + dim: int + + +@dataclass +class MPSCast(MPSNode1x1): + dtype: MPSDataType + + +## +## Convolution ops +## + + +@dataclass +class MPSConv2D(MPSConv): + pass + + +@dataclass +class MPSDepthwiseConv2D(MPSConv): + pass + + +## +## Comparison Ops +## +class MPSEq(MPSNode2x1): + pass + + +class MPSNe(MPSNode2x1): + pass + + +class MPSGe(MPSNode2x1): + pass + + +class MPSGt(MPSNode2x1): + pass + + +class MPSLe(MPSNode2x1): + pass + + +class MPSLt(MPSNode2x1): + pass + + +## +## Normalization op +## +@dataclass +class MPSBatchNorm: + input_id: int + mean_id: int + var_id: int + weight_id: int + bias_id: int + momentum: float + epsilon: float + output1_id: int + output2_id: int + output3_id: int + + +@dataclass +class MPSLayerNorm: + input1_id: int + normalized_shape: List[int] + weight_id: int + bias_id: int + eps: float + output1_id: int + output2_id: int + output3_id: int + + +## +## Pooling ops +## + + +@dataclass +class MPSMaxPool2DWithIndices(MPSPooling2D): + pass + + +@dataclass +class MPSAvgPool2D(MPSPooling2D): + pass + + +## +## Pad ops +## +@dataclass +class MPSConstantPadND(MPSNode1x1): + pad: List[int] 
= field(default_factory=list) + value: float = 0.0 + + +## +## Range ops +## +@dataclass +class MPSArange: + output_id: int + start: float + end: float + step: float + dtype: MPSDataType + + +MPSNodeUnion = Union[ + # Activation ops + MPSHardTanh, + MPSReLU, + MPSGELU, + MPSLeakyReLU, + MPSSoftmax, + # Binary ops + MPSAdd, + MPSSub, + MPSMul, + MPSDiv, + MPSMin, + MPSMax, + MPSPow, + MPSRemainder, + MPSAtan2, + MPSBitwiseAnd, + MPSBitwiseOr, + MPSBitwiseXor, + MPSMinimum, + # Unary ops + MPSExp, + MPSExp2, + MPSReciprocal, + MPSSqrt, + MPSNeg, + MPSLog, + MPSLog10, + MPSLog2, + MPSErf, + MPSFloor, + MPSCeil, + MPSRsqrt, + MPSSigmoid, + MPSSin, + MPSSign, + MPSCos, + MPSTan, + MPSAbs, + MPSAsin, + MPSAcos, + MPSAtan, + MPSSinh, + MPSCosh, + MPSTanh, + MPSAsinh, + MPSAcosh, + MPSAtanh, + MPSBitwiseNot, + MPSIsnan, + MPSIsinf, + MPSRound, + # Linear algebra ops + MPSMatMul, + MPSAddmm, + # Constant ops + MPSFull, + MPSFullLike, + # Clamp ops + MPSClamp, + MPSWhere, + # Reduce ops + MPSMean, + # Indexing ops + MPSIndexSelect, + MPSEmbedding, + # Shape ops + MPSPermute, + MPSView, + MPSExpand, + MPSCat, + MPSSqueeze, + MPSUnsqueeze, + MPSSelect, + MPSSlice, + MPSPixelShuffle, + MPSSplitWithSizes, + MPSCast, + # Convolution ops + MPSConv2D, + MPSDepthwiseConv2D, + # Comparison ops + MPSEq, + MPSNe, + MPSGe, + MPSGt, + MPSLe, + MPSLt, + # Normalization ops + MPSBatchNorm, + MPSLayerNorm, + # Pooling ops + MPSMaxPool2DWithIndices, + MPSAvgPool2D, + # Pad ops + MPSConstantPadND, + # Range ops + MPSArange, +] + + +@dataclass +class MPSNode: + mpsnode_union: "MPSNodeUnion" + min_max: Optional[MPSMinMax] = None + + +@dataclass +class Buffer: + storage: bytes + + +@dataclass +class MPSTensor: + datatype: MPSDataType + num_dims: int + dims: List[int] + constant_buffer_size: int + constant_buffer: Buffer + + +@dataclass +class MPSGraph: + version: str + mps_nodes: List[MPSNode] + mps_values: List[MPSTensor] + input_ids: List[int] + output_ids: List[int] + constant_ids: List[int] 
diff --git a/backends/apple/mps/serialization/mps_graph_serialize.py b/backends/apple/mps/serialization/mps_graph_serialize.py new file mode 100644 index 00000000000..6fa46a2f5e5 --- /dev/null +++ b/backends/apple/mps/serialization/mps_graph_serialize.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import tempfile + +import pkg_resources +from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSGraph +from executorch.exir._serialize._dataclass import _DataclassEncoder +from executorch.exir._serialize._flatbuffer import _flatc_compile + + +def convert_to_flatbuffer(mps_graph: MPSGraph) -> bytes: + mps_graph_json = json.dumps(mps_graph, cls=_DataclassEncoder) + with tempfile.TemporaryDirectory() as d: + schema_path = os.path.join(d, "schema.fbs") + with open(schema_path, "wb") as schema_file: + schema_file.write(pkg_resources.resource_string(__name__, "schema.fbs")) + json_path = os.path.join(d, "schema.json") + with open(json_path, "wb") as json_file: + json_file.write(mps_graph_json.encode("ascii")) + + _flatc_compile(d, schema_path, json_path) + output_path = os.path.join(d, "schema.bin") + with open(output_path, "rb") as output_file: + return output_file.read() diff --git a/backends/apple/mps/serialization/schema.fbs b/backends/apple/mps/serialization/schema.fbs new file mode 100644 index 00000000000..c3e3eaa4faf --- /dev/null +++ b/backends/apple/mps/serialization/schema.fbs @@ -0,0 +1,443 @@ +// +// Copyright (c) 2023 Apple Inc. All rights reserved. +// Provided subject to the LICENSE file in the top level directory. 
+// + +namespace mpsgraph; + +// Update after any BC breaking changes +file_identifier "MP00"; + +// datatype for mps-values +enum MPSDataType : short { + mps_data_type_invalid = 0, + mps_data_type_float16 = 1, + mps_data_type_float32 = 2, + mps_data_type_bfloat16 = 3, + mps_data_type_int8 = 4, + mps_data_type_int16 = 5, + mps_data_type_int32 = 6, + mps_data_type_int64 = 7, + mps_data_type_uint8 = 8, + mps_data_type_bool = 9, + mps_data_type_complex_float16 = 10, + mps_data_type_complex_float32 = 11, +} + +// Helper classes to define the number of input and output tensors for a node. +// Not meant to be used directly. + +// A node with one input and one output. +table _MPSNode1x1 { + input1_id:int; + output_id:int; +} + +// A node with two inputs and one output. +table _MPSNode2x1 { + input1_id:int; + input2_id:int; + output_id:int; +} + +table _MPSDivNode2x1 { + input1_id:int; + input2_id:int; + output_id:int; + rounding_mode:string; +} + +table _MPSNodeWithAlpha2x1 { + input1_id:int; + input2_id:int; + output_id:int; + alpha:float; +} + +// A node with three inputs and one output. +table _MPSNode3x1 { + input1_id:int; + input2_id:int; + input3_id:int; + output_id:int; +} + +table MPSMinMax { + min_value:float; + max_value:float; +} + +table MPSPooling2D { + input1_id:int; + kernel_height:int; + kernel_width:int; + stride_height:int; + stride_width:int; + padding_left:int; + padding_right:int; + padding_top:int; + padding_bottom:int; + dilation_height:int; + dilation_width:int; + ceil_mode:bool; + count_include_pad:bool; + divisor_override:int; + output1_id:int; + output2_id:int; +} + +// Activation ops. 
+table MPSHardTanh { + input1_id:int; + output_id:int; + min_value:float; + max_value:float; +} + +table MPSGELU { + input1_id:int; + output_id:int; + approximate:string; +} + +table MPSLeakyReLU { + input1_id:int; + output_id:int; + negative_slope:float; +} + +table MPSSoftmax { + input1_id:int; + output_id:int; + dim:int; + half_to_float:bool; +} + +// Clamp ops +table MPSClamp { + input1_id:int; + output_id:int; +} + +// Reduce ops +table MPSMean { + input1_id:int; + output_id:int; + num_dims:int; + dims:[int]; + keep_dims:bool; +} + +// Indexing ops +table MPSIndexSelect { + input1_id:int; + output_id:int; + dim:int; + index_id:int; +} + +table MPSEmbedding { + input1_id:int; + input2_id:int; + output_id:int; + padding_idx:int; + scale_grad_by_freq:bool; + sparse:bool; +} + +// Shape ops. +table MPSPermute { + input1_id:int; + output_id:int; + num_dims:int; + perm:[int]; +} + +table MPSView { + input1_id:int; + output_id:int; + num_dims:int; + shape:[int]; +} + +table MPSCat { + input_ids:[int]; + output_id:int; + dim:int; +} + +table MPSSqueeze { + input1_id:int; + output_id:int; + dims:[int]; +} + +table MPSUnsqueeze { + input1_id:int; + output_id:int; + dim:int; +} + +table MPSSelect { + input1_id:int; + output_id:int; + dim:int; + index:int; +} + +table MPSSlice { + input1_id:int; + output_id:int; + dim:long; + start:long; + end:long; + step:long; +} + +table MPSPixelShuffle { + input1_id:int; + output_id:int; + upscale_factor:int; +} + +table MPSSplitWithSizes { + input1_id:int; + output_ids:[int]; + split_sizes:[int]; + dim:int; +} + +table MPSCast { + input1_id:int; + output_id:int; + dtype:MPSDataType; +} + +// Linear algebra ops. +table MPSAddmm { + input1_id:int; + input2_id:int; + input3_id:int; + output_id:int; + beta:float; + alpha:float; +} + +// Constant ops +table _MPSFull { + input1_id:int; + output_id:int; + shape:[int]; + fill_value: float; + dtype:MPSDataType; +} + +// Convolution ops. 
+table MPSConv { + input1_id:int; + input2_id:int; + input3_id:int; + output_id:int; + stride_x:int; + stride_y:int; + dilation_x:int; + dilation_y:int; + groups:int; + padding_left:int; + padding_right:int; + padding_top:int; + padding_bottom:int; +} + +// Normalization ops. +table MPSBatchNorm { + input_id:int; + mean_id:int; + var_id:int; + weight_id:int; + bias_id:int; + momentum:float; + epsilon:float; + output2_id:int; + output1_id:int; + output3_id:int; +} + +table MPSLayerNorm { + input1_id:int; + normalized_shape:[int]; + weight_id:int; + bias_id:int; + eps:float; + output2_id:int; + output1_id:int; + output3_id:int; +} + +// Pooling ops + +// Pad ops +table MPSConstantPadND { + input1_id:int; + output_id:int; + pad:[int]; + value:float; +} + +// Range ops +table MPSArange { + output_id:int; + start:float; + end:float; + step:float; + dtype:MPSDataType; +} + +union MPSNodeUnion { + // Activation ops + MPSHardTanh, + MPSReLU: _MPSNode2x1, + MPSGELU, + MPSLeakyReLU, + MPSSoftmax, + MPSLogSoftmax: MPSSoftmax, + + // Binary ops + MPSAdd: _MPSNodeWithAlpha2x1, + MPSSub: _MPSNodeWithAlpha2x1, + MPSMul: _MPSNode2x1, + MPSDiv: _MPSDivNode2x1, + MPSFmod: _MPSDivNode2x1, + MPSRemainder: _MPSDivNode2x1, + MPSMin: _MPSNode2x1, + MPSMax: _MPSNode2x1, + MPSPow: _MPSNode2x1, + MPSAtan2: _MPSNode2x1, + MPSBitwiseAnd: _MPSNode2x1, + MPSBitwiseOr: _MPSNode2x1, + MPSBitwiseXor: _MPSNode2x1, + MPSMinimum: _MPSNode2x1, + + // Unary ops + MPSExp: _MPSNode1x1, + MPSExp2: _MPSNode1x1, + MPSReciprocal: _MPSNode1x1, + MPSSqrt: _MPSNode1x1, + MPSNeg: _MPSNode1x1, + MPSLog: _MPSNode1x1, + MPSLog10: _MPSNode1x1, + MPSLog2: _MPSNode1x1, + MPSErf: _MPSNode1x1, + MPSFloor: _MPSNode1x1, + MPSCeil: _MPSNode1x1, + MPSRsqrt: _MPSNode1x1, + MPSSigmoid: _MPSNode1x1, + MPSSin: _MPSNode1x1, + MPSSign: _MPSNode1x1, + MPSCos: _MPSNode1x1, + MPSTan: _MPSNode1x1, + MPSAbs: _MPSNode1x1, + MPSAsin: _MPSNode1x1, + MPSAcos: _MPSNode1x1, + MPSAtan: _MPSNode1x1, + MPSSinh: _MPSNode1x1, + MPSCosh: 
_MPSNode1x1, + MPSTanh: _MPSNode1x1, + MPSAsinh: _MPSNode1x1, + MPSAcosh: _MPSNode1x1, + MPSAtanh: _MPSNode1x1, + MPSBitwiseNot: _MPSNode1x1, + MPSIsnan: _MPSNode1x1, + MPSIsinf: _MPSNode1x1, + MPSRound: _MPSNode1x1, + + // Linear algebra ops + MPSMatMul: _MPSNode2x1, + MPSAddmm, + + // Constant ops + MPSFull: _MPSFull, + MPSFullLike: _MPSFull, + + // Clamp ops + MPSClamp, + MPSWhere: _MPSNode3x1, + + // Indexing ops + MPSIndexSelect, + MPSEmbedding, + + // Reduce ops + MPSMean, + + // Shape ops + MPSPermute, + MPSView, + MPSExpand: MPSView, + MPSCat, + MPSSqueeze, + MPSUnsqueeze, + MPSSelect, + MPSSlice, + MPSPixelShuffle, + MPSSplitWithSizes, + MPSCast, + + // Convolution ops + MPSConv2D: MPSConv, + MPSDepthwiseConv2D: MPSConv, + + // Comparison ops + MPSEq: _MPSNode2x1, + MPSNe: _MPSNode2x1, + MPSGe: _MPSNode2x1, + MPSGt: _MPSNode2x1, + MPSLe: _MPSNode2x1, + MPSLt: _MPSNode2x1, + + // Normalization ops + MPSBatchNorm, + MPSLayerNorm, + + // Pooling ops + MPSMaxPool2DWithIndices: MPSPooling2D, + MPSAvgPool2D: MPSPooling2D, + + // Pad ops + MPSConstantPadND, + + // Range ops + MPSArange, +} + +table MPSNode { + mpsnode_union:MPSNodeUnion; + min_max:MPSMinMax; +} + +// taken from executorch +// Data buffer abstraction. +table Buffer { + storage:[ubyte] (force_align: 16); +} + +table MPSTensor { + datatype:MPSDataType; + num_dims:int; + dims:[int]; + constant_buffer_size:int; + constant_buffer:Buffer; +} + +table MPSGraph { + // Schema version. + version:string; + mps_nodes:[MPSNode]; + mps_values:[MPSTensor]; + + input_ids:[int]; + output_ids:[int]; + constant_ids:[int]; +} + +root_type MPSGraph; diff --git a/backends/apple/mps/targets.bzl b/backends/apple/mps/targets.bzl index 6a5fe4f9de2..64bfdc7c187 100644 --- a/backends/apple/mps/targets.bzl +++ b/backends/apple/mps/targets.bzl @@ -12,6 +12,7 @@ def define_common_targets(is_xplat = False, platforms = []): TARGETS and BUCK files that call this function.
""" kwargs = { + "name": "mps", "compiler_flags": [ "-DEXIR_MPS_DELEGATE=1", "-Wno-global-constructors", @@ -23,26 +24,20 @@ def define_common_targets(is_xplat = False, platforms = []): "deps": [ "//executorch/runtime/core:core", "//executorch/runtime/core/exec_aten/util:tensor_util", + ":mps_schema", ], "exported_deps": [ "//executorch/runtime/backend:interface", + ":mps_schema", ], - "headers": [ - "runtime/MPSCompiler.h", - "runtime/MPSDevice.h", - "runtime/MPSExecutor.h", - "runtime/MPSStream.h", - "utils/MPSGraphPackageExport.h", - "utils/OperationUtils.h", - ], - "name": "mps", - "srcs": [ - "runtime/MPSBackend.mm", - "runtime/MPSCompiler.mm", - "runtime/MPSDevice.mm", - "runtime/MPSExecutor.mm", - "runtime/MPSStream.mm", - ], + "headers": native.glob([ + "runtime/*.h", + "runtime/operations/*.h", + ]), + "srcs": native.glob([ + "runtime/*.mm", + "runtime/operations/*.mm", + ]), "visibility": [ "//executorch/backends/apple/...", "//executorch/examples/...", @@ -65,4 +60,36 @@ def define_common_targets(is_xplat = False, platforms = []): kwargs["platforms"] = platforms if runtime.is_oss or is_xplat: + runtime.genrule( + name = "gen_mps_schema", + srcs = [ + "serialization/schema.fbs", + ], + outs = { + "schema_generated.h": ["schema_generated.h"], + }, + cmd = " ".join([ + "$(exe {})".format(runtime.external_dep_location("flatc")), + "--cpp", + "--cpp-std c++11", + "--scoped-enums", + "-o ${OUT}", + "${SRCS}", + ]), + default_outs = ["."], + ) + + runtime.cxx_library( + name = "mps_schema", + srcs = [], + exported_headers = { + "schema_generated.h": ":gen_mps_schema[schema_generated.h]", + }, + exported_external_deps = ["flatbuffers-api"], + visibility = [ + "//executorch/backends/apple/...", + "//executorch/examples/...", + ], + ) + runtime.cxx_library(**kwargs) diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index edf4beb5c8c..2d6d5a1ac25 100644 --- a/backends/apple/mps/test/test_mps.py +++ 
b/backends/apple/mps/test/test_mps.py @@ -19,14 +19,13 @@ from executorch.backends.apple.mps.test.test_mps_utils import ( _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG, - dump_executorch_program_info, OpSequencesAddConv2d, randomize_bn, TestMPS, ) -from executorch.exir import ExirExportedProgram from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.tests.models import ( BasicSinMax, CompositeDelegateModule, @@ -75,12 +74,11 @@ class MODEL_TYPE(Enum): def run_model( model: str, model_type: MODEL_TYPE = MODEL_TYPE.EXIR_DEFAULT_MODEL, - dump_non_lowered_module: bool = False, - dump_lowered_module: bool = False, + use_fp16: bool = False, ): logging.info(f"Step 1: Retrieving model: {model}...") if model_type == MODEL_TYPE.EXIR_DEFAULT_MODEL: - m, m_inputs, _ = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model]) + m, m_inputs = EagerModelFactory.create_model(*MODEL_NAME_TO_MODEL[model]) elif model_type == MODEL_TYPE.EXIR_TEST_MODEL: m, m_inputs = EXIR_MODEL_NAME_TO_MODEL.get(model)() elif model_type == MODEL_TYPE.MPS_TEST_MODEL: @@ -95,12 +93,12 @@ def run_model( _EDGE_COMPILE_CONFIG ) - if dump_non_lowered_module: - dump_executorch_program_info(edge=edge, module_info="Non-lowered") - # Step 3: Lower to MPSGraph logging.info("Step 3: Lowering to MPSGraph...") - lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, []) + compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))] + lowered_module = to_backend( + MPSBackend.__name__, edge.exported_program, compile_specs + ) logging.info("Step 4: Capturing executorch program with lowered module...") @@ -120,12 +118,6 @@ def forward(self, *args): .to_executorch() ) - if dump_lowered_module: - tmp_exported_program: ExirExportedProgram = exir.capture( - lowered_module, m_inputs, _CAPTURE_CONFIG - ).to_edge(_EDGE_COMPILE_CONFIG) - dump_executorch_program_info(edge=tmp_exported_program, module_info="Lowered") - 
logging.info("Step 5: Generating bundled program... ") logging.info( @@ -155,53 +147,6 @@ def forward(self, *args): file.write(bundled_program_buffer) -class TestMPSBackend_ExampleModels(unittest.TestCase): - def test_mul(self): - run_model(inspect.stack()[0].function[5:]) - - def test_linear(self): - run_model(inspect.stack()[0].function[5:]) - - def test_add(self): - run_model(inspect.stack()[0].function[5:]) - - def test_add_mul(self): - run_model(inspect.stack()[0].function[5:]) - - def test_emformer_transcribe(self): - run_model(inspect.stack()[0].function[5:]) - - def test_emformer_join(self): - run_model(inspect.stack()[0].function[5:]) - - def test_mobilebert(self): - run_model(inspect.stack()[0].function[5:]) - - def test_mv2(self): - run_model(inspect.stack()[0].function[5:]) - - def test_mv3(self): - run_model(inspect.stack()[0].function[5:]) - - def test_vit(self): - run_model(inspect.stack()[0].function[5:]) - - def test_ic3(self): - run_model(inspect.stack()[0].function[5:]) - - def test_ic4(self): - run_model(inspect.stack()[0].function[5:]) - - def test_resnet18(self): - run_model(inspect.stack()[0].function[5:]) - - def test_resnet50(self): - run_model(inspect.stack()[0].function[5:]) - - def test_edsr(self): - run_model(inspect.stack()[0].function[5:]) - - class TestMPSBackendExirModels(unittest.TestCase): def test_model_with_unused_arg(self): run_model(inspect.stack()[0].function[5:], MODEL_TYPE.EXIR_TEST_MODEL) @@ -390,7 +335,21 @@ def __init__(self): super().__init__() def forward(self, x): - return torch.squeeze(x, 2) + y = torch.squeeze(x, 2) + return torch.squeeze(y, 0) + + example_inputs = (torch.randn(1, 5, 1, 1, 4),) + self.lower_and_test_with_partitioner( + Squeeze(), example_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_backend_unsqueeze_dim_1(self): + class Squeeze(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.unsqueeze(x, 1) example_inputs = 
(torch.randn(1, 5, 1, 4),) self.lower_and_test_with_partitioner( @@ -464,7 +423,7 @@ def forward(self, x): TwoConv(), example_inputs, func_name=inspect.stack()[0].function[5:] ) - def test_mps_backend_conv2d_bn(self): + def test_mps_backend_conv2d_bn_1(self): class ModelConvBN(torch.nn.Module): def __init__(self, in_features: int, out_features: int, kernel_size): super().__init__() @@ -508,6 +467,76 @@ def test_mps_backend_conv2d(self): conv, example_inputs, func_name=inspect.stack()[0].function[5:] ) + def test_conv1d(self): + example_inputs = (torch.randn(1, 57, 40),) + stride = random.randint(1, 4) + padding = random.randint(1, 4) + conv = torch.nn.Conv1d( + 57, + 20, + stride=stride, + padding=padding, + kernel_size=3, + bias=random.choice([True, False]), + ) + conv.eval() + self.lower_and_test_with_partitioner( + conv, example_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_conv2d_simple(self): + N = 10 + C = 10 + H = 4 + W = 6 + groups = 2 + input_memory_format = torch.contiguous_format + weight_memory_format = torch.contiguous_format + strideX = random.randint(1, 4) + strideY = random.randint(1, 4) + example_inputs = ( + torch.randn(N, C, H, W).to(memory_format=input_memory_format), + ) + conv = torch.nn.Conv2d( + in_channels=N, + out_channels=C, + kernel_size=H, + groups=groups, + stride=(strideX, strideY), + bias=False, + ) + conv.weight.data = conv.weight.to(memory_format=weight_memory_format) + conv.eval() + self.lower_and_test_with_partitioner( + conv, example_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_conv2d_to_depthwise_conv_3d(self): + N = 10 + C = 10 + H = 4 + W = 6 + groups = 10 + input_memory_format = torch.contiguous_format + weight_memory_format = torch.contiguous_format + strideX = random.randint(1, 4) + strideY = random.randint(1, 4) + example_inputs = ( + torch.randn(N, C, H, W).to(memory_format=input_memory_format), + ) + conv = torch.nn.Conv2d( + in_channels=N, + out_channels=C, + kernel_size=H, + 
groups=groups, + stride=(strideX, strideY), + ) + conv.weight.data = conv.weight.to(memory_format=weight_memory_format) + conv.eval() + self.lower_and_test_with_partitioner( + conv, example_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_conv2d_single_int_params(self): groups = 1 stride = 2 @@ -578,6 +607,27 @@ def test_mps_backend_mm(self): linear, example_input, func_name=inspect.stack()[0].function[5:] ) + def test_mps_backend_bmm(self): + class BmmModule(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.bmm = torch.bmm + + def forward(self, x, y): + return self.bmm(x, y) + + mul_module = BmmModule() + model_inputs = ( + torch.randn((3, 1, 8)), + torch.randn((3, 8, 1)), + ) + + self.lower_and_test_with_partitioner( + mul_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_addmm(self): in_sizes = [1, 4, 4] input_sizes = [4, 37, 17] @@ -1496,6 +1546,23 @@ def forward(self, x): ReluModule(), (example_input,), func_name=inspect.stack()[0].function[5:] ) + def test_mps_backend_GELU(self): + class GELUModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + self.gelu_tanh = torch.nn.GELU(approximate="tanh") + + def forward(self, x): + return self.gelu(x) + # MPS TODO: MPS Gelu tanh fails + # return self.gelu_tanh(y) + + example_input = torch.randn(2, 3, 4) + self.lower_and_test_with_partitioner( + GELUModule(), (example_input,), func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_leaky_Relu(self): class LeakyReluModule(torch.nn.Module): def __init__(self): @@ -1641,7 +1708,29 @@ def forward(self, x, y, z): func_name=inspect.stack()[0].function[5:], ) - def test_mps_clamp(self): + def test_mps_clamp_min_max(self): + class Clamp(torch.nn.Module): + def __init__(self, min_val, max_val): + super().__init__() + self.clamp = torch.clamp + self.min_val = min_val + self.max_val = max_val + + def forward(self, *x): + out1 = 
self.clamp(x[0], min=-0.5, max=0.5) + out2 = self.clamp(x[0], min=-5, max=5) + return out1, out2 + + model_inputs = ( + torch.randn(1, 4, 122, 122) * 2, + torch.randint(-100, 100, (1, 4, 15, 20)), + ) + module = Clamp(-0.5, 0.5) + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_clamp_min(self): class Clamp(torch.nn.Module): def __init__(self, min_val, max_val): super().__init__() @@ -1653,7 +1742,24 @@ def forward(self, x): return self.clamp(x, min=self.min_val, max=self.max_val) model_inputs = (torch.randn(1, 4, 122, 122) * 2,) - module = Clamp(-0.5, 0.5) + module = Clamp(-0.5, None) + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_clamp_max(self): + class Clamp(torch.nn.Module): + def __init__(self, min_val, max_val): + super().__init__() + self.clamp = torch.clamp + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return self.clamp(x, min=self.min_val, max=self.max_val) + + model_inputs = (torch.randn(1, 4, 122, 122) * 2,) + module = Clamp(None, 0.5) self.lower_and_test_with_partitioner( module, model_inputs, func_name=inspect.stack()[0].function[5:] ) @@ -2383,6 +2489,7 @@ def __init__(self): def forward(self): out1 = torch.ops.aten.scalar_tensor(self._scalar) out2 = torch.ops.aten.scalar_tensor(self._scalar, dtype=torch.int32) + # issue 121117206 out3 = torch.ops.aten.scalar_tensor(self._bool, dtype=torch.bool) return out1 + out2 + out3 diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index c971c4f07e9..f59811fdf65 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -6,27 +6,76 @@ import logging import unittest -from typing import Any, Tuple +from typing import Any, Tuple, Union import executorch.exir as exir - import torch from executorch.backends.apple.mps.mps_preprocess import 
MPSBackend -from executorch.exir import ExecutorchProgram, ExirExportedProgram +from executorch.exir import ( + EdgeCompileConfig, + EdgeProgramManager, + ExecutorchProgram, + ExirExportedProgram, + to_edge, +) from executorch.exir.backend.backend_api import to_backend, validation_disabled -from executorch.exir.print_program import print_program +from executorch.exir.backend.backend_details import CompileSpec +from executorch.exir.tracer import Value from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.sdk import BundledProgram from executorch.sdk.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) +from torch._export import capture_pre_autograd_graph +from torch.export import export, ExportedProgram + # Config for Capturing the weights, will be moved in the future _CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True, _unlift=True) _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False) +def _to_core_aten( + model: Union[torch.fx.GraphModule, torch.nn.Module], + example_inputs: Tuple[Value, ...], +) -> ExportedProgram: + # post autograd export. 
eventually this will become .to_core_aten + if not isinstance(model, torch.fx.GraphModule): + raise ValueError( + f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}" + ) + core_aten_ep = export(model, example_inputs) + logging.info(f"Core ATen graph:\n{core_aten_ep.graph}") + return core_aten_ep + + +def _core_aten_to_edge( + core_aten_exir_ep: ExportedProgram, + edge_compile_config=None, +) -> EdgeProgramManager: + if not edge_compile_config: + edge_compile_config = exir.EdgeCompileConfig( + _check_ir_validity=False, # quant ops currently break ir verification + ) + edge_manager: EdgeProgramManager = to_edge( + core_aten_exir_ep, compile_config=edge_compile_config + ) + + edge_manager.exported_program().graph.print_tabular() + logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}") + return edge_manager + + +def export_to_edge( + model: Union[torch.fx.GraphModule, torch.nn.Module], + example_inputs: Tuple[Value, ...], + edge_compile_config=_EDGE_COMPILE_CONFIG, +) -> EdgeProgramManager: + core_aten_ep = _to_core_aten(model, example_inputs) + return _core_aten_to_edge(core_aten_ep, edge_compile_config) + + class ansi_colors: HEADER = "\033[95m" OKBLUE = "\033[94m" @@ -39,27 +88,6 @@ class ansi_colors: UNDERLINE = "\033[4m" -def dump_executorch_program_info( - edge: ExirExportedProgram, module_info: str = "Lowered" -): - module_info = f"\033[92m{module_info}\033[0m" - - logging.info("-----------------------------------") - logging.info(f"{module_info} exported edge graph:\n", edge.exported_program.graph) - executorch_program = edge.to_executorch() - program = executorch_program.program - logging.info("-----------------------------------") - logging.info(f"{module_info} flatbuffer representation:") - exir.print_program.pretty_print(program) - logging.info("-----------------------------------") - logging.info(f"{module_info} instruction list:") - print_program(program=program, show_meminfo=True, 
mark_dynamic_shape_tensor=True) - logging.info("-----------------------------------") - logging.info(f"{module_info} executorch program:") - logging.info(executorch_program.dump_exported_program()) - logging.info("-----------------------------------") - - class OpSequencesAddConv2d(torch.nn.Module): """ Module which include sequences of Memory Format sensitive ops. forward runs @@ -120,11 +148,10 @@ def lower_module_and_test_output( sample_inputs: Tuple[torch.Tensor], func_name: str, use_partitioner: bool = False, - dump_non_lowered_module: bool = False, - dump_lowered_module: bool = False, + use_fp16: bool = False, ) -> ExirExportedProgram: """ - Helper testing function that takes a torch.nn.Module and lowers it to XNNPACK with + Helper testing function that takes a torch.nn.Module and lowers it to MPS with the given sample inputs. It then runs the lowered module and compares its outputs with the outputs of the eager module. """ @@ -139,20 +166,24 @@ def __init__(self): def forward(self, *args): return self.one_module(*args) - edge_program = exir.capture( - WrappedModule(), sample_inputs, _CAPTURE_CONFIG - ).to_edge(_EDGE_COMPILE_CONFIG) + model = WrappedModule() + model = model.eval() + model = capture_pre_autograd_graph(model, sample_inputs) - if dump_non_lowered_module: - dump_executorch_program_info(edge=edge_program, module_info="Non-lowered") + edge_program = export_to_edge( + model, + sample_inputs, + edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) logging.info("Step 2: Lowering to MPSGraph...") if use_partitioner: with validation_disabled(): None else: + compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))] delegated_program = to_backend( - "MPSBackend", edge_program.exported_program, [] + MPSBackend.__name__, edge_program.exported_program(), compile_specs ) logging.info("Step 3: Capturing executorch program with lowered module...") @@ -169,38 +200,27 @@ def forward(self, *args): WrappedModule(), sample_inputs, _CAPTURE_CONFIG 
).to_edge(_EDGE_COMPILE_CONFIG) - if dump_lowered_module: - tmp_exported_program: ExirExportedProgram = exir.capture( - delegated_program, sample_inputs, _CAPTURE_CONFIG - ).to_edge(_EDGE_COMPILE_CONFIG) - dump_executorch_program_info( - edge=tmp_exported_program, module_info="Lowered" - ) - executorch_program: ExecutorchProgram = exported_program.to_executorch() - # Assert the backend name is mps - self.assertEqual( - executorch_program.program.execution_plan[0].delegates[0].id, - MPSBackend.__name__, - ) - logging.info("Step 4: Generating bundled program...") logging.info( " -> Number of execution plans: {len(executorch_program.program.execution_plan)}" ) + expected_output = module(*sample_inputs) + method_test_suites = [ MethodTestSuite( method_name="forward", test_cases=[ MethodTestCase( - input=sample_inputs, expected_outputs=module(*sample_inputs) + inputs=sample_inputs, expected_outputs=module(*sample_inputs) ) ], ) ] + logging.info(f"Expected output: {expected_output}") logging.info(" -> Test suites generated successfully") bundled_program = BundledProgram(executorch_program, method_test_suites) @@ -208,8 +228,8 @@ def forward(self, *args): bundled_program ) - filename = f"{func_name}.bpte" - logging.info(f"Step 5: Saving bundled program to {filename}...") + filename = f"{func_name}.pte" + logging.info(f"Step 5: Saving bundled program to {filename}") with open(filename, "wb") as file: file.write(bundled_program_buffer) @@ -218,6 +238,7 @@ def lower_and_test_with_partitioner( graph_module, example_inputs, func_name: str, + use_fp16: bool = False, ): logging.info(func_name) # MPS TODO: partitioner support @@ -226,4 +247,5 @@ def lower_and_test_with_partitioner( example_inputs, use_partitioner=False, func_name=func_name, + use_fp16=use_fp16, ) diff --git a/backends/apple/mps/utils/Bindings.mm b/backends/apple/mps/utils/Bindings.mm deleted file mode 100644 index a65b4a1c84f..00000000000 --- a/backends/apple/mps/utils/Bindings.mm +++ /dev/null @@ -1,345 +0,0 @@ 
-// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#include "MPSGraphInterface.h" - -namespace mps { -namespace { -using namespace torch; -// Create Python bindings for the Objective-C++ code. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // Main class to interface with MPSGraph. - py::class_(m, "MPSGraphModule") - // MPSGraphModule constructor. - .def(py::init<>()) - - // - // Graph placeholders. - // - .def("mpsGraphUnrankedPlaceHolder", &MPSGraphModule::mpsGraphUnrankedPlaceHolder) - .def("mpsGraphRankedPlaceHolder", &MPSGraphModule::mpsGraphRankedPlaceHolder) - .def("mpsGraphScalarPlaceHolder", &MPSGraphModule::mpsGraphScalarPlaceHolder) - .def("set_outputs", &MPSGraphModule::set_outputs) - - // - // Graph operators - // - .def("constant", &MPSGraphModule::constant) - .def("constantTensor", &MPSGraphModule::constantTensor) - .def("full", &MPSGraphModule::full) - .def("full_like", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, double scalar) { - return self.full_like(static_cast(inputTensor), scalar); - }) - .def("mm", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor) { - return self.mm(static_cast(primaryTensor), static_cast(secondaryTensor)); - }) - .def("conv2D", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor, - std::optional biasTensor, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef outputPadding, int64_t groups, bool is_depthwise) { - MPSGraphTensor *optionalBias = nullptr; - MPSGraphTensor *inputTensor = static_cast(primaryTensor); - if(biasTensor.has_value()){ - optionalBias = static_cast(*biasTensor); - } - return self.conv2D(static_cast(primaryTensor), static_cast(secondaryTensor), - optionalBias, stride, padding, dilation, transposed, outputPadding, groups, is_depthwise); - }) - .def("maxPool2DWithIndices", [](MPSGraphModule& 
self, PyMPSGraphTensor* inputTensor, IntArrayRef kernel_size, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) { - return self.maxPool2DWithIndices(static_cast(inputTensor), kernel_size, stride, padding, - dilation, ceil_mode); - }) - .def("avgPool2D", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef kernel_size, - IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, - c10::optional divisor_override) { - return self.avgPool2D(static_cast(inputTensor), kernel_size, stride, padding, - ceil_mode, count_include_pad, divisor_override); - - }) - .def("batchNorm", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, PyMPSGraphTensor* weightTensor, - PyMPSGraphTensor* biasTensor, PyMPSGraphTensor* meanTensor, PyMPSGraphTensor* varTensor, - float momentum, float epsilon) { - std::tuple result = self.batchNorm( - static_cast(inputTensor), - static_cast(meanTensor), - static_cast(varTensor), - static_cast(weightTensor), - static_cast(biasTensor), - momentum, epsilon); - return result; - - }) - .def("layerNorm", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef normalized_shape, - c10::optional weightTensor_opt, c10::optional biasTensor_opt, - float epsilon) { - - MPSGraphTensor* weightTensor = nil; - MPSGraphTensor* biasTensor = nil; - if(weightTensor_opt.has_value()){ - weightTensor = static_cast(*weightTensor_opt); - } - if(biasTensor_opt.has_value()){ - biasTensor = static_cast(*biasTensor_opt); - } - - std::tuple result = self.layerNorm( - static_cast(inputTensor), normalized_shape, - weightTensor, biasTensor, epsilon); - - return result; - }) - .def("hardTanh", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, float min_value, float max_value) { - return self.hardTanh(static_cast(inputTensor), min_value, max_value); - }) - .def("mean", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims) { - return 
self.mean(static_cast(inputTensor), dims, keep_dims); - }) - .def("minDim", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim, bool keep_dims) { - return self.minDim(static_cast(inputTensor), dim, keep_dims); - }) - .def("maxDim", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim, bool keep_dims) { - return self.maxDim(static_cast(inputTensor), dim, keep_dims); - }) - .def("amax", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims) { - return self.amax(static_cast(inputTensor), dims, keep_dims); - }) - .def("amin", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims) { - return self.amin(static_cast(inputTensor), dims, keep_dims); - }) - .def("argmax", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int64_t dim, bool keep_dims, bool flatten) { - return self.argmax(static_cast(inputTensor), dim, keep_dims, flatten); - }) - .def("argmin", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int64_t dim, bool keep_dims, bool flatten) { - return self.argmin(static_cast(inputTensor), dim, keep_dims, flatten); - }) - .def("identity", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor) { - return self.identity(static_cast(inputTensor)); - }) - .def("clamp", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, float min, float max, bool use_min, bool use_max) { - return self.clamp(static_cast(inputTensor), min, max, use_min, use_max); - }) - .def("relu", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor) { - return self.relu(static_cast(inputTensor)); - }) - .def("leaky_relu", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, float negative_slope) { - return self.leaky_relu(static_cast(inputTensor), negative_slope); - }) - .def("softmax", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim, bool half_to_float) { - return self.softmax(static_cast(inputTensor), dim, half_to_float); - }) - .def("log_softmax", [](MPSGraphModule& self, 
PyMPSGraphTensor* inputTensor, int dim, bool half_to_float) { - return self.log_softmax(static_cast(inputTensor), dim, half_to_float); - }) - .def("squeeze", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor) { - return self.squeeze(static_cast(inputTensor)); - }) - .def("squeeze", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim) { - return self.squeeze(static_cast(inputTensor), dim); - }) - .def("squeeze", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef dims) { - return self.squeeze(static_cast(inputTensor), dims); - }) - .def("unsqueeze", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dimension) { - return self.unsqueeze(static_cast(inputTensor), dimension); - }) - .def("gelu", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, const std::string &approximation) { - return self.gelu(static_cast(inputTensor), approximation); - }) - .def("glu", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int64_t dim) { - return self.glu(static_cast(inputTensor), dim); - }) - .def("pixel_shuffle", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int upscale_factor) { - return self.pixel_shuffle(static_cast(inputTensor), upscale_factor); - }) - .def("split_size", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef split_sizes, int dim) { - return self.split_size(static_cast(inputTensor), split_sizes, dim); - }) - .def("split", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int split_size, int dim) { - return self.split(static_cast(inputTensor), split_size, dim); - }) - .def("unbind", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim) { - return self.unbind(static_cast(inputTensor), dim); - }) - .def("cat", [](MPSGraphModule& self,int dim, py::args catTensors) { - return self.cat(dim, catTensors); - }) - .def("stack", [](MPSGraphModule& self,int dim, py::args stackTensors) { - return self.stack(dim, stackTensors); - }) - .def("slice", [](MPSGraphModule& self, 
PyMPSGraphTensor* inputTensor, int64_t dim, c10::optional start, c10::optional end, int64_t step) { - return self.slice(static_cast(inputTensor), dim, start, end, step); - }) - .def("expand", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef sizes){ - return self.expand(static_cast(inputTensor), sizes); - }) - .def("select", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim, int index) { - return self.select(static_cast(inputTensor), dim, index); - }) - .def("view", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef shape){ - return self.view(static_cast(inputTensor), shape); - }) - .def("permute", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef axes) { - return self.permute(static_cast(inputTensor), axes); - }) - .def("cumsum", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int dim) { - return self.cumsum(static_cast(inputTensor), dim); - }) - .def("addmm", [](MPSGraphModule& self, PyMPSGraphTensor* biasTensor, - PyMPSGraphTensor* inputTensor, PyMPSGraphTensor* weightTensor, - float beta, float alpha) { - return self.addmm(static_cast(biasTensor), - static_cast(inputTensor), - static_cast(weightTensor), - beta, alpha); - }) - .def("constant_pad_nd", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, IntArrayRef pad, const double value) { - return self.constant_pad_nd(static_cast(inputTensor), pad, value); - }) - .def("add", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor, float alpha) { - return self.additionWithTensor(static_cast(primaryTensor), static_cast(secondaryTensor), alpha); - }) - .def("add", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor, int alpha) { - return self.additionWithTensor(static_cast(primaryTensor), static_cast(secondaryTensor), alpha); - }) - .def("sub", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor, float alpha) { - return 
self.subtractionWithTensor(static_cast(primaryTensor), static_cast(secondaryTensor), alpha); - }) - .def("sub", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor, int alpha) { - return self.subtractionWithTensor(static_cast(primaryTensor), static_cast(secondaryTensor), alpha); - }) - .def("mulWithScalar", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, float scalar) { - return self.multiplicationWithScalar(static_cast(inputTensor), scalar); - }) - .def("mulWithScalar", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int scalar) { - return self.multiplicationWithScalar(static_cast(inputTensor), scalar); - }) - .def("arange", [](MPSGraphModule& self, int64_t start, int64_t end, int64_t step, - MPSDataType dtype, int numOfElements) { - return self.arange(start, end, step, dtype, numOfElements); - }) - .def("arange", [](MPSGraphModule& self, float start, float end, float step, - MPSDataType dtype, int numOfElements) { - return self.arange(start, end, step, dtype, numOfElements); - }) - .def("where", [](MPSGraphModule& self, PyMPSGraphTensor* cond, PyMPSGraphTensor* input, - PyMPSGraphTensor* other) { - return self.where(static_cast(cond), static_cast(input), - static_cast(other)); - }) - .def("scalar_out", [](MPSGraphModule& self, double scalar, MPSDataType dtype) { - return self.constant(scalar, dtype); - }) - .def("index_select", [](MPSGraphModule& self, PyMPSGraphTensor* inputTensor, int64_t dim, PyMPSGraphTensor* indexTensor) { - return self.index_select(static_cast(inputTensor), dim, static_cast(indexTensor)); - }) - .def("empty", [](MPSGraphModule& self, IntArrayRef sizes, MPSDataType dtype) { - return self.constantWithScalar(dtype, sizes, 0); - }) - // Arithmetic Binary Ops - REGISTER_PYBIND11_MPS_BINARY_OP("add", addition) - REGISTER_PYBIND11_MPS_BINARY_OP("sub", subtraction) - REGISTER_PYBIND11_MPS_BINARY_OP("mul", multiplication) - REGISTER_PYBIND11_MPS_BINARY_OP("min", minimum) - 
REGISTER_PYBIND11_MPS_BINARY_OP("max", maximum) - REGISTER_PYBIND11_MPS_BINARY_OP("pow", power) - REGISTER_PYBIND11_MPS_BINARY_OP("remainder", modulo) - REGISTER_PYBIND11_MPS_BINARY_OP("atan2", atan2) - REGISTER_PYBIND11_MPS_BINARY_OP("bmm", matrixMultiplication) - REGISTER_PYBIND11_MPS_BINARY_OP("minimum", minimum) - - // Comparison Ops - REGISTER_PYBIND11_MPS_BINARY_OP("eq", equal) - REGISTER_PYBIND11_MPS_BINARY_OP("ne", notEqual) - REGISTER_PYBIND11_MPS_BINARY_OP("ge", greaterThanOrEqualTo) - REGISTER_PYBIND11_MPS_BINARY_OP("gt", greaterThan) - REGISTER_PYBIND11_MPS_BINARY_OP("le", lessThanOrEqualTo) - REGISTER_PYBIND11_MPS_BINARY_OP("lt", lessThan) - - // Bitwise Ops - REGISTER_PYBIND11_MPS_BITWISE_BINARY_OP("bitwise_and", AND) - REGISTER_PYBIND11_MPS_BITWISE_BINARY_OP("bitwise_or", OR) - REGISTER_PYBIND11_MPS_BITWISE_BINARY_OP("bitwise_xor", XOR) - - .def("bitwise_not", [](MPSGraphModule& self, PyMPSGraphTensor* input) { - return self.bitwiseNotTensor(static_cast(input), "bitwise_not"); - }) - - // Boolean Binary Ops - REGISTER_PYBIND11_MPS_BINARY_OP("eq", equal) - REGISTER_PYBIND11_MPS_BINARY_OP("ne", notEqual) - REGISTER_PYBIND11_MPS_BINARY_OP("le", lessThanOrEqualTo) - REGISTER_PYBIND11_MPS_BINARY_OP("lt", lessThan) - REGISTER_PYBIND11_MPS_BINARY_OP("ge", greaterThanOrEqualTo) - REGISTER_PYBIND11_MPS_BINARY_OP("gt", greaterThan) - - // Unary Ops - - REGISTER_PYBIND11_MPS_UNARY_OP("abs", absolute) - REGISTER_PYBIND11_MPS_UNARY_OP("exp", exponent) - REGISTER_PYBIND11_MPS_UNARY_OP("exp2", exponentBase2) - REGISTER_PYBIND11_MPS_UNARY_OP("reciprocal", reciprocal) - REGISTER_PYBIND11_MPS_UNARY_OP("sqrt", squareRoot) - REGISTER_PYBIND11_MPS_UNARY_OP("neg", negative) - REGISTER_PYBIND11_MPS_UNARY_OP("log", logarithm) - REGISTER_PYBIND11_MPS_UNARY_OP("log10", logarithmBase10) - REGISTER_PYBIND11_MPS_UNARY_OP("log2", logarithmBase2) - REGISTER_PYBIND11_MPS_UNARY_OP("erf", erf) - REGISTER_PYBIND11_MPS_UNARY_OP("floor", floor) - REGISTER_PYBIND11_MPS_UNARY_OP("ceil", 
ceil) - REGISTER_PYBIND11_MPS_UNARY_OP("rsqrt", reverseSquareRoot) - REGISTER_PYBIND11_MPS_UNARY_OP("sin", sin) - REGISTER_PYBIND11_MPS_UNARY_OP("sign", sign) - REGISTER_PYBIND11_MPS_UNARY_OP("sigmoid", sigmoid) - REGISTER_PYBIND11_MPS_UNARY_OP("cos", cos) - REGISTER_PYBIND11_MPS_UNARY_OP("tan", tan) - REGISTER_PYBIND11_MPS_UNARY_OP("asin", asin) - REGISTER_PYBIND11_MPS_UNARY_OP("acos", acos) - REGISTER_PYBIND11_MPS_UNARY_OP("atan", atan) - REGISTER_PYBIND11_MPS_UNARY_OP("sinh", sinh) - REGISTER_PYBIND11_MPS_UNARY_OP("cosh", cosh) - REGISTER_PYBIND11_MPS_UNARY_OP("tanh", tanh) - REGISTER_PYBIND11_MPS_UNARY_OP("asinh", asinh) - REGISTER_PYBIND11_MPS_UNARY_OP("acosh", acosh) - REGISTER_PYBIND11_MPS_UNARY_OP("atanh", atanh) - REGISTER_PYBIND11_MPS_UNARY_OP("isinf", isInfinite) - REGISTER_PYBIND11_MPS_UNARY_OP("isnan", isNaN) - REGISTER_PYBIND11_MPS_UNARY_OP("round", round) - - .def("div", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor) { - return self.div_mode_template(static_cast(primaryTensor), static_cast(secondaryTensor), c10::nullopt, "div"); - }) - .def("fmod", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor) { - return self.div_mode_template(static_cast(primaryTensor), static_cast(secondaryTensor), "trunc", "fmod_mps_out"); - }) - .def("floor_divide", [](MPSGraphModule& self, PyMPSGraphTensor* primaryTensor, PyMPSGraphTensor* secondaryTensor) { - return self.div_mode_template(static_cast(primaryTensor), static_cast(secondaryTensor), "floor", "floor_divide"); - }) - - // - // Graph debug methods. - // - .def("printGraph", &MPSGraphModule::printGraph) - - // - // Serialization / deserialization methods. - // - .def("serialize", &MPSGraphModule::serialize); - - // Export `MPSDataType` Objective-C enum to python. 
- py::enum_(m, "MPSDataType") - .value("MPSDataTypeTypeInvalid", MPSDataType::MPSDataTypeFloatBit) - .value("MPSDataTypeFloat32", MPSDataType::MPSDataTypeFloat32) - .value("MPSDataTypeFloat16", MPSDataType::MPSDataTypeFloat16) - .value("MPSDataTypeInt32", MPSDataType::MPSDataTypeInt32) - .value("MPSDataTypeInt64", MPSDataType::MPSDataTypeInt64) - .value("MPSDataTypeInt16", MPSDataType::MPSDataTypeInt16) - .value("MPSDataTypeInt8", MPSDataType::MPSDataTypeInt8) - .value("MPSDataTypeUInt8", MPSDataType::MPSDataTypeUInt8) - .value("MPSDataTypeBool", MPSDataType::MPSDataTypeBool) - .export_values(); -} - -} // namespace -} // namespace mps diff --git a/backends/apple/mps/utils/MPSGraphInterface.h b/backends/apple/mps/utils/MPSGraphInterface.h deleted file mode 100644 index d51b6f1353f..00000000000 --- a/backends/apple/mps/utils/MPSGraphInterface.h +++ /dev/null @@ -1,259 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#pragma once - -#include -#include -#include -#include "OperationUtils.h" -#include "operations/BinaryOps.h" -#include "operations/UnaryOps.h" - -// Workaround for PyBind custom class return type. -// We need a type caster -// (https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html) for -// MPSGraphTensor. Return `void*` for now instead of `MPSGraphTensor*`. -typedef void PyMPSGraphTensor; - -namespace mps { - -using namespace torch; - -// ExecuTorch is supported only from macOS 14.0 and above -// Previous macOS version don't have support to generate .mpsgraphpackage -enum class MacOSVersion : uint32_t { - MACOS_VER_14_0_PLUS = 0, -}; - -class MPSGraphModule { - public: - MPSGraphModule(); - ~MPSGraphModule(); - - // Graph placeholders. 
- PyMPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSDataType dataType); - PyMPSGraphTensor* mpsGraphRankedPlaceHolder( - MPSDataType dataType, - const IntArrayRef& shape); - PyMPSGraphTensor* mpsGraphScalarPlaceHolder(MPSDataType dataType); - void set_outputs(py::args args); - - // Graph operators. - PyMPSGraphTensor* constant(double scalar, MPSDataType dataType); - PyMPSGraphTensor* constantTensor( - Tensor constant_tensor, - MPSDataType dataType); - PyMPSGraphTensor* constantWithScalar( - MPSDataType dtype, - const IntArrayRef& sizes, - double scalar); - PyMPSGraphTensor* full(IntArrayRef size, double scalar, MPSDataType dataType); - PyMPSGraphTensor* full_like(MPSGraphTensor* inputTensor, double scalar); - std::tuple batchNorm( - MPSGraphTensor* inputTensor, - MPSGraphTensor* meanTensor, - MPSGraphTensor* varTensor, - MPSGraphTensor* weightTensor, - MPSGraphTensor* biasTensor, - float momentum, - float epsilon); - std::tuple layerNorm( - MPSGraphTensor* inputTensor, - IntArrayRef normalized_shape, - MPSGraphTensor* weightTensor, - MPSGraphTensor* biasTensor, - float epsilon); - PyMPSGraphTensor* conv2D( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - MPSGraphTensor* bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool transpose, - IntArrayRef outputPadding, - int64_t groups, - bool is_depthwise); - std::tuple maxPool2DWithIndices( - MPSGraphTensor* inputTensor, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - bool ceil_mode); - PyMPSGraphTensor* avgPool2D( - MPSGraphTensor* inputTensor, - IntArrayRef kernel_size, - IntArrayRef stride, - IntArrayRef padding, - bool ceil_mode, - bool count_include_pad, - c10::optional divisor_override); - PyMPSGraphTensor* - hardTanh(MPSGraphTensor* inputTensor, float min_val, float max_val); - PyMPSGraphTensor* - mean(MPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims); - std::tuple - minDim(MPSGraphTensor* inputTensor, 
int dim, bool keep_dims); - std::tuple - maxDim(MPSGraphTensor* inputTensor, int dim, bool keep_dims); - PyMPSGraphTensor* - amax(MPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims); - PyMPSGraphTensor* - amin(MPSGraphTensor* inputTensor, IntArrayRef dims, bool keep_dims); - PyMPSGraphTensor* argmax( - MPSGraphTensor* inputTensor, - int64_t dim, - bool keep_dims, - bool flatten); - PyMPSGraphTensor* argmin( - MPSGraphTensor* inputTensor, - int64_t dim, - bool keep_dims, - bool flatten); - PyMPSGraphTensor* mm( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor); - PyMPSGraphTensor* identity(MPSGraphTensor* inputTensor); - PyMPSGraphTensor* clamp( - MPSGraphTensor* inputTensor, - float min, - float max, - bool use_min, - bool use_max); - PyMPSGraphTensor* relu(MPSGraphTensor* inputTensor); - PyMPSGraphTensor* leaky_relu( - MPSGraphTensor* inputTensor, - float negative_slope); - PyMPSGraphTensor* - softmax(MPSGraphTensor* inputTensor, int dim, bool half_to_float); - PyMPSGraphTensor* - log_softmax(MPSGraphTensor* inputTensor, int dim, bool half_to_float); - PyMPSGraphTensor* squeeze(MPSGraphTensor* inputTensor); - PyMPSGraphTensor* squeeze(MPSGraphTensor* inputTensor, int dim); - PyMPSGraphTensor* squeeze(MPSGraphTensor* inputTensor, IntArrayRef dim); - PyMPSGraphTensor* unsqueeze(MPSGraphTensor* inputTensor, int dimension); - PyMPSGraphTensor* gelu( - MPSGraphTensor* inputTensor, - const std::string& approximation); - PyMPSGraphTensor* glu(MPSGraphTensor* inputTensor, int64_t dim); - PyMPSGraphTensor* cat(int dim, py::args catTensors); - PyMPSGraphTensor* pixel_shuffle( - MPSGraphTensor* inputTensor, - int upscale_factor); - std::vector - split(MPSGraphTensor* inputTensor, int split_size, int dim); - std::vector - split_size(MPSGraphTensor* inputTensor, IntArrayRef split_sizes, int dim); - std::vector unbind(MPSGraphTensor* inputTensor, int dim); - PyMPSGraphTensor* stack(int dim, py::args stackTensors); - PyMPSGraphTensor* slice( - 
MPSGraphTensor* inputTensor, - int64_t dim, - c10::optional start, - c10::optional end, - int64_t step); - PyMPSGraphTensor* expand(MPSGraphTensor* inputTensor, IntArrayRef sizes); - PyMPSGraphTensor* select(MPSGraphTensor* inputTensor, int dim, int index); - PyMPSGraphTensor* view(MPSGraphTensor* inputTensor, IntArrayRef shape); - PyMPSGraphTensor* permute(MPSGraphTensor* inputTensor, IntArrayRef axes); - PyMPSGraphTensor* cumsum(MPSGraphTensor* inputTensor, int dim); - PyMPSGraphTensor* addmm( - MPSGraphTensor* biasTensor, - MPSGraphTensor* inputTensor, - MPSGraphTensor* weightTensor, - float beta, - float alpha); - MPSGraphTensor* trunc_tensor(MPSGraphTensor* inputTensor); - - // Binary Ops - PyMPSGraphTensor* div_mode_template( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - c10::optional rounding_mode, - const string& op_name); - PyMPSGraphTensor* binaryOpTensor( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - const std::string& op_name, - std::function - binaryOpFunction); - - PyMPSGraphTensor* binaryOpWithScalar( - MPSGraphTensor* inputTensor, - Scalar scalar, - const std::string& op_name, - std::function - binaryOpFunction); - - // Unary Ops - PyMPSGraphTensor* unaryOpTensor( - MPSGraphTensor* inputTensor, - const std::string& op_name, - std::function unaryOpFunction); - - PyMPSGraphTensor* additionWithTensor( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - Scalar alpha); - PyMPSGraphTensor* subtractionWithTensor( - MPSGraphTensor* primaryTensor, - MPSGraphTensor* secondaryTensor, - Scalar alpha); - PyMPSGraphTensor* multiplicationWithScalar( - MPSGraphTensor* inputTensor, - Scalar scalar); - - // bitwise Ops - PyMPSGraphTensor* bitwiseNotTensor( - MPSGraphTensor* inputTensor, - const std::string& op_name); - - // Pad Ops - PyMPSGraphTensor* - constant_pad_nd(MPSGraphTensor* input, IntArrayRef pad, const double value); - - // range Ops - PyMPSGraphTensor* arange( - Scalar start, - Scalar end, 
- Scalar step, - MPSDataType dataType, - const int numEle); - - // trinary Ops - PyMPSGraphTensor* - where(MPSGraphTensor* cond, MPSGraphTensor* input, MPSGraphTensor* other); - - // Indexing Ops - PyMPSGraphTensor* index_select( - MPSGraphTensor* inputTensor, - int64_t dim, - MPSGraphTensor* indexTensor); - - MPSGraph* getMPSGraph() { - return mpsGraph; - } - - // Graph debug methods. - void printGraph(); - bool macos_version_or_newer( - MacOSVersion version = MacOSVersion::MACOS_VER_14_0_PLUS); - - MPSGraphExecutable* compileMPSGraphExecutable(); - std::vector serialize(); - - private: - MPSGraph* mpsGraph; - std::vector outputTensors_; - std::vector inputTensors_; - MPSGraphExecutable* executable_; - - id device_; - id commandQueue_; -}; - -} // namespace mps diff --git a/backends/apple/mps/utils/MPSGraphInterface.mm b/backends/apple/mps/utils/MPSGraphInterface.mm deleted file mode 100644 index f2c270adb18..00000000000 --- a/backends/apple/mps/utils/MPSGraphInterface.mm +++ /dev/null @@ -1,154 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "MPSGraphInterface.h" -#include "MPSGraphPackageExport.h" - -namespace mps { - -using namespace torch; -MPSGraphModule::MPSGraphModule() { - TORCH_CHECK(macos_version_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS), - "MPS Executorch backend is supported only from macOS 14.0 and above."); - - mpsGraph = [MPSGraph new]; - device_ = MTLCreateSystemDefaultDevice(); - commandQueue_ = [device_ newCommandQueue]; -} - -MPSGraphModule::~MPSGraphModule() { - [mpsGraph release]; -} - -PyMPSGraphTensor* -MPSGraphModule::mpsGraphUnrankedPlaceHolder(MPSDataType dataType) { - inputTensors_.push_back([mpsGraph placeholderWithShape:nil - dataType:dataType - name:nil]); - return inputTensors_.back(); -} - -PyMPSGraphTensor* -MPSGraphModule::mpsGraphRankedPlaceHolder(MPSDataType dataType, const at::IntArrayRef& shape) { - inputTensors_.push_back([mpsGraph placeholderWithShape:getMPSShape(shape) - dataType:dataType - name:nil]); - return inputTensors_.back(); -} - -PyMPSGraphTensor* -MPSGraphModule::mpsGraphScalarPlaceHolder(MPSDataType dataType) { - inputTensors_.push_back([mpsGraph placeholderWithShape:@[@1] - dataType:dataType - name:nil]); - return inputTensors_.back(); -} - -void -MPSGraphModule::set_outputs(py::args args) { - for (const auto i: c10::irange(args.size())) { - MPSGraphTensor* outputTensor = static_cast(pybind11::cast(args[i])); - outputTensors_.push_back(outputTensor); - } -} - -PyMPSGraphTensor* -MPSGraphModule::mm(MPSGraphTensor* primaryTensor, MPSGraphTensor* secondaryTensor) { - return [mpsGraph matrixMultiplicationWithPrimaryTensor:primaryTensor - secondaryTensor:secondaryTensor - name:nil]; -} - -PyMPSGraphTensor* -MPSGraphModule::identity(MPSGraphTensor* inputTensor) { - return [mpsGraph identityWithTensor:inputTensor - name:nil]; -} - -bool MPSGraphModule::macos_version_or_newer(MacOSVersion version) { - id mpsCD = NSClassFromString(@"MPSGraph"); - static auto compileOptions = [[[MTLCompileOptions alloc] init] autorelease]; - - static bool 
_macos_14_0_plus = [mpsCD instancesRespondToSelector:@selector(imToColWithSourceTensor:descriptor:name:)] == YES; - - switch (version) { - case MacOSVersion::MACOS_VER_14_0_PLUS: return _macos_14_0_plus; - default: return false; - } -} - -void MPSGraphModule::printGraph() { - NSLog(@"%@", [mpsGraph debugDescription]); -} - -MPSGraphExecutable* -MPSGraphModule::compileMPSGraphExecutable() { - NSMutableDictionary *feeds = [NSMutableDictionary dictionary]; - for (const auto i: c10::irange(inputTensors_.size())) { - feeds[inputTensors_[i]] = [[MPSGraphShapedType alloc] initWithShape:[inputTensors_[i] shape] dataType:[inputTensors_[i] dataType]]; - } - - NSMutableArray *targetTensors = [NSMutableArray new]; - std::for_each(outputTensors_.begin(), outputTensors_.end(), ^(MPSGraphTensor* outputTensor) { - [targetTensors addObject:outputTensor]; - }); - - MPSGraphExecutable *exec = [mpsGraph compileWithDevice:[MPSGraphDevice deviceWithMTLDevice:device_] - feeds:feeds - targetTensors:targetTensors - targetOperations:nil - compilationDescriptor:nil]; - - return exec; -} - -std::vector MPSGraphModule::serialize() { - MPSGraphExecutable* exec = compileMPSGraphExecutable(); - - std::string dataFolder = "/tmp/"; - - std::string name = "mpsgraphmodule_" + std::to_string(arc4random_uniform(INT_MAX)); - std::string mpsgraphpackagePath = dataFolder + name + ".mpsgraphpackage"; - - NSString *mpsgraphpackageFileStr = [NSString stringWithUTF8String:mpsgraphpackagePath.c_str()]; - NSURL *bundleURL = [NSURL fileURLWithPath:mpsgraphpackageFileStr]; - - MPSGraphExecutableSerializationDescriptor *serializationDescriptor = [MPSGraphExecutableSerializationDescriptor new]; - serializationDescriptor.deploymentPlatform = MPSGraphDeploymentPlatformMacOS; - serializationDescriptor.minimumDeploymentTarget = @"14.0.0"; - [exec serializeToMPSGraphPackageAtURL:bundleURL descriptor:serializationDescriptor]; - - NSString* mpsgraphpackage_manifest_file = [NSString 
stringWithUTF8String:(mpsgraphpackagePath + "/manifest.plist").c_str()]; - NSString* mpsgraphpackage_model_0_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_0.mpsgraph").c_str()]; - NSString* mpsgraphpackage_model_1_file = [NSString stringWithUTF8String:(mpsgraphpackagePath + "/model_1.mpsgraph").c_str()]; - - NSURL* manifestPlistURL = [NSURL fileURLWithPath:mpsgraphpackage_manifest_file]; - NSURL* model0URL = [NSURL fileURLWithPath:mpsgraphpackage_model_0_file]; - NSURL* model1URL = [NSURL fileURLWithPath:mpsgraphpackage_model_1_file]; - - NSData* manifest_plist_data = [NSData dataWithContentsOfURL:manifestPlistURL]; - NSData* model_0_data = [NSData dataWithContentsOfURL:model0URL]; - NSData* model_1_data = [NSData dataWithContentsOfURL:model1URL]; - - int64_t total_package_size = sizeof(ExirMPSGraphPackage) + [manifest_plist_data length] + [model_0_data length] + [model_1_data length]; - ExirMPSGraphPackage *exirMPSGraphPackage = (ExirMPSGraphPackage*)malloc(total_package_size); - assert(exirMPSGraphPackage != nil); - - exirMPSGraphPackage->manifest_plist_offset = 0; - exirMPSGraphPackage->model_0_offset = [manifest_plist_data length]; - exirMPSGraphPackage->model_1_offset = exirMPSGraphPackage->model_0_offset + [model_0_data length]; - exirMPSGraphPackage->total_bytes = total_package_size; - - memcpy(exirMPSGraphPackage->data, [manifest_plist_data bytes], [manifest_plist_data length]); - memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_0_offset, [model_0_data bytes], [model_0_data length]); - memcpy(exirMPSGraphPackage->data + exirMPSGraphPackage->model_1_offset, [model_1_data bytes], [model_1_data length]); - - std::vector data((uint8_t*)exirMPSGraphPackage, (uint8_t*)exirMPSGraphPackage + total_package_size); - free(exirMPSGraphPackage); - - return data; -} - -} // namespace mps diff --git a/backends/apple/mps/utils/MPSGraphPackageExport.h b/backends/apple/mps/utils/MPSGraphPackageExport.h deleted file mode 100644 index 
fb84ed1d2e3..00000000000 --- a/backends/apple/mps/utils/MPSGraphPackageExport.h +++ /dev/null @@ -1,16 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#pragma once - -// Export the contents of the MPSGraphPackage directly as a list of bytes. -// Any changes to this structure will break previous exported models. -struct ExirMPSGraphPackage { - int64_t manifest_plist_offset; - int64_t model_0_offset; - int64_t model_1_offset; - int64_t total_bytes; - uint8_t data[]; -}; diff --git a/backends/apple/mps/utils/OperationUtils.h b/backends/apple/mps/utils/OperationUtils.h deleted file mode 100644 index 3be9b624bfe..00000000000 --- a/backends/apple/mps/utils/OperationUtils.h +++ /dev/null @@ -1,101 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. -// - -#pragma once - -#include -#include -#if !EXIR_MPS_DELEGATE -#include -#include -#include -#else -#include -#include -#endif -#import -#import - -namespace mps { - -#if EXIR_MPS_DELEGATE -using torch::executor::mps::delegate::MPSDevice; -#endif - -#if EXIR_MPS_DELEGATE -using namespace exec_aten; -#else -using namespace torch; -#endif - -MPSDataType getMPSDataType(ScalarType scalar_type); -MPSDataType getMPSScalarType(ScalarType scalar_type); -ScalarType getScalarType(MPSDataType mpsDataType); -MPSGraphTensor* -castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, ScalarType toType); -MPSGraphTensor* -castMPSTensor(MPSGraph* mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); - -// The MPSShape could vary based on memory format -MPSShape* getMPSShape( - const Tensor& t, - MemoryFormat memory_format = MemoryFormat::Contiguous); -MPSShape* getMPSShape( - const IntArrayRef& sizes, - MemoryFormat memory_format = MemoryFormat::Contiguous); -std::vector getMPSShapeVec(const MPSShape* shape); - -static inline id getMTLBufferStorage(const Tensor& 
tensor) { -#if EXIR_MPS_DELEGATE -#if TARGET_OS_SIMULATOR - // Simulator crashes in newBufferWithBytesNoCopy, so we're making a copy of - // the data. - uint8_t* data = tensor.mutable_data_ptr(); - return [MPSDevice::getInstance()->device() newBufferWithBytes:data - length:tensor.nbytes() - options:0]; -#else - uint8_t* data = tensor.mutable_data_ptr(); - return [MPSDevice::getInstance()->device() - newBufferWithBytesNoCopy:data - length:tensor.nbytes() - options:0 - deallocator:nil]; -#endif // TARGET_OS_SIMULATOR -#else - return __builtin_bit_cast(id, tensor.storage().data()); -#endif // EXIR_MPS_DELEGATE -} - -class Placeholder { - public: - Placeholder() - : _placeholder(nullptr), _value(nullptr), _tensor(Tensor(nullptr)) {} - Placeholder(MPSGraphTensor* mpsGraphTensor) - : _placeholder(mpsGraphTensor), - _value(nullptr), - _tensor(Tensor(nullptr)) {} - Placeholder( - MPSGraphTensor* mpsGraphTensor, - const Tensor& self, - MPSShape* mpsShape = nullptr, - MPSDataType dataType = MPSDataTypeInvalid); - MPSGraphTensor* getMPSGraphTensor() { - return _placeholder; - } - MPSGraphTensorData* getMPSGraphTensorData() { - return _value; - } - bool isIntermediate() { - return _value == nullptr; - } - - private: - MPSGraphTensor* _placeholder; - MPSGraphTensorData* _value; - Tensor _tensor; -}; - -} // namespace mps diff --git a/backends/apple/mps/utils/OperationUtils.mm b/backends/apple/mps/utils/OperationUtils.mm deleted file mode 100644 index 184ba70249c..00000000000 --- a/backends/apple/mps/utils/OperationUtils.mm +++ /dev/null @@ -1,144 +0,0 @@ -// -// Copyright (c) 2023 Apple Inc. All rights reserved. -// Provided subject to the LICENSE file in the top level directory. 
-// - -#include "OperationUtils.h" - -namespace mps { - -MPSShape* getMPSShape(const Tensor& t, MemoryFormat memory_format) { - return getMPSShape(t.sizes(), memory_format); -} - -MPSShape* getMPSShape(const IntArrayRef& sizes, MemoryFormat memory_format) { - if (memory_format == MemoryFormat::ChannelsLast) { - TORCH_INTERNAL_ASSERT(sizes.size() == 4, "ChannelsLast memory format must have 4 dimensions!"); - const NSUInteger N = sizes[0]; - const NSUInteger C = sizes[1]; - const NSUInteger H = sizes[2]; - const NSUInteger W = sizes[3]; - return @[@(N), @(H), @(W), @(C)]; - } - const int sz = sizes.size(); - const int sz_ = (sz > 0) ? sz : 1; - - std::vector numbers(sz_); - - for (int i = 0; i < sz_; i++) { - NSInteger sz_i = (i < sz) ? sizes[i] : 1; - NSNumber* number = [NSNumber numberWithInteger:sz_i]; - numbers[i] = number; - } - return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; -} - -std::vector getMPSShapeVec(const MPSShape* shape) { - __block std::vector shapeVec; - shapeVec.reserve([shape count]); - [shape enumerateObjectsUsingBlock:^(NSNumber * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { - shapeVec.push_back(obj.intValue); - }]; - return shapeVec; -} - -MPSDataType getMPSScalarType(ScalarType scalar_type) { - switch (scalar_type) { - // This is an intentional fallthrough supporting Double for Scalar - // types as they are casted to Float32 currently. 
- case ScalarType::Double: - case ScalarType::Float: - return MPSDataTypeFloat32; - case ScalarType::Half: return MPSDataTypeFloat16; - case ScalarType::Int: return MPSDataTypeInt32; - case ScalarType::Long: return MPSDataTypeInt64; - case ScalarType::Short: return MPSDataTypeInt16; - case ScalarType::Char: - case ScalarType::QInt8: - return MPSDataTypeInt8; - case ScalarType::Byte: - case ScalarType::QUInt8: - return MPSDataTypeUInt8; - case ScalarType::Bool: return MPSDataTypeBool; - default: - TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") - } -} - -MPSDataType getMPSDataType(ScalarType scalar_type) { - switch (scalar_type) { - case ScalarType::Float: return MPSDataTypeFloat32; - case ScalarType::Half: return MPSDataTypeFloat16; - case ScalarType::Int: return MPSDataTypeInt32; - case ScalarType::Long: return MPSDataTypeInt64; - case ScalarType::Short: return MPSDataTypeInt16; - case ScalarType::Char: - case ScalarType::QInt8: - return MPSDataTypeInt8; - case ScalarType::Byte: - case ScalarType::QUInt8: return MPSDataTypeUInt8; - case ScalarType::Bool: return MPSDataTypeBool; - case ScalarType::Double: - TORCH_CHECK_TYPE(false, "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. 
" - "Please use float32 instead.") - default: - TORCH_CHECK_TYPE(false, "Trying to convert ", scalar_type, " to the MPS backend but it does not have support for that dtype.") - } -} - -ScalarType getScalarType(MPSDataType mps_data_type) { - switch (mps_data_type) { - case MPSDataTypeFloat32: return ScalarType::Float; - case MPSDataTypeFloat16: return ScalarType::Half; - case MPSDataTypeInt32: return ScalarType::Int; - case MPSDataTypeInt64: return ScalarType::Long; - case MPSDataTypeInt16: return ScalarType::Short; - case MPSDataTypeInt8: return ScalarType::Char; - case MPSDataTypeUInt8: return ScalarType::Byte; - case MPSDataTypeBool: return ScalarType::Bool; - default: - TORCH_CHECK_TYPE(false, "Couldn't convert MPS data type ", mps_data_type, " to PyTorch data type"); - } -} - -MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) { - return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"]; -} - -MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) { - return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"]; -} - -MPSGraphTensorData* allocMPSGraphTensorData(id buffer, - MPSShape* mpsShape, - MPSDataType mpsDataType) { - MPSGraphTensorData *tensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer - shape:mpsShape - dataType:mpsDataType] autorelease]; - TORCH_INTERNAL_ASSERT(tensorData); - return tensorData; -} - -Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src, MPSShape *mpsShape, MPSDataType dataType) : _tensor(src) { - TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!"); - // extract the pointer to MTLBuffer from the Tensor's storage - id srcBuf = getMTLBufferStorage(src); - - // tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero. - // if buffer size is zero in here, it's not a user error. 
It could be a missing check for - // tensor.numel() == 0 in our internal implementations of ops. - TORCH_INTERNAL_ASSERT([srcBuf length] > 0, "Placeholder tensor is empty!"); - const MPSDataType mpsDataType = dataType != MPSDataTypeInvalid ? dataType : - _tensor.dim() == 0 ? getMPSScalarType(_tensor.scalar_type()) : getMPSDataType(_tensor.scalar_type()); - - if (!mpsShape) { - mpsShape = getMPSShape(_tensor); - } - - _value = allocMPSGraphTensorData(srcBuf, mpsShape, mpsDataType); - TORCH_INTERNAL_ASSERT(_value); - - _placeholder = mpsGraphTensor; -} - -} // namespace mps diff --git a/backends/apple/mps/utils/graph_bindings.py b/backends/apple/mps/utils/graph_bindings.py deleted file mode 100644 index 825003b375f..00000000000 --- a/backends/apple/mps/utils/graph_bindings.py +++ /dev/null @@ -1,42 +0,0 @@ -# -# Copyright (c) 2023 Apple Inc. All rights reserved. -# Provided subject to the LICENSE file in the top level directory. -# - -import torch.utils.cpp_extension - -sources = [ - "MPSGraphInterface.mm", - "OperationUtils.mm", - "Bindings.mm", -] - -ops = [ - "ConvolutionOps.mm", - "NormalizationOps.mm", - "ActivationOps.mm", - "ReduceOps.mm", - "ConstantOps.mm", - "UnaryOps.mm", - "BinaryOps.mm", - "ClampOps.mm", - "ShapeOps.mm", - "LinearAlgebraOps.mm", - "BitwiseOps.mm", - "PoolingOps.mm", - "PadOps.mm", - "RangeOps.mm", - "Indexing.mm", -] - -SOURCES_PATH = "backends/apple/mps/" -final_sources = [SOURCES_PATH + "utils/" + source for source in sources] + [ - SOURCES_PATH + "operations/" + op for op in ops -] - -graph_bindings = torch.utils.cpp_extension.load( - name="MPSGraphBindings", - sources=final_sources, - extra_include_paths=[SOURCES_PATH], - verbose=False, -) diff --git a/backends/apple/mps/utils/mps_utils.py b/backends/apple/mps/utils/mps_utils.py index f6f3f57bf24..5c26faa0464 100644 --- a/backends/apple/mps/utils/mps_utils.py +++ b/backends/apple/mps/utils/mps_utils.py @@ -3,35 +3,79 @@ # Provided subject to the LICENSE file in the top level 
directory. # +from typing import cast, Optional, Union + import torch -from executorch.backends.apple.mps.utils.graph_bindings import graph_bindings - - -def get_mps_data_type(dtype): - scalar_type_to_mps_dtype = { - "torch.float32": graph_bindings.MPSDataTypeFloat32, - "torch.float16": graph_bindings.MPSDataTypeFloat16, - "torch.int32": graph_bindings.MPSDataTypeInt32, - "torch.int64": graph_bindings.MPSDataTypeInt64, - "torch.int16": graph_bindings.MPSDataTypeInt16, - "torch.int8": graph_bindings.MPSDataTypeInt8, - "torch.qint8": graph_bindings.MPSDataTypeInt8, - "torch.uint8": graph_bindings.MPSDataTypeUInt8, - "torch.quint8": graph_bindings.MPSDataTypeUInt8, - "torch.bool": graph_bindings.MPSDataTypeBool, - torch.float32: graph_bindings.MPSDataTypeFloat32, - torch.float16: graph_bindings.MPSDataTypeFloat16, - torch.int32: graph_bindings.MPSDataTypeInt32, - torch.int64: graph_bindings.MPSDataTypeInt64, - torch.int16: graph_bindings.MPSDataTypeInt16, - torch.int8: graph_bindings.MPSDataTypeInt8, - torch.qint8: graph_bindings.MPSDataTypeInt8, - torch.uint8: graph_bindings.MPSDataTypeUInt8, - torch.quint8: graph_bindings.MPSDataTypeUInt8, - torch.bool: graph_bindings.MPSDataTypeBool, - } +from executorch.backends.apple.mps.serialization.mps_graph_schema import MPSDataType +from executorch.exir import ExportedProgram +from torch._export.utils import get_buffer, get_param, is_buffer, is_param + + +def get_input_node(node: torch.fx.Node, input_index: int) -> Union[torch.fx.Node, None]: + return None if node is None else cast(torch.fx.Node, node.args[input_index]) + +def get_scalar_val(node: torch.fx.Node, input_index: int) -> Union[float, int]: + return node.args[input_index] + + +def edge_dtype_to_mps_dtype(dtype: torch.dtype): + if not hasattr(edge_dtype_to_mps_dtype, "map"): + edge_dtype_to_mps_dtype.map = { + torch.float16: MPSDataType.mps_data_type_float16, + torch.float32: MPSDataType.mps_data_type_float32, + torch.bfloat16: MPSDataType.mps_data_type_bfloat16, + 
torch.int8: MPSDataType.mps_data_type_int8, + torch.int16: MPSDataType.mps_data_type_int16, + torch.int32: MPSDataType.mps_data_type_int32, + torch.int64: MPSDataType.mps_data_type_int64, + torch.uint8: MPSDataType.mps_data_type_uint8, + torch.bool: MPSDataType.mps_data_type_bool, + torch.cfloat: MPSDataType.mps_data_type_complex_float32, + torch.chalf: MPSDataType.mps_data_type_complex_float16, + } try: - return scalar_type_to_mps_dtype[dtype] + return edge_dtype_to_mps_dtype.map[dtype] except KeyError: - raise AssertionError(f"Invalid data type: {dtype}") + raise RuntimeError(f"Invalid data type: {dtype}") + + +def get_param_tensor( + exp_prog: ExportedProgram, node: torch.fx.Node +) -> Optional[torch.Tensor]: + if node is None: + return None + elif is_param(exp_prog, node): + return get_param(exp_prog, node) + elif is_buffer(exp_prog, node): + return get_buffer(exp_prog, node) + elif is_get_attr(node): + # Support both lifted and unlifted graph + try: + # Unlifted graph (coming from old exir.capture API) + return getattr(node.graph.owning_module, node.target) + except AttributeError: + return getattr(exp_prog.graph_module, node.target) + raise RuntimeError(f"unsupported param type, {node.op}.") + + +def is_get_attr(node: torch.fx.Node): + """ + Returns true if the given node is a get attr node for a tensor of the model + """ + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + +def is_parameter(exp_prog: torch.export.ExportedProgram, node: torch.fx.Node) -> bool: + """ + Check if a node is a lifted parameter (static data like weights and bias + are supplied as inputs to the graph).
+ + Args: + exp_prog (torch.export.ExportedProgram): exported program used to look up parameters and buffers + node (torch.fx.Node): node to check + + Returns: + bool: True if the node is a get_attr node, a parameter, or a buffer + """ + return is_get_attr(node) or is_param(exp_prog, node) or is_buffer(exp_prog, node) diff --git a/build/cmake_deps.toml b/build/cmake_deps.toml index b4d5dc3e133..9909631e0dd 100644 --- a/build/cmake_deps.toml +++ b/build/cmake_deps.toml @@ -157,6 +157,14 @@ deps = [ "executorch", ] +[targets.mps_schema] +buck_targets = [ + "//backends/apple/mps:mps_schema", +] +filters = [ + ".fbs$", +] + [targets.xnn_executor_runner] buck_targets = [ "//examples/xnnpack:xnn_executor_runner", diff --git a/examples/apple/mps/CMakeLists.txt b/examples/apple/mps/CMakeLists.txt index 2dbc2749f53..ec465684060 100644 --- a/examples/apple/mps/CMakeLists.txt +++ b/examples/apple/mps/CMakeLists.txt @@ -69,6 +69,7 @@ set( extract_sources(${EXECUTORCH_SRCS_FILE}) +set(_mps_schema_headers ${CMAKE_BINARY_DIR}/../../../schema/include/) include(${EXECUTORCH_SRCS_FILE}) target_include_directories( bundled_program @@ -76,12 +77,10 @@ target_include_directories( ${CMAKE_CURRENT_BINARY_DIR}/../../../sdk/include ${CMAKE_CURRENT_BINARY_DIR}/../../../sdk/bundled_program ${EXECUTORCH_ROOT}/third-party/flatbuffers/include + ${_mps_schema_headers} ) list(TRANSFORM _mps_executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/") add_executable(mps_executor_runner ${_mps_executor_runner__srcs}) -target_include_directories( - mps_executor_runner INTERFACE ${CMAKE_BINARY_DIR}/schema/include/ - ${EXECUTORCH_ROOT}/third-party/flatbuffers/include) target_link_libraries(mps_executor_runner bundled_program executorch gflags mpsdelegate diff --git a/examples/apple/mps/README.md b/examples/apple/mps/README.md index 38e93149414..055286ceee5 100644 --- a/examples/apple/mps/README.md +++ b/examples/apple/mps/README.md @@ -33,10 +33,10 @@ cmake -DBUCK2="$BUCK" \ -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_SDK=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ +
-DEXECUTORCH_BUILD_MPS=ON \ -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \ -Bcmake-out . cmake --build cmake-out -j9 --target install --config Release -# Build the mps_executor_runner CMAKE_PREFIX_PATH="${PWD}/cmake-out/lib/cmake/ExecuTorch;${PWD}/cmake-out/third-party/gflags" # build mps_executor_runner rm -rf cmake-out/examples/apple/mps diff --git a/examples/apple/mps/executor_runner/mps_executor_runner.mm b/examples/apple/mps/executor_runner/mps_executor_runner.mm index 379bd422825..c476d2ff902 100644 --- a/examples/apple/mps/executor_runner/mps_executor_runner.mm +++ b/examples/apple/mps/executor_runner/mps_executor_runner.mm @@ -12,11 +12,10 @@ * It uses the original bundled input data from the flatbuffer file. */ -#import -#import -#import - #include +#include +#include +#include #include @@ -31,6 +30,7 @@ #include #include #include +#include #include using namespace std::chrono; @@ -67,6 +67,11 @@ false, "True for showing profile data (e.g execution time)"); +DEFINE_bool( + skip_warmup, + false, + "If true, a warmup iteration won't be executed."); + using namespace torch::executor; using torch::executor::util::FileDataLoader; @@ -231,6 +236,18 @@ bool tensors_are_close_( } int main(int argc, char** argv) { + { + const char* usage = R"(MPS Executor Runner. Sample usage: + mps_executor_runner --model_path model.pte)"; + gflags::SetUsageMessage(usage); + } + + if (argc == 1) { + ET_LOG(Error, "No options provided."); + gflags::ShowUsageWithFlags(argv[0]); + return 1; + } + runtime_init(); gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -397,7 +414,7 @@ MemoryManager memory_manager( // Prepare the inputs. exec_aten::ArrayRef inputs; if (FLAGS_bundled_program) { - ET_LOG(Info, "Loading bundled program...\n"); + ET_LOG(Debug, "Loading bundled program"); // Use the inputs embedded in the bundled program. status = torch::executor::bundled_program::LoadBundledInput( *method, @@ -412,16 +429,21 @@ MemoryManager memory_manager( // Use ones-initialized inputs. 
inputs = torch::executor::util::PrepareInputTensors(*method); } - ET_LOG(Info, "Inputs prepared."); + ET_LOG(Debug, "Inputs prepared"); - for (int i = 0; i < FLAGS_num_runs; i++) { + int num_iterations = FLAGS_num_runs + (FLAGS_skip_warmup ? 0 : 1); + std::vector exec_times; + exec_times.reserve(FLAGS_num_runs); + for (int i = 0; i < num_iterations; i++) { auto start_exec_time = high_resolution_clock::now(); // Run the model. Error status = method->execute(); auto end_exec_time = high_resolution_clock::now(); - auto duration = duration_cast(end_exec_time - start_exec_time); + auto duration = duration_cast(end_exec_time - start_exec_time); + exec_times.push_back(duration.count()); if (FLAGS_profile) { - ET_LOG(Info, "[Run %d] Inference time: %lld milliseconds", i, duration.count()); + const float miliseconds = static_cast(duration.count()) / 1000.f; + ET_LOG(Info, "[Run %d] Inference time: %.3f miliseconds", i, miliseconds); } ET_CHECK_MSG( status == Error::Ok, @@ -429,7 +451,15 @@ MemoryManager memory_manager( method_name, status); } - ET_LOG(Info, "Model executed successfully."); + if (FLAGS_profile && FLAGS_num_runs) { + auto itr = exec_times.begin(); + if (!FLAGS_skip_warmup) + itr++; + + const float avg_time = (std::reduce(itr, exec_times.end()) / static_cast(FLAGS_num_runs)) / 1000.f; + std::cout << "Average inference time: " << std::setprecision(2) << std::fixed << avg_time << " miliseconds\n"; + } + ET_LOG(Debug, "Model executed successfully."); auto output_list = runtime_allocator.allocateList(method->outputs_size()); @@ -440,15 +470,10 @@ MemoryManager memory_manager( std::vector outputs(method->outputs_size()); status = method->get_outputs(outputs.data(), outputs.size()); ET_CHECK(status == Error::Ok); - for (EValue& output : outputs) { - // TODO(T159700776): This assumes that all outputs are fp32 tensors. Add - // support for other EValues and Tensor dtypes, and print tensors in a more - // readable way. 
- auto output_tensor = output.toTensor(); - auto data_output = output_tensor.const_data_ptr(); - for (size_t j = 0; j < output_tensor.numel(); ++j) { - ET_LOG(Info, "%f", data_output[j]); - } + // Print the first and last 100 elements of long lists of scalars. + std::cout << torch::executor::util::evalue_edge_items(100); + for (int i = 0; i < outputs.size(); ++i) { + std::cout << "Output " << i << ": " << outputs[i] << std::endl; } // Dump the profiling data to the specified file. @@ -464,13 +489,15 @@ MemoryManager memory_manager( if (FLAGS_bundled_program) { double rtol = 1e-05; double atol = 1e-08; - - if (strstr(model_path, "mv3") || + if (strstr(model_path, "fp16")) { + rtol = 1e-01; + atol = 1e-01; + } else if (strstr(model_path, "mv3") || strstr(model_path, "mv2") || + strstr(model_path, "conv") || strstr(model_path, "vit") || strstr(model_path, "resnet18") || strstr(model_path, "resnet50") || - strstr(model_path, "mobilebert") || strstr(model_path, "emformer") || strstr(model_path, "emformer_transcribe") || strstr(model_path, "emformer_join") || @@ -478,7 +505,10 @@ MemoryManager memory_manager( strstr(model_path, "llama2") || strstr(model_path, "ic3") || strstr(model_path, "ic4")) { - atol = 1e-04; + atol = 1e-04; + } else if (strstr(model_path, "mobilebert")) { + atol = 1e-01; + rtol = 1e-01; } status = torch::executor::bundled_program::VerifyResultWithBundledExpectedOutput( *method, diff --git a/examples/apple/mps/executor_runner/targets.bzl b/examples/apple/mps/executor_runner/targets.bzl index dba7db2447a..0e7a509bb3a 100644 --- a/examples/apple/mps/executor_runner/targets.bzl +++ b/examples/apple/mps/executor_runner/targets.bzl @@ -20,13 +20,14 @@ def define_common_targets(): "mps_executor_runner.mm", ], deps = [ + "//executorch/backends/apple/mps:mps_schema", "//executorch/backends/apple/mps:mps", "//executorch/runtime/executor:program", + "//executorch/extension/evalue_util:print_evalue", "//executorch/extension/data_loader:file_data_loader", 
"//executorch/kernels/portable:generated_lib_all_ops", "//executorch/extension/data_loader:file_data_loader", "//executorch/extension/data_loader:buffer_data_loader", - "//executorch/util:util", "//executorch/sdk/bundled_program:runtime", "//executorch/util:util", ], diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 3217569c525..f2ecfd6a46f 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -11,8 +11,10 @@ import torch._export as export from executorch import exir from executorch.backends.apple.mps.mps_preprocess import MPSBackend +from executorch.exir import EdgeCompileConfig from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.backend_details import CompileSpec from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.sdk.bundled_program.config import MethodTestCase, MethodTestSuite from executorch.sdk import BundledProgram @@ -23,12 +25,11 @@ from ....models import MODEL_NAME_TO_MODEL from ....models.model_factory import EagerModelFactory -from ....portable.utils import save_pte_program +from ....portable.utils import export_to_edge, save_pte_program FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -38,6 +39,13 @@ help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", ) + parser.add_argument( + "--use_fp16", + default=True, + action=argparse.BooleanOptionalAction, + help="Whether to automatically convert float32 operations to float16 operations.", + ) + parser.add_argument( "-b", "--bundled", @@ -60,20 +68,21 @@ # pre-autograd export. 
eventually this will become torch.export model = export.capture_pre_autograd_graph(model, example_inputs) - edge = exir.capture( - model, example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True) - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) - logging.info(f"Exported graph:\n{edge.exported_program.graph}") - - lowered_module = to_backend(MPSBackend.__name__, edge.exported_program, []) - - logging.info(f"Lowered graph:\n{edge.exported_program.graph}") + edge = export_to_edge( + model, + example_inputs, + edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))] + lowered_module = to_backend( + MPSBackend.__name__, edge.exported_program(), compile_specs + ) executorch_program = ( exir.capture( lowered_module, example_inputs, - exir.CaptureConfig(enable_aot=True, _unlift=True), + exir.CaptureConfig(enable_aot=True, _unlift=False), ) .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) .to_executorch(config=ExecutorchBackendConfig(extract_constant_segment=False)) @@ -92,12 +101,17 @@ ], ) ] + logging.info(f"Expected output: {model(*example_inputs)}") bundled_program = BundledProgram(executorch_program, method_test_suites) bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( bundled_program ) model_name = f"{model_name}_bundled" + extension = "fp16" + if not args.use_fp16: + extension = "fp32" + model_name = f"{model_name}_{extension}" program_buffer = bundled_program_buffer else: program_buffer = executorch_program.buffer From 5cc1394187a383b4c2a6fe5c508adf8183ec9724 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Fri, 19 Jan 2024 15:26:15 -0800 Subject: [PATCH 2/3] Revert partitioner changes (coming in a follow up PR) --- .../apple/mps/partition/mps_partitioner.py | 89 ++++++++----------- .../apple/mps/runtime/operations/PadOps.mm | 9 -- 2 files changed, 37 insertions(+), 61 deletions(-) diff --git 
a/backends/apple/mps/partition/mps_partitioner.py b/backends/apple/mps/partition/mps_partitioner.py index 3fdc09028f9..a9246219982 100644 --- a/backends/apple/mps/partition/mps_partitioner.py +++ b/backends/apple/mps/partition/mps_partitioner.py @@ -4,78 +4,63 @@ # import logging -from typing import Any, Dict, List, Union import torch -from executorch.backends.apple.mps.mps_preprocess import MPSBackend -from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors -from executorch.exir.backend.backend_details import CompileSpec -from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( - generate_partitions_from_list_of_nodes, -) from executorch.exir.backend.partitioner import ( DelegationSpec, Partitioner, PartitionResult, ) + from torch._export.exported_program import ExportedProgram -from torch.fx.passes.infra.partitioner import Partition +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) -class MPSOperatorSupport(OperatorSupportBase): - def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs): - self.node_visitors = get_node_visitors(edge_program) +class OperatorsSupportedForMpsBackend(OperatorSupportBase): def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - if node.op != "call_function": - return False - - if node.target.__name__ not in self.node_visitors: - return False - - return True + supported_mps_ops = [ + torch.ops.aten.add.Tensor, + torch.ops.aten.mm.default, + torch.ops.aten.div.default, + ] + ret_val = ( + (node.op == "call_function" and node.target in supported_mps_ops) + or node.op == "get_attr" + or node.op == "output" + ) + return ret_val +# TODO MPSPartitioner is work in progress currently. +# Use whole graph delegation instead when lowering to MPS. 
class MPSPartitioner(Partitioner): - compile_spec: List[CompileSpec] = [] + compile_spec = [] def __init__(self) -> None: - self.delegation_spec = DelegationSpec(MPSBackend.__name__, self.compile_spec) - self.partition_tags: Dict[str, DelegationSpec] = {} + self.delegation_spec = DelegationSpec("MPSBackend", self.compile_spec) - def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]: - self.supported_ops = MPSOperatorSupport( - edge_program=edge_program, compiler_specs=self.delegation_spec.compile_specs - ) - return generate_partitions_from_list_of_nodes( - edge_program.graph_module, - op_support=self.supported_ops, - ) + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + # Run the CapabilityBasedPartitioner to return the largest possible + # subgraphs containing the nodes with the tags + logger.info("MpsPartitioner::partition") + partition_tags = {} - def tag_nodes(self, partitions: List[Partition]) -> None: - for partition in partitions: + capability_partitioner = CapabilityBasedPartitioner( + exported_program.graph_module, + OperatorsSupportedForMpsBackend(), + allows_single_node_partition=True, + ) + partition_list = capability_partitioner.propose_partitions() + for partition in partition_list: for node in partition.nodes: - delegation_tag = f"mps_{partition.id}" - node.meta["delegation_tag"] = delegation_tag - self.partition_tags[delegation_tag] = self.delegation_spec + tag = f"tag{partition.id}" + node.meta["delegation_tag"] = tag + partition_tags[tag] = self.delegation_spec - @staticmethod - def check_partitions(partitions: Union[dict, list]) -> bool: - pl = len(partitions) - if pl == 0: - logging.warning("Nothing can be partitioned!") - else: - logging.info(f"Found {pl} subgraphs to be partitioned.") - return pl != 0 - - # override - def partition(self, edge_program: ExportedProgram) -> PartitionResult: - partitions = self.generate_partitions(edge_program=edge_program) - if self.check_partitions(partitions): - 
self.tag_nodes(partitions) - x = PartitionResult( - tagged_exported_program=edge_program, partition_tags=self.partition_tags + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags ) - - return x diff --git a/backends/apple/mps/runtime/operations/PadOps.mm b/backends/apple/mps/runtime/operations/PadOps.mm index 7728d7edf48..1d97af8cf3b 100644 --- a/backends/apple/mps/runtime/operations/PadOps.mm +++ b/backends/apple/mps/runtime/operations/PadOps.mm @@ -152,15 +152,6 @@ return Error::Ok; } -// PyMPSGraphTensor* -// MPSGraphModule::constant_pad_nd( -// MPSGraphTensor* input, -// IntArrayRef pad, -// const double value) { -// return pad_out_template(mpsGraph, input, pad, MPSGraphPaddingModeConstant, value); -// } - - } // namespace delegate } // namespace mps } // namespace executor From 923d30a215d1219aae792a1335f934c718702a21 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Fri, 19 Jan 2024 18:45:40 -0800 Subject: [PATCH 3/3] Remove comments; Add copyright to unary_ops.py --- backends/apple/mps/operators/shape_ops.py | 4 ---- backends/apple/mps/operators/unary_ops.py | 4 ++++ backends/apple/mps/runtime/MPSExecutor.mm | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/apple/mps/operators/shape_ops.py b/backends/apple/mps/operators/shape_ops.py index b6c9bad692e..76c559018be 100644 --- a/backends/apple/mps/operators/shape_ops.py +++ b/backends/apple/mps/operators/shape_ops.py @@ -238,10 +238,6 @@ def define_node( node: torch.fx.Node, mps_graph: MPSGraph, ) -> None: - # mps_node = self.create_unary_node( - # node, mps_graph, MPSSlice - # ) - input1_id = self.define_tensor(get_input_node(node, 0), mps_graph) output_ids = self.define_tensor_list(node, mps_graph) split_sizes = eval_shape(cast(torch.SymInt, node.args[1])) diff --git a/backends/apple/mps/operators/unary_ops.py b/backends/apple/mps/operators/unary_ops.py index a8b957dc44f..fd60e150e1a 100644 --- 
a/backends/apple/mps/operators/unary_ops.py +++ b/backends/apple/mps/operators/unary_ops.py @@ -1,3 +1,7 @@ +# Copyright (c) 2023 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + import torch from executorch.backends.apple.mps.operators.node_visitor import ( NodeVisitor, diff --git a/backends/apple/mps/runtime/MPSExecutor.mm b/backends/apple/mps/runtime/MPSExecutor.mm index 0311214bf58..f59cf42ace5 100644 --- a/backends/apple/mps/runtime/MPSExecutor.mm +++ b/backends/apple/mps/runtime/MPSExecutor.mm @@ -70,7 +70,7 @@ @interface MPSNDArrayDescriptor () __ET_NODISCARD Error MPSExecutor::forward(std::vector& outputs) { Error err = Error::Ok; MPSStream* mpsStream = getDefaultMPSStream(); - if (mpsStream->commitAndContinueEnabled() || mpsStream->hasLiveCommandBuffer() || true) { + if (mpsStream->commitAndContinueEnabled() || mpsStream->hasLiveCommandBuffer()) { id commandBuffer = mpsStream->commandBuffer(); [executable_ encodeToCommandBuffer:commandBuffer inputsArray:inputsArray_