From fd88cb7af9649e76886a5fd1674935e1f02bdd88 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 14 May 2023 20:24:56 +0200 Subject: [PATCH 1/2] Add PyFixedString and PyFixedUnicode implementors of Element to support Unicode arrays whose element length is know at compile time. --- CHANGELOG.md | 1 + src/datetime.rs | 10 ++- src/lib.rs | 2 + src/strings.rs | 230 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/array.rs | 151 ++++++++++++++++++++++++++++++- 5 files changed, 389 insertions(+), 5 deletions(-) create mode 100644 src/strings.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e5f65c1e..c7a7202bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog - Unreleased + - Add support for ASCII (`PyFixedString`) and Unicode (`PyFixedUnicode`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378)) - v0.19.0 - Add `PyUntypedArray` as an untyped base type for `PyArray` which can be used to inspect arguments before more targeted downcasts. This is accompanied by some methods like `dtype` and `shape` moving from `PyArray` to `PyUntypedArray`. They are still accessible though, as `PyArray` dereferences to `PyUntypedArray` via the `Deref` trait. ([#369](https://github.com/PyO3/rust-numpy/pull/369)) diff --git a/src/datetime.rs b/src/datetime.rs index bd5ef54e0..f5df21211 100644 --- a/src/datetime.rs +++ b/src/datetime.rs @@ -223,8 +223,8 @@ impl TypeDescriptors { fn from_unit<'py>(&'py self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> &'py PyArrayDescr { let mut dtypes = self.dtypes.get(py).borrow_mut(); - match dtypes.get_or_insert_with(Default::default).entry(unit) { - Entry::Occupied(entry) => entry.into_mut().clone().into_ref(py), + let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) { + Entry::Occupied(entry) => entry.into_mut(), Entry::Vacant(entry) => { let dtype = PyArrayDescr::new_from_npy_type(py, self.npy_type); @@ -237,9 +237,11 @@ impl TypeDescriptors { metadata.meta.num = 1; } - entry.insert(dtype.into()).clone().into_ref(py) + entry.insert(dtype.into()) } - } + }; + + dtype.clone().into_ref(py) } } diff --git a/src/lib.rs b/src/lib.rs index 80a4a0dbd..04ab7efe0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,6 +82,7 @@ mod dtype; mod error; pub mod npyffi; mod slice_container; +mod strings; mod sum_products; mod untyped_array; @@ -105,6 +106,7 @@ pub use crate::convert::{IntoPyArray, NpyIndex, ToNpyDims, ToPyArray}; pub use crate::dtype::{dtype, Complex32, Complex64, Element, PyArrayDescr}; pub use crate::error::{BorrowError, FromVecError, NotContiguousError}; pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; +pub use crate::strings::{PyFixedString, PyFixedUnicode}; pub use crate::sum_products::{dot, einsum, inner}; pub use crate::untyped_array::PyUntypedArray; diff --git a/src/strings.rs b/src/strings.rs new file mode 100644 index 000000000..f76cc01b4 --- /dev/null +++ b/src/strings.rs @@ -0,0 +1,230 @@ +//! Types to support arrays of [ASCII][ascii] and [UCS4][ucs4] strings +//! +//! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING +//! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE + +use std::cell::RefCell; +use std::collections::hash_map::Entry; +use std::convert::TryInto; +use std::fmt; +use std::mem::size_of; +use std::os::raw::c_char; +use std::str; + +use pyo3::{ + ffi::{Py_UCS1, Py_UCS4}, + sync::GILProtected, + Py, Python, +}; +use rustc_hash::FxHashMap; + +use crate::dtype::{Element, PyArrayDescr}; +use crate::npyffi::NPY_TYPES; + +/// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`byte` scalars][numpy-bytes] while satisfying coherence. +/// +/// Note that when creating arrays of ASCII strings without an explicit `dtype`, +/// NumPy will automatically determine the smallest possible array length at runtime. +/// +/// For example, +/// +/// ```python +/// array = numpy.array([b"foo", b"bar", b"foobar"]) +/// ``` +/// +/// yields `S6` for `array.dtype`. +/// +/// On the Rust side however, the length `N` of `PyFixedString` must always be given +/// explicitly and as a compile-time constant. For this work reliably, the Python code +/// should set the `dtype` explicitly, e.g. +/// +/// ```python +/// numpy.array([b"foo", b"bar", b"foobar"], dtype='S12') +/// ``` +/// +/// always matching `PyArray1>`. +/// +/// # Example +/// +/// ```rust +/// # use pyo3::Python; +/// use numpy::{PyArray1, PyFixedString}; +/// +/// # Python::with_gil(|py| { +/// let array = PyArray1::>::from_vec(py, vec![[b'f', b'o', b'o'].into()]); +/// +/// assert!(array.dtype().to_string().contains("S3")); +/// # }); +/// ``` +/// +/// [numpy-bytes]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_ +#[repr(transparent)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PyFixedString(pub [Py_UCS1; N]); + +impl fmt::Display for PyFixedString { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.write_str(str::from_utf8(&self.0).unwrap().trim_end_matches('\0')) + } +} + +impl From<[Py_UCS1; N]> for PyFixedString { + fn from(val: [Py_UCS1; N]) -> Self { + Self(val) + } +} + +unsafe impl Element for PyFixedString { + const IS_COPY: bool = true; + + fn get_dtype(py: Python) -> &PyArrayDescr { + static DTYPES: TypeDescriptors = TypeDescriptors::new(); + + unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::()) } + } +} + +/// A newtype wrapper around [`[PyUCS4; N]`][Py_UCS4] to handle [`str_` scalars][numpy-str] while satisfying coherence. +/// +/// Note that when creating arrays of Unicode strings without an explicit `dtype`, +/// NumPy will automatically determine the smallest possible array length at runtime. +/// +/// For example, +/// +/// ```python +/// numpy.array(["foo🐍", "bar🦀", "foobar"]) +/// ``` +/// +/// yields `U6` for `array.dtype`. +/// +/// On the Rust side however, the length `N` of `PyFixedUnicode` must always be given +/// explicitly and as a compile-time constant. For this work reliably, the Python code +/// should set the `dtype` explicitly, e.g. +/// +/// ```python +/// numpy.array(["foo🐍", "bar🦀", "foobar"], dtype='U12') +/// ``` +/// +/// always matching `PyArray1>`. +/// +/// # Example +/// +/// ```rust +/// # use pyo3::Python; +/// use numpy::{PyArray1, PyFixedUnicode}; +/// +/// # Python::with_gil(|py| { +/// let array = PyArray1::>::from_vec(py, vec![[b'b' as _, b'a' as _, b'r' as _].into()]); +/// +/// assert!(array.dtype().to_string().contains("U3")); +/// # }); +/// ``` +/// +/// [numpy-str]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.str_ +#[repr(transparent)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct PyFixedUnicode(pub [Py_UCS4; N]); + +impl fmt::Display for PyFixedUnicode { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + for character in self.0 { + if character == 0 { + break; + } + + write!(fmt, "{}", char::from_u32(character).unwrap())?; + } + + Ok(()) + } +} + +impl From<[Py_UCS4; N]> for PyFixedUnicode { + fn from(val: [Py_UCS4; N]) -> Self { + Self(val) + } +} + +unsafe impl Element for PyFixedUnicode { + const IS_COPY: bool = true; + + fn get_dtype(py: Python) -> &PyArrayDescr { + static DTYPES: TypeDescriptors = TypeDescriptors::new(); + + unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::()) } + } +} + +struct TypeDescriptors { + #[allow(clippy::type_complexity)] + dtypes: GILProtected>>>>, +} + +impl TypeDescriptors { + const fn new() -> Self { + Self { + dtypes: GILProtected::new(RefCell::new(None)), + } + } + + /// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size` + #[allow(clippy::wrong_self_convention)] + unsafe fn from_size<'py>( + &'py self, + py: Python<'py>, + npy_type: NPY_TYPES, + byteorder: c_char, + size: usize, + ) -> &'py PyArrayDescr { + let mut dtypes = self.dtypes.get(py).borrow_mut(); + + let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) { + Entry::Occupied(entry) => entry.into_mut(), + Entry::Vacant(entry) => { + let dtype = PyArrayDescr::new_from_npy_type(py, npy_type); + + let descr = &mut *dtype.as_dtype_ptr(); + descr.elsize = size.try_into().unwrap(); + descr.byteorder = byteorder; + + entry.insert(dtype.into()) + } + }; + + dtype.clone().into_ref(py) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn format_fixed_string() { + assert_eq!( + PyFixedString([b'f', b'o', b'o', 0, 0, 0]).to_string(), + "foo" + ); + assert_eq!( + PyFixedString([b'f', b'o', b'o', b'b', b'a', b'r']).to_string(), + "foobar" + ); + } + + #[test] + fn format_fixed_unicode() { + assert_eq!( + PyFixedUnicode([b'f' as _, b'o' as _, b'o' as _, 0, 0, 0]).to_string(), + "foo" + ); + assert_eq!( + PyFixedUnicode([0x1F980, 0x1F40D, 0, 0, 0, 0]).to_string(), + "🦀🐍" + ); + assert_eq!( + PyFixedUnicode([b'f' as _, b'o' as _, b'o' as _, b'b' as _, b'a' as _, b'r' as _]) + .to_string(), + "foobar" + ); + } +} diff --git a/tests/array.rs b/tests/array.rs index 6476cc293..6cfa8ac63 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -5,7 +5,7 @@ use half::f16; use ndarray::{array, s, Array1, Dim}; use numpy::{ dtype, get_array_module, npyffi::NPY_ORDER, pyarray, PyArray, PyArray1, PyArray2, PyArrayDescr, - PyArrayDyn, ToPyArray, + PyArrayDyn, PyFixedString, PyFixedUnicode, ToPyArray, }; use pyo3::{ py_run, pyclass, pymethods, @@ -562,3 +562,152 @@ fn half_works() { ); }); } + +#[test] +fn ascii_strings_with_explicit_dtype_works() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + let locals = [("np", np)].into_py_dict(py); + + let array = py + .eval( + "np.array([b'foo', b'bar', b'foobar'], dtype='S6')", + None, + Some(locals), + ) + .unwrap() + .downcast::>>() + .unwrap(); + + { + let array = array.readonly(); + let array = array.as_array(); + + assert_eq!(array[0].0, [b'f', b'o', b'o', 0, 0, 0]); + assert_eq!(array[1].0, [b'b', b'a', b'r', 0, 0, 0]); + assert_eq!(array[2].0, [b'f', b'o', b'o', b'b', b'a', b'r']); + } + + { + let mut array = array.readwrite(); + let mut array = array.as_array_mut(); + + array[2].0[5] = b'z'; + } + + py_run!(py, array np, "assert array[2] == b'foobaz'"); + }); +} + +#[test] +fn unicode_strings_with_explicit_dtype_works() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + let locals = [("np", np)].into_py_dict(py); + + let array = py + .eval( + "np.array(['foo', 'bar', 'foobar'], dtype='U6')", + None, + Some(locals), + ) + .unwrap() + .downcast::>>() + .unwrap(); + + { + let array = array.readonly(); + let array = array.as_array(); + + assert_eq!(array[0].0, [b'f' as _, b'o' as _, b'o' as _, 0, 0, 0]); + assert_eq!(array[1].0, [b'b' as _, b'a' as _, b'r' as _, 0, 0, 0]); + assert_eq!( + array[2].0, + [b'f' as _, b'o' as _, b'o' as _, b'b' as _, b'a' as _, b'r' as _] + ); + } + + { + let mut array = array.readwrite(); + let mut array = array.as_array_mut(); + + array[2].0[5] = b'z' as _; + } + + py_run!(py, array np, "assert array[2] == 'foobaz'"); + }); +} + +#[test] +fn ascii_strings_ignore_byteorder() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + let locals = [("np", np)].into_py_dict(py); + + let native_endian_works = py + .eval( + "np.array([b'foo', b'bar'], dtype='=S3')", + None, + Some(locals), + ) + .unwrap() + .downcast::>>() + .is_ok(); + + let little_endian_works = py + .eval( + "np.array(['bfoo', b'bar'], dtype='>>() + .is_ok(); + + let big_endian_works = py + .eval( + "np.array([b'foo', b'bar'], dtype='>S3')", + None, + Some(locals), + ) + .unwrap() + .downcast::>>() + .is_ok(); + + match (native_endian_works, little_endian_works, big_endian_works) { + (true, true, true) => (), + _ => panic!("All byteorders should work",), + } + }); +} + +#[test] +fn unicode_strings_respect_byteorder() { + Python::with_gil(|py| { + let np = py.eval("__import__('numpy')", None, None).unwrap(); + let locals = [("np", np)].into_py_dict(py); + + let native_endian_works = py + .eval("np.array(['foo', 'bar'], dtype='=U3')", None, Some(locals)) + .unwrap() + .downcast::>>() + .is_ok(); + + let little_endian_works = py + .eval("np.array(['foo', 'bar'], dtype='>>() + .is_ok(); + + let big_endian_works = py + .eval("np.array(['foo', 'bar'], dtype='>U3')", None, Some(locals)) + .unwrap() + .downcast::>>() + .is_ok(); + + match (native_endian_works, little_endian_works, big_endian_works) { + (true, true, false) | (true, false, true) => (), + _ => panic!("Only native byteorder should work"), + } + }); +} From ea5033f5173480551dd3b5b4c301137b190e2205 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Mon, 15 May 2023 18:23:40 +0200 Subject: [PATCH 2/2] Bump MSRV to 1.56 and Rust edition to 2021. --- .github/workflows/ci.yml | 2 +- CHANGELOG.md | 1 + Cargo.toml | 4 ++-- src/borrow/shared.rs | 14 ++------------ src/npyffi/mod.rs | 2 +- 5 files changed, 7 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 523066509..376ded49c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -128,7 +128,7 @@ jobs: with: python-version: 3.8 - name: Install Rust - uses: dtolnay/rust-toolchain@1.48.0 + uses: dtolnay/rust-toolchain@1.56.0 - uses: Swatinem/rust-cache@v2 with: workspaces: examples/simple diff --git a/CHANGELOG.md b/CHANGELOG.md index c7a7202bf..b6758c2a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog - Unreleased + - Increase MSRV to 1.56 released in October 2021 and available in Debain 12, RHEL 9 and Alpine 3.17 following the same change for PyO3. ([#378](https://github.com/PyO3/rust-numpy/pull/378)) - Add support for ASCII (`PyFixedString`) and Unicode (`PyFixedUnicode`) string arrays, i.e. dtypes `SN` and `UN` where `N` is the number of characters. ([#378](https://github.com/PyO3/rust-numpy/pull/378)) - v0.19.0 diff --git a/Cargo.toml b/Cargo.toml index 44ff6f75b..bc9cada11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ authors = [ ] description = "PyO3-based Rust bindings of the NumPy C-API" documentation = "https://docs.rs/numpy" -edition = "2018" -rust-version = "1.48" +edition = "2021" +rust-version = "1.56" repository = "https://github.com/PyO3/rust-numpy" categories = ["api-bindings", "development-tools::ffi", "science"] keywords = ["python", "numpy", "ffi", "pyo3"] diff --git a/src/borrow/shared.rs b/src/borrow/shared.rs index fb26a1deb..2193f36ef 100644 --- a/src/borrow/shared.rs +++ b/src/borrow/shared.rs @@ -125,7 +125,7 @@ fn insert_shared(py: Python) -> PyResult<*const Shared> { let module = get_array_module(py)?; let capsule: &PyCapsule = match module.getattr("_RUST_NUMPY_BORROW_CHECKING_API") { - Ok(capsule) => capsule.try_into()?, + Ok(capsule) => PyTryInto::try_into(capsule)?, Err(_err) => { let flags: *mut BorrowFlags = Box::into_raw(Box::default()); @@ -437,17 +437,7 @@ fn gcd_strides(array: *mut PyArrayObject) -> isize { let strides = unsafe { from_raw_parts((*array).strides, nd) }; - reduce(strides.iter().copied(), gcd).unwrap_or(1) -} - -// FIXME(adamreichold): Use `Iterator::reduce` from std when our MSRV reaches 1.51. -fn reduce(mut iter: I, f: F) -> Option -where - I: Iterator, - F: FnMut(I::Item, I::Item) -> I::Item, -{ - let first = iter.next()?; - Some(iter.fold(first, f)) + strides.iter().copied().reduce(gcd).unwrap_or(1) } #[cfg(test)] diff --git a/src/npyffi/mod.rs b/src/npyffi/mod.rs index 7cdd512e9..245e406c2 100644 --- a/src/npyffi/mod.rs +++ b/src/npyffi/mod.rs @@ -19,7 +19,7 @@ use pyo3::{ fn get_numpy_api(py: Python, module: &str, capsule: &str) -> PyResult<*const *const c_void> { let module = PyModule::import(py, module)?; - let capsule: &PyCapsule = module.getattr(capsule)?.try_into()?; + let capsule: &PyCapsule = PyTryInto::try_into(module.getattr(capsule)?)?; let api = capsule.pointer() as *const *const c_void;