Skip to content

Commit 1130a23

Browse files
committed
add some tests for dataclass_array_container
1 parent a5ec976 commit 1130a23

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

test/test_utils.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,16 @@
2222
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2323
THE SOFTWARE.
2424
"""
25+
import pytest
26+
27+
import numpy as np
2528

2629
import logging
2730
logger = logging.getLogger(__name__)
2831

2932

33+
# {{{ test_pt_actx_key_stringification_uniqueness
34+
3035
def test_pt_actx_key_stringification_uniqueness():
3136
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
3237

@@ -36,13 +41,63 @@ def test_pt_actx_key_stringification_uniqueness():
3641
assert (_ary_container_key_stringifier(("tup", 3, "endtup"))
3742
!= _ary_container_key_stringifier(((3,),)))
3843

44+
# }}}
45+
46+
47+
# {{{ test_dataclass_array_container
48+
49+
def test_dataclass_array_container():
50+
from typing import Optional
51+
from dataclasses import dataclass, field
52+
from arraycontext import dataclass_array_container
53+
54+
# {{{ string fields
55+
56+
@dataclass
57+
class ArrayContainerWithStringTypes:
58+
x: np.ndarray
59+
y: "np.ndarray"
60+
61+
with pytest.raises(AssertionError):
62+
dataclass_array_container(ArrayContainerWithStringTypes)
63+
64+
# }}}
65+
66+
# {{{ optional fields
67+
68+
@dataclass
69+
class ArrayContainerWithOptional:
70+
x: np.ndarray
71+
y: Optional[np.ndarray]
72+
73+
cls = dataclass_array_container(ArrayContainerWithOptional)
74+
assert cls is not None
75+
76+
# }}}
77+
78+
# {{{ field(init=False)
79+
80+
@dataclass
81+
class ArrayContainerWithOptionalInit:
82+
x: np.ndarray
83+
y: np.ndarray = field(default=np.zeros(42), init=False, repr=False)
84+
85+
cls = dataclass_array_container(ArrayContainerWithOptionalInit)
86+
ary = cls(x=np.array([1, 1], dtype=object))
87+
88+
from arraycontext import serialize_container
89+
assert len(serialize_container(ary)) == 1
90+
91+
# }}}
92+
93+
# }}}
94+
3995

4096
if __name__ == "__main__":
4197
import sys
4298
if len(sys.argv) > 1:
4399
exec(sys.argv[1])
44100
else:
45-
from pytest import main
46-
main([__file__])
101+
pytest.main([__file__])
47102

48103
# vim: fdm=marker

0 commit comments

Comments
 (0)