diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bf9246734..f917f0312 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -104,3 +104,69 @@ jobs:
           jekyll: false
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+
+  test_opencl:
+    runs-on: ubuntu-latest
+    env:
+      CC: mpicc
+      PETSC_DIR: ${{ github.workspace }}/petsc
+      PETSC_ARCH: default
+      PETSC_CONFIGURE_OPTIONS: --with-debugging=1 --with-shared-libraries=1 --with-c2html=0 --with-fortran-bindings=0 --with-opencl=1 --download-viennacl --with-viennacl=1
+
+    steps:
+      - name: Install system dependencies
+        shell: bash
+        run: |
+          sudo apt update
+          sudo apt install build-essential mpich libmpich-dev \
+            libblas-dev liblapack-dev gfortran
+          sudo apt install ocl-icd-opencl-dev pocl-opencl-icd
+
+      - name: Set correct Python version
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.10'
+
+      - name: Clone PETSc
+        uses: actions/checkout@v2
+        with:
+          repository: firedrakeproject/petsc
+          path: ${{ env.PETSC_DIR }}
+
+      - name: Build and install PETSc
+        shell: bash
+        working-directory: ${{ env.PETSC_DIR }}
+        run: |
+          ./configure ${PETSC_CONFIGURE_OPTIONS}
+          make
+
+      - name: Build and install petsc4py
+        shell: bash
+        working-directory: ${{ env.PETSC_DIR }}/src/binding/petsc4py
+        run: |
+          python -m pip install --upgrade cython numpy
+          python -m pip install --no-deps .
+
+      - name: Checkout PyOP2
+        uses: actions/checkout@v2
+        with:
+          path: PyOP2
+
+      - name: Install PyOP2
+        shell: bash
+        working-directory: PyOP2
+        run: |
+          python -m pip install pip==20.2  # pip 20.2 needed for loopy install to work.
+
+          # xargs is used to force installation of requirements in the order we specified.
+          xargs -l1 python -m pip install < requirements-ext.txt
+          xargs -l1 python -m pip install < requirements-git.txt
+          python -m pip install pulp
+          python -m pip install -U flake8
+          python -m pip install pyopencl
+          python -m pip install .
+
+      - name: Run tests
+        shell: bash
+        working-directory: PyOP2
+        run: pytest test -v --tb=native
diff --git a/pyop2/array.py b/pyop2/array.py
new file mode 100644
index 000000000..2126cf436
--- /dev/null
+++ b/pyop2/array.py
@@ -0,0 +1,274 @@
+import abc
+import collections
+import enum
+
+import numpy as np
+from petsc4py import PETSc
+
+from pyop2.configuration import configuration
+from pyop2.datatypes import ScalarType
+from pyop2.exceptions import DataTypeError
+from pyop2.offload_utils import _offloading as offloading, _backend as backend, OffloadingBackend
+from pyop2.types.access import READ, RW, WRITE, INC, MIN, MAX
+from pyop2.utils import verify_reshape
+
+try:
+    import pyopencl
+except ImportError:
+    pass
+
+try:
+    import pycuda
+except ImportError:
+    pass
+
+
+class MirroredArray(abc.ABC):
+    """An array that is available on both host and device, copying values where necessary."""
+
+    # TODO: Use __new__ instead?
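+    # Rough usage sketch (illustrative only, using names from this module):
+    # ``MirroredArray.new(np.zeros(10, dtype=ScalarType))`` returns a
+    # ``CPUOnlyArray`` under the default CPU backend; with the OpenCL or CUDA
+    # backend selected in ``pyop2.offload_utils`` the same call yields an
+    # ``OpenCLArray`` or ``CUDAArray``. Callers are expected to go through
+    # ``new`` rather than instantiate the subclasses directly.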
+    @classmethod
+    def new(cls, *args, **kwargs):
+        if backend == OffloadingBackend.CPU:
+            return CPUOnlyArray(*args, **kwargs)
+        elif backend == OffloadingBackend.OPENCL:
+            return OpenCLArray(*args, **kwargs)
+        elif backend == OffloadingBackend.CUDA:
+            return CUDAArray(*args, **kwargs)
+        else:
+            raise ValueError(f"Unrecognised offloading backend: {backend}")
+
+    def __init__(self, data):
+        # option 1: passed shape and dtype
+        if (len(data) == 2 and isinstance(data[0], collections.abc.Collection)
+                and isinstance(data[1], np.dtype)):
+            self.shape = data[0]
+            self.dtype = data[1]
+            self._lazy_host_data = None
+        # option 2: passed an array
+        elif isinstance(data, np.ndarray):
+            self.shape = data.shape
+            self.dtype = data.dtype
+            self._lazy_host_data = data
+        else:
+            raise ValueError("Unexpected arguments encountered")
+
+        # for now assume we are only ever passed numpy arrays (not device arrays)
+        self._lazy_device_data = None
+        self.is_available_on_host = True
+        self.is_available_on_device = False
+
+        # counter used to keep track of modifications
+        self.state = 0
+
+    @property
+    def data(self):
+        return self.device_data if offloading else self.host_data
+
+    @property
+    def data_ro(self):
+        return self.device_data_ro if offloading else self.host_data_ro
+
+    @property
+    def host_data(self):
+        return self.host_data_rw
+
+    @property
+    def host_data_rw(self):
+        self.state += 1
+
+        if not self.is_available_on_host:
+            self.device_to_host_copy()
+            self.is_available_on_host = True
+
+        # modifying on host invalidates the device
+        self.is_available_on_device = False
+
+        data = self._host_data.view()
+        data.setflags(write=True)
+        return data
+
+    @property
+    def host_data_ro(self):
+        if not self.is_available_on_host:
+            self.device_to_host_copy()
+            self.is_available_on_host = True
+
+        data = self._host_data.view()
+        data.setflags(write=False)
+        return data
+
+    @property
+    def host_data_wo(self):
+        self.state += 1  # TODO make a decorator
+
+        # since we are writing we do not need to copy from the device;
+        # modifying on host invalidates the device
+        self.is_available_on_host = True
+        self.is_available_on_device = False
+
+        data = self._host_data.view()
+        data.setflags(write=True)
+        return data
+
+    @property
+    def device_data(self):
+        return self.device_data_rw
+
+    @property
+    @abc.abstractmethod
+    def device_data_rw(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def device_data_ro(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def device_data_wo(self):
+        pass
+
+    @property
+    def ptr_rw(self):
+        return self.device_ptr_rw if offloading else self.host_ptr_rw
+
+    @property
+    def ptr_ro(self):
+        return self.device_ptr_ro if offloading else self.host_ptr_ro
+
+    @property
+    def ptr_wo(self):
+        return self.device_ptr_wo if offloading else self.host_ptr_wo
+
+    @property
+    def host_ptr_rw(self):
+        return self.host_data_rw.ctypes.data
+
+    @property
+    def host_ptr_ro(self):
+        return self.host_data_ro.ctypes.data
+
+    @property
+    def host_ptr_wo(self):
+        return self.host_data_wo.ctypes.data
+
+    @property
+    @abc.abstractmethod
+    def device_ptr_rw(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def device_ptr_ro(self):
+        pass
+
+    @property
+    @abc.abstractmethod
+    def device_ptr_wo(self):
+        pass
+
+    @abc.abstractmethod
+    def host_to_device_copy(self):
+        pass
+
+    @abc.abstractmethod
+    def device_to_host_copy(self):
+        pass
+
+    @abc.abstractmethod
+    def alloc_device_data(self, shape, dtype):
+        pass
+
+    @property
+    def size(self):
+        return self.data_ro.size
+
+    @property
+    def _host_data(self):
+        if self._lazy_host_data is None:
+            self._lazy_host_data = np.zeros(self.shape, self.dtype)
+        return self._lazy_host_data
+
+    @property
+    def _device_data(self):
+        if self._lazy_device_data is None:
+            self._lazy_device_data = self.alloc_device_data(self.shape, self.dtype)
+        return self._lazy_device_data
+
+
+class CPUOnlyArray(MirroredArray):
+    """A :class:`MirroredArray` whose "device" is the host; all copies are no-ops."""
+
+    @property
+    def device_data_rw(self):
+        return self.host_data_rw
+
+    @property
+    def device_data_ro(self):
+        return self.host_data_ro
+
+    @property
+    def device_data_wo(self):
+        return self.host_data_wo
+
+    @property
+    def device_ptr_rw(self):
+        return self.host_ptr_rw
+
+    @property
+    def device_ptr_ro(self):
+        return self.host_ptr_ro
+
+    @property
+    def device_ptr_wo(self):
+        return self.host_ptr_wo
+
+    def host_to_device_copy(self):
+        pass
+
+    def device_to_host_copy(self):
+        pass
+
+    def alloc_device_data(self, shape, dtype):
+        raise AssertionError("Should not hit this")
+
+
+class OpenCLArray(MirroredArray):
+
+    @property
+    def device_data_rw(self):
+        raise NotImplementedError
+
+    @property
+    def device_data_ro(self):
+        raise NotImplementedError
+
+    @property
+    def device_data_wo(self):
+        raise NotImplementedError
+
+    @property
+    def device_ptr_rw(self):
+        raise NotImplementedError
+
+    @property
+    def device_ptr_ro(self):
+        raise NotImplementedError
+
+    @property
+    def device_ptr_wo(self):
+        raise NotImplementedError
+
+    def host_to_device_copy(self):
+        self._device_data.set(self._host_data)
+
+    def device_to_host_copy(self):
+        # FIXME: self.queue not a thing
+        self._device_data.get(self.queue, self._host_data)
+
+    def alloc_device_data(self, shape, dtype):
+        raise NotImplementedError("TODO")
diff --git a/pyop2/backends/__init__.py b/pyop2/backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pyop2/backends/cpu.py b/pyop2/backends/cpu.py
new file mode 100644
index 000000000..b8581c6f4
--- /dev/null
+++ b/pyop2/backends/cpu.py
@@ -0,0 +1,187 @@
+from pyop2.types.dat import Dat as BaseDat, MixedDat, DatView
+from pyop2.types.set import Set, ExtrudedSet, Subset, MixedSet
+from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
+from pyop2.types.map import Map, MixedMap
+from pyop2.parloop import AbstractParloop
+from pyop2.global_kernel import AbstractGlobalKernel
+from pyop2.types.access import READ, INC, MIN, MAX
+from pyop2.types.mat import Mat
+from pyop2.types.glob import Global as BaseGlobal
+from pyop2.backends import AbstractComputeBackend
+from petsc4py import PETSc
+from pyop2 import (
+    compilation,
+    mpi,
+    utils
+)
+
+import ctypes
+import os
+import loopy as lp
+from contextlib import contextmanager
+import numpy as np
+
+
+class Dat(BaseDat):
+    @utils.cached_property
+    def _vec(self):
+        assert self.dtype == PETSc.ScalarType, \
+            "Can't create Vec with type %s, must be %s" % (self.dtype,
+                                                           PETSc.ScalarType)
+        # Can't duplicate layout_vec of dataset, because we then
+        # carry around extra unnecessary data.
+        # But use getSizes to save an Allreduce in computing the
+        # global size.
+        size = self.dataset.layout_vec.getSizes()
+        data = self._data[:size[0]]
+        vec = PETSc.Vec().createWithArray(data, size=size,
+                                          bsize=self.cdim, comm=self.comm)
+        return vec
+
+    @contextmanager
+    def vec_context(self, access):
+        # PETSc Vecs have a state counter and cache norm computations
+        # to return immediately if the state counter is unchanged.
+        # Since we've updated the data behind their back, we need to
+        # change that state counter.
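+        # (stateIncrease() bumps the PETSc object state so that any cached
+        # norms are recomputed on the next access.)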
+        self._vec.stateIncrease()
+        yield self._vec
+        if access is not READ:
+            self.halo_valid = False
+
+    def ensure_availability_on_device(self):
+        from pyop2.op2 import compute_backend
+        assert compute_backend is cpu_backend
+        # data transfer is a no-op for the CPU backend
+
+    def ensure_availability_on_host(self):
+        from pyop2.op2 import compute_backend
+        assert compute_backend is cpu_backend
+        # data transfer is a no-op for the CPU backend
+
+
+class Global(BaseGlobal):
+    @utils.cached_property
+    def _vec(self):
+        assert self.dtype == PETSc.ScalarType, \
+            "Can't create Vec with type %s, must be %s" % (self.dtype,
+                                                           PETSc.ScalarType)
+        # Can't duplicate layout_vec of dataset, because we then
+        # carry around extra unnecessary data.
+        # But use getSizes to save an Allreduce in computing the
+        # global size.
+        data = self._data
+        size = self.dataset.layout_vec.getSizes()
+        if self.comm.rank == 0:
+            return PETSc.Vec().createWithArray(data, size=size,
+                                               bsize=self.cdim,
+                                               comm=self.comm)
+        else:
+            return PETSc.Vec().createWithArray(np.empty(0, dtype=self.dtype),
+                                               size=size,
+                                               bsize=self.cdim,
+                                               comm=self.comm)
+
+    @contextmanager
+    def vec_context(self, access):
+        """A context manager for a :class:`PETSc.Vec` from a :class:`Global`.
+
+        :param access: Access descriptor: READ, WRITE, or RW."""
+        yield self._vec
+        if access is not READ:
+            data = self._data
+            self.comm.Bcast(data, 0)
+
+    def ensure_availability_on_device(self):
+        from pyop2.op2 import compute_backend
+        assert compute_backend is cpu_backend
+        # data transfer is a no-op for the CPU backend
+
+    def ensure_availability_on_host(self):
+        from pyop2.op2 import compute_backend
+        assert compute_backend is cpu_backend
+        # data transfer is a no-op for the CPU backend
+
+
+class GlobalKernel(AbstractGlobalKernel):
+
+    @utils.cached_property
+    def code_to_compile(self):
+        """Return the C/C++ source code as a string."""
+        from pyop2.codegen.rep2loopy import generate
+
+        wrapper = generate(self.builder)
+        code = lp.generate_code_v2(wrapper)
+
+        if self.local_kernel.cpp:
+            from loopy.codegen.result import process_preambles
+            preamble = "".join(
+                process_preambles(getattr(code, "device_preambles", [])))
+            device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
+            return preamble + '\nextern "C" {\n' + device_code + "\n}\n"
+        return code.device_code()
+
+    @PETSc.Log.EventDecorator()
+    @mpi.collective
+    def compile(self, comm):
+        """Compile the kernel.
+
+        :arg comm: The communicator the compilation is collective over.
+        :returns: A ctypes function pointer for the compiled function.
+ """ + extension = "cpp" if self.local_kernel.cpp else "c" + cppargs = ( + tuple("-I%s/include" % d for d in utils.get_petsc_dir()) + + tuple("-I%s" % d for d in self.local_kernel.include_dirs) + + ("-I%s" % os.path.abspath(os.path.dirname(__file__)),) + ) + ldargs = ( + tuple("-L%s/lib" % d for d in utils.get_petsc_dir()) + + tuple("-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()) + + ("-lpetsc", "-lm") + + tuple(self.local_kernel.ldargs) + ) + + return compilation.load(self, extension, self.name, + cppargs=cppargs, + ldargs=ldargs, + restype=ctypes.c_int, + comm=comm) + + +class Parloop(AbstractParloop): + @PETSc.Log.EventDecorator("ParLoopRednBegin") + @mpi.collective + def reduction_begin(self): + """Begin reductions.""" + requests = [] + for idx in self._reduction_idxs: + glob = self.arguments[idx].data + mpi_op = {INC: mpi.MPI.SUM, + MIN: mpi.MPI.MIN, + MAX: mpi.MPI.MAX}.get(self.accesses[idx]) + + if mpi.MPI.VERSION >= 3: + requests.append(self.comm.Iallreduce(glob.data, + glob._buf, + op=mpi_op)) + else: + self.comm.Allreduce(glob.data, glob._buf, op=mpi_op) + return tuple(requests) + + @PETSc.Log.EventDecorator("ParLoopRednEnd") + @mpi.collective + def reduction_end(self, requests): + """Finish reductions.""" + if mpi.MPI.VERSION >= 3: + mpi.MPI.Request.Waitall(requests) + for idx in self._reduction_idxs: + glob = self.arguments[idx].data + glob._array.data[:] = glob._buf + else: + assert len(requests) == 0 + + for idx in self._reduction_idxs: + glob = self.arguments[idx].data + glob._data[:] = glob._buf diff --git a/pyop2/backends/cuda.py b/pyop2/backends/cuda.py new file mode 100644 index 000000000..2d9d67122 --- /dev/null +++ b/pyop2/backends/cuda.py @@ -0,0 +1,663 @@ +"""OP2 CUDA backend.""" + +import os + +from hashlib import md5 +from contextlib import contextmanager + +from pyop2 import compilation, mpi, utils +from pyop2.offload_utils import (AVAILABLE_ON_HOST_ONLY, + AVAILABLE_ON_DEVICE_ONLY, + AVAILABLE_ON_BOTH, + DataAvailability) +from pyop2.configuration import configuration +from pyop2.types.set import (MixedSet, Subset as BaseSubset, + ExtrudedSet as BaseExtrudedSet, + Set) +from pyop2.types.map import Map as BaseMap, MixedMap +from pyop2.types.dat import Dat as BaseDat, MixedDat, DatView +from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet +from pyop2.types.mat import Mat +from pyop2.types.glob import Global as BaseGlobal +from pyop2.types.access import RW, READ, MIN, MAX, INC +from pyop2.parloop import AbstractParloop +from pyop2.global_kernel import AbstractGlobalKernel +from pyop2.backends import AbstractComputeBackend, cpu as cpu_backend +from petsc4py import PETSc + +import numpy +import pycuda.driver as cuda +import pycuda.gpuarray as cuda_np +import loopy as lp +from pytools import memoize_method +from dataclasses import dataclass +from typing import Tuple +import ctypes + + +class Map(BaseMap): + + def __init__(self, *args, **kwargs): + super(Map, self).__init__(*args, **kwargs) + self._availability_flag = AVAILABLE_ON_HOST_ONLY + + @utils.cached_property + def _cuda_values(self): + self._availability_flag = AVAILABLE_ON_BOTH + return cuda_np.to_gpu(self._values) + + def get_availability(self): + return self._availability_flag + + def ensure_availability_on_device(self): + self._cuda_values + + def ensure_availability_on_host(self): + # Map once initialized is not over-written so always available + # on host. 
+        pass
+
+    @property
+    def _kernel_args_(self):
+        if cuda_backend.offloading:
+            if not self.is_available_on_device():
+                self.ensure_availability_on_device()
+
+            return (self._cuda_values.gpudata,)
+        else:
+            return super(Map, self)._kernel_args_
+
+
+class ExtrudedSet(BaseExtrudedSet):
+    """
+    ExtrudedSet for CUDA.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(ExtrudedSet, self).__init__(*args, **kwargs)
+        self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+    @utils.cached_property
+    def cuda_layers_array(self):
+        self._availability_flag = AVAILABLE_ON_BOTH
+        return cuda_np.to_gpu(self.layers_array)
+
+    def get_availability(self):
+        return self._availability_flag
+
+    def ensure_availability_on_device(self):
+        self.cuda_layers_array
+
+    def ensure_availability_on_host(self):
+        # ExtrudedSet once initialized is not over-written so always available
+        # on host.
+        pass
+
+    @property
+    def _kernel_args_(self):
+        if cuda_backend.offloading:
+            if not self.is_available_on_device():
+                self.ensure_availability_on_device()
+
+            return (self.cuda_layers_array.gpudata,)
+        else:
+            return super(ExtrudedSet, self)._kernel_args_
+
+
+class Subset(BaseSubset):
+    """
+    Subset for CUDA.
+    """
+    def __init__(self, *args, **kwargs):
+        super(Subset, self).__init__(*args, **kwargs)
+        self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+    def get_availability(self):
+        return self._availability_flag
+
+    @utils.cached_property
+    def _cuda_indices(self):
+        self._availability_flag = AVAILABLE_ON_BOTH
+        return cuda_np.to_gpu(self._indices)
+
+    def ensure_availability_on_device(self):
+        self._cuda_indices
+
+    def ensure_availability_on_host(self):
+        # Subset once initialized is not over-written so always available
+        # on host.
+        pass
+
+    @property
+    def _kernel_args_(self):
+        if cuda_backend.offloading:
+            if not self.is_available_on_device():
+                self.ensure_availability_on_device()
+
+            return (self._cuda_indices.gpudata,)
+        else:
+            return super(Subset, self)._kernel_args_
+
+
+class Dat(BaseDat):
+    """
+    Dat for CUDA.
+    """
+    def __init__(self, *args, **kwargs):
+        super(Dat, self).__init__(*args, **kwargs)
+        # _availability_flag: only used when Dat cannot be represented as a
+        # petscvec; when Dat can be represented as a petscvec the availability
+        # flag is directly read from the petsc vec.
+        self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+    @utils.cached_property
+    def _cuda_data(self):
+        """
+        Only used when the Dat's data cannot be represented as a petsc Vec.
+        """
+        self._availability_flag = AVAILABLE_ON_BOTH
+        if self.can_be_represented_as_petscvec:
+            with self.vec as petscvec:
+                return cuda_np.GPUArray(shape=self._data.shape,
+                                        dtype=self._data.dtype,
+                                        gpudata=petscvec.getCUDAHandle("r"),
+                                        strides=self._data.strides)
+        else:
+            return cuda_np.to_gpu(self._data)
+
+    def zero(self, subset=None):
+        if subset is None:
+            self.data[:] = 0*self.data
+            self.halo_valid = True
+        else:
+            raise NotImplementedError
+
+    def copy(self, other, subset=None):
+        raise NotImplementedError
+
+    @utils.cached_property
+    def _vec(self):
+        assert self.dtype == PETSc.ScalarType, \
+            "Can't create Vec with type %s, must be %s" % (self.dtype,
+                                                           PETSc.ScalarType)
+        # Can't duplicate layout_vec of dataset, because we then
+        # carry around extra unnecessary data.
+        # But use getSizes to save an Allreduce in computing the
+        # global size.
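+        # getSizes() returns the (local, global) size pair, so size[0] below
+        # slices out the process-local (owned) part of the data.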
+        size = self.dataset.layout_vec.getSizes()
+        data = self._data[:size[0]]
+        return PETSc.Vec().createCUDAWithArrays(data, size=size,
+                                                bsize=self.cdim,
+                                                comm=self.comm)
+
+    def get_availability(self):
+        if self.can_be_represented_as_petscvec:
+            return DataAvailability(self._vec.getOffloadMask())
+        else:
+            return self._availability_flag
+
+    def ensure_availability_on_device(self):
+        if self.can_be_represented_as_petscvec:
+            if not cuda_backend.offloading:
+                raise NotImplementedError("PETSc limitation: can ensure availability"
+                                          " on GPU only within an offloading"
+                                          " context.")
+            # perform a host->device transfer if needed
+            self._vec.getCUDAHandle("r")
+        else:
+            if not self.is_available_on_device():
+                self._cuda_data.set(self._data)
+                self._availability_flag = AVAILABLE_ON_BOTH
+
+    def ensure_availability_on_host(self):
+        if self.can_be_represented_as_petscvec:
+            # perform a device->host transfer if needed
+            self._vec.getArray(readonly=True)
+        else:
+            if not self.is_available_on_host():
+                self._cuda_data.get(self._data)
+                self._availability_flag = AVAILABLE_ON_BOTH
+
+    @contextmanager
+    def vec_context(self, access):
+        r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.
+
+        :param access: Access descriptor: READ, WRITE, or RW."""
+        # PETSc Vecs have a state counter and cache norm computations
+        # to return immediately if the state counter is unchanged.
+        # Since we've updated the data behind their back, we need to
+        # change that state counter.
+        self._vec.stateIncrease()
+
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            self._vec.bindToCPU(False)
+        else:
+            self.ensure_availability_on_host()
+            self._vec.bindToCPU(True)
+
+        yield self._vec
+
+        if access is not READ:
+            self.halo_valid = False
+
+    @property
+    def _kernel_args_(self):
+        if self.can_be_represented_as_petscvec:
+            if cuda_backend.offloading:
+                self.ensure_availability_on_device()
+                # tell petsc that we have updated the data in the CUDA buffer
+                with self.vec as v:
+                    v.restoreCUDAHandle(v.getCUDAHandle())
+                return (self._cuda_data.gpudata,)
+            else:
+                self.ensure_availability_on_host()
+                # tell petsc that we have updated the data on the host
+                with self.vec as v:
+                    v.stateIncrease()
+                return (self._data.ctypes.data, )
+        else:
+            if cuda_backend.offloading:
+                self.ensure_availability_on_device()
+
+                self._availability_flag = AVAILABLE_ON_DEVICE_ONLY
+                return (self._cuda_data.gpudata, )
+            else:
+                self.ensure_availability_on_host()
+
+                self._availability_flag = AVAILABLE_ON_HOST_ONLY
+                return (self._data.ctypes.data, )
+
+    @mpi.collective
+    @property
+    def data(self):
+        if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0:
+            raise RuntimeError("Illegal access: no data associated with this Dat!")
+        self.halo_valid = False
+
+        if cuda_backend.offloading:
+
+            self.ensure_availability_on_device()
+
+            # {{{ marking data on the host as invalid
+
+            if self.can_be_represented_as_petscvec:
+                # report to petsc that we are updating values on the device
+                with self.vec as petscvec:
+                    petscvec.restoreCUDAHandle(petscvec.getCUDAHandle("w"))
+            else:
+                self._availability_flag = AVAILABLE_ON_DEVICE_ONLY
+
+            # }}}
+            v = self._cuda_data[:self.dataset.size]
+        else:
+            self.ensure_availability_on_host()
+            v = self._data[:self.dataset.size].view()
+            v.setflags(write=True)
+
+            # {{{ marking data on the device as invalid
+
+            if self.can_be_represented_as_petscvec:
+                # report to petsc that we are altering data on the CPU
+                with self.vec as petscvec:
+                    petscvec.stateIncrease()
+            else:
+                self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+            # }}}
+
+        return v
+
+    @property
+    @mpi.collective
+    def data_with_halos(self):
+        self.global_to_local_begin(RW)
+        self.global_to_local_end(RW)
+        return self.data
+
+    @property
+    @mpi.collective
+    def data_ro(self):
+        if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0:
+            raise RuntimeError("Illegal access: no data associated with this Dat!")
+
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            v = self._cuda_data[:self.dataset.size].view()
+            # FIXME: PyCUDA doesn't support 'read-only' arrays yet
+        else:
+            self.ensure_availability_on_host()
+            v = self._data[:self.dataset.size].view()
+            v.setflags(write=False)
+
+        return v
+
+    @property
+    @mpi.collective
+    def data_ro_with_halos(self):
+        self.global_to_local_begin(READ)
+        self.global_to_local_end(READ)
+
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            v = self._cuda_data.view()
+            # FIXME: PyCUDA doesn't support 'read-only' arrays yet
+        else:
+            self.ensure_availability_on_host()
+            v = self._data.view()
+            v.setflags(write=False)
+
+        return v
+
+
+class Global(BaseGlobal):
+    """
+    Global for CUDA.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(Global, self).__init__(*args, **kwargs)
+        self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+    @utils.cached_property
+    def _cuda_data(self):
+        self._availability_flag = AVAILABLE_ON_BOTH
+        return cuda_np.to_gpu(self._data)
+
+    def get_availability(self):
+        return self._availability_flag
+
+    def ensure_availability_on_device(self):
+        if not self.is_available_on_device():
+            self._cuda_data.set(self._data)
+            self._availability_flag = AVAILABLE_ON_BOTH
+
+    def ensure_availability_on_host(self):
+        if not self.is_available_on_host():
+            self._cuda_data.get(ary=self._data)
+            self._availability_flag = AVAILABLE_ON_BOTH
+
+    @property
+    def _kernel_args_(self):
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            self._availability_flag = AVAILABLE_ON_DEVICE_ONLY
+            return (self._cuda_data.gpudata,)
+        else:
+            self.ensure_availability_on_host()
+            self._availability_flag = AVAILABLE_ON_HOST_ONLY
+            return super(Global, self)._kernel_args_
+
+    @mpi.collective
+    @property
+    def data(self):
+        if len(self._data) == 0:
+            raise RuntimeError("Illegal access: No data associated with"
+                               " this Global!")
+
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            self._availability_flag = AVAILABLE_ON_DEVICE_ONLY
+            return self._cuda_data
+        else:
+            self.ensure_availability_on_host()
+            self._availability_flag = AVAILABLE_ON_HOST_ONLY
+            return self._data
+
+    @property
+    def data_ro(self):
+        if len(self._data) == 0:
+            raise RuntimeError("Illegal access: No data associated with"
+                               " this Global!")
+
+        if cuda_backend.offloading:
+            self.ensure_availability_on_device()
+            view = self._cuda_data.view()
+            # FIXME: PyCUDA doesn't support read-only arrays yet
+            return view
+        else:
+            self.ensure_availability_on_host()
+            view = self._data.view()
+            view.setflags(write=False)
+            return view
+
+    @utils.cached_property
+    def _vec(self):
+        raise NotImplementedError()
+
+    @contextmanager
+    def vec_context(self, access):
+        raise NotImplementedError()
+
+
+@dataclass(frozen=True)
+class CUFunctionWithExtraArgs:
+    """
+    A partial :class:`pycuda.driver.Function` with bound arguments *args_to_append*
+    that are appended to the arguments passed in :meth:`__call__`.
+    """
+    cu_func: cuda.Function
+    args_to_append: Tuple[cuda_np.GPUArray, ...]
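+
+    # Hypothetical example: ``CUFunctionWithExtraArgs(f, (a, b))(grid, block, *args)``
+    # launches ``f`` with ``args`` followed by the device pointers of ``a`` and ``b``.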
+
+    def __call__(self, grid, block, *args):
+        if all(grid + block):
+            self.cu_func.prepared_call(grid, block,
+                                       *args,
+                                       *[arg_to_append.gpudata
+                                         for arg_to_append in self.args_to_append])
+
+
+class GlobalKernel(AbstractGlobalKernel):
+    @classmethod
+    def _cache_key(cls, *args, **kwargs):
+        key = super()._cache_key(*args, **kwargs)
+        key = key + (configuration["gpu_strategy"], "cuda")
+        return key
+
+    @utils.cached_property
+    def encoded_cache_key(self):
+        return md5(str(self.cache_key[1:]).encode()).hexdigest()
+
+    @utils.cached_property
+    def computational_grid_expr_cache_file_path(self):
+        cachedir = configuration["cache_dir"]
+        return os.path.join(cachedir, f"{self.encoded_cache_key}_grid_params.py")
+
+    @utils.cached_property
+    def extra_args_cache_file_path(self):
+        cachedir = configuration["cache_dir"]
+        return os.path.join(cachedir, f"{self.encoded_cache_key}_extra_args.npz")
+
+    @memoize_method
+    def get_grid_size(self, start, end):
+        """
+        Returns a :class:`tuple` of the form ``(grid, block)``, where *grid* is
+        the 2-:class:`tuple` corresponding to the CUDA computational grid size
+        and ``block`` is the thread block size. Refer to the CUDA documentation
+        for a detailed explanation of these terms.
+        """
+        fpath = self.computational_grid_expr_cache_file_path
+
+        with open(fpath, "r") as f:
+            globals_dict = {}
+            exec(f.read(), globals_dict)
+            get_grid_sizes = globals_dict["get_grid_sizes"]
+
+        grid, block = get_grid_sizes(start=start, end=end)
+
+        # TODO: PyCUDA doesn't allow numpy ints for grid dimensions. Remove
+        # these casts after PyCUDA has been patched.
+        return (tuple(int(dim) for dim in grid),
+                tuple(int(dim) for dim in block))
+
+    @memoize_method
+    def get_extra_args(self):
+        """
+        Returns device buffers corresponding to array literals baked into
+        :attr:`local_kernel`.
+        """
+        fpath = self.extra_args_cache_file_path
+        npzfile = numpy.load(fpath)
+        assert npzfile["ids"].ndim == 1
+        assert len(npzfile.files) == len(npzfile["ids"]) + 1
+        extra_args_np = [npzfile[arg_id]
+                         for arg_id in npzfile["ids"]]
+
+        return tuple([cuda_np.to_gpu(arg_np) for arg_np in extra_args_np])
+
+    @utils.cached_property
+    def argtypes(self):
+        result = super().argtypes
+        return result + (ctypes.c_voidp,) * len(self.get_extra_args())
+
+    @utils.cached_property
+    def code_to_compile(self):
+        from pyop2.codegen.rep2loopy import generate
+        from pyop2.transforms.gpu_utils import apply_gpu_transforms
+        from pymbolic.interop.ast import to_evaluatable_python_function
+
+        t_unit = generate(self.builder,
+                          include_math=False,
+                          include_petsc=False,
+                          include_complex=False)
+
+        # Convert temporary variables with initializers into kernel arguments.
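+        # (apply_gpu_transforms is assumed to return the transformed
+        # translation unit together with the hoisted constant arrays as numpy
+        # arrays; these are cached to disk below and uploaded to the device
+        # once per kernel by get_extra_args.)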
+        t_unit, extra_args = apply_gpu_transforms(t_unit, "cuda")
+
+        ary_ids = [f"_op2_arg_{i}"
+                   for i in range(len(extra_args))]
+
+        numpy.savez(self.extra_args_cache_file_path,
+                    ids=numpy.array(ary_ids),
+                    **{ary_id: extra_arg
+                       for ary_id, extra_arg in zip(ary_ids, extra_args)})
+
+        # {{{ save python code to get grid sizes
+
+        with open(self.computational_grid_expr_cache_file_path, "w") as f:
+            glens, llens = (t_unit
+                            .default_entrypoint
+                            .get_grid_size_upper_bounds_as_exprs(t_unit
+                                                                 .callables_table))
+            glens = glens + (1,) * (2 - len(glens))  # cuda expects 2d grid shape
+            llens = llens + (1,) * (3 - len(llens))  # cuda expects 3d block shape
+            f.write(to_evaluatable_python_function((glens, llens), "get_grid_sizes"))
+
+        # }}}
+
+        code = lp.generate_code_v2(t_unit).device_code()
+
+        return code
+
+    @mpi.collective
+    def __call__(self, comm, *args):
+        key = id(comm)
+        try:
+            func = self._func_cache[key]
+        except KeyError:
+            func = self.compile(comm)
+            self._func_cache[key] = func
+
+        grid, block = self.get_grid_size(args[0], args[1])
+        func(grid, block, *args)
+
+    @mpi.collective
+    def compile(self, comm):
+        cu_func = compilation.get_prepared_cuda_function(comm, self)
+        return CUFunctionWithExtraArgs(cu_func, self.get_extra_args())
+
+
+class Parloop(AbstractParloop):
+    @PETSc.Log.EventDecorator("ParLoopRednBegin")
+    @mpi.collective
+    def reduction_begin(self):
+        """Begin reductions."""
+        requests = []
+        for idx in self._reduction_idxs:
+            glob = self.arguments[idx].data
+            mpi_op = {INC: mpi.MPI.SUM,
+                      MIN: mpi.MPI.MIN,
+                      MAX: mpi.MPI.MAX}.get(self.accesses[idx])
+
+            if mpi.MPI.VERSION >= 3:
+                # FIXME: Get rid of this D2H for CUDA-Aware MPI.
+                glob.ensure_availability_on_host()
+                requests.append(self.comm.Iallreduce(glob._data,
+                                                     glob._buf,
+                                                     op=mpi_op))
+            else:
+                # FIXME: Get rid of this D2H for CUDA-Aware MPI.
+                glob.ensure_availability_on_host()
+                self.comm.Allreduce(glob._data, glob._buf, op=mpi_op)
+        return tuple(requests)
+
+    @PETSc.Log.EventDecorator("ParLoopRednEnd")
+    @mpi.collective
+    def reduction_end(self, requests):
+        """Finish reductions."""
+        if mpi.MPI.VERSION >= 3:
+            mpi.MPI.Request.Waitall(requests)
+            for idx in self._reduction_idxs:
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                # FIXME: Get rid of this H2D for CUDA-Aware MPI.
+                glob.ensure_availability_on_device()
+        else:
+            assert len(requests) == 0
+
+            for idx in self._reduction_idxs:
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                # FIXME: Get rid of this H2D for CUDA-Aware MPI.
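+                # (Resetting the flag above makes this trigger a real
+                # host->device copy of the reduced value.)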
+                glob.ensure_availability_on_device()
+
+
+class CUDABackend(AbstractComputeBackend):
+    Parloop_offloading = Parloop
+    Parloop_no_offloading = cpu_backend.Parloop
+    GlobalKernel_offloading = GlobalKernel
+    GlobalKernel_no_offloading = cpu_backend.GlobalKernel
+
+    Parloop = cpu_backend.Parloop
+    GlobalKernel = cpu_backend.GlobalKernel
+    Set = Set
+    ExtrudedSet = ExtrudedSet
+    MixedSet = MixedSet
+    Subset = Subset
+    DataSet = DataSet
+    MixedDataSet = MixedDataSet
+    Map = Map
+    MixedMap = MixedMap
+    Dat = Dat
+    MixedDat = MixedDat
+    DatView = DatView
+    Mat = Mat
+    Global = Global
+    GlobalDataSet = GlobalDataSet
+    PETScVecType = PETSc.Vec.Type.CUDA
+
+    def __init__(self):
+        self.offloading = False
+
+    def turn_on_offloading(self):
+        self.offloading = True
+        self.Parloop = self.Parloop_offloading
+        self.GlobalKernel = self.GlobalKernel_offloading
+
+    def turn_off_offloading(self):
+        self.offloading = False
+        self.Parloop = self.Parloop_no_offloading
+        self.GlobalKernel = self.GlobalKernel_no_offloading
+
+    @property
+    def cache_key(self):
+        return (type(self), self.offloading)
+
+
+cuda_backend = CUDABackend()
diff --git a/pyop2/backends/opencl.py b/pyop2/backends/opencl.py
new file mode 100644
index 000000000..2df3073e2
--- /dev/null
+++ b/pyop2/backends/opencl.py
@@ -0,0 +1,376 @@
+"""OP2 OpenCL backend."""
+
+import os
+from hashlib import md5
+from contextlib import contextmanager
+
+from pyop2 import compilation, mpi, utils
+from pyop2.offload_utils import (AVAILABLE_ON_HOST_ONLY,
+                                 AVAILABLE_ON_DEVICE_ONLY,
+                                 AVAILABLE_ON_BOTH,
+                                 DataAvailability)
+from pyop2.configuration import configuration
+from pyop2.types.set import (MixedSet, Subset,
+                             ExtrudedSet,
+                             Set)
+from pyop2.types.map import Map, MixedMap
+from pyop2.types.dat import Dat as BaseDat, MixedDat, DatView
+from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
+from pyop2.types.mat import Mat
+from pyop2.types.glob import Global as BaseGlobal
+from pyop2.types.access import RW, READ, MIN, MAX, INC
+from pyop2.parloop import Parloop
+from pyop2.global_kernel import AbstractGlobalKernel
+from pyop2.backends import AbstractComputeBackend, cpu as cpu_backend
+from petsc4py import PETSc
+
+import numpy
+import loopy as lp
+from pytools import memoize_method
+from dataclasses import dataclass
+from typing import Tuple
+import ctypes
+
+import pyopencl as cl
+import pyopencl.array as cla
+
+
+def read_only_clarray_setitem(self, *args, **kwargs):
+    # emulates np.ndarray.__setitem__ for numpy arrays with read-only flags
+    raise ValueError("assignment destination is read-only")
+
+
+class Dat(BaseDat):
+    """
+    Dat for OpenCL.
+    """
+    def __init__(self, *args, **kwargs):
+        super(Dat, self).__init__(*args, **kwargs)
+        # _availability_flag: only used when Dat cannot be represented as a
+        # petscvec; when Dat can be represented as a petscvec the availability
+        # flag is directly read from the petsc vec.
+        self._availability_flag = AVAILABLE_ON_HOST_ONLY
+
+    @utils.cached_property
+    def _opencl_data(self):
+        """
+        Only used when the Dat's data cannot be represented as a petsc Vec.
+ """ + self._availability_flag = AVAILABLE_ON_BOTH + if self.can_be_represented_as_petscvec: + with self.vec as petscvec: + return cla.Array(cq=opencl_backend.queue, + shape=self._data.shape, + dtype=self._data.dtype, + data=cl.Buffer.from_int_ptr( + petscvec.getCLMemHandle("r"), + retain=False), + strides=self._data.strides) + else: + return cla.to_device(opencl_backend.queue, self._data) + + @utils.cached_property + def _vec(self): + assert self.dtype == PETSc.ScalarType, \ + "Can't create Vec with type %s, must be %s" % (self.dtype, + PETSc.ScalarType) + # Can't duplicate layout_vec of dataset, because we then + # carry around extra unnecessary data. + # But use getSizes to save an Allreduce in computing the + # global size. + size = self.dataset.layout_vec.getSizes() + data = self._data[:size[0]] + return PETSc.Vec().createViennaCLWithArrays(data, size=size, + bsize=self.cdim, + comm=self.comm) + + def get_availability(self): + if self.can_be_represented_as_petscvec: + return DataAvailability(self._vec.getOffloadMask()) + else: + return self._availability_flag + + def ensure_availability_on_device(self): + if self.can_be_represented_as_petscvec: + if not opencl_backend.offloading: + raise NotImplementedError("PETSc limitation: can ensure availability" + " on GPU only within an offloading" + " context.") + # perform a host->device transfer if needed + self._vec.getCLMemHandle("r") + else: + if not self.is_available_on_device(): + self._opencl_data.set(self._data) + self._availability_flag = AVAILABLE_ON_BOTH + + def ensure_availability_on_host(self): + if self.can_be_represented_as_petscvec: + # perform a device->host transfer if needed + self._vec.getArray(readonly=True) + else: + if not self.is_available_on_host(): + self._opencl_data.get(opencl_backend.queue, self._data) + self._availability_flag = AVAILABLE_ON_BOTH + + @contextmanager + def vec_context(self, access): + r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`. + + :param access: Access descriptor: READ, WRITE, or RW.""" + # PETSc Vecs have a state counter and cache norm computations + # to return immediately if the state counter is unchanged. + # Since we've updated the data behind their back, we need to + # change that state counter. + self._vec.stateIncrease() + + if opencl_backend.offloading: + self.ensure_availability_on_device() + self._vec.bindToCPU(False) + else: + self.ensure_availability_on_host() + self._vec.bindToCPU(True) + + yield self._vec + + if access is not READ: + self.halo_valid = False + + +class Global(BaseGlobal): + """ + Global for OpenCL. 
+ """ + + def __init__(self, *args, **kwargs): + super(Global, self).__init__(*args, **kwargs) + self._availability_flag = AVAILABLE_ON_HOST_ONLY + + @utils.cached_property + def _opencl_data(self): + self._availability_flag = AVAILABLE_ON_BOTH + return cla.to_device(opencl_backend.queue, self._data) + + def get_availability(self): + return self._availability_flag + + def ensure_availability_on_device(self): + if not self.is_available_on_device(): + self._opencl_data.set(ary=self._data, + queue=opencl_backend.queue) + self._availability_flag = AVAILABLE_ON_BOTH + + def ensure_availability_on_host(self): + if not self.is_available_on_host(): + self._opencl_data.get(ary=self._data, + queue=opencl_backend.queue) + self._availability_flag = AVAILABLE_ON_BOTH + + @property + def _kernel_args_(self): + if opencl_backend.offloading: + self.ensure_availability_on_device() + self._availability_flag = AVAILABLE_ON_DEVICE_ONLY + return (self._opencl_data.data,) + else: + self.ensure_availability_on_host() + self._availability_flag = AVAILABLE_ON_HOST_ONLY + return super(Global, self)._kernel_args_ + + @utils.cached_property + def _vec(self): + raise NotImplementedError() + + @contextmanager + def vec_context(self, access): + raise NotImplementedError() + + +@dataclass(frozen=True) +class CLKernelWithExtraArgs: + cl_kernel: cl.Kernel + args_to_append: Tuple[cla.Array, ...] + + def __call__(self, grid, block, start, end, *args): + if all(grid + block): + self.cl_kernel(opencl_backend.queue, + grid, block, + numpy.int32(start), + numpy.int32(end), + *args, + *[arg_to_append.data + for arg_to_append in self.args_to_append], + g_times_l=True) + + +class GlobalKernel(AbstractGlobalKernel): + @classmethod + def _cache_key(cls, *args, **kwargs): + key = super()._cache_key(*args, **kwargs) + key = key + (configuration["gpu_strategy"], "opencl") + return key + + @utils.cached_property + def encoded_cache_key(self): + return md5(str(self.cache_key[1:]).encode()).hexdigest() + + @utils.cached_property + def computational_grid_expr_cache_file_path(self): + cachedir = configuration["cache_dir"] + return os.path.join(cachedir, f"{self.encoded_cache_key}_grid_params.py") + + @utils.cached_property + def extra_args_cache_file_path(self): + cachedir = configuration["cache_dir"] + return os.path.join(cachedir, f"{self.encoded_cache_key}_extra_args.npz") + + @memoize_method + def get_grid_size(self, start, end): + fpath = self.computational_grid_expr_cache_file_path + + with open(fpath, "r") as f: + globals_dict = {} + exec(f.read(), globals_dict) + get_grid_sizes = globals_dict["get_grid_sizes"] + + return get_grid_sizes(start=start, end=end) + + @memoize_method + def get_extra_args(self): + """ + Returns device buffers corresponding to array literals baked into + :attr:`local_kernel`. 
+ """ + fpath = self.extra_args_cache_file_path + npzfile = numpy.load(fpath) + assert npzfile["ids"].ndim == 1 + assert len(npzfile.files) == len(npzfile["ids"]) + 1 + extra_args_np = [npzfile[arg_id] + for arg_id in npzfile["ids"]] + return tuple(cla.to_device(opencl_backend.queue, arg) + for arg in extra_args_np) + + @utils.cached_property + def argtypes(self): + result = super().argtypes + return result + (ctypes.c_voidp,) * len(self.get_extra_args()) + + @utils.cached_property + def code_to_compile(self): + from pyop2.codegen.rep2loopy import generate + from pyop2.transforms.gpu_utils import apply_gpu_transforms + from pymbolic.interop.ast import to_evaluatable_python_function + + t_unit = generate(self.builder, + include_math=False, + include_petsc=False, + include_complex=False) + + # Make temporary variables with initializers kernel's arguments. + t_unit, extra_args = apply_gpu_transforms(t_unit, "opencl") + + ary_ids = [f"_op2_arg_{i}" + for i in range(len(extra_args))] + + numpy.savez(self.extra_args_cache_file_path, + ids=numpy.array(ary_ids), + **{ary_id: extra_arg + for ary_id, extra_arg in zip(ary_ids, extra_args)}) + + # {{{ save python code to get grid sizes + + with open(self.computational_grid_expr_cache_file_path, "w") as f: + glens, llens = (t_unit + .default_entrypoint + .get_grid_size_upper_bounds_as_exprs(t_unit + .callables_table)) + f.write(to_evaluatable_python_function((glens, llens), "get_grid_sizes")) + + code = lp.generate_code_v2(t_unit).device_code() + + # }}} + + return code + + @mpi.collective + def __call__(self, comm, *args): + key = id(comm) + try: + func = self._func_cache[key] + except KeyError: + func = self.compile(comm) + self._func_cache[key] = func + + grid, block = self.get_grid_size(args[0], args[1]) + func(grid, block, *args) + + @mpi.collective + def compile(self, comm): + cl_knl = compilation.get_opencl_kernel(comm, + self) + return CLKernelWithExtraArgs(cl_knl, self.get_extra_args()) + + +class OpenCLBackend(AbstractComputeBackend): + Parloop_offloading = Parloop + Parloop_no_offloading = cpu_backend.Parloop + GlobalKernel_offloading = GlobalKernel + GlobalKernel_no_offloading = cpu_backend.GlobalKernel + + Parloop = cpu_backend.Parloop + GlobalKernel = cpu_backend.GlobalKernel + Set = Set + ExtrudedSet = ExtrudedSet + MixedSet = MixedSet + Subset = Subset + DataSet = DataSet + MixedDataSet = MixedDataSet + Map = Map + MixedMap = MixedMap + Dat = Dat + MixedDat = MixedDat + DatView = DatView + Mat = Mat + Global = Global + GlobalDataSet = GlobalDataSet + PETScVecType = PETSc.Vec.Type.VIENNACL + + def __init__(self): + self.offloading = False + + @utils.cached_property + def context(self): + # create a dummy vector and extract its underlying context + x = PETSc.Vec().create(PETSc.COMM_WORLD) + x.setType("viennacl") + x.setSizes(size=1) + ctx_ptr = x.getCLContextHandle() + return cl.Context.from_int_ptr(ctx_ptr, retain=False) + + @utils.cached_property + def queue(self): + # TODO: Instruct the user to pass + # -viennacl_backend opencl + # -viennacl_opencl_device_type gpu + # create a dummy vector and extract its associated command queue + x = PETSc.Vec().create(PETSc.COMM_WORLD) + x.setType("viennacl") + x.setSizes(size=1) + queue_ptr = x.getCLQueueHandle() + return cl.CommandQueue.from_int_ptr(queue_ptr, retain=False) + + def turn_on_offloading(self): + self.offloading = True + self.Parloop = self.Parloop_offloading + self.GlobalKernel = self.GlobalKernel_offloading + + def turn_off_offloading(self): + self.offloading = False + self.Parloop 
+        self.GlobalKernel = self.GlobalKernel_no_offloading
+
+    @property
+    def cache_key(self):
+        return (type(self), self.offloading)
+
+
+opencl_backend = OpenCLBackend()
diff --git a/pyop2/codegen/rep2loopy.py b/pyop2/codegen/rep2loopy.py
index 7085c324d..1fdb04fda 100644
--- a/pyop2/codegen/rep2loopy.py
+++ b/pyop2/codegen/rep2loopy.py
@@ -400,7 +400,7 @@ def bounds(exprs):
     return dict((op, (names[op], dep)) for op, dep in deps.items())
 
 
-def generate(builder, wrapper_name=None):
+def generate(builder, wrapper_name=None, include_math=True, include_petsc=True, include_complex=True):
     if builder.layer_index is not None:
         outer_inames = frozenset([builder._loop_index.name,
                                   builder.layer_index.name])
@@ -547,7 +548,16 @@ def renamer(expr):
     # register kernel
     kernel = builder.kernel
     headers = set(kernel.headers)
-    headers = headers | set(["#include <math.h>", "#include <petsc.h>", "#include <complex.h>"])
+
+    if include_math:
+        headers.add("#include <math.h>")
+
+    if include_petsc:
+        headers.add("#include <petsc.h>")
+
+    if include_complex:
+        headers.add("#include <complex.h>")
+
     if PETSc.Log.isActive():
         headers = headers | set(["#include <petsclog.h>"])
     preamble = "\n".join(sorted(headers))
@@ -622,15 +632,22 @@ def statement_assign(expr, context):
     if isinstance(lvalue, Indexed):
         context.index_ordering.append(tuple(i.name for i in lvalue.index_ordering()))
     lvalue, rvalue = tuple(expression(c, context.parameters) for c in expr.children)
-    within_inames = context.within_inames[expr]
 
+    if isinstance(expr.label, (PreUnpackInst, UnpackInst)):
+        tag = "scatter"
+    elif isinstance(expr.label, PackInst):
+        tag = "gather"
+    else:
+        raise NotImplementedError()
+
+    within_inames = context.within_inames[expr]
     id, depends_on = context.instruction_dependencies[expr]
     predicates = frozenset(context.conditions)
     return loopy.Assignment(lvalue, rvalue, within_inames=within_inames,
                             within_inames_is_final=True,
                             predicates=predicates,
                             id=id,
-                            depends_on=depends_on, depends_on_is_final=True)
+                            depends_on=depends_on, depends_on_is_final=True,
+                            tags=frozenset([tag]))
 
 
 @statement.register(FunctionCall)
diff --git a/pyop2/compilation.py b/pyop2/compilation.py
index 8fd7bf023..3449f9df1 100644
--- a/pyop2/compilation.py
+++ b/pyop2/compilation.py
@@ -147,6 +147,29 @@ def sniff_compiler(exe):
     return compiler
 
 
+def _check_src_hashes(comm, global_kernel):
+    hsh = md5(str(global_kernel.cache_key[1:]).encode())
+    basename = hsh.hexdigest()
+    dirpart, basename = basename[:2], basename[2:]
+    cachedir = configuration["cache_dir"]
+    cachedir = os.path.join(cachedir, dirpart)
+
+    if configuration["check_src_hashes"] or configuration["debug"]:
+        matching = comm.allreduce(basename, op=_check_op)
+        if matching != basename:
+            # Dump all src code to disk for debugging
+            output = os.path.join(cachedir, "mismatching-kernels")
+            srcfile = os.path.join(output, "src-rank%d.c" % comm.rank)
+            if comm.rank == 0:
+                os.makedirs(output, exist_ok=True)
+            comm.barrier()
+            with open(srcfile, "w") as f:
+                f.write(global_kernel.code_to_compile)
+            comm.barrier()
+            raise CompilationError("Generated code differs across ranks"
+                                   f" (see output in {output})")
+
+
 class Compiler(ABC):
     """A compiler for shared libraries.
@@ -317,19 +340,8 @@ def get_so(self, jitmodule, extension):
         # atomically (avoiding races).
         tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid))
 
-        if configuration['check_src_hashes'] or configuration['debug']:
-            matching = self.comm.allreduce(basename, op=_check_op)
-            if matching != basename:
-                # Dump all src code to disk for debugging
-                output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
-                srcfile = os.path.join(output, "src-rank%d.c" % self.comm.rank)
-                if self.comm.rank == 0:
-                    os.makedirs(output, exist_ok=True)
-                self.comm.barrier()
-                with open(srcfile, "w") as f:
-                    f.write(jitmodule.code_to_compile)
-                self.comm.barrier()
-                raise CompilationError("Generated code differs across ranks (see output in %s)" % output)
+        _check_src_hashes(self.comm, jitmodule)
+
         try:
             # Are we in the cache?
             return ctypes.CDLL(soname)
@@ -662,3 +674,81 @@ def clear_cache(prompt=False):
             shutil.rmtree(cachedir)
     else:
         print("Not removing cached libraries")
+
+
+def _get_code_to_compile(comm, global_kernel):
+    # Determine cache key
+    hsh = md5(str(global_kernel.cache_key[1:]).encode())
+    basename = hsh.hexdigest()
+    cachedir = configuration["cache_dir"]
+    dirpart, basename = basename[:2], basename[2:]
+    cachedir = os.path.join(cachedir, dirpart)
+    cname = os.path.join(cachedir, f"{basename}_code.cu")
+
+    _check_src_hashes(comm, global_kernel)
+
+    if os.path.isfile(cname):
+        # Are we in the cache?
+        with open(cname, "r") as f:
+            code_to_compile = f.read()
+    else:
+        # No, let's go ahead and build
+        if comm.rank == 0:
+            # No need to do this on all ranks
+            os.makedirs(cachedir, exist_ok=True)
+        with progress(INFO, "Compiling wrapper"):
+            # make sure that it compiles successfully before writing to file
+            code_to_compile = global_kernel.code_to_compile
+            with open(cname, "w") as f:
+                f.write(code_to_compile)
+        comm.barrier()
+
+    return code_to_compile
+
+
+@mpi.collective
+def get_prepared_cuda_function(comm, global_kernel):
+    from pycuda.compiler import SourceModule
+
+    # Determine cache key
+    hsh = md5(str(global_kernel.cache_key[1:]).encode())
+    basename = hsh.hexdigest()
+    cachedir = configuration["cache_dir"]
+    dirpart, basename = basename[:2], basename[2:]
+    cachedir = os.path.join(cachedir, dirpart)
+
+    nvcc_opts = ["-use_fast_math", "-w"]
+
+    code_to_compile = _get_code_to_compile(comm, global_kernel)
+    source_module = SourceModule(code_to_compile, options=nvcc_opts,
+                                 cache_dir=cachedir)
+
+    cu_func = source_module.get_function(global_kernel.name)
+
+    type_map = {ctypes.c_void_p: "P", ctypes.c_int: "i"}
+    argtypes = "".join(type_map[t] for t in global_kernel.argtypes)
+    cu_func.prepare(argtypes)
+
+    return cu_func
+
+
+@mpi.collective
+def get_opencl_kernel(comm, global_kernel):
+    import pyopencl as cl
+    from pyop2.backends.opencl import opencl_backend
+    cl_ctx = opencl_backend.context
+
+    # Determine cache key
+    hsh = md5(str(global_kernel.cache_key[1:]).encode())
+    basename = hsh.hexdigest()
+    cachedir = configuration["cache_dir"]
+    dirpart, basename = basename[:2], basename[2:]
+    cachedir = os.path.join(cachedir, dirpart)
+
+    code_to_compile = _get_code_to_compile(comm, global_kernel)
+
+    prg = cl.Program(cl_ctx, code_to_compile).build(options=[],
+                                                    cache_dir=cachedir)
+
+    cl_knl = cl.Kernel(prg, global_kernel.name)
+    return cl_knl
diff --git a/pyop2/configuration.py b/pyop2/configuration.py
index 29717718c..07681cc1c 100644
--- a/pyop2/configuration.py
+++ b/pyop2/configuration.py
@@ -74,10 +74,18 @@ class Configuration(dict):
         cdim > 1 be built as block sparsities, or dof
         sparsities. The former saves memory but changes which
         preconditioners are available for the resulting matrices. (Default yes)
+    :param gpu_strategy: A :class:`str` indicating the transformation strategy
+        that must be applied to a :class:`pyop2.global_kernel.GlobalKernel`
+        when offloading to a GPGPU. Can be one of:
+
+        - ``"snpt"``: Single-"N" Per Thread. In this transform strategy, the
+          work of each element of the iteration set over which a global kernel
+          operates is assigned to a work-item (i.e. a CUDA thread).
     """
    # name, env variable, type, default, write once
    cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid())
    DEFAULTS = {
+        "backend":
+            ("PYOP2_BACKEND", str, "CPU_ONLY"),
        "cc":
            ("PYOP2_CC", str, ""),
        "cxx":
@@ -113,7 +121,9 @@ class Configuration(dict):
         "matnest":
             ("PYOP2_MATNEST", bool, True),
         "block_sparsity":
-            ("PYOP2_BLOCK_SPARSITY", bool, True)
+            ("PYOP2_BLOCK_SPARSITY", bool, True),
+        "gpu_strategy":
+            ("PYOP2_GPU_STRATEGY", str, "snpt"),
     }
     """Default values for PyOP2 configuration parameters"""
diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py
index 0277ef773..f0e39c8e5 100644
--- a/pyop2/global_kernel.py
+++ b/pyop2/global_kernel.py
@@ -1,20 +1,22 @@
-import collections.abc
+import abc
+import collections
 import ctypes
+import os
 from dataclasses import dataclass
 import itertools
-import os
 from typing import Optional, Tuple
 
 import loopy as lp
-from petsc4py import PETSc
 import numpy as np
+from petsc4py import PETSc
 
-from pyop2 import compilation, mpi
+from pyop2 import compilation, mpi, utils
 from pyop2.caching import Cached
 from pyop2.configuration import configuration
 from pyop2.datatypes import IntType, as_ctypes
 from pyop2.types import IterationRegion
-from pyop2.utils import cached_property, get_petsc_dir
+from pyop2.offload_utils import OffloadingBackend, _backend as backend, _offloading as offloading
+from pyop2.utils import cached_property
 
 
 # We set eq=False to force identity-based hashing. This is required for when
@@ -226,7 +228,7 @@ def pack(self):
         return MatPack
 
 
-class GlobalKernel(Cached):
+class GlobalKernel(Cached, abc.ABC):
     """Class representing the generated code for the global computation.
 
     :param local_kernel: :class:`pyop2.LocalKernel` instance representing the
@@ -266,6 +268,16 @@ def _cache_key(cls, local_knl, arguments, **kwargs):
 
         return tuple(key)
 
+    @classmethod
+    def new(cls, *args, **kwargs):
+        if offloading:
+            if backend == OffloadingBackend.CPU:
+                return CPUGlobalKernel(*args, **kwargs)
+            else:
+                raise NotImplementedError
+        else:
+            return CPUGlobalKernel(*args, **kwargs)
+
     def __init__(self, local_kernel, arguments, *,
                  extruded=False,
                  extruded_periodic=False,
@@ -348,6 +360,38 @@ def builder(self):
             builder.add_argument(arg)
         return builder
 
+    @cached_property
+    def argtypes(self):
+        """Return the ctypes datatypes of the compiled function."""
+        # The first two arguments to the global kernel are the 'start' and 'stop'
+        # indices. All other arguments are declared to be void pointers.
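+        # For example (hypothetical case), a wrapper taking one Dat and one
+        # Map has the ctypes signature (IntType, IntType, void*, void*).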
+        dtypes = [as_ctypes(IntType)] * 2
+        dtypes.extend([ctypes.c_voidp for _ in self.builder.wrapper_args[2:]])
+        return tuple(dtypes)
+
+    def num_flops(self, iterset):
+        """Compute the number of FLOPs done by the kernel."""
+        size = 1
+        if iterset._extruded:
+            region = self._iteration_region
+            layers = np.mean(iterset.layers_array[:, 1] - iterset.layers_array[:, 0])
+            if region is IterationRegion.INTERIOR_FACETS:
+                size = layers - 2
+            elif region not in {IterationRegion.TOP, IterationRegion.BOTTOM}:
+                size = layers - 1
+        return size * self.local_kernel.num_flops
+
+    @abc.abstractproperty
+    def code_to_compile(self):
+        """Return the C/C++ source code as a string."""
+
+    @abc.abstractmethod
+    def compile(self):
+        pass
+
+
+class CPUGlobalKernel(GlobalKernel):
+
     @cached_property
     def code_to_compile(self):
         """Return the C/C++ source code as a string."""
@@ -358,9 +402,10 @@ def code_to_compile(self):
 
         if self.local_kernel.cpp:
             from loopy.codegen.result import process_preambles
-            preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
+            preamble = "".join(
+                process_preambles(getattr(code, "device_preambles", [])))
             device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
-            return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n"
+            return preamble + '\nextern "C" {\n' + device_code + "\n}\n"
         return code.device_code()
 
     @PETSc.Log.EventDecorator()
@@ -373,13 +418,13 @@ def compile(self, comm):
         """
         extension = "cpp" if self.local_kernel.cpp else "c"
         cppargs = (
-            tuple("-I%s/include" % d for d in get_petsc_dir())
+            tuple("-I%s/include" % d for d in utils.get_petsc_dir())
             + tuple("-I%s" % d for d in self.local_kernel.include_dirs)
             + ("-I%s" % os.path.abspath(os.path.dirname(__file__)),)
         )
         ldargs = (
-            tuple("-L%s/lib" % d for d in get_petsc_dir())
-            + tuple("-Wl,-rpath,%s/lib" % d for d in get_petsc_dir())
+            tuple("-L%s/lib" % d for d in utils.get_petsc_dir())
+            + tuple("-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir())
             + ("-lpetsc", "-lm")
             + tuple(self.local_kernel.ldargs)
         )
@@ -390,23 +435,113 @@ def compile(self, comm):
                                 restype=ctypes.c_int,
                                 comm=comm)
 
-    @cached_property
+
+class OpenCLGlobalKernel(GlobalKernel):
+    @classmethod
+    def _cache_key(cls, *args, **kwargs):
+        key = super()._cache_key(*args, **kwargs)
+        key = key + (configuration["gpu_strategy"], "opencl")
+        return key
+
+    @utils.cached_property
+    def encoded_cache_key(self):
+        return md5(str(self.cache_key[1:]).encode()).hexdigest()
+
+    @utils.cached_property
+    def computational_grid_expr_cache_file_path(self):
+        cachedir = configuration["cache_dir"]
+        return os.path.join(cachedir, f"{self.encoded_cache_key}_grid_params.py")
+
+    @utils.cached_property
+    def extra_args_cache_file_path(self):
+        cachedir = configuration["cache_dir"]
+        return os.path.join(cachedir, f"{self.encoded_cache_key}_extra_args.npz")
+
+    # @memoize_method
+    def get_grid_size(self, start, end):
+        fpath = self.computational_grid_expr_cache_file_path
+
+        with open(fpath, "r") as f:
+            globals_dict = {}
+            exec(f.read(), globals_dict)
+            get_grid_sizes = globals_dict["get_grid_sizes"]
+
+        return get_grid_sizes(start=start, end=end)
+
+    # @memoize_method
+    def get_extra_args(self):
+        """
+        Returns device buffers corresponding to array literals baked into
+        :attr:`local_kernel`.
+ """ + fpath = self.extra_args_cache_file_path + npzfile = numpy.load(fpath) + assert npzfile["ids"].ndim == 1 + assert len(npzfile.files) == len(npzfile["ids"]) + 1 + extra_args_np = [npzfile[arg_id] + for arg_id in npzfile["ids"]] + return tuple(cla.to_device(opencl_backend.queue, arg) + for arg in extra_args_np) + + @utils.cached_property def argtypes(self): - """Return the ctypes datatypes of the compiled function.""" - # The first two arguments to the global kernel are the 'start' and 'stop' - # indices. All other arguments are declared to be void pointers. - dtypes = [as_ctypes(IntType)] * 2 - dtypes.extend([ctypes.c_voidp for _ in self.builder.wrapper_args[2:]]) - return tuple(dtypes) + result = super().argtypes + return result + (ctypes.c_voidp,) * len(self.get_extra_args()) - def num_flops(self, iterset): - """Compute the number of FLOPs done by the kernel.""" - size = 1 - if iterset._extruded: - region = self._iteration_region - layers = np.mean(iterset.layers_array[:, 1] - iterset.layers_array[:, 0]) - if region is IterationRegion.INTERIOR_FACETS: - size = layers - 2 - elif region not in {IterationRegion.TOP, IterationRegion.BOTTOM}: - size = layers - 1 - return size * self.local_kernel.num_flops + @utils.cached_property + def code_to_compile(self): + from pyop2.codegen.rep2loopy import generate + from pyop2.transforms.gpu_utils import apply_gpu_transforms + from pymbolic.interop.ast import to_evaluatable_python_function + + t_unit = generate(self.builder, + include_math=False, + include_petsc=False, + include_complex=False) + + # Make temporary variables with initializers kernel's arguments. + t_unit, extra_args = apply_gpu_transforms(t_unit, "opencl") + + ary_ids = [f"_op2_arg_{i}" + for i in range(len(extra_args))] + + numpy.savez(self.extra_args_cache_file_path, + ids=numpy.array(ary_ids), + **{ary_id: extra_arg + for ary_id, extra_arg in zip(ary_ids, extra_args)}) + + # {{{ save python code to get grid sizes + + with open(self.computational_grid_expr_cache_file_path, "w") as f: + glens, llens = (t_unit + .default_entrypoint + .get_grid_size_upper_bounds_as_exprs(t_unit + .callables_table)) + f.write(to_evaluatable_python_function((glens, llens), "get_grid_sizes")) + + code = lp.generate_code_v2(t_unit).device_code() + + # }}} + + return code + + @mpi.collective + def __call__(self, comm, *args): + key = id(comm) + try: + func = self._func_cache[key] + except KeyError: + func = self.compile(comm) + self._func_cache[key] = func + + grid, block = self.get_grid_size(args[0], args[1]) + func(grid, block, *args) + + @mpi.collective + def compile(self, comm): + cl_knl = compilation.get_opencl_kernel(comm, + self) + return CLKernelWithExtraArgs(cl_knl, self.get_extra_args()) + +class CUDAGlobalKernel(GlobalKernel): + ... diff --git a/pyop2/offload_utils.py b/pyop2/offload_utils.py new file mode 100644 index 000000000..d3bd91239 --- /dev/null +++ b/pyop2/offload_utils.py @@ -0,0 +1,56 @@ +from contextlib import contextmanager +import enum + + +class OffloadingBackend(enum.Enum): + """TODO""" + CPU = enum.auto() + OPENCL = enum.auto() + CUDA = enum.auto() + + +_offloading = False +"""Global state indicating whether or not we are running on the host or device""" + +_backend = OffloadingBackend.CPU +"""TODO""" + + +def set_offloading_backend(backend): + """ + Sets a backend for offloading operations to. + + By default the operations are executed on the host, to mark operations for + offloading wrap them within a :func:`~pyop2.offload_utils.offloading` + context. 
+
+ :arg backend: An instance of :class:`pyop2.offload_utils.OffloadingBackend`.
+
+ .. warning::
+
+ * Must be called before any PyOP2 object is allocated.
+ * Calling :func:`set_offloading_backend` with different values of
+ *backend* over the course of the program is undefined behaviour
+ and is best avoided.
+ """
+ global _backend
+
+ if not isinstance(backend, OffloadingBackend):
+ raise TypeError(f"Expected an OffloadingBackend, not {type(backend).__name__}")
+ _backend = backend
+
+
+@contextmanager
+def offloading():
+ """
+ Operations (such as manipulating a :class:`~pyop2.types.Dat`, calling a
+ :class:`~pyop2.global_kernel.GlobalKernel`, etc) within the offloading
+ region will be executed on the backend selected via
+ :func:`set_offloading_backend`.
+ """
+ global _offloading
+
+ _offloading = True
+ try:
+ yield
+ finally:
+ _offloading = False
diff --git a/pyop2/op2.py b/pyop2/op2.py
index 726168e79..10a4cfb81 100644
--- a/pyop2/op2.py
+++ b/pyop2/op2.py
@@ -49,13 +49,15 @@
ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL)
from pyop2.local_kernel import CStringLocalKernel, LoopyLocalKernel, CoffeeLocalKernel, Kernel # noqa: F401
-from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
- MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel)
+from pyop2.global_kernel import (GlobalKernel, GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
+ MatKernelArg, MixedMatKernelArg, MapKernelArg)
from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401
- MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop)
+ MatParloopArg, MixedMatParloopArg, Parloop,
+ parloop, par_loop)
from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401
MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop)
+from pyop2.offload_utils import set_offloading_backend, offloading
__all__ = ['configuration', 'READ', 'WRITE', 'RW', 'INC', 'MIN', 'MAX',
'ON_BOTTOM', 'ON_TOP', 'ON_INTERIOR_FACETS', 'ALL',
@@ -64,7 +66,8 @@
'MixedSet', 'Subset', 'DataSet', 'GlobalDataSet', 'MixedDataSet',
'Halo', 'Dat', 'MixedDat', 'Mat', 'Global', 'Map', 'MixedMap',
'Sparsity', 'parloop', 'Parloop', 'ParLoop', 'par_loop',
- 'DatView', 'PermutedMap', 'ComposedMap']
+ 'DatView', 'PermutedMap', 'ComposedMap', 'set_offloading_backend',
+ 'offloading']
_initialised = False
diff --git a/pyop2/parloop.py b/pyop2/parloop.py
index ac78e6bda..2c849c583 100644
--- a/pyop2/parloop.py
+++ b/pyop2/parloop.py
@@ -12,11 +12,12 @@
from pyop2.configuration import configuration
from pyop2.datatypes import as_numpy_dtype
from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError
-from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
- MatKernelArg, MixedMatKernelArg, GlobalKernel)
+from pyop2.global_kernel import (GlobalKernel, GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
+ MatKernelArg, MixedMatKernelArg)
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, CoffeeLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat,
Set, MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
+from pyop2.types.access import READ, WRITE, RW, INC, MIN, MAX
from pyop2.utils import cached_property
@@ -60,7 +61,7 @@ def __post_init__(self):
@property
def map_kernel_args(self):
- return self.map_._kernel_args_ if self.map_ else ()
+ return self.map_.kernel_args if self.map_ else ()
@property
def maps(self):
@@ -82,7 +83,7 @@ def __post_init__(self):
@property def
map_kernel_args(self): - return self.map_._kernel_args_ if self.map_ else () + return self.map_.kernel_args if self.map_ else () @property def maps(self): @@ -104,7 +105,7 @@ def __post_init__(self): @property def map_kernel_args(self): rmap, cmap = self.maps - return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_))) + return tuple(itertools.chain(*itertools.product(rmap.kernel_args, cmap.kernel_args))) @dataclass @@ -122,7 +123,7 @@ def __post_init__(self): @property def map_kernel_args(self): rmap, cmap = self.maps - return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_))) + return tuple(itertools.chain(*itertools.product(rmap.kernel_args, cmap.kernel_args))) class Parloop: @@ -168,9 +169,18 @@ def accesses(self): @property def arglist(self): """Prepare the argument list for calling generated code.""" - arglist = self.iterset._kernel_args_ - for d in self.arguments: - arglist += d.data._kernel_args_ + arglist = self.iterset.kernel_args_ro + for larg, _, plarg in self.zipped_arguments: + if larg.access == READ: + kernel_args = plarg.data.kernel_args_ro + elif larg.access in {WRITE, MIN, MAX}: + kernel_args = plarg.data.kernel_args_wo + elif larg.access in {RW, INC}: + # TODO: Is this definitely right for INC? + kernel_args = plarg.data.kernel_args_rw + else: + raise NotImplementedError + arglist += kernel_args # Collect an ordered set of maps (ignore duplicates) maps = {m: None for d in self.arguments for m in d.map_kernel_args} @@ -226,7 +236,7 @@ def __call__(self): def zero_global_increments(self): """Zero any global increments every time the loop is executed.""" for g in self.reduced_globals.keys(): - g._data[...] = 0 + g.data[...] = 0 def replace_lgmaps(self): """Swap out any lgmaps for any :class:`MatParloopArg` instances @@ -357,14 +367,20 @@ def reduction_begin(self): requests = [] for idx in self._reduction_idxs: glob = self.arguments[idx].data - mpi_op = {Access.INC: mpi.MPI.SUM, - Access.MIN: mpi.MPI.MIN, - Access.MAX: mpi.MPI.MAX}.get(self.accesses[idx]) + # FIXME also do for buf + # if not glob.data.is_available_on_host: + # glob.data.device_to_host_copy() + + mpi_op = {INC: mpi.MPI.SUM, + MIN: mpi.MPI.MIN, + MAX: mpi.MPI.MAX}.get(self.accesses[idx]) if mpi.MPI.VERSION >= 3: - requests.append(self.comm.Iallreduce(glob._data, glob._buf, op=mpi_op)) + requests.append(self.comm.Iallreduce(glob.data, + glob._stashed_data.data, + op=mpi_op)) else: - self.comm.Allreduce(glob._data, glob._buf, op=mpi_op) + self.comm.Allreduce(glob.data, glob._stashed_data.data, op=mpi_op) return tuple(requests) @PETSc.Log.EventDecorator("ParLoopRednEnd") @@ -372,16 +388,16 @@ def reduction_begin(self): def reduction_end(self, requests): """Finish reductions.""" if mpi.MPI.VERSION >= 3: - for idx, req in zip(self._reduction_idxs, requests): - req.Wait() + mpi.MPI.Request.Waitall(requests) + for idx in self._reduction_idxs: glob = self.arguments[idx].data - glob._data[:] = glob._buf + glob._data.data[...] = glob._stashed_data.data else: assert len(requests) == 0 for idx in self._reduction_idxs: glob = self.arguments[idx].data - glob._data[:] = glob._buf + glob._data[...] 
= glob._stashed_data.data @cached_property def _reduction_idxs(self): @@ -393,7 +409,7 @@ def _reduction_idxs(self): def finalize_global_increments(self): """Finalise global increments.""" for tmp, glob in self.reduced_globals.items(): - glob.data._data += tmp._data + glob.data.data[:] += tmp.data_ro @mpi.collective def update_arg_data_state(self): @@ -467,7 +483,7 @@ def prepare_reduced_globals(self, arguments, global_knl): reduced_globals = {} for i, (lk_arg, gk_arg, pl_arg) in enumerate(self.zip_arguments(global_knl, arguments)): if isinstance(gk_arg, GlobalKernelArg) and lk_arg.access == Access.INC: - tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype, comm=self.comm) + tmp = compute_backend.Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype, comm=self.comm) reduced_globals[tmp] = pl_arg arguments[i] = GlobalParloopArg(tmp) @@ -632,12 +648,12 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs): extruded_periodic = iterset._extruded_periodic constant_layers = extruded and iterset.constant_layers subset = isinstance(iterset, Subset) - global_knl = GlobalKernel(local_knl, global_knl_args, - extruded=extruded, - extruded_periodic=extruded_periodic, - constant_layers=constant_layers, - subset=subset, - **kwargs) + global_knl = GlobalKernel.new(local_knl, global_knl_args, + extruded=extruded, + extruded_periodic=extruded_periodic, + constant_layers=constant_layers, + subset=subset, + **kwargs) parloop_args = tuple(a.parloop_arg for a in args) return Parloop(global_knl, iterset, parloop_args) diff --git a/pyop2/transforms/__init__.py b/pyop2/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pyop2/transforms/gpu_utils.py b/pyop2/transforms/gpu_utils.py new file mode 100644 index 000000000..28cd558e1 --- /dev/null +++ b/pyop2/transforms/gpu_utils.py @@ -0,0 +1,94 @@ +import loopy as lp +from pyop2.configuration import configuration + + +def get_loopy_target(target): + if target == "opencl": + return lp.PyOpenCLTarget() + elif target == "cuda": + return lp.CudaTarget() + else: + raise NotImplementedError() + + +def preprocess_t_unit_for_gpu(t_unit): + + # {{{ inline all kernels in t_unit + + kernels_to_inline = { + name for name, clbl in t_unit.callables_table.items() + if isinstance(clbl, lp.CallableKernel)} + + for knl_name in kernels_to_inline: + t_unit = lp.inline_callable_kernel(t_unit, knl_name) + + # }}} + + kernel = t_unit.default_entrypoint + + # changing the address space of temps + def _change_aspace_tvs(tv): + if tv.read_only: + assert tv.initializer is not None + return tv.copy(address_space=lp.AddressSpace.GLOBAL) + else: + return tv.copy(address_space=lp.AddressSpace.PRIVATE) + + new_tvs = {tv_name: _change_aspace_tvs(tv) for tv_name, tv in + kernel.temporary_variables.items()} + kernel = kernel.copy(temporary_variables=new_tvs) + + def insn_needs_atomic(insn): + # updates to global variables are atomic + import pymbolic + if isinstance(insn, lp.Assignment): + if isinstance(insn.assignee, pymbolic.primitives.Subscript): + assignee_name = insn.assignee.aggregate.name + else: + assert isinstance(insn.assignee, pymbolic.primitives.Variable) + assignee_name = insn.assignee.name + + if assignee_name in kernel.arg_dict: + return assignee_name in insn.read_dependency_names() + return False + + new_insns = [] + args_marked_for_atomic = set() + for insn in kernel.instructions: + if insn_needs_atomic(insn): + atomicity = (lp.AtomicUpdate(insn.assignee.aggregate.name), ) + insn = 
insn.copy(atomicity=atomicity) + args_marked_for_atomic |= set([insn.assignee.aggregate.name]) + + new_insns.append(insn) + + # label args as atomic + new_args = [] + for arg in kernel.args: + if arg.name in args_marked_for_atomic: + new_args.append(arg.copy(for_atomic=True)) + else: + new_args.append(arg) + + kernel = kernel.copy(instructions=new_insns, args=new_args) + + return t_unit.with_kernel(kernel) + + +def apply_gpu_transforms(t_unit, target): + t_unit = t_unit.copy(target=get_loopy_target(target)) + t_unit = preprocess_t_unit_for_gpu(t_unit) + kernel = t_unit.default_entrypoint + transform_strategy = configuration["gpu_strategy"] + + kernel = lp.assume(kernel, "end > start") + + if transform_strategy == "snpt": + from pyop2.transforms.snpt import split_n_across_workgroups + kernel, args_to_make_global = split_n_across_workgroups(kernel, 32) + else: + raise NotImplementedError(f"'{transform_strategy}' transform strategy.") + + t_unit = t_unit.with_kernel(kernel) + + return t_unit, args_to_make_global diff --git a/pyop2/transforms/snpt.py b/pyop2/transforms/snpt.py new file mode 100644 index 000000000..23c145f1c --- /dev/null +++ b/pyop2/transforms/snpt.py @@ -0,0 +1,50 @@ +import loopy as lp + + +def _make_tv_array_arg(tv): + assert tv.address_space != lp.AddressSpace.PRIVATE + arg = lp.ArrayArg(name=tv.name, + dtype=tv.dtype, + shape=tv.shape, + dim_tags=tv.dim_tags, + offset=tv.offset, + dim_names=tv.dim_names, + order=tv.order, + alignment=tv.alignment, + address_space=tv.address_space, + is_output=not tv.read_only, + is_input=tv.read_only) + return arg + + +def split_n_across_workgroups(kernel, workgroup_size): + """ + Returns a transformed version of *kernel* with the workload in the loop + with induction variable 'n' distributed across work-groups of size + *workgroup_size* and each work-item in the work-group performing the work + of a single iteration of 'n'. 
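The work distribution described in this docstring can be mimicked in plain Python; a sketch using the work-group size of 32 that apply_gpu_transforms passes in:

WORKGROUP_SIZE = 32

def snpt_coords(n):
    # outer_tag="g.0", inner_tag="l.0" in the split below means:
    return n // WORKGROUP_SIZE, n % WORKGROUP_SIZE  # (work-group, work-item)

assert snpt_coords(70) == (2, 6)  # iteration 70 -> work-group 2, work-item 6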
+ """ + + kernel = lp.assume(kernel, "start < end") + kernel = lp.split_iname(kernel, "n", workgroup_size, + outer_tag="g.0", inner_tag="l.0") + + # {{{ making consts as globals: necessary to make the strategy emit valid + # kernels for all forms + + old_temps = kernel.temporary_variables.copy() + args_to_make_global = [tv.initializer.flatten() + for tv in old_temps.values() + if tv.initializer is not None] + + new_temps = {tv.name: tv + for tv in old_temps.values() + if tv.initializer is None} + kernel = kernel.copy(args=kernel.args+[_make_tv_array_arg(tv) + for tv in old_temps.values() + if tv.initializer is not None], + temporary_variables=new_temps) + + # }}} + + return kernel, args_to_make_global diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 9819c3452..78a632ecc 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -1,4 +1,4 @@ -import abc +import collections import contextlib import ctypes import itertools @@ -7,6 +7,7 @@ import loopy as lp import numpy as np from petsc4py import PETSc +import pytools from pyop2 import ( configuration as conf, @@ -15,13 +16,15 @@ mpi, utils ) -from pyop2.types.access import Access +from pyop2.types.access import READ, WRITE, RW, INC, MIN, MAX from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet -from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin +from pyop2.types.data_carrier import DataCarrier, VecAccessMixin from pyop2.types.set import ExtrudedSet, GlobalSet, Set +from pyop2.array import MirroredArray +from pyop2.offload_utils import _offloading as offloading -class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): +class AbstractDat(DataCarrier): """OP2 vector data. A :class:`Dat` holds values on every element of a :class:`DataSet`.o @@ -61,28 +64,35 @@ class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): _zero_kernels = {} """Class-level cache for zero kernels.""" - _modes = [Access.READ, Access.WRITE, Access.RW, Access.INC, Access.MIN, Access.MAX] + _modes = [READ, WRITE, RW, INC, MIN, MAX] @utils.validate_type(('dataset', (DataCarrier, DataSet, Set), ex.DataSetTypeError), ('name', str, ex.NameTypeError)) @utils.validate_dtype(('dtype', None, ex.DataTypeError)) def __init__(self, dataset, data=None, dtype=None, name=None): - if isinstance(dataset, Dat): self.__init__(dataset.dataset, None, dtype=dataset.dtype, name="copy_of_%s" % dataset.name) dataset.copy(self) return + + # If a Set, rather than a dataset is passed in, default to + # a dataset dimension of 1. if type(dataset) is Set or type(dataset) is ExtrudedSet: - # If a Set, rather than a dataset is passed in, default to - # a dataset dimension of 1. 
dataset = dataset ** 1 - self._shape = (dataset.total_size,) + (() if dataset.cdim == 1 else dataset.dim) - EmptyDataMixin.__init__(self, data, dtype, self._shape) - self._dataset = dataset + # handle the interplay between dataset, data and dtype + shape = (dataset.total_size,) + (() if dataset.cdim == 1 else dataset.dim) + if data is not None: + data = utils.verify_reshape(data, dtype, shape) + else: + dtype = np.dtype(dtype or DataCarrier.DEFAULT_DTYPE) + data = (shape, dtype) + self.comm = mpi.internal_comm(dataset.comm) self.halo_valid = True + self._dataset = dataset + self._data = MirroredArray.new(data) self._name = name or "dat_#x%x" % id(self) self._halo_frozen = False @@ -92,9 +102,21 @@ def __del__(self): if hasattr(self, "comm"): mpi.decref(self.comm) - @utils.cached_property - def _kernel_args_(self): - return (self._data.ctypes.data, ) + @property + def kernel_args_rw(self): + return (self._data.ptr_rw,) + + @property + def kernel_args_ro(self): + return (self._data.ptr_ro,) + + @property + def kernel_args_wo(self): + return (self._data.ptr_wo,) + + @property + def dat_version(self): + return self._data.state @utils.cached_property def _argtypes_(self): @@ -152,14 +174,10 @@ def data(self): :meth:`data_with_halos`. """ - # Increment dat_version since this accessor assumes data modification - self.increment_dat_version() if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0: raise RuntimeError("Illegal access: no data associated with this Dat!") self.halo_valid = False - v = self._data[:self.dataset.size].view() - v.setflags(write=True) - return v + return self._data.data[:self.dataset.size] @property @mpi.collective @@ -172,12 +190,10 @@ def data_with_halos(self): With this accessor, you get to see up to date halo values, but you should not try and modify them, because they will be overwritten by the next halo exchange.""" - self.global_to_local_begin(Access.RW) - self.global_to_local_end(Access.RW) + self.global_to_local_begin(RW) + self.global_to_local_end(RW) self.halo_valid = False - v = self._data.view() - v.setflags(write=True) - return v + return self._data.data @property @mpi.collective @@ -193,9 +209,7 @@ def data_ro(self): """ if self.dataset.total_size > 0 and self._data.size == 0 and self.cdim > 0: raise RuntimeError("Illegal access: no data associated with this Dat!") - v = self._data[:self.dataset.size].view() - v.setflags(write=False) - return v + return self._data.data_ro[:self.dataset.size] @property @mpi.collective @@ -211,11 +225,9 @@ def data_ro_with_halos(self): overwritten by the next halo exchange. """ - self.global_to_local_begin(Access.READ) - self.global_to_local_end(Access.READ) - v = self._data.view() - v.setflags(write=False) - return v + self.global_to_local_begin(READ) + self.global_to_local_end(READ) + return self._data.data_ro def save(self, filename): """Write the data array to file ``filename`` in NumPy format.""" @@ -238,13 +250,13 @@ def load(self, filename): else: self.data[:] = np.load(filename) - @utils.cached_property + @property def shape(self): - return self._shape + return self._data.shape - @utils.cached_property + @property def dtype(self): - return self._dtype + return self._data.dtype @utils.cached_property def nbytes(self): @@ -265,10 +277,9 @@ def zero(self, subset=None): :arg subset: A :class:`Subset` of entries to zero (optional).""" # Data modification - self.increment_dat_version() # If there is no subset we can safely zero the halo values. 
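The accessor semantics introduced above can be summarised with a small example (a sketch; assumes PyOP2 has been initialised):

import numpy as np
from pyop2 import op2

d = op2.Dat(op2.Set(4) ** 1, data=np.zeros(4))
d.data[...] = 1.0  # writable view of the owned entries; marks the halo dirty
ro = d.data_ro     # read-only view; writing through it raises ValueError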
if subset is None: - self._data[:] = 0 + self.data_with_halos[...] = 0 self.halo_valid = True elif subset.superset != self.dataset.set: raise ex.MapValueError("The subset and dataset are incompatible") @@ -286,10 +297,11 @@ def copy(self, other, subset=None): if subset is None: # If the current halo is valid we can also copy these values across. if self.halo_valid: - other._data[:] = self._data + other.data_with_halos[...] = self.data_ro_with_halos other.halo_valid = True else: - other.data[:] = self.data_ro + other.data[...] = self.data_ro + other.halo_valid = False elif subset.superset != self.dataset.set: raise ex.MapValueError("The subset and dataset are incompatible") else: @@ -354,13 +366,13 @@ def _op(self, other, op): ret = Dat(self.dataset, None, self.dtype) if np.isscalar(other): - other = Global(1, data=other, comm=self.comm) + other = compute_backend.Global(1, data=other, comm=self.comm) globalp = True else: self._check_shape(other) globalp = False parloop(self._op_kernel(op, globalp, other.dtype), - self.dataset.set, self(Access.READ), other(Access.READ), ret(Access.WRITE)) + self.dataset.set, self(READ), other(READ), ret(WRITE)) return ret def _iop_kernel(self, op, globalp, other_is_self, dtype): @@ -398,18 +410,18 @@ def _iop_kernel(self, op, globalp, other_is_self, dtype): return self._iop_kernel_cache.setdefault(key, Kernel(knl, name)) def _iop(self, other, op): - from pyop2.parloop import parloop from pyop2.types.glob import Global + from pyop2.parloop import parloop globalp = False if np.isscalar(other): - other = Global(1, data=other, comm=self.comm) + other = compute_backend.Global(1, data=other, comm=self.comm) globalp = True elif other is not self: self._check_shape(other) - args = [self(Access.INC)] + args = [self(INC)] if other is not self: - args.append(other(Access.READ)) + args.append(other(READ)) parloop(self._iop_kernel(op, globalp, other is self, other.dtype), self.dataset.set, *args) return self @@ -450,9 +462,9 @@ def inner(self, other): from pyop2.types.glob import Global self._check_shape(other) - ret = Global(1, data=0, dtype=self.dtype, comm=self.comm) + ret = compute_backend.Global(1, data=0, dtype=self.dtype, comm=self.comm) parloop(self._inner_kernel(other.dtype), self.dataset.set, - self(Access.READ), other(Access.READ), ret(Access.INC)) + self(READ), other(READ), ret(INC)) return ret.data_ro[0] @property @@ -501,7 +513,7 @@ def __neg__(self): from pyop2.parloop import parloop neg = Dat(self.dataset, dtype=self.dtype) - parloop(self._neg_kernel, self.dataset.set, neg(Access.WRITE), self(Access.READ)) + parloop(self._neg_kernel, self.dataset.set, neg(WRITE), self(READ)) return neg def __sub__(self, other): @@ -555,11 +567,11 @@ def global_to_local_begin(self, access_mode): halo = self.dataset.halo if halo is None or self._halo_frozen: return - if not self.halo_valid and access_mode in {Access.READ, Access.RW}: - halo.global_to_local_begin(self, Access.WRITE) - elif access_mode in {Access.INC, Access.MIN, Access.MAX}: + if not self.halo_valid and access_mode in {READ, RW}: + halo.global_to_local_begin(self, WRITE) + elif access_mode in {INC, MIN, MAX}: min_, max_ = dtypes.dtype_limits(self.dtype) - val = {Access.MAX: min_, Access.MIN: max_, Access.INC: 0}[access_mode] + val = {MAX: min_, MIN: max_, INC: 0}[access_mode] self._data[self.dataset.size:] = val else: # WRITE @@ -574,10 +586,10 @@ def global_to_local_end(self, access_mode): halo = self.dataset.halo if halo is None or self._halo_frozen: return - if not self.halo_valid and access_mode in 
{Access.READ, Access.RW}: - halo.global_to_local_end(self, Access.WRITE) + if not self.halo_valid and access_mode in {READ, RW}: + halo.global_to_local_end(self, WRITE) self.halo_valid = True - elif access_mode in {Access.INC, Access.MIN, Access.MAX}: + elif access_mode in {INC, MIN, MAX}: self.halo_valid = False else: # WRITE @@ -661,9 +673,17 @@ def __init__(self, dat, index): name="view[%s](%s)" % (index, dat.name)) self._parent = dat - @utils.cached_property - def _kernel_args_(self): - return self._parent._kernel_args_ + @property + def kernel_args_rw(self): + return self._parent.kernel_args_rw + + @property + def kernel_args_ro(self): + return self._parent.kernel_args_ro + + @property + def kernel_args_wo(self): + return self._parent.kernel_args_wo @utils.cached_property def _argtypes_(self): @@ -711,33 +731,85 @@ def data_ro_with_halos(self): class Dat(AbstractDat, VecAccessMixin): - def __init__(self, *args, **kwargs): - AbstractDat.__init__(self, *args, **kwargs) - # Determine if we can rely on PETSc state counter - petsc_counter = (self.dtype == PETSc.ScalarType) - VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + super().__init__(*args, **kwargs) + self._lazy_petsc_vec = None - @utils.cached_property - def _vec(self): - assert self.dtype == PETSc.ScalarType, \ - "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType) - # Can't duplicate layout_vec of dataset, because we then - # carry around extra unnecessary data. - # But use getSizes to save an Allreduce in computing the - # global size. - size = self.dataset.layout_vec.getSizes() - data = self._data[:size[0]] - return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm) + @property + def petsc_vec(self): + if self.dtype != dtypes.ScalarType: + raise ex.DataTypeError( + "Only arrays with dtype matching PETSc's scalar type can be " + "represented as a PETSc Vec") + if self._lazy_petsc_vec is None: + assert self._data._lazy_host_data is not None + + bsize = np.prod(self._data.shape[1:], dtype=int) + self._lazy_petsc_vec = PETSc.Vec().createWithArray( + self._data._lazy_host_data, size=self._data.size, bsize=bsize, comm=self.comm + ) + + return self._lazy_petsc_vec + + @property @contextlib.contextmanager - def vec_context(self, access): - r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`. 
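The block size fed to createWithArray in petsc_vec above is the product of the trailing shape dimensions; a standalone sketch:

import numpy as np
from petsc4py import PETSc

host = np.zeros((10, 2, 2), dtype=PETSc.ScalarType)
bsize = int(np.prod(host.shape[1:]))  # trailing dims -> block size 4
vec = PETSc.Vec().createWithArray(host, bsize=bsize)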
+ def vec(self): + """TODO""" + if offloading: + if not self._data.is_available_on_device: + self._data.host_to_device_copy() + self.petsc_vec.bindToCPU(False) + else: + if not self._data.is_available_on_host: + self._data.device_to_host_copy() + self.petsc_vec.bindToCPU(True) - :param access: Access descriptor: READ, WRITE, or RW.""" - yield self._vec - if access is not Access.READ: - self.halo_valid = False + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + self.halo_valid = False + if offloading: + self._data.is_available_on_host = False + else: + self._data.is_available_on_device = False + + @property + @contextlib.contextmanager + def vec_ro(self): + """TODO""" + if offloading: + if not self._data.is_available_on_device: + self._data.host_to_device_copy() + self.petsc_vec.bindToCPU(False) + else: + if not self._data.is_available_on_host: + self._data.device_to_host_copy() + self.petsc_vec.bindToCPU(True) + + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + @property + @contextlib.contextmanager + def vec_wo(self): + """TODO""" + if offloading: + self.petsc_vec.bindToCPU(False) + else: + self.petsc_vec.bindToCPU(True) + + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + self.halo_valid = False + if offloading: + self._data.is_available_on_host = False + else: + self._data.is_available_on_device = False class MixedDat(AbstractDat, VecAccessMixin): @@ -774,6 +846,8 @@ def what(x): # TODO: Think about different communicators on dats (c.f. MixedSet) self.comm = mpi.internal_comm(self._dats[0].comm) + self._lazy_petsc_vec = None + @property def dat_version(self): return sum(d.dat_version for d in self._dats) @@ -782,9 +856,17 @@ def __call__(self, access, path=None): from pyop2.parloop import MixedDatLegacyArg return MixedDatLegacyArg(self, path, access) - @utils.cached_property - def _kernel_args_(self): - return tuple(itertools.chain(*(d._kernel_args_ for d in self))) + @property + def kernel_args_rw(self): + return tuple(itertools.chain(*(d.kernel_args_rw for d in self))) + + @property + def kernel_args_ro(self): + return tuple(itertools.chain(*(d.kernel_args_ro for d in self))) + + @property + def kernel_args_wo(self): + return tuple(itertools.chain(*(d.kernel_args_wo for d in self))) @utils.cached_property def _argtypes_(self): @@ -813,12 +895,6 @@ def dataset(self): r""":class:`MixedDataSet`\s this :class:`MixedDat` is defined on.""" return MixedDataSet(tuple(s.dataset for s in self._dats)) - @utils.cached_property - def _data(self): - """Return the user-provided data buffer, or a zeroed buffer of - the correct size if none was provided.""" - return tuple(d._data for d in self) - @property @mpi.collective def data(self): @@ -1049,42 +1125,121 @@ def _vec(self): # because we're not placing an array. return self.dataset.layout_vec.duplicate() + @property + def petsc_vec(self): + if self._lazy_petsc_vec is None: + subvecs = [dat.petsc_vec for dat in self._dats] + self._lazy_petsc_vec = PETSc.Vec().createNest(subvecs, comm=self.comm) + + return self._lazy_petsc_vec + + @property @contextlib.contextmanager - def vec_context(self, access): - r"""A context manager scattering the arrays of all components of this - :class:`MixedDat` into a contiguous :class:`PETSc.Vec` and reverse - scattering to the original arrays when exiting the context. 
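A usage sketch of the three new context managers on a hypothetical Dat d of PETSc scalar type (host/device copies and bindToCPU are handled inside the context):

with d.vec as v:     # read-write: afterwards the other side's copy is stale
    v.scale(2.0)
with d.vec_ro as v:  # read-only: nothing is invalidated afterwards
    nrm = v.norm()
with d.vec_wo as v:  # write-only: no copy-in, contents assumed overwritten
    v.set(0.0)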
+ def vec(self): + """TODO""" + for subdat in self._dats: + if offloading: + if not subdat._data.is_available_on_device: + subdat._data.host_to_device_copy() + subdat.petsc_vec.bindToCPU(False) + else: + if not subdat._data.is_available_on_host: + subdat._data.device_to_host_copy() + subdat.petsc_vec.bindToCPU(True) - :param access: Access descriptor: READ, WRITE, or RW. + subdat.petsc_vec.stateSet(subdat._data.state) - .. note:: + yield self.petsc_vec - The :class:`~PETSc.Vec` obtained from this context is in - the correct order to be left multiplied by a compatible - :class:`MixedMat`. In parallel it is *not* just a - concatenation of the underlying :class:`Dat`\s.""" - # Do the actual forward scatter to fill the full vector with - # values - if access is not Access.WRITE: - offset = 0 - with self._vec as array: - for d in self: - with d.vec_ro as v: - size = v.local_size - array[offset:offset+size] = v.array_r[:] - offset += size - - yield self._vec - if access is not Access.READ: - # Reverse scatter to get the values back to their original locations - offset = 0 - array = self._vec.array_r - for d in self: - with d.vec_wo as v: - size = v.local_size - v.array[:] = array[offset:offset+size] - offset += size - self.halo_valid = False + for subdat in self._dats: + subdat._data.state = subdat.petsc_vec.stateGet() + subdat.halo_valid = False + if offloading: + subdat._data.is_available_on_host = False + else: + subdat._data.is_available_on_device = False + + @property + @contextlib.contextmanager + def vec_ro(self): + """TODO""" + for subdat in self._dats: + if offloading: + if not subdat._data.is_available_on_device: + subdat._data.host_to_device_copy() + subdat.petsc_vec.bindToCPU(False) + else: + if not subdat._data.is_available_on_host: + subdat._data.device_to_host_copy() + subdat.petsc_vec.bindToCPU(True) + + subdat.petsc_vec.stateSet(subdat._data.state) + + yield self.petsc_vec + + for subdat in self._dats: + subdat._data.state = subdat.petsc_vec.stateGet() + + @property + @contextlib.contextmanager + def vec_wo(self): + """TODO""" + for subdat in self._dats: + if offloading: + subdat.petsc_vec.bindToCPU(False) + else: + subdat.petsc_vec.bindToCPU(True) + + subdat.petsc_vec.stateSet(subdat._data.state) + + yield self.petsc_vec + + for subdat in self._dats: + subdat._data.state = subdat.petsc_vec.stateGet() + + subdat.halo_valid = False + if offloading: + subdat._data.is_available_on_host = False + else: + subdat._data.is_available_on_device = False + + # TODO VecNest? + # @contextlib.contextmanager + # def vec_context(self, access): + # r"""A context manager scattering the arrays of all components of this + # :class:`MixedDat` into a contiguous :class:`PETSc.Vec` and reverse + # scattering to the original arrays when exiting the context. + # + # :param access: Access descriptor: READ, WRITE, or RW. + # + # .. note:: + # + # The :class:`~PETSc.Vec` obtained from this context is in + # the correct order to be left multiplied by a compatible + # :class:`MixedMat`. 
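Since petsc_vec is now built with createNest, the MixedDat contexts yield a Vec whose blocks alias the sub-Dats rather than a contiguous gather; a sketch with hypothetical Dats d1 and d2:

md = op2.MixedDat([d1, d2])
with md.vec_ro as v:  # blocks alias d1 and d2; not a contiguous scatter/gather
    print(v.norm())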
In parallel it is *not* just a + # concatenation of the underlying :class:`Dat`\s.""" + # # Do the actual forward scatter to fill the full vector with + # # values + # if access in {READ, RW}: + # offset = 0 + # with self.petsc_vec as array: + # for subdat in self: + # with subdat.vec_ro as vec: + # size = vec.local_size + # array[offset:offset+size] = vec.array_r + # offset += size + # + # yield self.petsc_vec + # if access is not Access.READ: + # # Reverse scatter to get the values back to their original locations + # offset = 0 + # array = self._vec.array_r + # for d in self: + # with d.vec_wo as v: + # size = v.local_size + # v.array[:] = array[offset:offset+size] + # offset += size + # self.halo_valid = False class frozen_halo: diff --git a/pyop2/types/data_carrier.py b/pyop2/types/data_carrier.py index 73d3974c2..9e0259b85 100644 --- a/pyop2/types/data_carrier.py +++ b/pyop2/types/data_carrier.py @@ -8,6 +8,7 @@ utils ) from pyop2.types.access import Access +from pyop2.array import MirroredArray class DataCarrier(abc.ABC): @@ -18,6 +19,8 @@ class DataCarrier(abc.ABC): (:class:`Global`), rank 1 (:class:`Dat`), or rank 2 (:class:`Mat`)""" + DEFAULT_DTYPE = dtypes.ScalarType + @utils.cached_property def dtype(self): """The Python type of the data.""" @@ -44,90 +47,27 @@ def cdim(self): the product of the dim tuple.""" return self._cdim - def increment_dat_version(self): - pass - - -class EmptyDataMixin(abc.ABC): - """A mixin for :class:`Dat` and :class:`Global` objects that takes - care of allocating data on demand if the user has passed nothing - in. - - Accessing the :attr:`_data` property allocates a zeroed data array - if it does not already exist. - """ - def __init__(self, data, dtype, shape): - if data is None: - self._dtype = np.dtype(dtype if dtype is not None else dtypes.ScalarType) - else: - self._numpy_data = utils.verify_reshape(data, dtype, shape, allow_none=True) - self._dtype = self._data.dtype - - @utils.cached_property - def _data(self): - """Return the user-provided data buffer, or a zeroed buffer of - the correct size if none was provided.""" - if not self._is_allocated: - self._numpy_data = np.zeros(self.shape, dtype=self._dtype) - return self._numpy_data - - @property - def _is_allocated(self): - """Return True if the data buffer has been allocated.""" - return hasattr(self, '_numpy_data') - class VecAccessMixin(abc.ABC): - - def __init__(self, petsc_counter=None): - if petsc_counter: - # Use lambda since `_vec` allocates the data buffer - # -> Dat/Global should not allocate storage until accessed - self._dat_version = lambda: self._vec.stateGet() - self.increment_dat_version = lambda: self._vec.stateIncrease() - else: - # No associated PETSc Vec if incompatible type: - # -> Equip Dat/Global with their own counter. - self._version = 0 - self._dat_version = lambda: self._version - - def _inc(): - self._version += 1 - self.increment_dat_version = _inc + """Mixin class providing access to arrays as PETSc Vecs.""" @property - def dat_version(self): - return self._dat_version() - + @mpi.collective @abc.abstractmethod - def vec_context(self, access): - pass - - @abc.abstractproperty - def _vec(self): + def vec(self): + """Return a read-write context manager for wrapping data with a PETSc Vec.""" pass @property @mpi.collective - def vec(self): - """Context manager for a PETSc Vec appropriate for this Dat. 
- - You're allowed to modify the data you get back from this view.""" - return self.vec_context(access=Access.RW) + @abc.abstractmethod + def vec_ro(self): + """Return a read-only context manager for wrapping data with a PETSc Vec.""" + pass @property @mpi.collective + @abc.abstractmethod def vec_wo(self): - """Context manager for a PETSc Vec appropriate for this Dat. - - You're allowed to modify the data you get back from this view, - but you cannot read from it.""" - return self.vec_context(access=Access.WRITE) - - @property - @mpi.collective - def vec_ro(self): - """Context manager for a PETSc Vec appropriate for this Dat. - - You're not allowed to modify the data you get back from this view.""" - return self.vec_context(access=Access.READ) + """Return a write-only context manager for wrapping data with a PETSc Vec.""" + pass diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 4e114032a..848aad633 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -10,6 +10,7 @@ mpi, utils ) +from pyop2.offload_utils import _backend as backend, OffloadingBackend from pyop2.types.set import ExtrudedSet, GlobalSet, MixedSet, Set, Subset @@ -188,9 +189,14 @@ def local_ises(self): @utils.cached_property def layout_vec(self): """A PETSc Vec compatible with the dof layout of this DataSet.""" + if backend == OffloadingBackend.CPU: + vec_type = PETSc.Vec.Type.STANDARD + else: + raise NotImplementedError + vec = PETSc.Vec().create(comm=self.comm) - size = (self.size * self.cdim, None) - vec.setSizes(size, bsize=self.cdim) + vec.setSizes(self.size*self.cdim, bsize=self.cdim) + vec.setType(vec_type) vec.setUp() return vec @@ -315,9 +321,11 @@ def local_ises(self): @utils.cached_property def layout_vec(self): """A PETSc Vec compatible with the dof layout of this DataSet.""" + from pyop2.op2 import compute_backend vec = PETSc.Vec().create(comm=self.comm) size = (self.size * self.cdim, None) vec.setSizes(size, bsize=self.cdim) + vec.setType(compute_backend.PETScVecType) vec.setUp() return vec @@ -463,10 +471,12 @@ def __repr__(self): @utils.cached_property def layout_vec(self): """A PETSc Vec compatible with the dof layout of this MixedDataSet.""" + from pyop2.op2 import compute_backend vec = PETSc.Vec().create(comm=self.comm) # Compute local and global size from sizes of layout vecs lsize, gsize = map(sum, zip(*(d.layout_vec.sizes for d in self))) vec.setSizes((lsize, gsize), bsize=1) + vec.setType(compute_backend.PETScVecType) vec.setUp() return vec diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 4bc67fed8..10c48cb5a 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -1,21 +1,25 @@ +import collections import contextlib import ctypes import operator +import numbers import numpy as np from petsc4py import PETSc from pyop2 import ( + datatypes, exceptions as ex, mpi, utils ) from pyop2.types.access import Access -from pyop2.types.dataset import GlobalDataSet -from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin +from pyop2.types.data_carrier import DataCarrier, VecAccessMixin +from pyop2.array import MirroredArray +from pyop2.offload_utils import _offloading as offloading -class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): +class Global(DataCarrier, VecAccessMixin): """OP2 global value. 
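The setType call added to layout_vec above is the hook where a device backend would substitute a GPU vector class; the CPU path is equivalent to:

from petsc4py import PETSc

vec = PETSc.Vec().create()
vec.setSizes(12, bsize=3)
vec.setType(PETSc.Vec.Type.STANDARD)  # a device backend would pick a GPU Vec type here
vec.setUp()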
@@ -44,27 +48,43 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): self.__init__(dim._dim, None, dtype=dim.dtype, name="copy_of_%s" % dim.name, comm=dim.comm) dim.copy(self) + return + + dim = utils.as_tuple(dim, int) + + # TODO make this a function so can be shared by Dat and Global + # handle the interplay between dataset, data and dtype + if data is not None: + data = utils.verify_reshape(data, dtype, dim) + self._stashed_data = MirroredArray.new((data.shape, data.dtype)) else: - self._dim = utils.as_tuple(dim, int) - self._cdim = np.prod(self._dim).item() - EmptyDataMixin.__init__(self, data, dtype, self._dim) - self._buf = np.empty(self.shape, dtype=self.dtype) - self._name = name or "global_#x%x" % id(self) - if comm is None: - import warnings - warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!") - self.comm = mpi.internal_comm(comm) - # Object versioning setup - petsc_counter = (comm and self.dtype == PETSc.ScalarType) - VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + dtype = np.dtype(dtype or DataCarrier.DEFAULT_DTYPE) + data = (dim, dtype) + self._stashed_data = MirroredArray.new(data) + self._data = MirroredArray.new(data) + + # TODO shouldn't need to set these + self._dim = dim + self._cdim = np.prod(dim) + + self._name = name or "global_#x%x" % id(self) + self.comm = mpi.internal_comm(comm) def __del__(self): if hasattr(self, "comm"): mpi.decref(self.comm) - @utils.cached_property - def _kernel_args_(self): - return (self._data.ctypes.data, ) + @property + def kernel_args_rw(self): + return (self._data.ptr_rw,) + + @property + def kernel_args_ro(self): + return (self._data.ptr_ro,) + + @property + def kernel_args_wo(self): + return (self._data.ptr_wo,) @utils.cached_property def _argtypes_(self): @@ -97,43 +117,34 @@ def __getitem__(self, idx): def __str__(self): return "OP2 Global Argument: %s with dim %s and value %s" \ - % (self._name, self._dim, self._data) + % (self._name, self._dim, self.data_ro) def __repr__(self): - return "Global(%r, %r, %r, %r)" % (self._dim, self._data, - self._data.dtype, self._name) + return "Global(%r, %r, %r, %r)" % (self.shape, self.data_ro, + self.data.dtype, self._name) @utils.cached_property def dataset(self): - return GlobalDataSet(self) + from pyop2.op2 import compute_backend + return compute_backend.GlobalDataSet(self) @property def shape(self): - return self._dim + return self._data.shape @property def data(self): """Data array.""" - self.increment_dat_version() - if len(self._data) == 0: - raise RuntimeError("Illegal access: No data associated with this Global!") - return self._data + return self._data.data @property def dtype(self): - return self._dtype + return self._data.dtype @property def data_ro(self): """Data array.""" - view = self._data.view() - view.setflags(write=False) - return view - - @data.setter - def data(self, value): - self.increment_dat_version() - self._data[:] = utils.verify_reshape(value, self.dtype, self.dim) + return self._data.data_ro @property def data_with_halos(self): @@ -177,8 +188,7 @@ def copy(self, other, subset=None): @mpi.collective def zero(self, subset=None): assert subset is None - self.increment_dat_version() - self._data[...] = 0 + self._data.data[...] 
= 0 @mpi.collective def global_to_local_begin(self, access_mode): @@ -300,32 +310,96 @@ def inner(self, other): assert isinstance(other, Global) return np.dot(self.data_ro, np.conj(other.data_ro)) - @utils.cached_property - def _vec(self): - assert self.dtype == PETSc.ScalarType, \ - "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType) - # Can't duplicate layout_vec of dataset, because we then - # carry around extra unnecessary data. - # But use getSizes to save an Allreduce in computing the - # global size. - data = self._data - size = self.dataset.layout_vec.getSizes() - if self.comm.rank == 0: - return PETSc.Vec().createWithArray(data, size=size, - bsize=self.cdim, - comm=self.comm) + @property + def dat_version(self): + return self._data.state + + @property + def petsc_vec(self): + if self._lazy_petsc_vec is None: + if self.dtype != datatypes.ScalarType: + raise ex.DataTypeError( + "Only arrays with dtype matching PETSc's scalar type can be " + "represented as a PETSc Vec") + + assert self.comm is not None + assert self._data._lazy_host_data is not None + + self._lazy_petsc_vec = PETSc.Vec().createWithArray( + self._data._lazy_host_data, size=self._data.size, + bsize=self.cdim, comm=self.comm) + + return self._lazy_petsc_vec + + @property + @contextlib.contextmanager + def vec(self): + """TODO""" + if offloading: + raise NotImplementedError("TODO") + # if offloading: + # if not self._data.is_available_on_device: + # self._data.host_to_device_copy() + # self.petsc_vec.bindToCPU(False) + # else: + # if not self._data.is_available_on_host: + # self._data.device_to_host_copy() + # self.petsc_vec.bindToCPU(True) + + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + if offloading: + self._data.is_available_on_host = False else: - return PETSc.Vec().createWithArray(np.empty(0, dtype=self.dtype), - size=size, - bsize=self.cdim, - comm=self.comm) + self._data.is_available_on_device = False + # not sure this is right + self.comm.Bcast(self._data._lazy_host_data, 0) @contextlib.contextmanager def vec_context(self, access): - """A context manager for a :class:`PETSc.Vec` from a :class:`Global`. 
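With the data setter removed from Global, values are now written through the array view instead; a sketch:

from mpi4py import MPI
from pyop2 import op2

g = op2.Global(1, data=0.0, comm=MPI.COMM_WORLD)
g.data[...] = 10  # replaces the removed `g.data = 10` setter
assert g.data_ro[0] == 10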
+ """A context manager for a :class:`PETSc.Vec` from a :class:`Global`.""" + + @property + @contextlib.contextmanager + def vec_ro(self): + """TODO""" + if offloading: + raise NotImplementedError("TODO") + if offloading: + if not self._data.is_available_on_device: + self._data.host_to_device_copy() + self.petsc_vec.bindToCPU(False) + else: + if not self._data.is_available_on_host: + self._data.device_to_host_copy() + self.petsc_vec.bindToCPU(True) + + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + @property + @contextlib.contextmanager + def vec_wo(self): + """TODO""" + if offloading: + raise NotImplementedError("TODO") + if offloading: + self.petsc_vec.bindToCPU(False) + else: + self.petsc_vec.bindToCPU(True) + + self.petsc_vec.stateSet(self._data.state) + yield self.petsc_vec + self._data.state = self.petsc_vec.stateGet() + + self.halo_valid = False + if offloading: + self._data.is_available_on_host = False + else: + self._data.is_available_on_device = False - :param access: Access descriptor: READ, WRITE, or RW.""" - yield self._vec - if access is not Access.READ: - data = self._data - self.comm.Bcast(data, 0) + # not sure this is right + self.comm.Bcast(self._data._lazy_host_data, 0) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 91224d52a..65ffbac79 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -12,6 +12,7 @@ ) from pyop2 import mpi from pyop2.types.set import GlobalSet, MixedSet, Set +from pyop2.array import MirroredArray class Map: @@ -33,23 +34,24 @@ class Map: @utils.validate_type(('iterset', Set, ex.SetTypeError), ('toset', Set, ex.SetTypeError), ('arity', numbers.Integral, ex.ArityTypeError), ('name', str, ex.NameTypeError)) - def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): + def __init__(self, iterset, toset, arity, values, name=None, offset=None, offset_quotient=None): + shape = (iterset.total_size, arity) + values = utils.verify_reshape(values, self.dtype, shape) + self._iterset = iterset self._toset = toset self.comm = mpi.internal_comm(toset.comm) self._arity = arity - self._values = utils.verify_reshape(values, dtypes.IntType, - (iterset.total_size, arity), allow_none=True) - self.shape = (iterset.total_size, arity) + self._values = MirroredArray.new(values) self._name = name or "map_#x%x" % id(self) if offset is None or len(offset) == 0: self._offset = None else: - self._offset = utils.verify_reshape(offset, dtypes.IntType, (arity, )) + self._offset = utils.verify_reshape(offset, self.dtype, (arity, )) if offset_quotient is None or len(offset_quotient) == 0: self._offset_quotient = None else: - self._offset_quotient = utils.verify_reshape(offset_quotient, dtypes.IntType, (arity, )) + self._offset_quotient = utils.verify_reshape(offset_quotient, self.dtype, (arity, )) # A cache for objects built on top of this map self._cache = {} @@ -57,9 +59,10 @@ def __del__(self): if hasattr(self, "comm"): mpi.decref(self.comm) - @utils.cached_property - def _kernel_args_(self): - return (self._values.ctypes.data, ) + @property + def kernel_args(self): + # we only ever read from maps + return (self._values.ptr_ro,) @utils.cached_property def _wrapper_cache_key_(self): @@ -120,22 +123,22 @@ def arange(self): """Tuple of arity offsets for each constituent :class:`Map`.""" return (0, self._arity) - @utils.cached_property + @property def values(self): """Mapping array. 
This only returns the map values for local points, to see the halo points too, use :meth:`values_with_halo`.""" - return self._values[:self.iterset.size] + return self._values.data[:self.iterset.size] - @utils.cached_property + @property def values_with_halo(self): """Mapping array. This returns all map values (including halo points), see :meth:`values` if you only need to look at the local points.""" - return self._values + return self._values.data @utils.cached_property def name(self): @@ -263,8 +266,8 @@ def __init__(self, *maps_, name=None): self.maps_ = tuple(maps_) @utils.cached_property - def _kernel_args_(self): - return tuple(itertools.chain(*[m._kernel_args_ for m in self.maps_])) + def kernel_args(self): + return tuple(itertools.chain(*[m.kernel_args for m in self.maps_])) @utils.cached_property def _wrapper_cache_key_(self): @@ -328,9 +331,9 @@ def _process_args(cls, *args, **kwargs): def _cache_key(cls, maps): return maps - @utils.cached_property - def _kernel_args_(self): - return tuple(itertools.chain(*(m._kernel_args_ for m in self if m is not None))) + @property + def kernel_args(self): + return tuple(itertools.chain(*(m.kernel_args for m in self if m is not None))) @utils.cached_property def _argtypes_(self): diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index aefd77de1..b1c402f6d 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -641,10 +641,13 @@ def __init__(self, *args, **kwargs): # Firedrake relies on this to distinguish between MatBlock and not for boundary conditions local_to_global_maps = (None, None) - @utils.cached_property - def _kernel_args_(self): + @property + def kernel_args_rw(self): return tuple(a.handle.handle for a in self) + kernel_args_ro = kernel_args_rw + kernel_args_wo = kernel_args_rw + @mpi.collective def _init(self): if not self.dtype == PETSc.ScalarType: diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 1f6ea30c8..28dfc5e83 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -11,6 +11,7 @@ mpi, utils ) +from pyop2.array import MirroredArray class Set: @@ -55,7 +56,9 @@ class Set: _extruded = False _extruded_periodic = False - _kernel_args_ = () + kernel_args_rw = () + kernel_args_ro = () + kernel_args_wo = () _argtypes_ = () @utils.cached_property @@ -163,6 +166,7 @@ def __call__(self, *indices): indices = indices[0] if np.isscalar(indices): indices = [indices] + return Subset(self, indices) def __contains__(self, dset): @@ -206,7 +210,7 @@ def difference(self, other): if other is self: return Subset(self, []) else: - return type(other)(self, np.setdiff1d(np.asarray(range(self.total_size), dtype=dtypes.IntType), other._indices)) + return type(other)(self, np.setdiff1d(np.asarray(range(self.total_size), dtype=dtypes.IntType), other.indices)) def symmetric_difference(self, other): self._check_operands(other) @@ -221,7 +225,9 @@ class GlobalSet(Set): """A proxy set allowing a :class:`Global` to be used in place of a :class:`Dat` where appropriate.""" - _kernel_args_ = () + kernel_args_rw = () + kernel_args_ro = () + kernel_args_wo = () _argtypes_ = () def __init__(self, comm=None): @@ -329,13 +335,35 @@ def __init__(self, parent, layers, extruded_periodic=False): layers = np.asarray([[0, layers]], dtype=dtypes.IntType) self.constant_layers = True - self._layers = layers + self._layers_array = MirroredArray.new(layers) self._extruded = True self._extruded_periodic = extruded_periodic + self._sizes = parent._sizes + self._cache = parent._cache - @utils.cached_property - def _kernel_args_(self): - return 
(self.layers_array.ctypes.data, ) + @property + def comm(self): + return self._parent.comm + + @property + def _name(self): + return self._parent._name + + @property + def _layers(self): + return self._layers_array.data + + @property + def kernel_args_rw(self): + return (self._layers_array.ptr_rw,) + + @property + def kernel_args_ro(self): + return (self._layers_array.ptr_ro,) + + @property + def kernel_args_wo(self): + return (self._layers_array.ptr_wo,) @utils.cached_property def _argtypes_(self): @@ -402,23 +430,42 @@ def __init__(self, superset, indices): assert type(superset) is Set or type(superset) is ExtrudedSet, \ 'Subset construction failed, should not happen' - self._superset = superset - self._indices = utils.verify_reshape(indices, dtypes.IntType, (len(indices),)) + indices = utils.verify_reshape(indices, dtypes.IntType, (len(indices),)) - if len(self._indices) > 0 and (self._indices[0] < 0 or self._indices[-1] >= self._superset.total_size): + if len(indices) > 0 and (indices[0] < 0 or indices[-1] >= superset.total_size): raise ex.SubsetIndexOutOfBounds( 'Out of bounds indices in Subset construction: [%d, %d) not [0, %d)' % - (self._indices[0], self._indices[-1], self._superset.total_size)) + (indices[0], indices[-1], superset.total_size)) - self._sizes = ((self._indices < superset.core_size).sum(), - (self._indices < superset.size).sum(), - len(self._indices)) + self._superset = superset + self._indices = MirroredArray.new(indices) + self._sizes = ((indices < superset.core_size).sum(), + (indices < superset.size).sum(), + len(indices)) self._extruded = superset._extruded self._extruded_periodic = superset._extruded_periodic + self.constant_layers = superset.constant_layers if hasattr(superset, "constant_layers") else None + self._cache = superset._cache - @utils.cached_property - def _kernel_args_(self): - return self._superset._kernel_args_ + (self._indices.ctypes.data, ) + @property + def name(self): + return self._superset.name + + @property + def comm(self): + return self._superset.comm + + @property + def kernel_args_rw(self): + return self._superset.kernel_args_rw + (self._indices.ptr_rw,) + + @property + def kernel_args_ro(self): + return self._superset.kernel_args_ro + (self._indices.ptr_ro,) + + @property + def kernel_args_wo(self): + return self._superset.kernel_args_wo + (self._indices.ptr_wo,) @utils.cached_property def _argtypes_(self): @@ -439,7 +486,7 @@ def __str__(self): (self._name, self._sizes) def __repr__(self): - return "Subset(%r, %r)" % (self._superset, self._indices) + return "Subset(%r, %r)" % (self._superset, self.indices) def __call__(self, *indices): """Build a :class:`Subset` from this :class:`Subset` @@ -454,24 +501,24 @@ def __call__(self, *indices): indices = [indices] return Subset(self, indices) - @utils.cached_property + @property def superset(self): """Returns the superset Set""" return self._superset - @utils.cached_property + @property def indices(self): """Returns the indices pointing in the superset.""" - return self._indices + return self._indices.data - @utils.cached_property + @property def owned_indices(self): """Return the indices that correspond to the owned entities of the superset. 
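For reference, the Subset machinery being reworked here is exercised like so (a sketch; assumes PyOP2 has been initialised):

from pyop2 import op2

s = op2.Set(10)
sub = s([0, 2, 4])         # Set.__call__ builds a Subset
owned = sub.owned_indices  # only the indices below superset.size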
""" return self.indices[self.indices < self.superset.size] - @utils.cached_property + @property def layers_array(self): if self._superset.constant_layers: return self._superset.layers_array @@ -493,28 +540,28 @@ def intersection(self, other): if other is self._superset: return self else: - return type(self)(self._superset, np.intersect1d(self._indices, other._indices)) + return type(self)(self._superset, np.intersect1d(self.indices, other.indices)) def union(self, other): self._check_operands(other) if other is self._superset: return other else: - return type(self)(self._superset, np.union1d(self._indices, other._indices)) + return type(self)(self._superset, np.union1d(self.indices, other.indices)) def difference(self, other): self._check_operands(other) if other is self._superset: return Subset(other, []) else: - return type(self)(self._superset, np.setdiff1d(self._indices, other._indices)) + return type(self)(self._superset, np.setdiff1d(self.indices, other.indices)) def symmetric_difference(self, other): self._check_operands(other) if other is self._superset: return other.symmetric_difference(self) else: - return type(self)(self._superset, np.setxor1d(self._indices, other._indices)) + return type(self)(self._superset, np.setxor1d(self.indices, other.indices)) class SetPartition: diff --git a/setup.py b/setup.py index a271e2415..5ed54b7fe 100644 --- a/setup.py +++ b/setup.py @@ -142,7 +142,8 @@ def run(self): ], install_requires=install_requires + test_requires, dependency_links=dep_links, - packages=['pyop2', 'pyop2.codegen', 'pyop2.types'], + packages=['pyop2', 'pyop2.backends', 'pyop2.codegen', 'pyop2.types', + 'pyop2.transforms'], package_data={ 'pyop2': ['assets/*', '*.h', '*.pxd', '*.pyx', 'codegen/c/*.c']}, scripts=glob('scripts/*'), diff --git a/test/unit/test_api.py b/test/unit/test_api.py index 8ec4cf2ab..fbb308736 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -825,7 +825,8 @@ def test_dat_ro_write_accessor(self, dat): def test_dat_lazy_allocation(self, dset): "Temporary Dats should not allocate storage until accessed." d = op2.Dat(dset) - assert not d._is_allocated + assert not d._data._lazy_host_data + assert not d._data._lazy_device_data def test_dat_zero_cdim(self, set): "A Dat built on a DataSet with zero dim should be allowed." @@ -1284,16 +1285,6 @@ def test_global_properties(self): assert g.dim == (2, 2) and g.dtype == np.float64 and g.name == 'bar' \ and g.data.sum() == 4 - def test_global_setter(self, g): - "Setter attribute on data should correct set data value." - g.data = 2 - assert g.data.sum() == 2 - - def test_global_setter_malformed_data(self, g): - "Setter attribute should reject malformed data." - with pytest.raises(exceptions.DataValueError): - g.data = [1, 2] - def test_global_iter(self, g): "Global should be iterable and yield self." for g_ in g: diff --git a/test/unit/test_global_reduction.py b/test/unit/test_global_reduction.py index fa2258924..2f281b538 100644 --- a/test/unit/test_global_reduction.py +++ b/test/unit/test_global_reduction.py @@ -400,7 +400,7 @@ def test_1d_multi_inc_same_global_reset(self, k1_inc_to_global, set, d1): d1(op2.READ)) assert g.data == d1.data.sum() - g.data = 10 + g.data[...] 
diff --git a/test/unit/test_globals.py b/test/unit/test_globals.py
index b7adf57c6..2a0a99cb9 100644
--- a/test/unit/test_globals.py
+++ b/test/unit/test_globals.py
@@ -72,7 +72,7 @@ def test_global_dat_version():
     assert g2.dat_version == 1

     # Access data setter
-    g2.data = d1
+    g2.data[...] = d1

     assert g1.dat_version == 2
     assert g2.dat_version == 2
diff --git a/test/unit/test_indirect_loop.py b/test/unit/test_indirect_loop.py
index ab77d182b..fed56b0b8 100644
--- a/test/unit/test_indirect_loop.py
+++ b/test/unit/test_indirect_loop.py
@@ -112,22 +112,16 @@ def test_mismatching_iterset(self, iterset, indset, x):
         the par_loop's should raise an exception."""
         with pytest.raises(MapValueError):
             op2.par_loop(op2.Kernel("", "dummy"), iterset,
-                         x(op2.WRITE, op2.Map(op2.Set(nelems), indset, 1)))
+                         x(op2.WRITE, op2.Map(op2.Set(nelems), indset, 1,
+                                              np.zeros(nelems))))

     def test_mismatching_indset(self, iterset, x):
         """Accessing a par_loop argument via a Map with toset not matching
         the Dat's should raise an exception."""
         with pytest.raises(MapValueError):
             op2.par_loop(op2.Kernel("", "dummy"), iterset,
-                         x(op2.WRITE, op2.Map(iterset, op2.Set(nelems), 1)))
-
-    def test_uninitialized_map(self, iterset, indset, x):
-        """Accessing a par_loop argument via an uninitialized Map should raise
-        an exception."""
-        kernel_wo = "static void wo(unsigned int* x) { *x = 42; }\n"
-        with pytest.raises(MapValueError):
-            op2.par_loop(op2.Kernel(kernel_wo, "wo"), iterset,
-                         x(op2.WRITE, op2.Map(iterset, indset, 1)))
+                         x(op2.WRITE, op2.Map(iterset, op2.Set(nelems), 1,
+                                              np.zeros(nelems))))

     def test_onecolor_wo(self, iterset, x, iterset2indset):
         """Set a Dat to a scalar value with op2.WRITE."""
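`Map` objects now take their values at construction time, which is why `test_uninitialized_map` above is deleted: with mirrored storage there is no longer an "allocated but unset" state to guard against, and the index data can be staged for the device as soon as the map exists. A hedged sketch of the updated constructor call, mirroring the arity-1 maps used in these tests:

    import numpy as np
    from pyop2 import op2

    iterset = op2.Set(4)
    indset = op2.Set(4)
    # the fourth (values) argument is now required; the tests pass plain
    # np.zeros, so the constructor presumably verifies and casts the dtype
    m = op2.Map(iterset, indset, 1, np.zeros(4))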
diff --git a/test/unit/test_linalg.py b/test/unit/test_linalg.py
index 5fce55d0e..b7587a055 100644
--- a/test/unit/test_linalg.py
+++ b/test/unit/test_linalg.py
@@ -84,19 +84,19 @@ class TestLinAlgOp:
     """

     def test_add(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x + y).data == 3 * y.data)

     def test_sub(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x - y).data == y.data)

     def test_mul(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x * y).data == 2 * y.data * y.data)

     def test_div(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x / y).data == 2.0)

     def test_add_shape_mismatch(self, x2, y2):
@@ -116,11 +116,11 @@ def test_div_shape_mismatch(self, x2, y2):
             x2 / y2

     def test_add_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         assert all(x.data == (y + 1.0).data)

     def test_radd_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         assert all(x.data == (1.0 + y).data)

     def test_pos_copies(self, y):
@@ -134,23 +134,23 @@ def test_neg_copies(self, y):
         assert z is not y

     def test_sub_scalar(self, x, y):
-        x._data = y.data - 1.0
+        x.data[...] = y.data - 1.0
         assert all(x.data == (y - 1.0).data)

     def test_rsub_scalar(self, x, y):
-        x._data = 1.0 - y.data
+        x.data[...] = 1.0 - y.data
         assert all(x.data == (1.0 - y).data)

     def test_mul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all(x.data == (y * 2.0).data)

     def test_rmul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all(x.data == (2.0 * y).data)

     def test_div_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x / 2.0).data == y.data)

     def test_add_ftype(self, y, yi):
@@ -187,7 +187,7 @@ def test_div_itype(self, y, yi):

     def test_linalg_and_parloop(self, x, y):
         """Linear algebra operators should force computation"""
-        x._data = np.zeros(x.dataset.total_size, dtype=np.float64)
+        x.data[...] = np.zeros(x.dataset.total_size, dtype=np.float64)
         k = op2.Kernel('static void k(double *x) { *x = 1.0; }', 'k')
         op2.par_loop(k, x.dataset.set, x(op2.WRITE))
         z = x + y
@@ -201,22 +201,22 @@ class TestLinAlgIop:
     """

     def test_iadd(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x += y
         assert all(x.data == 3 * y.data)

     def test_isub(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x -= y
         assert all(x.data == y.data)

     def test_imul(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x *= y
         assert all(x.data == 2 * y.data * y.data)

     def test_idiv(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x /= y
         assert all(x.data == 2.0)

@@ -237,22 +237,22 @@ def test_idiv_shape_mismatch(self, x2, y2):
             x2 /= y2

     def test_iadd_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         y += 1.0
         assert all(x.data == y.data)

     def test_isub_scalar(self, x, y):
-        x._data = y.data - 1.0
+        x.data[...] = y.data - 1.0
         y -= 1.0
         assert all(x.data == y.data)

     def test_imul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         y *= 2.0
         assert all(x.data == y.data)

     def test_idiv_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x /= 2.0
         assert all(x.data == y.data)
diff --git a/test/unit/test_linalg_complex.py b/test/unit/test_linalg_complex.py
index f7ee2f4cd..eed8cc1bd 100644
--- a/test/unit/test_linalg_complex.py
+++ b/test/unit/test_linalg_complex.py
@@ -94,30 +94,30 @@ class TestLinAlgOp:
     """

     def test_add(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x + y).data == 3 * y.data)

     def test_sub(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x - y).data == y.data)

     def test_mul_complex(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         assert all((x * y).data == (2+2j) * y.data * y.data)

     def test_div_complex(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         # Note complex division does not have the same stability as
         # floating point when vectorised
         assert all(x.data / y.data == 2.0+2.j)
         assert np.allclose((x / y).data, 2.0+2.j)

     def test_mul(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x * y).data == 2 * y.data * y.data)

     def test_div(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x.data / y.data
         # Note complex division does not have the same stability as
         # floating point when vectorised
@@ -141,19 +141,19 @@ def test_div_shape_mismatch(self, x2, y2):
             x2 / y2

     def test_add_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         assert all(x.data == (y + 1.0).data)

     def test_radd_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         assert all(x.data == (1.0 + y).data)

     def test_add_complex_scalar(self, x, y):
-        x._data = y.data + (1.0+1.j)
+        x.data[...] = y.data + (1.0+1.j)
         assert all(x.data == (y + (1.0+1.j)).data)

     def test_radd_complex_scalar(self, x, y):
-        x._data = y.data + (1.0+1.j)
+        x.data[...] = y.data + (1.0+1.j)
         assert all(x.data == ((1.0+1.j) + y).data)

     def test_pos_copies(self, y):
@@ -167,39 +167,39 @@ def test_neg_copies(self, y):
         assert z is not y

     def test_sub_scalar(self, x, y):
-        x._data = y.data - 1.0
+        x.data[...] = y.data - 1.0
         assert all(x.data == (y - 1.0).data)

     def test_rsub_scalar(self, x, y):
-        x._data = 1.0 - y.data
+        x.data[...] = 1.0 - y.data
         assert all(x.data == (1.0 - y).data)

     def test_mul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all(x.data == (y * 2.0).data)

     def test_rmul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all(x.data == (2.0 * y).data)

     def test_sub_complex_scalar(self, x, y):
-        x._data = y.data - (1.0+1.j)
+        x.data[...] = y.data - (1.0+1.j)
         assert all(x.data == (y - (1.0+1.j)).data)

     def test_rsub_complex_scalar(self, x, y):
-        x._data = (1.0+1.j) - y.data
+        x.data[...] = (1.0+1.j) - y.data
         assert all(x.data == ((1.0+1.j) - y).data)

     def test_mul_complex_scalar(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         assert all(x.data == (y * (2.0+2.j)).data)

     def test_rmul_complex_scalar(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         assert all(x.data == ((2.0+2.j) * y).data)

     def test_div_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         assert all((x / 2.0).data == y.data)

     def test_add_ftype(self, y, yf):
@@ -252,7 +252,7 @@ def test_div_itype(self, y, yi):

     def test_linalg_and_parloop(self, x, y):
         """Linear algebra operators should force computation"""
-        x._data = np.zeros(x.dataset.total_size, dtype=np.complex128)
+        x.data[...] = np.zeros(x.dataset.total_size, dtype=np.complex128)
         k = op2.Kernel('static void k(complex double *x) { *x = 1.0+1.0*I; }', 'k')
         op2.par_loop(k, x.dataset.set, x(op2.WRITE))
         z = x + y
@@ -266,22 +266,22 @@ class TestLinAlgIop:
     """

     def test_iadd(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x += y
         assert all(x.data == 3 * y.data)

     def test_isub(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x -= y
         assert all(x.data == y.data)

     def test_imul(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x *= y
         assert all(x.data == 2 * y.data * y.data)

     def test_idiv(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x /= y
         # Note complex division does not have the same stability as
         # floating point when vectorised
@@ -304,42 +304,42 @@ def test_idiv_shape_mismatch(self, x2, y2):
             x2 /= y2

     def test_iadd_scalar(self, x, y):
-        x._data = y.data + 1.0
+        x.data[...] = y.data + 1.0
         y += 1.0
         assert all(x.data == y.data)

     def test_isub_scalar(self, x, y):
-        x._data = y.data - 1.0
+        x.data[...] = y.data - 1.0
         y -= 1.0
         assert all(x.data == y.data)

     def test_imul_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         y *= 2.0
         assert all(x.data == y.data)

     def test_idiv_scalar(self, x, y):
-        x._data = 2 * y.data
+        x.data[...] = 2 * y.data
         x /= 2.0
         assert all(x.data == y.data)

     def test_iadd_complex_scalar(self, x, y):
-        x._data = y.data + (1.0+1.j)
+        x.data[...] = y.data + (1.0+1.j)
         y += (1.0+1.j)
         assert all(x.data == y.data)

     def test_isub_complex_scalar(self, x, y):
-        x._data = y.data - (1.0+1.j)
+        x.data[...] = y.data - (1.0+1.j)
         y -= (1.0+1.j)
         assert all(x.data == y.data)

     def test_imul_complex_scalar(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         y *= (2.0+2.j)
         assert all(x.data == y.data)

     def test_idiv_complex_scalar(self, x, y):
-        x._data = (2+2j) * y.data
+        x.data[...] = (2+2j) * y.data
         x /= (2.0+2j)
         assert all(x.data == y.data)
diff --git a/test/unit/test_matrices.py b/test/unit/test_matrices.py
index 34b467e21..1e6b39a84 100644
--- a/test/unit/test_matrices.py
+++ b/test/unit/test_matrices.py
@@ -567,13 +567,6 @@ class TestSparsity:
     Sparsity tests
     """

-    def test_sparsity_null_maps(self):
-        """Building sparsity from a pair of non-initialized maps should fail."""
-        s = op2.Set(5)
-        with pytest.raises(MapValueError):
-            m = op2.Map(s, s, 1)
-            op2.Sparsity((s, s), (m, m))
-
     def test_sparsity_has_diagonal_space(self):
         # A sparsity should have space for diagonal entries if rmap==cmap
         s = op2.Set(1)
diff --git a/test/unit/test_opencl_offloading.py b/test/unit/test_opencl_offloading.py
new file mode 100644
index 000000000..b39f11a1c
--- /dev/null
+++ b/test/unit/test_opencl_offloading.py
@@ -0,0 +1,119 @@
+import pytest
+pytest.importorskip("pyopencl")
+
+import sys
+import petsc4py
+petsc4py.init(sys.argv
+              + "-viennacl_backend opencl".split()
+              + "-viennacl_opencl_device_type cpu".split())
+from pyop2 import op2
+import pyopencl.array as cla
+import numpy as np
+import loopy as lp
+lp.set_caching_enabled(False)
+
+
+def allclose(a, b, rtol=1e-05, atol=1e-08):
+    """
+    Prefer this routine over np.allclose(...) to allow pycuda/pyopencl arrays
+    """
+    return bool(abs(a - b) < (atol + rtol * abs(b)))
+
+
+def pytest_generate_tests(metafunc):
+    if "backend" in metafunc.fixturenames:
+        from pyop2.backends.opencl import opencl_backend
+        metafunc.parametrize("backend", [opencl_backend])
+
+
+def test_new_backend_raises_not_implemented_error():
+    from pyop2.backends import AbstractComputeBackend
+    unimplemented_backend = AbstractComputeBackend()
+
+    attrs = ["GlobalKernel", "Parloop", "Set", "ExtrudedSet", "MixedSet",
+             "Subset", "DataSet", "MixedDataSet", "Map", "MixedMap", "Dat",
+             "MixedDat", "DatView", "Mat", "Global", "GlobalDataSet",
+             "PETScVecType"]
+
+    for attr in attrs:
+        with pytest.raises(NotImplementedError):
+            getattr(unimplemented_backend, attr)
+
+
+def test_dat_with_petscvec_representation(backend):
+    op2.set_offloading_backend(backend)
+
+    nelems = 9
+    data = np.random.rand(nelems)
+    set_ = op2.compute_backend.Set(nelems)
+    dset = op2.compute_backend.DataSet(set_, 1)
+    dat = op2.compute_backend.Dat(dset, data.copy())
+
+    assert isinstance(dat.data_ro, np.ndarray)
+    dat.data[:] *= 3
+
+    with op2.offloading():
+        assert isinstance(dat.data_ro, cla.Array)
+        dat.data[:] *= 2
+
+    assert isinstance(dat.data_ro, np.ndarray)
+    np.testing.assert_allclose(dat.data_ro, 6*data)
+
+
+def test_dat_not_as_petscvec(backend):
+    op2.set_offloading_backend(backend)
+
+    nelems = 9
+    data = np.random.randint(low=-10, high=10,
+                             size=nelems,
+                             dtype=np.int64)
+    set_ = op2.compute_backend.Set(nelems)
+    dset = op2.compute_backend.DataSet(set_, 1)
+    dat = op2.compute_backend.Dat(dset, data.copy())
+
+    assert isinstance(dat.data_ro, np.ndarray)
+    dat.data[:] *= 3
+
+    with op2.offloading():
+        assert isinstance(dat.data_ro, cla.Array)
+        dat.data[:] *= 2
+
+    assert isinstance(dat.data_ro, np.ndarray)
+    np.testing.assert_allclose(dat.data_ro, 6*data)
+
+
+def test_global_reductions(backend):
+    op2.set_offloading_backend(backend)
+
+    sum_knl = lp.make_function(
+        "{ : }",
+        """
+        g[0] = g[0] + x[0]
+        """,
+        [lp.GlobalArg("g,x",
+                      dtype="float64",
+                      shape=lp.auto,
+                      )],
+        name="sum_knl",
+        target=lp.CWithGNULibcTarget(),
+        lang_version=(2018, 2))
+
+    rng = np.random.default_rng()
+    nelems = 4_000
+    data_to_sum = rng.random(nelems)
+
+    with op2.offloading():
+        set_ = op2.compute_backend.Set(4_000)
+        dset = op2.compute_backend.DataSet(set_, 1)
+
+        dat = op2.compute_backend.Dat(dset, data_to_sum.copy())
+        glob = op2.compute_backend.Global(1, 0, np.float64, "g")
+
+        op2.parloop(op2.Kernel(sum_knl, "sum_knl"),
+                    set_,
+                    glob(op2.INC),
+                    dat(op2.READ))
+
+        assert isinstance(glob.data_ro, cla.Array)
+        assert allclose(glob.data_ro[0], data_to_sum.sum())
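The new test module above pins down the offloading contract: outside `op2.offloading()` a `Dat` hands back NumPy arrays, inside it hands back `pyopencl.array.Array`s, with the mirrored storage copying between host and device lazily as accesses alternate. A minimal usage sketch, assuming an OpenCL-capable build like the one configured in the CI job (the comments describe the expected copy behaviour, not guaranteed internals):

    import numpy as np
    from pyop2 import op2
    from pyop2.backends.opencl import opencl_backend

    op2.set_offloading_backend(opencl_backend)

    s = op2.compute_backend.Set(8)
    d = op2.compute_backend.Dat(op2.compute_backend.DataSet(s, 1), np.ones(8))

    d.data[:] += 1.0              # host write; any device copy goes stale
    with op2.offloading():
        d.data[:] *= 2.0          # device write; the host copy goes stale
    print(d.data_ro.sum())        # back on host: triggers a device-to-host copy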
diff --git a/test/unit/test_petsc.py b/test/unit/test_petsc.py
index 57068a7aa..0d5dbe7ca 100644
--- a/test/unit/test_petsc.py
+++ b/test/unit/test_petsc.py
@@ -74,10 +74,10 @@ def test_mixed_vec_access(self):
             assert np.allclose(v.array_r, [1.0, 2.0])

         d.data[0][:] = 0.0
-        d.data[0][:] = 0.0
+        d.data[1][:] = 0.0

         with d.vec_wo as v:
-            assert np.allclose(v.array_r, [1.0, 2.0])
+            assert np.allclose(v.array_r, [0.0, 0.0])
             v.array[:] = 1

         assert d.data[0][0] == 1
diff --git a/test/unit/test_subset.py b/test/unit/test_subset.py
index a2c2f9f9b..603c852bf 100644
--- a/test/unit/test_subset.py
+++ b/test/unit/test_subset.py
@@ -270,8 +270,8 @@ def test_set_set_operations(self):
         s = a.symmetric_difference(a)
         assert u is a
         assert i is a
-        assert d._indices.size == 0
-        assert s._indices.size == 0
+        assert d.indices.size == 0
+        assert s.indices.size == 0

     def test_set_subset_operations(self):
         """Test standard set operations between a set and a subset"""
@@ -283,8 +283,8 @@ def test_set_subset_operations(self):
         s = a.symmetric_difference(b)
         assert u is a
         assert i is b
-        assert (d._indices == [0, 1, 4, 6, 8, 9]).all()
-        assert (s._indices == d._indices).all()
+        assert (d.indices == [0, 1, 4, 6, 8, 9]).all()
+        assert (s.indices == d.indices).all()

     def test_subset_set_operations(self):
         """Test standard set operations between a subset and a set"""
@@ -296,8 +296,8 @@ def test_subset_set_operations(self):
         s = b.symmetric_difference(a)
         assert u is a
         assert i is b
-        assert d._indices.size == 0
-        assert (s._indices == [0, 1, 4, 6, 8, 9]).all()
+        assert d.indices.size == 0
+        assert (s.indices == [0, 1, 4, 6, 8, 9]).all()

     def test_subset_subset_operations(self):
         """Test standard set operations between two subsets"""
@@ -308,7 +308,7 @@ def test_subset_subset_operations(self):
         i = b.intersection(c)
         d = b.difference(c)
         s = b.symmetric_difference(c)
-        assert (u._indices == [2, 3, 4, 5, 6, 7, 8]).all()
-        assert (i._indices == [2, ]).all()
-        assert (d._indices == [3, 5, 7]).all()
-        assert (s._indices == [3, 4, 5, 6, 7, 8]).all()
+        assert (u.indices == [2, 3, 4, 5, 6, 7, 8]).all()
+        assert (i.indices == [2, ]).all()
+        assert (d.indices == [3, 5, 7]).all()
+        assert (s.indices == [3, 4, 5, 6, 7, 8]).all()