Add some unit tests#111
Conversation
Code generated using AI and modified afterwards. Model: Opus 4.7
There was a problem hiding this comment.
Pull request overview
This PR fixes a couple of frontend/codegen regressions (sparsity dict parsing + gemmforge discovery import), expands accepted architecture precision strings, and introduces a substantial Python pytest suite plus a CI workflow to run it.
Changes:
- Fix empty dict sparsity parsing by safely handling
next(iter(...))on empty dictionaries. - Adjust gemmforge discovery to use
importlib.utiland add support forF32/F64precision strings. - Add a comprehensive pytest-based unit/integration test suite and a GitHub Actions workflow to run it with coverage.
Reviewed changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
yateto/type.py |
Handles empty dict sparsity patterns without raising StopIteration. |
yateto/codegen/copyscaleadd/factory.py |
Switches to importlib.util for gemmforge spec discovery. |
yateto/arch.py |
Accepts F32/F64 as precision strings in addition to S/D. |
tests/pytest/conftest.py |
Adds fixtures (arch setup/reset, common tensor helpers, AST pipeline helper). |
tests/pytest/test_type.py |
Adds frontend type tests (naming, shape invariants, sparsity/value handling, collections). |
tests/pytest/test_pipeline.py |
Adds Python-level pipeline “integration” tests over example scripts and codegen emission smoke checks. |
tests/pytest/test_opt.py |
Adds direct tests for strengthReduction decisions and degenerate cases. |
tests/pytest/test_memory.py |
Adds tests for DenseMemoryLayout (bbox/stride/alignment/addressing/permutation). |
tests/pytest/test_indices.py |
Adds tests for Indices, Range, BoundingBox, and LoGCost. |
tests/pytest/test_import.py |
Adds import/reload regression tests for top-level package and key submodules. |
tests/pytest/test_generator.py |
Adds tests for Generator, Kernel, and KernelFamily behaviors. |
tests/pytest/test_cost.py |
Adds tests for cost estimators (shape/bbox/exact) and abstract base behavior. |
tests/pytest/test_controlflow.py |
Adds tests for CFG lowering and CFG-level transforms (liveness, substitution, merges). |
tests/pytest/test_ast_visitor.py |
Adds tests for AST visitors (dispatch, cached visitor, flop/sparsity/constant evaluation). |
tests/pytest/test_ast_transformer.py |
Adds tests for AST transformer passes and their ordering/invariants. |
tests/pytest/test_ast_node.py |
Adds tests for core AST node construction/invariants and lowered nodes. |
tests/pytest/test_aspp.py |
Adds tests for aspp dense/general patterns and mixed-type dispatch. |
.github/workflows/yateto-pytest.yml |
Adds a workflow to run the new pytest suite across multiple Python versions with coverage upload. |
Comments suppressed due to low confidence (1)
yateto/codegen/copyscaleadd/factory.py:16
- The
except:block raises a string (raise ('...')), which will itself raiseTypeErrorand masks the real import failure whengemmforgeisn't available. Replace with raising a real exception type (e.g.,ImportError/ModuleNotFoundError) and avoid bareexceptso unrelated errors aren’t misreported; also consider replacing deprecatedloader.load_module()withimportlib.import_module/module_from_spec+exec_module.
import importlib.util
gf_spec = importlib.util.find_spec('gemmforge')
try:
if gf_spec:
gf = gf_spec.loader.load_module()
from .csa_gen import CopyScaleAddGenerator
except RuntimeError as err:
raise err
except:
raise ('gemmforge module is not found. You can install it with pip3. e.g., pip3 install gemmforge')
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| uses: actions/upload-artifact@v4 | ||
| with: | ||
| name: python-coverage | ||
| path: tests/python/coverage.xml |
There was a problem hiding this comment.
pytest-cov --cov-report=xml writes coverage.xml in the working directory by default, but the upload step points to tests/python/coverage.xml (which doesn’t exist in this repo). Either change the upload path to coverage.xml or set --cov-report=xml:tests/python/coverage.xml to match.
| path: tests/python/coverage.xml | |
| path: coverage.xml |
|
|
||
| def test_bad_sparsity_shape_raises(self): | ||
| wrong = np.ones((2, 2), dtype=bool) | ||
| with pytest.raises(Exception): |
There was a problem hiding this comment.
This test is too permissive: pytest.raises(Exception) will pass even for unrelated failures (e.g., NameError). Since the contract is a shape mismatch, it should assert a specific exception type (ideally ValueError) and optionally match the message so regressions in error handling don’t get hidden.
| with pytest.raises(Exception): | |
| with pytest.raises(ValueError): |
There was a problem hiding this comment.
Done (the Exception hid another error in the construction of the ValueError)
| self.precision = precision.upper() | ||
| if self.precision == 'D': | ||
| if self.precision in ('D', 'F64'): | ||
| self.bytesPerReal = 8 | ||
| self.typename = 'double' | ||
| self.epsilon = 2.22e-16 | ||
| elif self.precision == 'S': | ||
| elif self.precision in ('S', 'F32'): | ||
| self.bytesPerReal = 4 |
There was a problem hiding this comment.
Accepting precision strings 'F32'/'F64' without normalizing self.precision breaks downstream code that assumes 'S'/'D' (e.g., formatConstant() only checks == 'S', and GEMM tool naming uses arch.precision.lower() to form symbols like cblas_sgemm/cblas_dgemm). Consider mapping inputs F32->'S' and F64->'D' (or introducing a separate normalized field) so existing codegen paths keep working; also update the docstring that currently says precision is only 'd'/'s'.
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # FusedGEMMs - list-like container of LoGs |
There was a problem hiding this comment.
What is a LoG here? Is it a list of GEMMs?
There was a problem hiding this comment.
LoG == Loop over GEMM. I.e. you execute multiple GEMM operations in series to model more general tensor contractions of two tensors.
The FusedGEMM operation in essence lists multiple matrix multiplications/GEMMs to become one GPU kernel for ChainForge or tinytc (not TensorForge, that one works differently). ... And yes, it's more strictly matrix multiplications right now (without any other more general tensor contraction-ish things); so probably that comment might need an extra hint.
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _lower_to_cfg(kernel, arch): |
There was a problem hiding this comment.
I was under the impression we use _ for private variables and methods within classes to highlight they are supposed to be private. Do we also use them for helper functions? And does the _ here mention that it is not supposed to be used elsewhere?
There was a problem hiding this comment.
Yes; it's supposed to be private to the file. Like anonymous namespaces in C++ (or static functions outside of classes).
| # Diagonal-of-diagonal must be much cheaper - exactly N products | ||
| # (the contraction collapses to elementwise pairing). | ||
| # We expect something a lot less than 112 - let's require <= 16. | ||
| assert cost <= 16 |
There was a problem hiding this comment.
If I understand this correctly, 16 here is arbitrary. Am I correct? If that is the case, can we guarantee that the cost is always less than or equal to 16? I.e., are there possible edge cases where this is not satisfied?
There was a problem hiding this comment.
The test basically multiplies a 4×4 diagonal matrix with another 4×4 diagonal matrix—i.e. we get 4 multiplications (or FMAs against 0—which get subtracted away in the calculation).
I checked it—we get exactly 4 as a result. I've updated the test to require 4.
| P = Tensor("P", (4, 4)) | ||
| kernel = Kernel("k", A["ij"] <= A["ij"], prefetch=P) | ||
| # _prefetch is stored as a list internally. | ||
| assert kernel._prefetch == [P] |
There was a problem hiding this comment.
Are we not supposed to treat _prefetch like a private variable and not access it directly from outside? Is it worth writing a getter function for this?
There was a problem hiding this comment.
Yeah, a getter sounds good I think
vikaskurapati
left a comment
There was a problem hiding this comment.
LGTM. Could be merged after addressing the comments.
importlibinstead ofimportlib.utilis importedF32andF64as precision strings (to support the changes in SeisSol upstream)