From db8e69b91f7d5a03e362c0bb4cd4827b5323941e Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Mon, 2 Feb 2026 13:45:15 +0100 Subject: [PATCH 01/15] spec: Add initial tooling to check data formats, prepare for more elaborate type checking --- spec/chip.typ | 1 - spec/memory.typ | 1 + spec/src/bitwise.toml | 24 +-- spec/src/branch.toml | 2 +- spec/src/config.toml | 1 - spec/src/cpu.toml | 7 + spec/src/lt.toml | 2 +- spec/src/memw.toml | 29 ++- spec/src/mul.toml | 6 +- spec/src/page.toml | 2 + spec/src/shift.toml | 4 +- spec/tooling/chip.py | 435 ++++++++++++++++++++++++++++++++++++++++++ 12 files changed, 489 insertions(+), 25 deletions(-) create mode 100644 spec/tooling/chip.py diff --git a/spec/chip.typ b/spec/chip.typ index 10479943e..4749b886e 100644 --- a/spec/chip.typ +++ b/spec/chip.typ @@ -113,7 +113,6 @@ } if "poly" in def { - // assert(false, message: repr(index_all(var_name, gather_indices(def)))) ( [], table.cell(align: right, emph[definition]), diff --git a/spec/memory.typ b/spec/memory.typ index 1fcb7b54e..62059de37 100644 --- a/spec/memory.typ +++ b/spec/memory.typ @@ -229,3 +229,4 @@ add the required balancing terms to the LogUp sum. = Future topics of interest - Optimize memory systems after determining factual bottlenecks (e.g. taking inspiration from Twist and Shout, or other recent research) +- Double check whether IS_BYTE constraints are needed for fini diff --git a/spec/src/bitwise.toml b/spec/src/bitwise.toml index 9b4a3f951..75e8faee4 100644 --- a/spec/src/bitwise.toml +++ b/spec/src/bitwise.toml @@ -4,67 +4,67 @@ name = "BITWISE" name = "X" type = "Byte" desc = "" -precomputed = "true" +precomputed = true [[variables.input]] name = "Y" type = "Byte" desc = "" -precomputed = "true" +precomputed = true [[variables.input]] name = "Z" type = "B4" desc = "" -precomputed = "true" +precomputed = true [[variables.output]] name = "AND" type = "Byte" desc = "the binary AND of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "OR" type = "Byte" desc = "the binary OR of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "XOR" type = "Byte" desc = "the binary XOR of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "MSB8" type = "Bit" desc = "the most significant bit of `X`" -precomputed = "true" +precomputed = true [[variables.output]] name = "MSB16" type = "Bit" desc = "the most significant bit of `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "ZERO" type = "Bit" desc = "whether $#`X` = 0$, $#`Y` = 0$ and $#`Z` = 0$." -precomputed = "true" +precomputed = true [[variables.output]] name = "SLL" type = "Half" desc = "`X||Y` logically left-shifted by `Z`: $((#`X` + 256#`Y`) #`<<` #`Z`) mod 2^16$" -precomputed = "true" +precomputed = true [[variables.output]] name = "SLLC" type = "Half" desc = "`X||Y` logically right-shifted by `Z`: $(#`X` + 256#`Y`) #`>>` (16 - #`Z`)$" -precomputed = "true" +precomputed = true [[variables.multiplicity]] name = "μ_AND" @@ -197,4 +197,4 @@ kind = "interaction" tag = "HWSLC" input = [["+", "X", ["*", 256, "Y"]], "Z"] output = "SLLC" -multiplicity = ["-", "μ_HWSLC"] \ No newline at end of file +multiplicity = ["-", "μ_HWSLC"] diff --git a/spec/src/branch.toml b/spec/src/branch.toml index e66974c8e..a521b85ee 100644 --- a/spec/src/branch.toml +++ b/spec/src/branch.toml @@ -145,4 +145,4 @@ kind = "interaction" tag = "BRANCH" input = ["pc", "offset", "register", "JALR"] output = "next_pc" -multiplicity = "-μ" +multiplicity = ["-", "μ"] diff --git a/spec/src/config.toml b/spec/src/config.toml index d836f80e5..cf628e92f 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -19,7 +19,6 @@ desc = "Variable that can only assume values in the range $[0, 2^4)$." [[variables.types]] label = "Byte" subtypes = ["BaseField"] -count = 1 desc = "Variable that can only assume values in the range $[0, 2^8)$." [[variables.types]] diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index c151b6eff..ccc2fcd0c 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -343,6 +343,7 @@ name = "decode" kind = "interaction" tag = "DECODE" input = ["pc", "imm", "packed_decode"] +multiplicity = 1 [[constraint_groups]] @@ -515,34 +516,40 @@ ref = "cpu:c:range_EBREAK" kind = "interaction" tag = "IS_BYTE" input = ["rs1"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = ["rs2"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = ["rd"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "arg1", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "arg2", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "res", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraint_groups]] diff --git a/spec/src/lt.toml b/spec/src/lt.toml index 10497b637..1941dbb7a 100644 --- a/spec/src/lt.toml +++ b/spec/src/lt.toml @@ -160,4 +160,4 @@ kind = "interaction" tag = "LT" input = [["cast", "lhs", "DWordWL"], ["cast", "rhs", "DWordWL"], "signed"] output = "lt" -multiplicity = "-μ" +multiplicity = ["-", "μ"] diff --git a/spec/src/memw.toml b/spec/src/memw.toml index f7276a9cd..af005c2b4 100644 --- a/spec/src/memw.toml +++ b/spec/src/memw.toml @@ -129,7 +129,7 @@ kind = "template" tag = "ADD" input = ["base_address", 1] output = ["cast", ["idx", "address_add", 0], "DWordWL"] -multiplicity = "w2" +cond = "w2" [[constraints.consistency]] kind = "template" @@ -137,7 +137,7 @@ tag = "ADD" input = ["base_address", ["+", "i", 1]] output = ["cast", ["idx", "address_add", "i"], "DWordWL"] iter = ["i", 1, 2] -multiplicity = "w4" +cond = "w4" [[constraints.consistency]] kind = "template" @@ -145,16 +145,37 @@ tag = "ADD" input = ["base_address", ["+", "i", 1]] output = ["cast", ["idx", "address_add", "i"], "DWordWL"] iter = ["i", 3, 6] -multiplicity = "write8" +cond = "write8" + +[[constraints.consistency]] +kind = "interaction" +tag = "IS_HALFWORD" +input = [["idx", ["idx", "address_add", "i"], "j"]] +iters = [ + ["i", 0, 0], + ["j", 0, 3], +] +multiplicity = "w2" [[constraints.consistency]] kind = "interaction" tag = "IS_HALFWORD" input = [["idx", ["idx", "address_add", "i"], "j"]] iters = [ - ["i", 0, 6], + ["i", 1, 2], ["j", 0, 3], ] +multiplicity = "w4" + +[[constraints.consistency]] +kind = "interaction" +tag = "IS_HALFWORD" +input = [["idx", ["idx", "address_add", "i"], "j"]] +iters = [ + ["i", 3, 6], + ["j", 0, 3], +] +multiplicity = "write8" [[constraints.consistency]] kind = "interaction" diff --git a/spec/src/mul.toml b/spec/src/mul.toml index 238bfe01f..a798c682d 100644 --- a/spec/src/mul.toml +++ b/spec/src/mul.toml @@ -182,7 +182,7 @@ name = "prod" [[constraints.prod]] kind = "arith" constraint = "$#`raw_product[i]` = sum_(#`k`=0)^1 2^(16k) sum_(#`j`=0)^(2i+k) #`lhs_ext[j]` dot #`rhs_ext[2i+k-j]`$" -poly = ["-", ["sum", ["=", "k", 0], "1", ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]] +poly = ["-", ["sum", ["=", "k", 0], 1, ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]] iter = ["i", 0, 3] ref = "mul:c:raw_product" @@ -192,7 +192,7 @@ name = "lookup" [[constraints.lookup]] kind = "interaction" tag = "MUL" -input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "0"] +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", 0] output = ["cast", "lo", "DWordWL"] multiplicity = ["-", "μ_lo"] ref = "mul:c:lookup_lo" @@ -200,7 +200,7 @@ ref = "mul:c:lookup_lo" [[constraints.lookup]] kind = "interaction" tag = "MUL" -input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "1"] +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", 1] output = ["cast", "hi", "DWordWL"] multiplicity = ["-", "μ_hi"] ref = "mul:c:lookup_hi" diff --git a/spec/src/page.toml b/spec/src/page.toml index 8053d63df..9ec7ed621 100644 --- a/spec/src/page.toml +++ b/spec/src/page.toml @@ -38,11 +38,13 @@ name = "all" kind = "interaction" tag = "IS_BYTE" input = ["init"] +multiplicity = 1 [[constraints.all]] kind = "interaction" tag = "IS_BYTE" input = ["fini"] +multiplicity = 1 [[constraints.all]] kind = "interaction" diff --git a/spec/src/shift.toml b/spec/src/shift.toml index 591efb839..e2ca8d12d 100644 --- a/spec/src/shift.toml +++ b/spec/src/shift.toml @@ -203,7 +203,7 @@ tag = "ZERO" input = ["bit_shift"] output = "zbs" ref = "shift:c:zbs" -multiplicity = "μ" +cond = "μ" [[constraint_groups]] @@ -292,5 +292,5 @@ kind = "interaction" tag = "SHIFT" input = ["in", "shift", "direction", "signed", "word_instr"] output = "out" -multiplicity = "-μ" +multiplicity = ["-", "μ"] ref = "shift:c:lookup" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py new file mode 100644 index 000000000..bd3156468 --- /dev/null +++ b/spec/tooling/chip.py @@ -0,0 +1,435 @@ +from dataclasses import dataclass +import sys +import tomllib +from typing import Optional, Union + +class ErrorReporter: + def __init__(self, location): + self.reported = False + self.location = location + + def update_location(self, loc): + self.reported = False + self.location = loc + + def error(self, message): + self.reported = True + print(f"ERROR {self.location}: {message}", file=sys.stderr) + + def asserts(self, condition, message): + if not condition: + self.error(message) + +reporter = ErrorReporter("unknown") + +def assert_no_unexpected(data, possible_keys): + for key in data.keys(): + reporter.asserts(key in possible_keys, f"Unexpected key: {key!r}") + +type Expr = (int + | VarExpr + | IdxExpr + | CastExpr + | MulExpr + | AddExpr + | SubExpr + | PowExpr + | SumExpr + | NotExpr + | DummyExpr + ) + +@dataclass +class VarExpr: + name: str + +@dataclass +class IdxExpr: + base: Expr + idx: Expr + +@dataclass +class CastExpr: + base: Expr + type: "Type" + +@dataclass +class MulExpr: + factors: list[Expr] + +@dataclass +class AddExpr: + terms: list[Expr] + +@dataclass +class SubExpr: + head: Expr + subs: list[Expr] + +@dataclass +class PowExpr: + base: Expr + exp: Expr + +@dataclass +class SumExpr: + iter: "Iter" + terms: Expr + +@dataclass +class NotExpr: + inner: Expr + +@dataclass +class DummyExpr: + pass + +def build_expr(config: Optional["Config"], data) -> Expr: + # TODO + # Does this need config, or do we delay any config-checking to when we use the expr? + match data: + case int(x): + return x + case str(x): + reporter.asserts(x.isidentifier(), f"Invalid identifier name for variable {x!r}") + return x + case ["idx", x, y]: + return IdxExpr(build_expr(config, x), build_expr(config, y)) + case ["cast", x, t]: + assert config is not None + return CastExpr(build_expr(config, x), Type(config.variables.types, t)) + case ["*", *factors]: + return MulExpr([build_expr(config, f) for f in factors]) + case ["+", *terms]: + return AddExpr([build_expr(config, t) for t in terms]) + case ["-", head, *subs]: + return SubExpr(build_expr(config, head), [build_expr(config, s) for s in subs]) + case ["^", base, exp]: + return PowExpr(build_expr(config, base), build_expr(config, exp)) + case ["sum", ["=", str(var), start], stop, terms]: + return SumExpr(Iter(config, var, start, stop), build_expr(config, terms)) + case ["not", e]: + return NotExpr(build_expr(config, e)) + case other: + reporter.error(f"Unknown expression: {other!r}") + return DummyExpr() + +@dataclass +class Iter: + name: str + start: Expr + stop: Expr + + def __init__(self, config, name, start, stop): + self.name = name + reporter.asserts(isinstance(self.name, str), f"iter name is not a string: {self.name!r}") + reporter.asserts(self.name.isidentifier(), f"Not a valid identifier: {self.name!r}") + self.start = build_expr(config, start) + self.stop = build_expr(config, stop) + +def iters_of(obj, name = None) -> list[Iter]: + """Return a list of iterators needed by `obj`. Taken from `iters` or `iter`. + Prepend `name` to every iterator, if given. + Adapted from the corresponding typst implementation.""" + def clean_iter(it): + arr = it if isinstance(it, list) else [it] + if name is not None: + arr = [name] + arr + + if len(arr) == 2: + # Assume single-element range + arr.append(arr[-1]) + + if len(arr) != 3: + reporter.error(f"Invalid length iter: {arr!r}") + return Iter(config, "_", 0, 0) + return Iter(config, *arr) + + if "iters" in obj: + reporter.asserts("iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}") + return [clean_iter(it) for it in obj["iters"]] + elif "iter" in obj: + return [clean_iter(obj["iter"])] + else: + return [] + +@dataclass +class Type: + base: Union["Type", str] + dimension: Optional[int] + + def __init__(self, valid_types: Optional[list["TypeConfig"]], data): + match data: + case str(x): + reporter.asserts(valid_types is None or x in [tc.label for tc in valid_types], f"Invalid variable type: {x!r}") + self.base = x + self.dimension = None + case [base, int(dim)]: + self.base = Type(valid_types, base) + self.dimension = dim + case other: + reporter.error(f"Unable to parse type: {other!r}") + +@dataclass +class TypeConfig: + label: str + subtypes: list[Type] + desc: str + preprocessed: bool + + def __init__(self, data, valid_types=None): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.label = data["label"] + self.subtypes = [Type(valid_types, tp) for tp in data["subtypes"]] + self.desc = data["desc"] + self.preprocessed = data.get("preprocessed", False) + +@dataclass +class ConfigCategories: + all: list[str] + instantiated: list[str] + + def __init__(self, data): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.all = data["all"] + self.instantiated = data["instantiated"] + reporter.asserts(all(isinstance(v, str) for v in self.all), f"Something's not a string: {self.all}") + reporter.asserts(all(isinstance(v, str) for v in self.instantiated), f"Something's not a string: {self.instantiated}") + + +@dataclass +class ConfigVariables: + types: list[TypeConfig] + categories: ConfigCategories + + def __init__(self, data): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.types = [] + for tp in data["types"]: + if tp["subtypes"] == [tp["label"]]: + self.types.append(TypeConfig(tp, valid_types=None)) + else: + self.types.append(TypeConfig(tp, valid_types=self.types)) + self.categories = ConfigCategories(data["categories"]) + +@dataclass +class ConfigMetadata: + version: int + + def __init__(self, data): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.version = data["version"] + reporter.asserts(isinstance(self.version, int), f"version {self.version!r} is not an int") + +@dataclass +class Config: + metadata: ConfigMetadata + variables: ConfigVariables + + def __init__(self, data): + """Construct a Config from toml-parsed data""" + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.metadata= ConfigMetadata(data["metadata"]) + self.variables = ConfigVariables(data["variables"]) + + @classmethod + def from_file(cls, filename): + reporter.update_location(filename) + return cls(tomllib.load(open(filename, "rb"))) + + @classmethod + def from_string(cls, s): + reporter.update_location("") + return cls(tomllib.loads(s)) + + +@dataclass +class Variable: + category: str + name: str + type: Type + desc: str + pad: Expr + precomputed: bool + + def __init__(self, config: Config, category: str, data): + self.category = category + assert_no_unexpected(data, Variable.__annotations__.keys()) + self.name = data["name"] + reporter.asserts(isinstance(self.name, str), f"{self.name!r} is not a string") + reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") + self.type = Type(config.variables.types, data["type"]) + self.desc = data["desc"] + reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") + self.pad = build_expr(None, data.get("pad", 0)) + self.precomputed = data.get("precomputed", False) + reporter.asserts(isinstance(self.precomputed, bool), f"precomputed is not a bool: {self.precomputed!r}") + +@dataclass +class VirtualDef: + # A list of polynomials with each a set of iters they range over + defs: list[tuple[list[Iter], Expr]] + + def __init__(self, config: Config, name: str, tp: Type, data): + # TODO? More sanity checking the format (or is that duplicating work done in typst already) + if "poly" in data: + idx = data.get("idx", None) + self.defs = [(iters_of(data, name = idx), build_expr(config, data["poly"]))] + elif "polys" in data: + idx = data.get("idx", None) + self.defs = [(iters_of(poly, name = idx), build_expr(config, poly["poly"])) for poly in data["polys"]] + else: + self.defs = [([], build_expr(config, data))] + +@dataclass +class VirtualVariable(Variable): + def_: VirtualDef + + def __init__(self, config: Config, category: str, data): + assert_no_unexpected(data, set(Variable.__annotations__.keys()) | {"def"}) + reporter.asserts("def" in data, f"Missing def for virtual column: {data!r}") + def_ = data.pop("def", {}) + super().__init__(config, category, data) + self.def_ = VirtualDef(config, self.name, self.type, def_) + +@dataclass +class Assumption: + desc: str + iters: list[Iter] + + def __init__(self, config: Config, data): + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"}) + self.desc = data["desc"] + self.iters = iters_of(data) + +@dataclass +class ArithConstraint: + constraint: str + desc: str + poly: Expr + iters: list[Iter] + + def __init__(self, config: Config, data): + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) + assert data["kind"] == "arith" + self.constraint = data["constraint"] + reporter.asserts(isinstance(self.constraint, str), f"Constraint not a string: {self.constraint!r}") + self.desc = data.get("desc", "") + reporter.asserts(isinstance(self.desc, str), f"desc is not a string: {self.desc!r}") + self.poly = build_expr(config, data["poly"]) + self.iters = iters_of(data) + + +@dataclass +class TemplateConstraint: + tag: str + desc: str + input: list[Expr] + output: Optional[Expr] + cond: Optional[Expr] + iters: list[Iter] + + def __init__(self, config: Config, data): + assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) + assert data["kind"] == "template" + self.tag = data["tag"] + reporter.asserts(isinstance(self.tag, str), f"tag is not a string: {self.tag!r}") + self.desc = data.get("desc", "") + reporter.asserts(isinstance(self.desc, str), f"Description is not a string: {self.desc!r}") + self.input = [build_expr(config, inp) for inp in data["input"]] + if "output" in data: + self.output = build_expr(config, data["output"]) + else: + self.output = None + if "cond" in data: + self.cond = build_expr(config, data["cond"]) + else: + self.cond = None + self.iters = iters_of(data) + +@dataclass +class InteractionConstraint: + tag: str + input: list[Expr] + output: Optional[Expr] + multiplicity: Expr + iters: list[Iter] + + def __init__(self, config: Config, data): + assert data["kind"] == "interaction" + self.tag = data["tag"] + reporter.asserts(isinstance(self.tag, str), f"tag {self.tag!r} is not a string") + self.input = [build_expr(config, inp) for inp in data["input"]] + if "output" in data: + self.output = build_expr(config, data["output"]) + else: + self.output = None + self.multiplicity = build_expr(config, data["multiplicity"]) + self.iters = iters_of(data) + +@dataclass +class DummyConstraint: + pass + +type Constraint = ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint + +def build_constraint(config, data) -> Constraint: + match data["kind"]: + case "arith": + return ArithConstraint(config, data) + case "template": + return TemplateConstraint(config, data) + case "interaction": + return InteractionConstraint(config, data) + case other: + reporter.error(f"Unknown constraint kind: {other!r}") + return DummyConstraint() + +@dataclass +class Chip: + config: Config + name: str + variables: list[Variable] + assumptions: list[Assumption] + constraints: list[Constraint] + + def __init__(self, config: Config, data): + """Construct a chip from toml-parsed data""" + assert_no_unexpected(data, set(type(self).__annotations__.keys()) | {"constraint_groups"}) + assert_no_unexpected(data["variables"], config.variables.categories.all) + self.config = config + self.name = data["name"] + reporter.asserts(isinstance(self.name, str), f"name is not a string: {self.name!r}") + reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") + self.variables = [(Variable if cat != "virtual" else VirtualVariable)(config, cat, var) for cat, vars in data["variables"].items() for var in vars] + self.assumptions = [Assumption(config, asm) for asm in data.get("assumptions", [])] + constraint_groups = [grp["name"] for grp in data.get("constraint_groups", [])] + assert_no_unexpected(data.get("constraints", {}), constraint_groups) + self.constraints = [build_constraint(config, con) for group in data.get("constraints", {}).values() for con in group] + + @classmethod + def from_file(cls, config, filename): + reporter.update_location(filename) + return cls(config, tomllib.load(open(filename, "rb"))) + + @classmethod + def from_string(cls, config, s): + reporter.update_location("") + return cls(config, tomllib.loads(s)) + + +if __name__ == "__main__": + import pprint + config = Config.from_file(sys.argv[1]) + if reporter.reported: + sys.exit(1) + reported = False + chips = [] + for file in sys.argv[2:]: + if file == sys.argv[1]: + continue + print("Processing", file) + chips.append(Chip.from_file(config, file)) + reported = reported or reporter.reported + if not reported: + pprint.pprint(chips) From 7f02bdfd6e76d08c3ec1e88d7ea2873b1b7e58f8 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 3 Feb 2026 10:28:28 +0100 Subject: [PATCH 02/15] Initial type checking --- spec/tooling/chip.py | 288 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 246 insertions(+), 42 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index bd3156468..ec60d5c3d 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1,32 +1,42 @@ -from dataclasses import dataclass +from pathlib import Path +import copy import sys import tomllib -from typing import Optional, Union +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from typing import Optional, Union, Never + +def Bit_type(): + return Type(None, "Bit") class ErrorReporter: - def __init__(self, location): + reported: bool + location: str + + def __init__(self, location: str): self.reported = False self.location = location - def update_location(self, loc): + def update_location(self, loc: str): self.reported = False self.location = loc - def error(self, message): + def error(self, message: str): self.reported = True print(f"ERROR {self.location}: {message}", file=sys.stderr) - def asserts(self, condition, message): + def asserts(self, condition: bool, message: str): if not condition: self.error(message) reporter = ErrorReporter("unknown") -def assert_no_unexpected(data, possible_keys): +def assert_no_unexpected(data: dict, possible_keys: Iterable[str]): for key in data.keys(): reporter.asserts(key in possible_keys, f"Unexpected key: {key!r}") - -type Expr = (int + + +type Expr = (LitExpr | VarExpr | IdxExpr | CastExpr @@ -39,60 +49,180 @@ def assert_no_unexpected(data, possible_keys): | DummyExpr ) +# We can either have an explicit int literal (or const expression) or a known type +# Returning 0 as dummy value should work in most cases, as constants can be used for +# almost anything. The only exception being indexing. +type TypeCheck = Type | int + +@dataclass +class Environment: + config: "Config" + valmap: dict[str, int] + typemap: dict[str, "Type"] + + def resolve_index(self, base: "Type", idx: int) -> TypeCheck: + if base.dimension is not None: + if not (0 <= idx < base.dimension): + reporter.error(f"Index out of range for {base!r}: {idx!r}") + idx = 0 + if isinstance(base.base, str): + return Type(None, base.base) + else: + return base.base + + assert isinstance(base.base, str), "We somehow made a type that's not an array, but has a non-str base" + typeconfigs = [tc for tc in self.config.variables.types if tc.label == base.base] + if len(typeconfigs) != 1: + reporter.error(f"Unable to resolve type: {base!r}") + return 0 + typeconfig = typeconfigs[0] + if not (0 <= idx < len(typeconfig.subtypes)): + reporter.error(f"Index out of range for {base!r}: {idx!r}") + idx = 0 + return typeconfig.subtypes[idx] + +def type_match(a: TypeCheck, b: TypeCheck, context: str) -> TypeCheck: + """Check that `a` and `b` are "compatible" TypeCheck values. + That is, either one of them is a constant, or the type is the same""" + # TODO: improve here; e.g. by allowing thing to match if their subtype is identical? + # Then would have to return the subtype to be sure? + # Maybe break everything down to subtypes, then it's purely structural matching? + if isinstance(a, int): + return b + if isinstance(b, int): + return a + reporter.asserts(a == b, f"Type mismatch between {a!r} and {b!r} [{context}]") + return 0 + +@dataclass +class LitExpr: + lit: int + + def typecheck(self, _env: Environment) -> TypeCheck: + return self.lit + @dataclass class VarExpr: name: str + def typecheck(self, env: Environment) -> TypeCheck: + if self.name in env.valmap: + return env.valmap[self.name] + if self.name in env.typemap: + return env.typemap[self.name] + reporter.error(f"Unknown variable: {self.name!r}") + return 0 + @dataclass class IdxExpr: base: Expr idx: Expr + def typecheck(self, env: Environment) -> TypeCheck: + base = self.base.typecheck(env) + idx = self.idx.typecheck(env) + if isinstance(base, int): + reporter.error(f"Trying to index a constant value: {self.base!r}") + return 0 + if not isinstance(idx, int): + reporter.error(f"Trying to index with a non-constant: {self.idx!r}") + return 0 + return env.resolve_index(base, idx) + @dataclass class CastExpr: base: Expr type: "Type" + def typecheck(self, env: Environment) -> TypeCheck: + _base = self.base.typecheck(env) + # TODO? encode/list valid casts + return self.type + @dataclass class MulExpr: factors: list[Expr] + def typecheck(self, env: Environment) -> TypeCheck: + t: TypeCheck = 0 + for f in self.factors: + t = type_match(t, f.typecheck(env), repr(self)) + return t + @dataclass class AddExpr: terms: list[Expr] + def typecheck(self, env: Environment) -> TypeCheck: + t: TypeCheck = 0 + for term in self.terms: + t = type_match(t, term.typecheck(env), repr(self)) + return t + @dataclass class SubExpr: head: Expr subs: list[Expr] + def typecheck(self, env: Environment) -> TypeCheck: + t = self.head.typecheck(env) + for term in self.subs: + t = type_match(t, term.typecheck(env), repr(self)) + return t + @dataclass class PowExpr: base: Expr exp: Expr + def typecheck(self, env: Environment) -> TypeCheck: + base = self.base.typecheck(env) + exp = self.exp.typecheck(env) + if not isinstance(base, int): + reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") + return 0 + if not isinstance(exp, int): + reporter.error(f"Invalid exponentiation with non-const exponent: {self.exp!r}") + return 0 + return base**exp + + @dataclass class SumExpr: iter: "Iter" terms: Expr + def typecheck(self, env: Environment) -> TypeCheck: + t: TypeCheck = 0 + for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): + t = type_match(t, tc, repr(self)) + return t + @dataclass class NotExpr: inner: Expr + def typecheck(self, env: Environment) -> TypeCheck: + inner = self.inner.typecheck(env) + if isinstance(inner, int): + reporter.asserts(inner in {0, 1}, f"Not a bool passed to `not`: {self.inner!r}") + return 1 - inner + reporter.asserts(inner == Bit_type(), f"Not a bool passed to `not`: {self.inner!r}") + return Bit_type() + @dataclass class DummyExpr: - pass + def typecheck(self, _env: Environment) -> TypeCheck: + return 0 -def build_expr(config: Optional["Config"], data) -> Expr: - # TODO +def build_expr(config: Optional["Config"], data: object) -> Expr: # Does this need config, or do we delay any config-checking to when we use the expr? match data: case int(x): - return x + return LitExpr(x) case str(x): reporter.asserts(x.isidentifier(), f"Invalid identifier name for variable {x!r}") - return x + return VarExpr(x) case ["idx", x, y]: return IdxExpr(build_expr(config, x), build_expr(config, y)) case ["cast", x, t]: @@ -107,6 +237,7 @@ def build_expr(config: Optional["Config"], data) -> Expr: case ["^", base, exp]: return PowExpr(build_expr(config, base), build_expr(config, exp)) case ["sum", ["=", str(var), start], stop, terms]: + assert config is not None return SumExpr(Iter(config, var, start, stop), build_expr(config, terms)) case ["not", e]: return NotExpr(build_expr(config, e)) @@ -120,14 +251,30 @@ class Iter: start: Expr stop: Expr - def __init__(self, config, name, start, stop): + def __init__(self, config: "Config", name: str, start: object, stop: object): self.name = name reporter.asserts(isinstance(self.name, str), f"iter name is not a string: {self.name!r}") reporter.asserts(self.name.isidentifier(), f"Not a valid identifier: {self.name!r}") self.start = build_expr(config, start) self.stop = build_expr(config, stop) -def iters_of(obj, name = None) -> list[Iter]: + def typecheck[T](self, env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: + start = self.start.typecheck(env) + if not isinstance(start, int): + reporter.error(f"Starting value of summation not a const: {self!r}") + start = 0 + stop = self.stop.typecheck(env) + if not isinstance(stop, int): + reporter.error(f"Ending value of summation not a const: {self!r}") + stop = 0 + + for i in range(start, stop + 1): + old_env = copy.deepcopy(env) + env.valmap[self.name] = i + yield from callback(env) + env = old_env + +def iters_of(obj: dict, name = None) -> list[Iter]: """Return a list of iterators needed by `obj`. Taken from `iters` or `iter`. Prepend `name` to every iterator, if given. Adapted from the corresponding typst implementation.""" @@ -158,7 +305,7 @@ class Type: base: Union["Type", str] dimension: Optional[int] - def __init__(self, valid_types: Optional[list["TypeConfig"]], data): + def __init__(self, valid_types: Optional[list["TypeConfig"]], data: object): match data: case str(x): reporter.asserts(valid_types is None or x in [tc.label for tc in valid_types], f"Invalid variable type: {x!r}") @@ -177,7 +324,7 @@ class TypeConfig: desc: str preprocessed: bool - def __init__(self, data, valid_types=None): + def __init__(self, data: dict, valid_types=None): assert_no_unexpected(data, type(self).__annotations__.keys()) self.label = data["label"] self.subtypes = [Type(valid_types, tp) for tp in data["subtypes"]] @@ -189,7 +336,7 @@ class ConfigCategories: all: list[str] instantiated: list[str] - def __init__(self, data): + def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.all = data["all"] self.instantiated = data["instantiated"] @@ -202,7 +349,7 @@ class ConfigVariables: types: list[TypeConfig] categories: ConfigCategories - def __init__(self, data): + def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.types = [] for tp in data["types"]: @@ -216,7 +363,7 @@ def __init__(self, data): class ConfigMetadata: version: int - def __init__(self, data): + def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.version = data["version"] reporter.asserts(isinstance(self.version, int), f"version {self.version!r} is not an int") @@ -226,19 +373,19 @@ class Config: metadata: ConfigMetadata variables: ConfigVariables - def __init__(self, data): + def __init__(self, data: dict): """Construct a Config from toml-parsed data""" assert_no_unexpected(data, type(self).__annotations__.keys()) self.metadata= ConfigMetadata(data["metadata"]) self.variables = ConfigVariables(data["variables"]) @classmethod - def from_file(cls, filename): - reporter.update_location(filename) + def from_file(cls, filename: str | Path) -> "Config": + reporter.update_location(str(filename)) return cls(tomllib.load(open(filename, "rb"))) @classmethod - def from_string(cls, s): + def from_string(cls, s: str) -> "Config": reporter.update_location("") return cls(tomllib.loads(s)) @@ -252,7 +399,7 @@ class Variable: pad: Expr precomputed: bool - def __init__(self, config: Config, category: str, data): + def __init__(self, config: Config, category: str, data: dict): self.category = category assert_no_unexpected(data, Variable.__annotations__.keys()) self.name = data["name"] @@ -265,12 +412,18 @@ def __init__(self, config: Config, category: str, data): self.precomputed = data.get("precomputed", False) reporter.asserts(isinstance(self.precomputed, bool), f"precomputed is not a bool: {self.precomputed!r}") +def all_iters[T](its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: + if not its: + yield from callback(env) + else: + yield from its[0].typecheck(env, lambda e: all_iters(its[1:], e, callback)) + @dataclass class VirtualDef: # A list of polynomials with each a set of iters they range over defs: list[tuple[list[Iter], Expr]] - def __init__(self, config: Config, name: str, tp: Type, data): + def __init__(self, config: Config, name: str, tp: Type, data: dict): # TODO? More sanity checking the format (or is that duplicating work done in typst already) if "poly" in data: idx = data.get("idx", None) @@ -285,19 +438,23 @@ def __init__(self, config: Config, name: str, tp: Type, data): class VirtualVariable(Variable): def_: VirtualDef - def __init__(self, config: Config, category: str, data): + def __init__(self, config: Config, category: str, data: dict): assert_no_unexpected(data, set(Variable.__annotations__.keys()) | {"def"}) reporter.asserts("def" in data, f"Missing def for virtual column: {data!r}") def_ = data.pop("def", {}) super().__init__(config, category, data) self.def_ = VirtualDef(config, self.name, self.type, def_) + def typecheck(self, env: Environment) -> TypeCheck: + # TODO + return 0 + @dataclass class Assumption: desc: str iters: list[Iter] - def __init__(self, config: Config, data): + def __init__(self, config: Config, data: dict): assert_no_unexpected(data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"}) self.desc = data["desc"] self.iters = iters_of(data) @@ -309,7 +466,7 @@ class ArithConstraint: poly: Expr iters: list[Iter] - def __init__(self, config: Config, data): + def __init__(self, config: Config, data: dict): assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) assert data["kind"] == "arith" self.constraint = data["constraint"] @@ -319,6 +476,18 @@ def __init__(self, config: Config, data): self.poly = build_expr(config, data["poly"]) self.iters = iters_of(data) + def typecheck(self, env: Environment) -> Iterable[Never]: + # TODO: is there any reason to typecheck if something is equatable to 0? + # Iteration for the side effect of typechecking and reporting errors + for _ in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): + pass + return [] + +@dataclass +class TemplateSignature: + tag: str + input: list[TypeCheck] + output: Optional[TypeCheck] @dataclass class TemplateConstraint: @@ -329,7 +498,7 @@ class TemplateConstraint: cond: Optional[Expr] iters: list[Iter] - def __init__(self, config: Config, data): + def __init__(self, config: Config, data: dict): assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) assert data["kind"] == "template" self.tag = data["tag"] @@ -347,6 +516,22 @@ def __init__(self, config: Config, data): self.cond = None self.iters = iters_of(data) + def typecheck(self, env: Environment) -> Iterable[TemplateSignature]: + def callback(e: Environment) -> Iterable[TemplateSignature]: + # TODO: Should we be able to check cond somehow? + if self.cond is not None: + self.cond.typecheck(e) + return [TemplateSignature(self.tag, + [inp.typecheck(e) for inp in self.input], + self.output.typecheck(e) if self.output else None)] + return all_iters(self.iters, env, callback) + +@dataclass +class InteractionSignature: + tag: str + input: list[TypeCheck] + output: Optional[TypeCheck] + @dataclass class InteractionConstraint: tag: str @@ -355,7 +540,7 @@ class InteractionConstraint: multiplicity: Expr iters: list[Iter] - def __init__(self, config: Config, data): + def __init__(self, config: Config, data: dict): assert data["kind"] == "interaction" self.tag = data["tag"] reporter.asserts(isinstance(self.tag, str), f"tag {self.tag!r} is not a string") @@ -367,13 +552,23 @@ def __init__(self, config: Config, data): self.multiplicity = build_expr(config, data["multiplicity"]) self.iters = iters_of(data) + def typecheck(self, env: Environment) -> Iterable[InteractionSignature]: + def callback(e: Environment) -> Iterable[InteractionSignature]: + # TODO: Should we be able to check multiplicity somehow? + self.multiplicity.typecheck(e) + return [InteractionSignature(self.tag, + [inp.typecheck(e) for inp in self.input], + self.output.typecheck(e) if self.output else None)] + return all_iters(self.iters, env, callback) + @dataclass class DummyConstraint: - pass + def typecheck(self, env: Environment) -> list[Never]: + return [] type Constraint = ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint -def build_constraint(config, data) -> Constraint: +def build_constraint(config, data: dict) -> Constraint: match data["kind"]: case "arith": return ArithConstraint(config, data) @@ -393,7 +588,7 @@ class Chip: assumptions: list[Assumption] constraints: list[Constraint] - def __init__(self, config: Config, data): + def __init__(self, config: Config, data: dict): """Construct a chip from toml-parsed data""" assert_no_unexpected(data, set(type(self).__annotations__.keys()) | {"constraint_groups"}) assert_no_unexpected(data["variables"], config.variables.categories.all) @@ -408,28 +603,37 @@ def __init__(self, config: Config, data): self.constraints = [build_constraint(config, con) for group in data.get("constraints", {}).values() for con in group] @classmethod - def from_file(cls, config, filename): - reporter.update_location(filename) + def from_file(cls, config: Config, filename: str | Path) -> "Chip": + reporter.update_location(str(filename)) return cls(config, tomllib.load(open(filename, "rb"))) @classmethod - def from_string(cls, config, s): + def from_string(cls, config: Config, s: str) -> "Chip": reporter.update_location("") return cls(config, tomllib.loads(s)) + + def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: + env = Environment(self.config, {}, {v.name: v.type for v in self.variables}) + for v in self.variables: + if isinstance(v, VirtualVariable): + v.typecheck(env) + for c in self.constraints: + yield from c.typecheck(env) if __name__ == "__main__": - import pprint config = Config.from_file(sys.argv[1]) if reporter.reported: sys.exit(1) reported = False - chips = [] + chips: list[Chip] = [] for file in sys.argv[2:]: if file == sys.argv[1]: continue - print("Processing", file) chips.append(Chip.from_file(config, file)) reported = reported or reporter.reported if not reported: - pprint.pprint(chips) + for chip in chips: + reporter.update_location(f"Chip {chip.name}") + # TODO: do something with the signatures + (list(chip.typecheck())) From d3ae40b14623833a74c39ee605364722445a27ac Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 3 Feb 2026 10:29:22 +0100 Subject: [PATCH 03/15] ruff format --- spec/tooling/chip.py | 245 ++++++++++++++++++++++++++++++++----------- 1 file changed, 183 insertions(+), 62 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index ec60d5c3d..d01b19fd8 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -6,13 +6,15 @@ from dataclasses import dataclass from typing import Optional, Union, Never + def Bit_type(): return Type(None, "Bit") + class ErrorReporter: reported: bool location: str - + def __init__(self, location: str): self.reported = False self.location = location @@ -29,31 +31,35 @@ def asserts(self, condition: bool, message: str): if not condition: self.error(message) + reporter = ErrorReporter("unknown") + def assert_no_unexpected(data: dict, possible_keys: Iterable[str]): for key in data.keys(): reporter.asserts(key in possible_keys, f"Unexpected key: {key!r}") - - -type Expr = (LitExpr - | VarExpr - | IdxExpr - | CastExpr - | MulExpr - | AddExpr - | SubExpr - | PowExpr - | SumExpr - | NotExpr - | DummyExpr - ) + + +type Expr = ( + LitExpr + | VarExpr + | IdxExpr + | CastExpr + | MulExpr + | AddExpr + | SubExpr + | PowExpr + | SumExpr + | NotExpr + | DummyExpr +) # We can either have an explicit int literal (or const expression) or a known type # Returning 0 as dummy value should work in most cases, as constants can be used for # almost anything. The only exception being indexing. type TypeCheck = Type | int + @dataclass class Environment: config: "Config" @@ -70,8 +76,12 @@ def resolve_index(self, base: "Type", idx: int) -> TypeCheck: else: return base.base - assert isinstance(base.base, str), "We somehow made a type that's not an array, but has a non-str base" - typeconfigs = [tc for tc in self.config.variables.types if tc.label == base.base] + assert isinstance(base.base, str), ( + "We somehow made a type that's not an array, but has a non-str base" + ) + typeconfigs = [ + tc for tc in self.config.variables.types if tc.label == base.base + ] if len(typeconfigs) != 1: reporter.error(f"Unable to resolve type: {base!r}") return 0 @@ -81,6 +91,7 @@ def resolve_index(self, base: "Type", idx: int) -> TypeCheck: idx = 0 return typeconfig.subtypes[idx] + def type_match(a: TypeCheck, b: TypeCheck, context: str) -> TypeCheck: """Check that `a` and `b` are "compatible" TypeCheck values. That is, either one of them is a constant, or the type is the same""" @@ -94,6 +105,7 @@ def type_match(a: TypeCheck, b: TypeCheck, context: str) -> TypeCheck: reporter.asserts(a == b, f"Type mismatch between {a!r} and {b!r} [{context}]") return 0 + @dataclass class LitExpr: lit: int @@ -101,6 +113,7 @@ class LitExpr: def typecheck(self, _env: Environment) -> TypeCheck: return self.lit + @dataclass class VarExpr: name: str @@ -113,6 +126,7 @@ def typecheck(self, env: Environment) -> TypeCheck: reporter.error(f"Unknown variable: {self.name!r}") return 0 + @dataclass class IdxExpr: base: Expr @@ -129,6 +143,7 @@ def typecheck(self, env: Environment) -> TypeCheck: return 0 return env.resolve_index(base, idx) + @dataclass class CastExpr: base: Expr @@ -139,6 +154,7 @@ def typecheck(self, env: Environment) -> TypeCheck: # TODO? encode/list valid casts return self.type + @dataclass class MulExpr: factors: list[Expr] @@ -149,6 +165,7 @@ def typecheck(self, env: Environment) -> TypeCheck: t = type_match(t, f.typecheck(env), repr(self)) return t + @dataclass class AddExpr: terms: list[Expr] @@ -159,6 +176,7 @@ def typecheck(self, env: Environment) -> TypeCheck: t = type_match(t, term.typecheck(env), repr(self)) return t + @dataclass class SubExpr: head: Expr @@ -170,6 +188,7 @@ def typecheck(self, env: Environment) -> TypeCheck: t = type_match(t, term.typecheck(env), repr(self)) return t + @dataclass class PowExpr: base: Expr @@ -182,10 +201,12 @@ def typecheck(self, env: Environment) -> TypeCheck: reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") return 0 if not isinstance(exp, int): - reporter.error(f"Invalid exponentiation with non-const exponent: {self.exp!r}") + reporter.error( + f"Invalid exponentiation with non-const exponent: {self.exp!r}" + ) return 0 return base**exp - + @dataclass class SumExpr: @@ -198,6 +219,7 @@ def typecheck(self, env: Environment) -> TypeCheck: t = type_match(t, tc, repr(self)) return t + @dataclass class NotExpr: inner: Expr @@ -205,23 +227,31 @@ class NotExpr: def typecheck(self, env: Environment) -> TypeCheck: inner = self.inner.typecheck(env) if isinstance(inner, int): - reporter.asserts(inner in {0, 1}, f"Not a bool passed to `not`: {self.inner!r}") + reporter.asserts( + inner in {0, 1}, f"Not a bool passed to `not`: {self.inner!r}" + ) return 1 - inner - reporter.asserts(inner == Bit_type(), f"Not a bool passed to `not`: {self.inner!r}") + reporter.asserts( + inner == Bit_type(), f"Not a bool passed to `not`: {self.inner!r}" + ) return Bit_type() + @dataclass class DummyExpr: def typecheck(self, _env: Environment) -> TypeCheck: return 0 + def build_expr(config: Optional["Config"], data: object) -> Expr: # Does this need config, or do we delay any config-checking to when we use the expr? match data: case int(x): return LitExpr(x) case str(x): - reporter.asserts(x.isidentifier(), f"Invalid identifier name for variable {x!r}") + reporter.asserts( + x.isidentifier(), f"Invalid identifier name for variable {x!r}" + ) return VarExpr(x) case ["idx", x, y]: return IdxExpr(build_expr(config, x), build_expr(config, y)) @@ -233,7 +263,9 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: case ["+", *terms]: return AddExpr([build_expr(config, t) for t in terms]) case ["-", head, *subs]: - return SubExpr(build_expr(config, head), [build_expr(config, s) for s in subs]) + return SubExpr( + build_expr(config, head), [build_expr(config, s) for s in subs] + ) case ["^", base, exp]: return PowExpr(build_expr(config, base), build_expr(config, exp)) case ["sum", ["=", str(var), start], stop, terms]: @@ -245,6 +277,7 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: reporter.error(f"Unknown expression: {other!r}") return DummyExpr() + @dataclass class Iter: name: str @@ -253,12 +286,18 @@ class Iter: def __init__(self, config: "Config", name: str, start: object, stop: object): self.name = name - reporter.asserts(isinstance(self.name, str), f"iter name is not a string: {self.name!r}") - reporter.asserts(self.name.isidentifier(), f"Not a valid identifier: {self.name!r}") + reporter.asserts( + isinstance(self.name, str), f"iter name is not a string: {self.name!r}" + ) + reporter.asserts( + self.name.isidentifier(), f"Not a valid identifier: {self.name!r}" + ) self.start = build_expr(config, start) self.stop = build_expr(config, stop) - def typecheck[T](self, env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: + def typecheck[T]( + self, env: Environment, callback: Callable[[Environment], Iterable[T]] + ) -> Iterable[T]: start = self.start.typecheck(env) if not isinstance(start, int): reporter.error(f"Starting value of summation not a const: {self!r}") @@ -274,10 +313,12 @@ def typecheck[T](self, env: Environment, callback: Callable[[Environment], Itera yield from callback(env) env = old_env -def iters_of(obj: dict, name = None) -> list[Iter]: + +def iters_of(obj: dict, name=None) -> list[Iter]: """Return a list of iterators needed by `obj`. Taken from `iters` or `iter`. Prepend `name` to every iterator, if given. Adapted from the corresponding typst implementation.""" + def clean_iter(it): arr = it if isinstance(it, list) else [it] if name is not None: @@ -293,13 +334,16 @@ def clean_iter(it): return Iter(config, *arr) if "iters" in obj: - reporter.asserts("iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}") + reporter.asserts( + "iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}" + ) return [clean_iter(it) for it in obj["iters"]] elif "iter" in obj: return [clean_iter(obj["iter"])] else: return [] + @dataclass class Type: base: Union["Type", str] @@ -308,7 +352,10 @@ class Type: def __init__(self, valid_types: Optional[list["TypeConfig"]], data: object): match data: case str(x): - reporter.asserts(valid_types is None or x in [tc.label for tc in valid_types], f"Invalid variable type: {x!r}") + reporter.asserts( + valid_types is None or x in [tc.label for tc in valid_types], + f"Invalid variable type: {x!r}", + ) self.base = x self.dimension = None case [base, int(dim)]: @@ -317,6 +364,7 @@ def __init__(self, valid_types: Optional[list["TypeConfig"]], data: object): case other: reporter.error(f"Unable to parse type: {other!r}") + @dataclass class TypeConfig: label: str @@ -331,6 +379,7 @@ def __init__(self, data: dict, valid_types=None): self.desc = data["desc"] self.preprocessed = data.get("preprocessed", False) + @dataclass class ConfigCategories: all: list[str] @@ -340,8 +389,14 @@ def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.all = data["all"] self.instantiated = data["instantiated"] - reporter.asserts(all(isinstance(v, str) for v in self.all), f"Something's not a string: {self.all}") - reporter.asserts(all(isinstance(v, str) for v in self.instantiated), f"Something's not a string: {self.instantiated}") + reporter.asserts( + all(isinstance(v, str) for v in self.all), + f"Something's not a string: {self.all}", + ) + reporter.asserts( + all(isinstance(v, str) for v in self.instantiated), + f"Something's not a string: {self.instantiated}", + ) @dataclass @@ -359,6 +414,7 @@ def __init__(self, data: dict): self.types.append(TypeConfig(tp, valid_types=self.types)) self.categories = ConfigCategories(data["categories"]) + @dataclass class ConfigMetadata: version: int @@ -366,7 +422,10 @@ class ConfigMetadata: def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.version = data["version"] - reporter.asserts(isinstance(self.version, int), f"version {self.version!r} is not an int") + reporter.asserts( + isinstance(self.version, int), f"version {self.version!r} is not an int" + ) + @dataclass class Config: @@ -376,9 +435,9 @@ class Config: def __init__(self, data: dict): """Construct a Config from toml-parsed data""" assert_no_unexpected(data, type(self).__annotations__.keys()) - self.metadata= ConfigMetadata(data["metadata"]) + self.metadata = ConfigMetadata(data["metadata"]) self.variables = ConfigVariables(data["variables"]) - + @classmethod def from_file(cls, filename: str | Path) -> "Config": reporter.update_location(str(filename)) @@ -398,7 +457,7 @@ class Variable: desc: str pad: Expr precomputed: bool - + def __init__(self, config: Config, category: str, data: dict): self.category = category assert_no_unexpected(data, Variable.__annotations__.keys()) @@ -410,14 +469,21 @@ def __init__(self, config: Config, category: str, data: dict): reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") self.pad = build_expr(None, data.get("pad", 0)) self.precomputed = data.get("precomputed", False) - reporter.asserts(isinstance(self.precomputed, bool), f"precomputed is not a bool: {self.precomputed!r}") + reporter.asserts( + isinstance(self.precomputed, bool), + f"precomputed is not a bool: {self.precomputed!r}", + ) + -def all_iters[T](its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]]) -> Iterable[T]: +def all_iters[T]( + its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]] +) -> Iterable[T]: if not its: yield from callback(env) else: yield from its[0].typecheck(env, lambda e: all_iters(its[1:], e, callback)) + @dataclass class VirtualDef: # A list of polynomials with each a set of iters they range over @@ -427,17 +493,21 @@ def __init__(self, config: Config, name: str, tp: Type, data: dict): # TODO? More sanity checking the format (or is that duplicating work done in typst already) if "poly" in data: idx = data.get("idx", None) - self.defs = [(iters_of(data, name = idx), build_expr(config, data["poly"]))] + self.defs = [(iters_of(data, name=idx), build_expr(config, data["poly"]))] elif "polys" in data: idx = data.get("idx", None) - self.defs = [(iters_of(poly, name = idx), build_expr(config, poly["poly"])) for poly in data["polys"]] + self.defs = [ + (iters_of(poly, name=idx), build_expr(config, poly["poly"])) + for poly in data["polys"] + ] else: self.defs = [([], build_expr(config, data))] + @dataclass class VirtualVariable(Variable): def_: VirtualDef - + def __init__(self, config: Config, category: str, data: dict): assert_no_unexpected(data, set(Variable.__annotations__.keys()) | {"def"}) reporter.asserts("def" in data, f"Missing def for virtual column: {data!r}") @@ -449,16 +519,20 @@ def typecheck(self, env: Environment) -> TypeCheck: # TODO return 0 + @dataclass class Assumption: desc: str iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected(data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"}) + assert_no_unexpected( + data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"} + ) self.desc = data["desc"] self.iters = iters_of(data) + @dataclass class ArithConstraint: constraint: str @@ -467,12 +541,19 @@ class ArithConstraint: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) + assert_no_unexpected( + data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} + ) assert data["kind"] == "arith" self.constraint = data["constraint"] - reporter.asserts(isinstance(self.constraint, str), f"Constraint not a string: {self.constraint!r}") + reporter.asserts( + isinstance(self.constraint, str), + f"Constraint not a string: {self.constraint!r}", + ) self.desc = data.get("desc", "") - reporter.asserts(isinstance(self.desc, str), f"desc is not a string: {self.desc!r}") + reporter.asserts( + isinstance(self.desc, str), f"desc is not a string: {self.desc!r}" + ) self.poly = build_expr(config, data["poly"]) self.iters = iters_of(data) @@ -483,12 +564,14 @@ def typecheck(self, env: Environment) -> Iterable[Never]: pass return [] + @dataclass class TemplateSignature: tag: str input: list[TypeCheck] output: Optional[TypeCheck] + @dataclass class TemplateConstraint: tag: str @@ -499,12 +582,18 @@ class TemplateConstraint: iters: list[Iter] def __init__(self, config: Config, data: dict): - assert_no_unexpected(data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"}) + assert_no_unexpected( + data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} + ) assert data["kind"] == "template" self.tag = data["tag"] - reporter.asserts(isinstance(self.tag, str), f"tag is not a string: {self.tag!r}") + reporter.asserts( + isinstance(self.tag, str), f"tag is not a string: {self.tag!r}" + ) self.desc = data.get("desc", "") - reporter.asserts(isinstance(self.desc, str), f"Description is not a string: {self.desc!r}") + reporter.asserts( + isinstance(self.desc, str), f"Description is not a string: {self.desc!r}" + ) self.input = [build_expr(config, inp) for inp in data["input"]] if "output" in data: self.output = build_expr(config, data["output"]) @@ -521,17 +610,24 @@ def callback(e: Environment) -> Iterable[TemplateSignature]: # TODO: Should we be able to check cond somehow? if self.cond is not None: self.cond.typecheck(e) - return [TemplateSignature(self.tag, - [inp.typecheck(e) for inp in self.input], - self.output.typecheck(e) if self.output else None)] + return [ + TemplateSignature( + self.tag, + [inp.typecheck(e) for inp in self.input], + self.output.typecheck(e) if self.output else None, + ) + ] + return all_iters(self.iters, env, callback) + @dataclass class InteractionSignature: tag: str input: list[TypeCheck] output: Optional[TypeCheck] + @dataclass class InteractionConstraint: tag: str @@ -556,17 +652,27 @@ def typecheck(self, env: Environment) -> Iterable[InteractionSignature]: def callback(e: Environment) -> Iterable[InteractionSignature]: # TODO: Should we be able to check multiplicity somehow? self.multiplicity.typecheck(e) - return [InteractionSignature(self.tag, - [inp.typecheck(e) for inp in self.input], - self.output.typecheck(e) if self.output else None)] + return [ + InteractionSignature( + self.tag, + [inp.typecheck(e) for inp in self.input], + self.output.typecheck(e) if self.output else None, + ) + ] + return all_iters(self.iters, env, callback) + @dataclass class DummyConstraint: def typecheck(self, env: Environment) -> list[Never]: return [] -type Constraint = ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint + +type Constraint = ( + ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint +) + def build_constraint(config, data: dict) -> Constraint: match data["kind"]: @@ -580,6 +686,7 @@ def build_constraint(config, data: dict) -> Constraint: reporter.error(f"Unknown constraint kind: {other!r}") return DummyConstraint() + @dataclass class Chip: config: Config @@ -587,21 +694,35 @@ class Chip: variables: list[Variable] assumptions: list[Assumption] constraints: list[Constraint] - + def __init__(self, config: Config, data: dict): """Construct a chip from toml-parsed data""" - assert_no_unexpected(data, set(type(self).__annotations__.keys()) | {"constraint_groups"}) + assert_no_unexpected( + data, set(type(self).__annotations__.keys()) | {"constraint_groups"} + ) assert_no_unexpected(data["variables"], config.variables.categories.all) self.config = config self.name = data["name"] - reporter.asserts(isinstance(self.name, str), f"name is not a string: {self.name!r}") + reporter.asserts( + isinstance(self.name, str), f"name is not a string: {self.name!r}" + ) reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") - self.variables = [(Variable if cat != "virtual" else VirtualVariable)(config, cat, var) for cat, vars in data["variables"].items() for var in vars] - self.assumptions = [Assumption(config, asm) for asm in data.get("assumptions", [])] + self.variables = [ + (Variable if cat != "virtual" else VirtualVariable)(config, cat, var) + for cat, vars in data["variables"].items() + for var in vars + ] + self.assumptions = [ + Assumption(config, asm) for asm in data.get("assumptions", []) + ] constraint_groups = [grp["name"] for grp in data.get("constraint_groups", [])] assert_no_unexpected(data.get("constraints", {}), constraint_groups) - self.constraints = [build_constraint(config, con) for group in data.get("constraints", {}).values() for con in group] - + self.constraints = [ + build_constraint(config, con) + for group in data.get("constraints", {}).values() + for con in group + ] + @classmethod def from_file(cls, config: Config, filename: str | Path) -> "Chip": reporter.update_location(str(filename)) @@ -619,7 +740,7 @@ def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: v.typecheck(env) for c in self.constraints: yield from c.typecheck(env) - + if __name__ == "__main__": config = Config.from_file(sys.argv[1]) From 6b7b916d83ece3dc20356e5250b8e89fe650a3b0 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 12:15:29 +0100 Subject: [PATCH 04/15] Update some more typing mismatches --- spec/src/branch.toml | 4 ++-- spec/src/config.toml | 13 ++++++++----- spec/src/cpu.toml | 12 ++++++------ spec/src/is_bit.toml | 2 +- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/spec/src/branch.toml b/spec/src/branch.toml index a521b85ee..0659479c5 100644 --- a/spec/src/branch.toml +++ b/spec/src/branch.toml @@ -59,7 +59,7 @@ name = "next_pc_unmasked" type = "DWordWL" desc = "The combination of `next_pc_high`, `next_pc_low[1]` and `unmasked_low_byte` to constrain the addition. This is the computed value for the next pc, before masking off the LSB as required by the ISA." def = {idx = "i", polys = [ - {iter = 0, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 0]], ["*", ["^", 2, 8], ["idx", "next_pc_low", 1]], ["idx", "unmasked_low_byte", 0]]}, + {iter = 0, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 0]], ["*", ["^", 2, 8], ["idx", "next_pc_low", 1]], "unmasked_low_byte"]}, {iter = 1, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 2]], ["idx", "next_pc_high", 1]]}, ]} @@ -124,7 +124,7 @@ multiplicity = "μ" [[constraints.all]] kind = "interaction" tag = "AND_BYTE" -input = [["idx", "unmasked_low_byte", 0], 254] +input = ["unmasked_low_byte", 254] output = ["idx", "next_pc_low", 0] multiplicity = "μ" diff --git a/spec/src/config.toml b/spec/src/config.toml index cf628e92f..af6b34971 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -4,36 +4,43 @@ version = 1 [[variables.types]] label = "BaseField" subtypes = ["BaseField"] +range = [0, 18446744069414584320] desc = "Variable that can assume any value in the base field." [[variables.types]] label = "Bit" subtypes = ["BaseField"] +range = [0, 1] desc = "Variable that can only assume values in the set ${0,1}$." [[variables.types]] label = "B4" subtypes = ["BaseField"] +range = [0, 15] desc = "Variable that can only assume values in the range $[0, 2^4)$." [[variables.types]] label = "Byte" subtypes = ["BaseField"] +range = [0, 255] desc = "Variable that can only assume values in the range $[0, 2^8)$." [[variables.types]] label = "Half" subtypes = ["BaseField"] +range = [0, 65535] desc = "Variable that can only assume values in the range $[0, 2^16)$." [[variables.types]] label = "B20" subtypes = ["BaseField"] +range = [0, 1048575] desc = "Variable that can only assume values in the range $[0, 2^20)$." [[variables.types]] label = "Word" subtypes = ["BaseField"] +range = [0, 4294967295] desc = "Variable that can only assume values in the range $[0, 2^32)$." [[variables.types]] @@ -52,14 +59,10 @@ desc = """\ Represented as an array of four `Byte` variables.\ """ -[[variables.types]] -label = "B35" -subtypes = ["BaseField"] -desc = "Variable that can only assume values in the range $[0, 2^35)$." - [[variables.types]] label = "B51" subtypes = ["BaseField"] +range = [0, 2251799813685247] desc = "Variable that can only assume values in the range $[0, 2^51)$." [[variables.types]] diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index ccc2fcd0c..e415b9ae3 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -647,7 +647,7 @@ prefix = "M" [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rs1"], "rv1", ["+", "timestamp", 0], 1, 0, 0] +input = [1, ["*", 2, "rs1"], "rv1", ["+", "timestamp", ["cast", 0, "DWordWL"]], 1, 0, 0] output = "rv1" multiplicity = "read_register1" @@ -661,7 +661,7 @@ iter = ["i", 0, 2] [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rs2"], "rv2", ["+", "timestamp", 1], 1, 0, 0] +input = [1, ["*", 2, "rs2"], "rv2", ["+", "timestamp", ["cast", 1, "DWordWL"]], 1, 0, 0] output = "rv2" multiplicity = "read_register2" @@ -675,13 +675,13 @@ iter = ["i", 0, 2] [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rd"], "rvd", ["+", "timestamp", 2], 1, 0, 0] +input = [1, ["*", 2, "rd"], "rvd", ["+", "timestamp", ["cast", 2, "DWordWL"]], 1, 0, 0] multiplicity = "write_register" [[constraints.mem]] kind = "interaction" tag = "LOAD" -input = [0, "res", ["+", "timestamp", 0], "memory_2bytes", "memory_4bytes", "memory_8bytes", "signed"] +input = [0, "res", ["+", "timestamp", ["cast", 0, "DWordWL"]], "memory_2bytes", "memory_4bytes", "memory_8bytes", "signed"] output = "rvd" multiplicity = "LOAD" @@ -689,14 +689,14 @@ multiplicity = "LOAD" [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [0, "res", "rv2", ["+", "timestamp", 1], "memory_2bytes", "memory_4bytes", "memory_8bytes"] +input = [0, "res", "rv2", ["+", "timestamp", ["cast", 1, "DWordWL"]], "memory_2bytes", "memory_4bytes", "memory_8bytes"] multiplicity = "STORE" # TODO: no types available, so no casting yet [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, 255], "next_pc", ["+", "timestamp", 1], 1, 0, 0] +input = [1, ["*", 2, 255], "next_pc", ["+", "timestamp", ["cast", 1, "DWordWL"]], 1, 0, 0] output = "pc" multiplicity = ["not", "pad"] diff --git a/spec/src/is_bit.toml b/spec/src/is_bit.toml index 47e96a27e..a72b5f648 100644 --- a/spec/src/is_bit.toml +++ b/spec/src/is_bit.toml @@ -16,5 +16,5 @@ name = "all" [[constraints.all]] kind = "arith" constraint = "$#`cond` => #`X` (1-#`X`) = 0$" -poly = ["*", "cond", "X", ["not", "X"]] +poly = ["*", "cond", "X", ["-", 1, "X"]] ref = "isbit:c:isbit" From 615b76b2f690689e2e78733804c7b05c244fa8fd Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 12:16:31 +0100 Subject: [PATCH 05/15] Move to range-based type checks --- spec/tooling/chip.py | 369 ++++++++++++++++++++++++++----------------- 1 file changed, 224 insertions(+), 145 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index d01b19fd8..f303b8261 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1,14 +1,9 @@ from pathlib import Path -import copy import sys import tomllib from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Optional, Union, Never - - -def Bit_type(): - return Type(None, "Bit") +from typing import Optional, Never class ErrorReporter: @@ -40,6 +35,26 @@ def assert_no_unexpected(data: dict, possible_keys: Iterable[str]): reporter.asserts(key in possible_keys, f"Unexpected key: {key!r}") +@dataclass(frozen=True) +class Range: + low: int + high: int + + def is_bool(self): + return self.is_lit and self.low >= 0 and self.high <= 1 + + def is_lit(self): + return self.low == self.high + + def get_lit(self) -> int: + assert self.is_lit() + return self.low + + +type Type = list[Type] | Range + +DEFAULT_TYPE: Type = Range(0, 0) + type Expr = ( LitExpr | VarExpr @@ -54,77 +69,33 @@ def assert_no_unexpected(data: dict, possible_keys: Iterable[str]): | DummyExpr ) -# We can either have an explicit int literal (or const expression) or a known type -# Returning 0 as dummy value should work in most cases, as constants can be used for -# almost anything. The only exception being indexing. -type TypeCheck = Type | int - @dataclass class Environment: config: "Config" - valmap: dict[str, int] - typemap: dict[str, "Type"] - - def resolve_index(self, base: "Type", idx: int) -> TypeCheck: - if base.dimension is not None: - if not (0 <= idx < base.dimension): - reporter.error(f"Index out of range for {base!r}: {idx!r}") - idx = 0 - if isinstance(base.base, str): - return Type(None, base.base) - else: - return base.base - - assert isinstance(base.base, str), ( - "We somehow made a type that's not an array, but has a non-str base" - ) - typeconfigs = [ - tc for tc in self.config.variables.types if tc.label == base.base - ] - if len(typeconfigs) != 1: - reporter.error(f"Unable to resolve type: {base!r}") - return 0 - typeconfig = typeconfigs[0] - if not (0 <= idx < len(typeconfig.subtypes)): - reporter.error(f"Index out of range for {base!r}: {idx!r}") - idx = 0 - return typeconfig.subtypes[idx] - - -def type_match(a: TypeCheck, b: TypeCheck, context: str) -> TypeCheck: - """Check that `a` and `b` are "compatible" TypeCheck values. - That is, either one of them is a constant, or the type is the same""" - # TODO: improve here; e.g. by allowing thing to match if their subtype is identical? - # Then would have to return the subtype to be sure? - # Maybe break everything down to subtypes, then it's purely structural matching? - if isinstance(a, int): - return b - if isinstance(b, int): - return a - reporter.asserts(a == b, f"Type mismatch between {a!r} and {b!r} [{context}]") - return 0 + valmap: dict[str, Range] + typemap: dict[str, Type] @dataclass class LitExpr: lit: int - def typecheck(self, _env: Environment) -> TypeCheck: - return self.lit + def typecheck(self, _env: Environment) -> Type: + return Range(self.lit, self.lit) @dataclass class VarExpr: name: str - def typecheck(self, env: Environment) -> TypeCheck: + def typecheck(self, env: Environment) -> Type: if self.name in env.valmap: return env.valmap[self.name] if self.name in env.typemap: return env.typemap[self.name] reporter.error(f"Unknown variable: {self.name!r}") - return 0 + return DEFAULT_TYPE @dataclass @@ -132,26 +103,36 @@ class IdxExpr: base: Expr idx: Expr - def typecheck(self, env: Environment) -> TypeCheck: + def typecheck(self, env: Environment) -> Type: base = self.base.typecheck(env) idx = self.idx.typecheck(env) - if isinstance(base, int): - reporter.error(f"Trying to index a constant value: {self.base!r}") - return 0 - if not isinstance(idx, int): - reporter.error(f"Trying to index with a non-constant: {self.idx!r}") - return 0 - return env.resolve_index(base, idx) + if not isinstance(idx, Range) or not idx.is_lit(): + reporter.error(f"Invalid index: {idx!r}") + return Range(-1, -1) + idxlit = idx.get_lit() + if not isinstance(base, list): + reporter.error(f"Indexing into non-array type: {self!r}") + return DEFAULT_TYPE + if not (0 <= idxlit < len(base)): + reporter.error(f"Index out of range {self!r}") + idxlit = 0 + return base[idxlit] @dataclass class CastExpr: base: Expr - type: "Type" + type: Type - def typecheck(self, env: Environment) -> TypeCheck: - _base = self.base.typecheck(env) - # TODO? encode/list valid casts + def typecheck(self, env: Environment) -> Type: + base = self.base.typecheck(env) + # TODO? Detect more sorts of invalid casts + baselen = len(base) if isinstance(base, list) else 1 + castlen = len(self.type) if isinstance(self.type, list) else 1 + reporter.asserts( + baselen >= castlen or (isinstance(base, Range) and base.is_lit()), + f"Casting from fewer columns to more: {self!r} {base} {self.type}", + ) return self.type @@ -159,10 +140,22 @@ def typecheck(self, env: Environment) -> TypeCheck: class MulExpr: factors: list[Expr] - def typecheck(self, env: Environment) -> TypeCheck: - t: TypeCheck = 0 + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + reporter.error(f"Multiplication of non-scalar types: {self!r}") + return DEFAULT_TYPE + elif isinstance(a, list): + return [self.type_match(x, b) for x in a] + elif isinstance(b, list): + return self.type_match(b, a) + else: + extrema = [x * y for x in [a.low, a.high] for y in [b.low, b.high]] + return Range(min(extrema), max(extrema)) + + def typecheck(self, env: Environment) -> Type: + t: Type = Range(1, 1) for f in self.factors: - t = type_match(t, f.typecheck(env), repr(self)) + t = self.type_match(t, f.typecheck(env)) return t @@ -170,10 +163,26 @@ def typecheck(self, env: Environment) -> TypeCheck: class AddExpr: terms: list[Expr] - def typecheck(self, env: Environment) -> TypeCheck: - t: TypeCheck = 0 - for term in self.terms: - t = type_match(t, term.typecheck(env), repr(self)) + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + assert False + reporter.error(f"Adding array types of different length {self!r}") + return [DEFAULT_TYPE for _ in range(len(b))] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Adding of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low + b.low, a.high + b.high) + + def typecheck(self, env: Environment) -> Type: + if not self.terms: + reporter.error("Empty add") + return Range(0, 0) + t: Type = self.terms[0].typecheck(env) + for term in self.terms[1:]: + t = self.type_match(t, term.typecheck(env)) return t @@ -182,10 +191,22 @@ class SubExpr: head: Expr subs: list[Expr] - def typecheck(self, env: Environment) -> TypeCheck: + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + reporter.error(f"Subtracting array types of different length {self!r}") + return [DEFAULT_TYPE for _ in range(len(a))] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Subtraction of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low - b.high, a.high - b.low) + + def typecheck(self, env: Environment) -> Type: t = self.head.typecheck(env) for term in self.subs: - t = type_match(t, term.typecheck(env), repr(self)) + t = self.type_match(t, term.typecheck(env)) return t @@ -194,18 +215,19 @@ class PowExpr: base: Expr exp: Expr - def typecheck(self, env: Environment) -> TypeCheck: + def typecheck(self, env: Environment) -> Type: base = self.base.typecheck(env) exp = self.exp.typecheck(env) - if not isinstance(base, int): + if isinstance(base, list) or not base.is_lit(): reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") - return 0 - if not isinstance(exp, int): + return DEFAULT_TYPE + if isinstance(exp, list) or not exp.is_lit(): reporter.error( f"Invalid exponentiation with non-const exponent: {self.exp!r}" ) - return 0 - return base**exp + return DEFAULT_TYPE + val = base.get_lit() ** exp.get_lit() + return Range(val, val) @dataclass @@ -213,10 +235,22 @@ class SumExpr: iter: "Iter" terms: Expr - def typecheck(self, env: Environment) -> TypeCheck: - t: TypeCheck = 0 + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + reporter.error(f"Summing array types of different length {self!r}") + return [DEFAULT_TYPE for _ in range(len(b))] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Summing of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low + b.low, a.high + b.high) + + def typecheck(self, env: Environment) -> Type: + t: Type = Range(0, 0) for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): - t = type_match(t, tc, repr(self)) + t = self.type_match(t, tc) return t @@ -224,23 +258,20 @@ def typecheck(self, env: Environment) -> TypeCheck: class NotExpr: inner: Expr - def typecheck(self, env: Environment) -> TypeCheck: + def typecheck(self, env: Environment) -> Type: inner = self.inner.typecheck(env) - if isinstance(inner, int): + if isinstance(inner, list) or not inner.is_bool(): reporter.asserts( inner in {0, 1}, f"Not a bool passed to `not`: {self.inner!r}" ) - return 1 - inner - reporter.asserts( - inner == Bit_type(), f"Not a bool passed to `not`: {self.inner!r}" - ) - return Bit_type() + return Range(0, 1) + return Range(1 - inner.high, 1 - inner.low) @dataclass class DummyExpr: - def typecheck(self, _env: Environment) -> TypeCheck: - return 0 + def typecheck(self, _env: Environment) -> Type: + return DEFAULT_TYPE def build_expr(config: Optional["Config"], data: object) -> Expr: @@ -257,7 +288,8 @@ def build_expr(config: Optional["Config"], data: object) -> Expr: return IdxExpr(build_expr(config, x), build_expr(config, y)) case ["cast", x, t]: assert config is not None - return CastExpr(build_expr(config, x), Type(config.variables.types, t)) + assert isinstance(t, (list, str)) + return CastExpr(build_expr(config, x), build_type(config, t)) case ["*", *factors]: return MulExpr([build_expr(config, f) for f in factors]) case ["+", *terms]: @@ -299,19 +331,24 @@ def typecheck[T]( self, env: Environment, callback: Callable[[Environment], Iterable[T]] ) -> Iterable[T]: start = self.start.typecheck(env) - if not isinstance(start, int): + if isinstance(start, list) or not start.is_lit(): reporter.error(f"Starting value of summation not a const: {self!r}") - start = 0 + start = Range(0, 0) stop = self.stop.typecheck(env) - if not isinstance(stop, int): + if isinstance(stop, list) or not stop.is_lit(): reporter.error(f"Ending value of summation not a const: {self!r}") - stop = 0 - - for i in range(start, stop + 1): - old_env = copy.deepcopy(env) - env.valmap[self.name] = i + stop = Range(start.get_lit(), start.get_lit()) + + # While it's tempting to replace this loop by an assignment of Range(start, stop + 1) to self.name + # that would break both detection of literals, and narrowing down to the correct type for indexing + # heterogenous array types + for i in range(start.get_lit(), stop.get_lit() + 1): + old_val: Optional[Range] = env.valmap.get(self.name, None) + env.valmap[self.name] = Range(i, i) yield from callback(env) - env = old_env + env.valmap.pop(self.name) + if old_val is not None: + env.valmap[self.name] = old_val def iters_of(obj: dict, name=None) -> list[Iter]: @@ -344,41 +381,46 @@ def clean_iter(it): return [] -@dataclass -class Type: - base: Union["Type", str] - dimension: Optional[int] - - def __init__(self, valid_types: Optional[list["TypeConfig"]], data: object): - match data: - case str(x): - reporter.asserts( - valid_types is None or x in [tc.label for tc in valid_types], - f"Invalid variable type: {x!r}", - ) - self.base = x - self.dimension = None - case [base, int(dim)]: - self.base = Type(valid_types, base) - self.dimension = dim - case other: - reporter.error(f"Unable to parse type: {other!r}") - - @dataclass class TypeConfig: label: str subtypes: list[Type] + range: Optional[Range] desc: str preprocessed: bool - def __init__(self, data: dict, valid_types=None): + def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.label = data["label"] - self.subtypes = [Type(valid_types, tp) for tp in data["subtypes"]] + if "range" in data: + reporter.asserts( + data["subtypes"] == [default_name], + f"Specified a range non a non-base composite type: {data!r}", + ) + reporter.asserts( + isinstance(data["range"], list) and len(data["range"]) == 2, + f"Invalid range: {data!r}", + ) + start, stop = data["range"] + if not isinstance(start, int): + reporter.error(f"Range start not an int: {data!r}") + start = 0 + if not isinstance(stop, int): + reporter.error(f"Range end not an int: {data!r}") + stop = start + self.range = Range(start, stop) + self.subtypes = [] + else: + self.range = None + self.subtypes = [lookup(tp) for tp in data["subtypes"]] self.desc = data["desc"] self.preprocessed = data.get("preprocessed", False) + def as_type(self): + if self.range is not None: + return self.range + return self.subtypes[:] + @dataclass class ConfigCategories: @@ -407,13 +449,20 @@ class ConfigVariables: def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.types = [] + base_type = None for tp in data["types"]: - if tp["subtypes"] == [tp["label"]]: - self.types.append(TypeConfig(tp, valid_types=None)) - else: - self.types.append(TypeConfig(tp, valid_types=self.types)) + if base_type is None: + base_type = tp["label"] + self.types.append(TypeConfig(base_type, self.lookup_type, tp)) self.categories = ConfigCategories(data["categories"]) + def lookup_type(self, typename: str) -> Type: + matches = [t for t in self.types if t.label == typename] + if len(matches) != 1: + reporter.error(f"Couldn't lookup type by name: {typename!r}") + return DEFAULT_TYPE + return matches[0].as_type() + @dataclass class ConfigMetadata: @@ -449,6 +498,16 @@ def from_string(cls, s: str) -> "Config": return cls(tomllib.loads(s)) +def build_type(config: Config, data: list | str): + if isinstance(data, list): + if len(data) != 2: + reporter.error(f"Invalid type: {data!r}") + return DEFAULT_TYPE + return [build_type(config, data[0]) for _ in range(data[1])] + else: + return config.variables.lookup_type(data) + + @dataclass class Variable: category: str @@ -464,7 +523,7 @@ def __init__(self, config: Config, category: str, data: dict): self.name = data["name"] reporter.asserts(isinstance(self.name, str), f"{self.name!r} is not a string") reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") - self.type = Type(config.variables.types, data["type"]) + self.type = build_type(config, data["type"]) self.desc = data["desc"] reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") self.pad = build_expr(None, data.get("pad", 0)) @@ -490,7 +549,6 @@ class VirtualDef: defs: list[tuple[list[Iter], Expr]] def __init__(self, config: Config, name: str, tp: Type, data: dict): - # TODO? More sanity checking the format (or is that duplicating work done in typst already) if "poly" in data: idx = data.get("idx", None) self.defs = [(iters_of(data, name=idx), build_expr(config, data["poly"]))] @@ -515,9 +573,11 @@ def __init__(self, config: Config, category: str, data: dict): super().__init__(config, category, data) self.def_ = VirtualDef(config, self.name, self.type, def_) - def typecheck(self, env: Environment) -> TypeCheck: + def typecheck(self, env: Environment) -> Type: # TODO - return 0 + # - Check no indices are covered twice and everything covered + # - Check type fits? At least structure should match + return DEFAULT_TYPE @dataclass @@ -558,18 +618,29 @@ def __init__(self, config: Config, data: dict): self.iters = iters_of(data) def typecheck(self, env: Environment) -> Iterable[Never]: - # TODO: is there any reason to typecheck if something is equatable to 0? - # Iteration for the side effect of typechecking and reporting errors - for _ in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): - pass + # TODO? Should we check that there's no overflow of the modulus? + # This would probably struggle due to things like one-hot invariants + + def check_includes_zero(t: Type): + if isinstance(t, Range): + reporter.asserts( + t.low <= 0 <= t.high, + f"Unsatisfiable constraint, 0 not in range: {self!r} {t}", + ) + else: + for sub in t: + check_includes_zero(sub) + + for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): + check_includes_zero(t) return [] @dataclass class TemplateSignature: tag: str - input: list[TypeCheck] - output: Optional[TypeCheck] + input: list[Type] + output: Optional[Type] @dataclass @@ -624,8 +695,8 @@ def callback(e: Environment) -> Iterable[TemplateSignature]: @dataclass class InteractionSignature: tag: str - input: list[TypeCheck] - output: Optional[TypeCheck] + input: list[Type] + output: Optional[Type] @dataclass @@ -734,7 +805,15 @@ def from_string(cls, config: Config, s: str) -> "Chip": return cls(config, tomllib.loads(s)) def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: - env = Environment(self.config, {}, {v.name: v.type for v in self.variables}) + typemap = {} + for v in self.variables: + if isinstance(v.type, list) and len(v.type) == 1: + t = v.type[0] + else: + t = v.type + typemap[v.name] = t + + env = Environment(self.config, {}, typemap) for v in self.variables: if isinstance(v, VirtualVariable): v.typecheck(env) From 9adb4fea276d34b1434f61f66de1fe4af8719381 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 12:22:55 +0100 Subject: [PATCH 06/15] Avoid casting to more limbs by leveraging scalar-array mult and literal casts --- spec/src/cpu.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index e415b9ae3..5906e876a 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -618,7 +618,7 @@ multiplicity = "SHIFT" [[constraints.alu]] kind = "template" tag = "ADD" -input = ["pc", ["cast", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], "DWordWL"]] +input = ["pc", ["*", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], ["cast", 1, "DWordWL"]]] output = ["cast", "res", "DWordWL"] cond = "JALR" @@ -817,6 +817,6 @@ multiplicity = "branch_cond" [[constraints.misc]] kind = "template" tag = "ADD" -input = ["pc", ["cast", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], "DWordWL"]] +input = ["pc", ["*", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], ["cast", 1, "DWordWL"]]] output = "next_pc" desc = "Increment `pc` to `next_pc` if we're not branching" From 96ebfbe0d7e8dd70665b93494cf45fee4e9385b1 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 15:56:53 +0100 Subject: [PATCH 07/15] toml fixes to pass type checks --- spec/src/branch.toml | 2 +- spec/src/cpu.toml | 2 +- spec/src/page.toml | 8 +++++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/spec/src/branch.toml b/spec/src/branch.toml index 0659479c5..beb3c1922 100644 --- a/spec/src/branch.toml +++ b/spec/src/branch.toml @@ -11,7 +11,7 @@ pad = 0 [[variables.input]] name = "offset" -type = "Word" +type = "DWordWL" desc = "The offset from the base address to jump to" pad = 0 diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index 5906e876a..994fda508 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -802,7 +802,7 @@ poly = ["+", ["-", "branch_cond"], "JALR", ["*", ["idx", "res", 0], ["not", "mp_selector"], "BLT"], - ["*", ["not", ["idx", "res", 0]], "mp_selector", "BLT"], + ["*", ["-", 1, ["idx", "res", 0]], "mp_selector", "BLT"], ["*", "is_equal", ["not", "mp_selector"], "BEQ"], ["*", ["not", "is_equal"], "mp_selector", "BEQ"] ] diff --git a/spec/src/page.toml b/spec/src/page.toml index 9ec7ed621..21ec76757 100644 --- a/spec/src/page.toml +++ b/spec/src/page.toml @@ -2,6 +2,12 @@ name = "PAGE" # Input +# TODO: add `page` as "constant" column or smth +[[variables.input]] +name = "page" +type = "DWordWL" +desc = "Constant column containing the page base address; should be integrated into the constraints directly" + [[variables.input]] name = "offset" type = "RowIndex" @@ -28,7 +34,7 @@ desc = "The timestamp at which this address was last accessed" name = "address" type = "DWordWL" desc = "Adding `offset` to the page base address `page`. `page` is a constant with respect to a single instance of this table." -def = ["+", "page", ["cast", "offset", "DWordWL"]] +def = ["+", "page", ["*", "offset", ["cast", 1, "DWordWL"]]] [[constraint_groups]] From 54127d8fab1bb81d9b6045a62b3a2b6a655e99ee Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 15:57:13 +0100 Subject: [PATCH 08/15] Type check virtual definitions properly now --- spec/tooling/chip.py | 116 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 109 insertions(+), 7 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index f303b8261..1c71b38e1 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -41,7 +41,7 @@ class Range: high: int def is_bool(self): - return self.is_lit and self.low >= 0 and self.high <= 1 + return self.low >= 0 and self.high <= 1 def is_lit(self): return self.low == self.high @@ -332,11 +332,11 @@ def typecheck[T]( ) -> Iterable[T]: start = self.start.typecheck(env) if isinstance(start, list) or not start.is_lit(): - reporter.error(f"Starting value of summation not a const: {self!r}") + reporter.error(f"Starting value of iterator not a const: {self!r}") start = Range(0, 0) stop = self.stop.typecheck(env) if isinstance(stop, list) or not stop.is_lit(): - reporter.error(f"Ending value of summation not a const: {self!r}") + reporter.error(f"Ending value of iterator not a const: {self!r}") stop = Range(start.get_lit(), start.get_lit()) # While it's tempting to replace this loop by an assignment of Range(start, stop + 1) to self.name @@ -574,10 +574,112 @@ def __init__(self, config: Config, category: str, data: dict): self.def_ = VirtualDef(config, self.name, self.type, def_) def typecheck(self, env: Environment) -> Type: - # TODO - # - Check no indices are covered twice and everything covered - # - Check type fits? At least structure should match - return DEFAULT_TYPE + def structure_match(a: Type, b: Type): + if isinstance(a, Range) and isinstance(b, (Range, type(None))): + return True + elif isinstance(a, list) and isinstance(b, list): + return len(a) == len(b) and all( + structure_match(x, y) for x, y in zip(a, b) + ) + else: + return False + + def handle_iters( + env: Environment, + iters: list[Iter], + poly: Expr, + expected: Type, + indices: list[int], + seen: set[tuple], + ): + if not iters: + asn = poly.typecheck(env) + # Check not doubly defined + for s in seen: + ln = min(len(s), len(indices)) + if s[:ln] == tuple(indices[:ln]): + reporter.error( + f"Double definition for virtual column: {self!r} at index {indices}" + ) + break + # check asn structure matches assigned + reporter.asserts( + structure_match(asn, expected), + f"Invalid structure for definition to virtual column: {self!r}", + ) + # Check type fits? + + seen.add(tuple(indices)) + else: + it, *its = iters + # Some duplicated code/concepts from Iter.typecheck + start = it.start.typecheck(env) + if isinstance(start, list) or not start.is_lit(): + reporter.error( + f"Starting value of virtual def iter not a const: {self!r}" + ) + start = Range(0, 0) + stop = it.stop.typecheck(env) + if isinstance(stop, list) or not stop.is_lit(): + reporter.error( + f"Ending value of virtual def iter not a const: {self!r}" + ) + stop = Range(start.get_lit(), start.get_lit()) + + for i in range(start.get_lit(), stop.get_lit() + 1): + if isinstance(expected, Range): + reporter.error( + f"Virtual definition has an iter for a scalar: {self!r}" + ) + break + if not 0 <= i < len(expected): + reporter.error( + f"Virtual definition index {i} out of range for {expected}: {self!r}" + ) + break + old_val: Optional[Range] = env.valmap.get(it.name, None) + env.valmap[it.name] = Range(i, i) + handle_iters(env, its, poly, expected[i], indices + [i], seen) + env.valmap.pop(it.name) + if old_val is not None: + env.valmap[it.name] = old_val + + def is_covered(seen: set[tuple], indices: list[int]) -> bool: + for s in seen: + if len(s) <= len(indices) and s == tuple(indices[:len(s)]): + return True + return False + + def check_covered(t: Type, seen: set[tuple], indices: list[int]): + if isinstance(t, Range): + reporter.asserts(is_covered(seen, indices), f"Virtual column {self.name!r} not completely defined") + else: + for i in range(len(t)): + check_covered(t[i], seen, indices + [i]) + + # Special case for better error messages + if isinstance(self.type, Range): + reporter.asserts( + len(self.def_.defs) == 1 and not self.def_.defs[0][0], + f"Invalid def for scalar column: {self!r}", + ) + assigned_type = self.def_.defs[0][1].typecheck(env) + if not isinstance(assigned_type, Range): + reporter.error( + f"Assigning non-scalar type to scalar virtual column: {self!r}" + ) + return self.type + # Check type fits? + # Leaving this out because it produces too much noise with one-hot assumptions + # reporter.asserts(self.type.low <= assigned_type.low <= assigned_type.high <= self.type.high, f"Definition may not fit in virtual column: {self!r}") + else: + # Check no indices are covered twice + seen: set[tuple] = set() + for iters, poly in self.def_.defs: + handle_iters(env, iters, poly, self.type, [], seen) + # Check everything is covered + check_covered(self.type, seen, []) + return self.type @dataclass From b9a218beb29270a273e2205d24a6c8623440a179 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 16:01:56 +0100 Subject: [PATCH 09/15] ruff format --- spec/tooling/chip.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 1c71b38e1..d7419b40f 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -646,13 +646,16 @@ def handle_iters( def is_covered(seen: set[tuple], indices: list[int]) -> bool: for s in seen: - if len(s) <= len(indices) and s == tuple(indices[:len(s)]): + if len(s) <= len(indices) and s == tuple(indices[: len(s)]): return True return False def check_covered(t: Type, seen: set[tuple], indices: list[int]): if isinstance(t, Range): - reporter.asserts(is_covered(seen, indices), f"Virtual column {self.name!r} not completely defined") + reporter.asserts( + is_covered(seen, indices), + f"Virtual column {self.name!r} not completely defined", + ) else: for i in range(len(t)): check_covered(t[i], seen, indices + [i]) From 223fa81b6fc098f5c42084e94f4ca92315f5d963 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Thu, 5 Feb 2026 16:30:43 +0100 Subject: [PATCH 10/15] Make typst compile by turning big range values to string --- spec/src/config.toml | 2 +- spec/tooling/chip.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spec/src/config.toml b/spec/src/config.toml index af6b34971..dbcbc50ff 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -4,7 +4,7 @@ version = 1 [[variables.types]] label = "BaseField" subtypes = ["BaseField"] -range = [0, 18446744069414584320] +range = [0, "18446744069414584320"] desc = "Variable that can assume any value in the base field." [[variables.types]] diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index d7419b40f..44e7d28ad 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -402,13 +402,13 @@ def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict) f"Invalid range: {data!r}", ) start, stop = data["range"] - if not isinstance(start, int): + if not isinstance(start, int) and not (isinstance(start, str) and start.isdigit()): reporter.error(f"Range start not an int: {data!r}") start = 0 - if not isinstance(stop, int): + if not isinstance(stop, int) and not (isinstance(stop, str) and stop.isdigit()): reporter.error(f"Range end not an int: {data!r}") stop = start - self.range = Range(start, stop) + self.range = Range(int(start), int(stop)) self.subtypes = [] else: self.range = None From d4b9fda80f9eba50cca2a26dcddaaa30a950a4d4 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Mon, 9 Feb 2026 11:25:09 +0100 Subject: [PATCH 11/15] Switch some isinstance checks around to make both mypy and ty work --- spec/tooling/chip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 44e7d28ad..d85d07c69 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -110,7 +110,7 @@ def typecheck(self, env: Environment) -> Type: reporter.error(f"Invalid index: {idx!r}") return Range(-1, -1) idxlit = idx.get_lit() - if not isinstance(base, list): + if isinstance(base, Range): reporter.error(f"Indexing into non-array type: {self!r}") return DEFAULT_TYPE if not (0 <= idxlit < len(base)): @@ -144,7 +144,7 @@ def type_match(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): reporter.error(f"Multiplication of non-scalar types: {self!r}") return DEFAULT_TYPE - elif isinstance(a, list): + elif not isinstance(a, Range): return [self.type_match(x, b) for x in a] elif isinstance(b, list): return self.type_match(b, a) From 7ef197062860e57211d8c3fa81607437ffff948c Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Mon, 9 Feb 2026 11:46:12 +0100 Subject: [PATCH 12/15] Fix issues after rebasing on spec/main --- spec/src/dvrm.toml | 8 ++++---- spec/src/shift.toml | 2 +- spec/tooling/chip.py | 6 ++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/spec/src/dvrm.toml b/spec/src/dvrm.toml index d93449228..52583907c 100644 --- a/spec/src/dvrm.toml +++ b/spec/src/dvrm.toml @@ -376,13 +376,13 @@ desc = "Each row contributes the following to the LogUp sum" [[constraints.output]] kind = "interaction" tag = "DVRM" -input = ["n", "d", "signed", "0"] +input = ["n", "d", "signed", 0] output = ["cast", "q", "DWordWL"] -multiplicity = "-μ_q" +multiplicity = ["-", "μ_q"] [[constraints.output]] kind = "interaction" tag = "DVRM" -input = ["n", "d", "signed", "1"] +input = ["n", "d", "signed", 1] output = ["cast", "r", "DWordWL"] -multiplicity = "-μ_r" \ No newline at end of file +multiplicity = ["-", "μ_r"] diff --git a/spec/src/shift.toml b/spec/src/shift.toml index e2ca8d12d..bd6c471a6 100644 --- a/spec/src/shift.toml +++ b/spec/src/shift.toml @@ -203,7 +203,7 @@ tag = "ZERO" input = ["bit_shift"] output = "zbs" ref = "shift:c:zbs" -cond = "μ" +multiplicity = "μ" [[constraint_groups]] diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index d85d07c69..c17686c3d 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -821,6 +821,7 @@ def __init__(self, config: Config, data: dict): self.output = build_expr(config, data["output"]) else: self.output = None + assert "multiplicity" in data, data self.multiplicity = build_expr(config, data["multiplicity"]) self.iters = iters_of(data) @@ -928,12 +929,13 @@ def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: if __name__ == "__main__": config = Config.from_file(sys.argv[1]) + signatures = sys.argv[2] # Later if reporter.reported: sys.exit(1) reported = False chips: list[Chip] = [] - for file in sys.argv[2:]: - if file == sys.argv[1]: + for file in sys.argv[3:]: + if file in sys.argv[1:3]: continue chips.append(Chip.from_file(config, file)) reported = reported or reporter.reported From 992c5b35a9faf5983a33947fec138315227a9dc0 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Mon, 9 Feb 2026 15:34:53 +0100 Subject: [PATCH 13/15] Address review comments --- spec/src/config.toml | 34 ++--- spec/tooling/chip.py | 308 ++++++++++++++++++++++++------------------- 2 files changed, 184 insertions(+), 158 deletions(-) diff --git a/spec/src/config.toml b/spec/src/config.toml index dbcbc50ff..0f6ef11d6 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -43,22 +43,6 @@ subtypes = ["BaseField"] range = [0, 4294967295] desc = "Variable that can only assume values in the range $[0, 2^32)$." -[[variables.types]] -label = "WordHL" -subtypes = ["Half", "Half"] -desc = """\ - Variable that can only assume values in the range $[0, 2^32)$. \\ - Represented as an array of two `Half` variables.\ - """ - -[[variables.types]] -label = "WordBL" -subtypes = ["Byte", "Byte", "Byte", "Byte"] -desc = """\ - Variable that can only assume values in the range $[0, 2^32)$. \\ - Represented as an array of four `Byte` variables.\ - """ - [[variables.types]] label = "B51" subtypes = ["BaseField"] @@ -98,6 +82,15 @@ desc = """\ The `Word` is the *least* significant digit. """ +[[variables.types]] +label = "DWordWHH" +subtypes = ["Half", "Half", "Word"] +desc = """\ + Variable that can only assume values in the range $[0, 2^64)$. \\ + Represented as a `Word` and two `Half` variables.\ + The `Word` is the *most* significant digit. + """ + [[variables.types]] label = "QuadHL" subtypes = ["Half", "Half", "Half", "Half", "Half", "Half", "Half", "Half"] @@ -114,15 +107,6 @@ desc = """\ Represented as an array of four `Word` variables.\ """ -[[variables.types]] -label = "DWordWHH" -subtypes = ["Half", "Half", "Word"] -desc = """\ - Variable that can only assume values in the range $[0, 2^64)$. \\ - Represented as a `Word` and two `Half` variables.\ - The `Word` is the *most* significant digit. - """ - [[variables.types]] label = "Timestamp" subtypes = ["DWordWL"] diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index c17686c3d..dfdb59e42 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -1,9 +1,9 @@ -from pathlib import Path import sys import tomllib from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Optional, Never +from pathlib import Path +from typing import Never, Optional, Self class ErrorReporter: @@ -40,6 +40,10 @@ class Range: low: int high: int + @classmethod + def lit(cls, x: int) -> Self: + return cls(x, x) + def is_bool(self): return self.low >= 0 and self.high <= 1 @@ -53,7 +57,7 @@ def get_lit(self) -> int: type Type = list[Type] | Range -DEFAULT_TYPE: Type = Range(0, 0) +DEFAULT_TYPE: Type = Range.lit(0) type Expr = ( LitExpr @@ -76,13 +80,16 @@ class Environment: valmap: dict[str, Range] typemap: dict[str, Type] + def with_val(self, key: str, val: Range) -> Self: + return type(self)(self.config, {**self.valmap, key: val}, self.typemap) + @dataclass class LitExpr: lit: int def typecheck(self, _env: Environment) -> Type: - return Range(self.lit, self.lit) + return Range.lit(self.lit) @dataclass @@ -98,6 +105,15 @@ def typecheck(self, env: Environment) -> Type: return DEFAULT_TYPE +@dataclass +class ArrExpr: + elems: list[Expr] + + def typecheck(self, env: Environment) -> Type: + reporter.asserts(self.elems != [], f"Empty array: {self!r}") + return [e.typecheck(env) for e in self.elems] + + @dataclass class IdxExpr: base: Expr @@ -108,7 +124,7 @@ def typecheck(self, env: Environment) -> Type: idx = self.idx.typecheck(env) if not isinstance(idx, Range) or not idx.is_lit(): reporter.error(f"Invalid index: {idx!r}") - return Range(-1, -1) + return Range.lit(-1) idxlit = idx.get_lit() if isinstance(base, Range): reporter.error(f"Indexing into non-array type: {self!r}") @@ -153,7 +169,8 @@ def type_match(self, a: Type, b: Type) -> Type: return Range(min(extrema), max(extrema)) def typecheck(self, env: Environment) -> Type: - t: Type = Range(1, 1) + reporter.asserts(self.factors != [], f"Empty product: {self!r}") + t: Type = Range.lit(1) for f in self.factors: t = self.type_match(t, f.typecheck(env)) return t @@ -166,9 +183,8 @@ class AddExpr: def type_match(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): - assert False reporter.error(f"Adding array types of different length {self!r}") - return [DEFAULT_TYPE for _ in range(len(b))] + return [DEFAULT_TYPE for _ in b] return [self.type_match(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Adding of scalar and array types {self!r}") @@ -179,7 +195,7 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: if not self.terms: reporter.error("Empty add") - return Range(0, 0) + return Range.lit(0) t: Type = self.terms[0].typecheck(env) for term in self.terms[1:]: t = self.type_match(t, term.typecheck(env)) @@ -195,7 +211,7 @@ def type_match(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Subtracting array types of different length {self!r}") - return [DEFAULT_TYPE for _ in range(len(a))] + return [DEFAULT_TYPE for _ in a] return [self.type_match(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Subtraction of scalar and array types {self!r}") @@ -205,6 +221,11 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: t = self.head.typecheck(env) + if not self.subs: + if not isinstance(t, Range): + reporter.error(f"Negating a non-scalar type: {self!r}") + return t + return Range(-t.high, -t.low) for term in self.subs: t = self.type_match(t, term.typecheck(env)) return t @@ -226,8 +247,8 @@ def typecheck(self, env: Environment) -> Type: f"Invalid exponentiation with non-const exponent: {self.exp!r}" ) return DEFAULT_TYPE - val = base.get_lit() ** exp.get_lit() - return Range(val, val) + val = pow(base.get_lit(), exp.get_lit(), env.config.variables.prime) + return Range.lit(val) @dataclass @@ -239,7 +260,7 @@ def type_match(self, a: Type, b: Type) -> Type: if isinstance(a, list) and isinstance(b, list): if len(a) != len(b): reporter.error(f"Summing array types of different length {self!r}") - return [DEFAULT_TYPE for _ in range(len(b))] + return [DEFAULT_TYPE for _ in b] return [self.type_match(x, y) for x, y in zip(a, b)] elif isinstance(a, list) or isinstance(b, list): reporter.error(f"Summing of scalar and array types {self!r}") @@ -248,7 +269,7 @@ def type_match(self, a: Type, b: Type) -> Type: return Range(a.low + b.low, a.high + b.high) def typecheck(self, env: Environment) -> Type: - t: Type = Range(0, 0) + t: Type = Range.lit(0) for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): t = self.type_match(t, tc) return t @@ -261,9 +282,7 @@ class NotExpr: def typecheck(self, env: Environment) -> Type: inner = self.inner.typecheck(env) if isinstance(inner, list) or not inner.is_bool(): - reporter.asserts( - inner in {0, 1}, f"Not a bool passed to `not`: {self.inner!r}" - ) + reporter.error(f"Not a bool passed to `not`: {self.inner!r}") return Range(0, 1) return Range(1 - inner.high, 1 - inner.low) @@ -333,22 +352,17 @@ def typecheck[T]( start = self.start.typecheck(env) if isinstance(start, list) or not start.is_lit(): reporter.error(f"Starting value of iterator not a const: {self!r}") - start = Range(0, 0) + start = Range.lit(0) stop = self.stop.typecheck(env) if isinstance(stop, list) or not stop.is_lit(): reporter.error(f"Ending value of iterator not a const: {self!r}") - stop = Range(start.get_lit(), start.get_lit()) + stop = Range.lit(start.get_lit()) # While it's tempting to replace this loop by an assignment of Range(start, stop + 1) to self.name # that would break both detection of literals, and narrowing down to the correct type for indexing # heterogenous array types for i in range(start.get_lit(), stop.get_lit() + 1): - old_val: Optional[Range] = env.valmap.get(self.name, None) - env.valmap[self.name] = Range(i, i) - yield from callback(env) - env.valmap.pop(self.name) - if old_val is not None: - env.valmap[self.name] = old_val + yield from callback(env.with_val(self.name, Range.lit(i))) def iters_of(obj: dict, name=None) -> list[Iter]: @@ -395,19 +409,24 @@ def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict) if "range" in data: reporter.asserts( data["subtypes"] == [default_name], - f"Specified a range non a non-base composite type: {data!r}", + f"Specified a range on a non-base composite type: {data!r}", ) reporter.asserts( isinstance(data["range"], list) and len(data["range"]) == 2, f"Invalid range: {data!r}", ) start, stop = data["range"] - if not isinstance(start, int) and not (isinstance(start, str) and start.isdigit()): + if not isinstance(start, int) and not ( + isinstance(start, str) and start.isdigit() + ): reporter.error(f"Range start not an int: {data!r}") start = 0 - if not isinstance(stop, int) and not (isinstance(stop, str) and stop.isdigit()): + if not isinstance(stop, int) and not ( + isinstance(stop, str) and stop.isdigit() + ): reporter.error(f"Range end not an int: {data!r}") stop = start + reporter.asserts(int(start) <= int(stop), f"Inverted range: {data!r}") self.range = Range(int(start), int(stop)) self.subtypes = [] else: @@ -416,10 +435,8 @@ def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict) self.desc = data["desc"] self.preprocessed = data.get("preprocessed", False) - def as_type(self): - if self.range is not None: - return self.range - return self.subtypes[:] + def as_type(self) -> Type: + return self.range or self.subtypes[:] @dataclass @@ -439,22 +456,28 @@ def __init__(self, data: dict): all(isinstance(v, str) for v in self.instantiated), f"Something's not a string: {self.instantiated}", ) + reporter.asserts( + set(self.instantiated) < set(self.all), + f"Instantiated not a subset of all: {self!r}", + ) @dataclass class ConfigVariables: types: list[TypeConfig] categories: ConfigCategories + prime: int def __init__(self, data: dict): assert_no_unexpected(data, type(self).__annotations__.keys()) self.types = [] - base_type = None + base_type = data["types"][0]["label"] for tp in data["types"]: - if base_type is None: - base_type = tp["label"] self.types.append(TypeConfig(base_type, self.lookup_type, tp)) self.categories = ConfigCategories(data["categories"]) + basefield = self.lookup_type(base_type) + assert isinstance(basefield, Range) + self.prime = basefield.high + 1 def lookup_type(self, typename: str) -> Type: matches = [t for t in self.types if t.label == typename] @@ -488,12 +511,12 @@ def __init__(self, data: dict): self.variables = ConfigVariables(data["variables"]) @classmethod - def from_file(cls, filename: str | Path) -> "Config": + def from_file(cls, filename: str | Path) -> Self: reporter.update_location(str(filename)) return cls(tomllib.load(open(filename, "rb"))) @classmethod - def from_string(cls, s: str) -> "Config": + def from_string(cls, s: str) -> Self: reporter.update_location("") return cls(tomllib.loads(s)) @@ -543,23 +566,35 @@ def all_iters[T]( yield from its[0].typecheck(env, lambda e: all_iters(its[1:], e, callback)) +@dataclass +class PolyWithIters: + poly: Expr + iters: list[Iter] + + @dataclass class VirtualDef: # A list of polynomials with each a set of iters they range over - defs: list[tuple[list[Iter], Expr]] + defs: list[PolyWithIters] def __init__(self, config: Config, name: str, tp: Type, data: dict): if "poly" in data: idx = data.get("idx", None) - self.defs = [(iters_of(data, name=idx), build_expr(config, data["poly"]))] + self.defs = [ + PolyWithIters( + build_expr(config, data["poly"]), iters_of(data, name=idx) + ) + ] elif "polys" in data: idx = data.get("idx", None) self.defs = [ - (iters_of(poly, name=idx), build_expr(config, poly["poly"])) + PolyWithIters( + build_expr(config, poly["poly"]), iters_of(poly, name=idx) + ) for poly in data["polys"] ] else: - self.defs = [([], build_expr(config, data))] + self.defs = [PolyWithIters(build_expr(config, data), [])] @dataclass @@ -574,12 +609,12 @@ def __init__(self, config: Config, category: str, data: dict): self.def_ = VirtualDef(config, self.name, self.type, def_) def typecheck(self, env: Environment) -> Type: - def structure_match(a: Type, b: Type): + def structure_matches(a: Type, b: Type) -> bool: if isinstance(a, Range) and isinstance(b, (Range, type(None))): return True elif isinstance(a, list) and isinstance(b, list): return len(a) == len(b) and all( - structure_match(x, y) for x, y in zip(a, b) + structure_matches(x, y) for x, y in zip(a, b) ) else: return False @@ -593,7 +628,6 @@ def handle_iters( seen: set[tuple], ): if not iters: - asn = poly.typecheck(env) # Check not doubly defined for s in seen: ln = min(len(s), len(indices)) @@ -602,9 +636,11 @@ def handle_iters( f"Double definition for virtual column: {self!r} at index {indices}" ) break - # check asn structure matches assigned + + val = poly.typecheck(env) + # check val structure matches assigned reporter.asserts( - structure_match(asn, expected), + structure_matches(val, expected), f"Invalid structure for definition to virtual column: {self!r}", ) # Check type fits? @@ -613,36 +649,41 @@ def handle_iters( else: it, *its = iters # Some duplicated code/concepts from Iter.typecheck + # But threading the extra needed state through overly complicates everything start = it.start.typecheck(env) if isinstance(start, list) or not start.is_lit(): reporter.error( f"Starting value of virtual def iter not a const: {self!r}" ) - start = Range(0, 0) + start = Range.lit(0) stop = it.stop.typecheck(env) if isinstance(stop, list) or not stop.is_lit(): reporter.error( f"Ending value of virtual def iter not a const: {self!r}" ) - stop = Range(start.get_lit(), start.get_lit()) + stop = Range.lit(start.get_lit()) + + if isinstance(expected, Range): + reporter.error( + f"Virtual definition has an iter for a scalar: {self!r}" + ) + return + + if not 0 <= start.get_lit() <= stop.get_lit() < len(expected): + reporter.error( + f"Virtual definition index [{start.get_lit()}, {stop.get_lit()}] out of range for {expected}: {self!r}" + ) + return for i in range(start.get_lit(), stop.get_lit() + 1): - if isinstance(expected, Range): - reporter.error( - f"Virtual definition has an iter for a scalar: {self!r}" - ) - break - if not 0 <= i < len(expected): - reporter.error( - f"Virtual definition index {i} out of range for {expected}: {self!r}" - ) - break - old_val: Optional[Range] = env.valmap.get(it.name, None) - env.valmap[it.name] = Range(i, i) - handle_iters(env, its, poly, expected[i], indices + [i], seen) - env.valmap.pop(it.name) - if old_val is not None: - env.valmap[it.name] = old_val + handle_iters( + env.with_val(it.name, Range.lit(i)), + its, + poly, + expected[i], + indices + [i], + seen, + ) def is_covered(seen: set[tuple], indices: list[int]) -> bool: for s in seen: @@ -657,16 +698,16 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): f"Virtual column {self.name!r} not completely defined", ) else: - for i in range(len(t)): - check_covered(t[i], seen, indices + [i]) + for i, elt in enumerate(t): + check_covered(elt, seen, indices + [i]) # Special case for better error messages if isinstance(self.type, Range): reporter.asserts( - len(self.def_.defs) == 1 and not self.def_.defs[0][0], + len(self.def_.defs) == 1 and not self.def_.defs[0].iters, f"Invalid def for scalar column: {self!r}", ) - assigned_type = self.def_.defs[0][1].typecheck(env) + assigned_type = self.def_.defs[0].poly.typecheck(env) if not isinstance(assigned_type, Range): reporter.error( f"Assigning non-scalar type to scalar virtual column: {self!r}" @@ -678,8 +719,10 @@ def check_covered(t: Type, seen: set[tuple], indices: list[int]): else: # Check no indices are covered twice seen: set[tuple] = set() - for iters, poly in self.def_.defs: - handle_iters(env, iters, poly, self.type, [], seen) + for poly_iters in self.def_.defs: + handle_iters( + env, poly_iters.iters, poly_iters.poly, self.type, [], seen + ) # Check everything is covered check_covered(self.type, seen, []) return self.type @@ -733,8 +776,9 @@ def check_includes_zero(t: Type): f"Unsatisfiable constraint, 0 not in range: {self!r} {t}", ) else: - for sub in t: - check_includes_zero(sub) + reporter.error( + f"Non-scalar value for polynomial constraing: {self!r} {t}" + ) for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): check_includes_zero(t) @@ -742,26 +786,42 @@ def check_includes_zero(t: Type): @dataclass -class TemplateSignature: +class Signature: tag: str input: list[Type] output: Optional[Type] @dataclass -class TemplateConstraint: +class InteractionLike: + kind: str + conditional_name: str + conditional_required: bool + signature: type[Signature] + tag: str desc: str input: list[Expr] output: Optional[Expr] - cond: Optional[Expr] + conditional: Optional[Expr] iters: list[Iter] def __init__(self, config: Config, data: dict): assert_no_unexpected( - data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} + data, + { + "tag", + "desc", + "input", + "output", + self.conditional_name, + "kind", + "ref", + "iter", + "iters", + }, ) - assert data["kind"] == "template" + assert data["kind"] == self.kind self.tag = data["tag"] reporter.asserts( isinstance(self.tag, str), f"tag is not a string: {self.tag!r}" @@ -775,19 +835,23 @@ def __init__(self, config: Config, data: dict): self.output = build_expr(config, data["output"]) else: self.output = None - if "cond" in data: - self.cond = build_expr(config, data["cond"]) + if self.conditional_name in data: + self.conditional = build_expr(config, data[self.conditional_name]) else: - self.cond = None + reporter.asserts( + not self.conditional_required, + f"Missing {self.conditional_name}: {data!r}", + ) + self.conditional = None self.iters = iters_of(data) - def typecheck(self, env: Environment) -> Iterable[TemplateSignature]: - def callback(e: Environment) -> Iterable[TemplateSignature]: - # TODO: Should we be able to check cond somehow? - if self.cond is not None: - self.cond.typecheck(e) + def typecheck(self, env: Environment) -> Iterable[Signature]: + def callback(e: Environment) -> Iterable[Signature]: + # TODO: Should we be able to check cond/multiplicity somehow? + if self.conditional is not None: + self.conditional.typecheck(e) return [ - TemplateSignature( + self.signature( self.tag, [inp.typecheck(e) for inp in self.input], self.output.typecheck(e) if self.output else None, @@ -797,47 +861,26 @@ def callback(e: Environment) -> Iterable[TemplateSignature]: return all_iters(self.iters, env, callback) -@dataclass -class InteractionSignature: - tag: str - input: list[Type] - output: Optional[Type] +class TemplateSignature(Signature): + pass -@dataclass -class InteractionConstraint: - tag: str - input: list[Expr] - output: Optional[Expr] - multiplicity: Expr - iters: list[Iter] +class TemplateConstraint(InteractionLike): + kind = "template" + conditional_name = "cond" + conditional_required = False + signature = TemplateSignature - def __init__(self, config: Config, data: dict): - assert data["kind"] == "interaction" - self.tag = data["tag"] - reporter.asserts(isinstance(self.tag, str), f"tag {self.tag!r} is not a string") - self.input = [build_expr(config, inp) for inp in data["input"]] - if "output" in data: - self.output = build_expr(config, data["output"]) - else: - self.output = None - assert "multiplicity" in data, data - self.multiplicity = build_expr(config, data["multiplicity"]) - self.iters = iters_of(data) - def typecheck(self, env: Environment) -> Iterable[InteractionSignature]: - def callback(e: Environment) -> Iterable[InteractionSignature]: - # TODO: Should we be able to check multiplicity somehow? - self.multiplicity.typecheck(e) - return [ - InteractionSignature( - self.tag, - [inp.typecheck(e) for inp in self.input], - self.output.typecheck(e) if self.output else None, - ) - ] +class InteractionSignature(Signature): + pass - return all_iters(self.iters, env, callback) + +class InteractionConstraint(InteractionLike): + kind = "interaction" + conditional_name = "multiplicity" + conditional_required = True + signature = InteractionSignature @dataclass @@ -889,28 +932,26 @@ def __init__(self, config: Config, data: dict): for cat, vars in data["variables"].items() for var in vars ] - self.assumptions = [ - Assumption(config, asm) for asm in data.get("assumptions", []) - ] + self.assumptions = [Assumption(config, a) for a in data.get("assumptions", [])] constraint_groups = [grp["name"] for grp in data.get("constraint_groups", [])] assert_no_unexpected(data.get("constraints", {}), constraint_groups) self.constraints = [ - build_constraint(config, con) + build_constraint(config, constraint) for group in data.get("constraints", {}).values() - for con in group + for constraint in group ] @classmethod - def from_file(cls, config: Config, filename: str | Path) -> "Chip": + def from_file(cls, config: Config, filename: str | Path) -> Self: reporter.update_location(str(filename)) return cls(config, tomllib.load(open(filename, "rb"))) @classmethod - def from_string(cls, config: Config, s: str) -> "Chip": + def from_string(cls, config: Config, s: str) -> Self: reporter.update_location("") return cls(config, tomllib.loads(s)) - def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: + def typecheck(self) -> Iterable[Signature]: typemap = {} for v in self.variables: if isinstance(v.type, list) and len(v.type) == 1: @@ -929,7 +970,7 @@ def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: if __name__ == "__main__": config = Config.from_file(sys.argv[1]) - signatures = sys.argv[2] # Later + signatures = sys.argv[2] # Later if reporter.reported: sys.exit(1) reported = False @@ -938,9 +979,10 @@ def typecheck(self) -> Iterable[TemplateSignature | InteractionSignature]: if file in sys.argv[1:3]: continue chips.append(Chip.from_file(config, file)) - reported = reported or reporter.reported + reported |= reporter.reported if not reported: for chip in chips: reporter.update_location(f"Chip {chip.name}") # TODO: do something with the signatures - (list(chip.typecheck())) + # Use list for the sideeffect of forcing the generator until we use the content + list(chip.typecheck()) From ddeb505964b43d653758b051cc5963d3e8b722f8 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 10 Feb 2026 14:08:37 +0100 Subject: [PATCH 14/15] Review comments --- spec/tooling/chip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index dfdb59e42..6971a9029 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -457,7 +457,7 @@ def __init__(self, data: dict): f"Something's not a string: {self.instantiated}", ) reporter.asserts( - set(self.instantiated) < set(self.all), + set(self.instantiated) <= set(self.all), f"Instantiated not a subset of all: {self!r}", ) @@ -777,7 +777,7 @@ def check_includes_zero(t: Type): ) else: reporter.error( - f"Non-scalar value for polynomial constraing: {self!r} {t}" + f"Non-scalar value for polynomial constraint: {self!r} {t}" ) for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): From 3d846e6677323522be00755496877457251b740c Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Tue, 10 Feb 2026 15:51:07 +0100 Subject: [PATCH 15/15] lit -> const --- spec/tooling/chip.py | 70 ++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 6971a9029..8a15ae338 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -41,23 +41,23 @@ class Range: high: int @classmethod - def lit(cls, x: int) -> Self: + def const(cls, x: int) -> Self: return cls(x, x) def is_bool(self): return self.low >= 0 and self.high <= 1 - def is_lit(self): + def is_const(self): return self.low == self.high - def get_lit(self) -> int: - assert self.is_lit() + def get_const(self) -> int: + assert self.is_const() return self.low type Type = list[Type] | Range -DEFAULT_TYPE: Type = Range.lit(0) +DEFAULT_TYPE: Type = Range.const(0) type Expr = ( LitExpr @@ -89,7 +89,7 @@ class LitExpr: lit: int def typecheck(self, _env: Environment) -> Type: - return Range.lit(self.lit) + return Range.const(self.lit) @dataclass @@ -122,17 +122,17 @@ class IdxExpr: def typecheck(self, env: Environment) -> Type: base = self.base.typecheck(env) idx = self.idx.typecheck(env) - if not isinstance(idx, Range) or not idx.is_lit(): + if not isinstance(idx, Range) or not idx.is_const(): reporter.error(f"Invalid index: {idx!r}") - return Range.lit(-1) - idxlit = idx.get_lit() + return Range.const(-1) + idxconst = idx.get_const() if isinstance(base, Range): reporter.error(f"Indexing into non-array type: {self!r}") return DEFAULT_TYPE - if not (0 <= idxlit < len(base)): + if not (0 <= idxconst < len(base)): reporter.error(f"Index out of range {self!r}") - idxlit = 0 - return base[idxlit] + idxconst = 0 + return base[idxconst] @dataclass @@ -146,7 +146,7 @@ def typecheck(self, env: Environment) -> Type: baselen = len(base) if isinstance(base, list) else 1 castlen = len(self.type) if isinstance(self.type, list) else 1 reporter.asserts( - baselen >= castlen or (isinstance(base, Range) and base.is_lit()), + baselen >= castlen or (isinstance(base, Range) and base.is_const()), f"Casting from fewer columns to more: {self!r} {base} {self.type}", ) return self.type @@ -170,7 +170,7 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: reporter.asserts(self.factors != [], f"Empty product: {self!r}") - t: Type = Range.lit(1) + t: Type = Range.const(1) for f in self.factors: t = self.type_match(t, f.typecheck(env)) return t @@ -195,7 +195,7 @@ def type_match(self, a: Type, b: Type) -> Type: def typecheck(self, env: Environment) -> Type: if not self.terms: reporter.error("Empty add") - return Range.lit(0) + return Range.const(0) t: Type = self.terms[0].typecheck(env) for term in self.terms[1:]: t = self.type_match(t, term.typecheck(env)) @@ -239,16 +239,16 @@ class PowExpr: def typecheck(self, env: Environment) -> Type: base = self.base.typecheck(env) exp = self.exp.typecheck(env) - if isinstance(base, list) or not base.is_lit(): + if isinstance(base, list) or not base.is_const(): reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") return DEFAULT_TYPE - if isinstance(exp, list) or not exp.is_lit(): + if isinstance(exp, list) or not exp.is_const(): reporter.error( f"Invalid exponentiation with non-const exponent: {self.exp!r}" ) return DEFAULT_TYPE - val = pow(base.get_lit(), exp.get_lit(), env.config.variables.prime) - return Range.lit(val) + val = pow(base.get_const(), exp.get_const(), env.config.variables.prime) + return Range.const(val) @dataclass @@ -269,7 +269,7 @@ def type_match(self, a: Type, b: Type) -> Type: return Range(a.low + b.low, a.high + b.high) def typecheck(self, env: Environment) -> Type: - t: Type = Range.lit(0) + t: Type = Range.const(0) for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): t = self.type_match(t, tc) return t @@ -350,19 +350,19 @@ def typecheck[T]( self, env: Environment, callback: Callable[[Environment], Iterable[T]] ) -> Iterable[T]: start = self.start.typecheck(env) - if isinstance(start, list) or not start.is_lit(): + if isinstance(start, list) or not start.is_const(): reporter.error(f"Starting value of iterator not a const: {self!r}") - start = Range.lit(0) + start = Range.const(0) stop = self.stop.typecheck(env) - if isinstance(stop, list) or not stop.is_lit(): + if isinstance(stop, list) or not stop.is_const(): reporter.error(f"Ending value of iterator not a const: {self!r}") - stop = Range.lit(start.get_lit()) + stop = Range.const(start.get_const()) # While it's tempting to replace this loop by an assignment of Range(start, stop + 1) to self.name - # that would break both detection of literals, and narrowing down to the correct type for indexing + # that would break both detection of consts, and narrowing down to the correct type for indexing # heterogenous array types - for i in range(start.get_lit(), stop.get_lit() + 1): - yield from callback(env.with_val(self.name, Range.lit(i))) + for i in range(start.get_const(), stop.get_const() + 1): + yield from callback(env.with_val(self.name, Range.const(i))) def iters_of(obj: dict, name=None) -> list[Iter]: @@ -651,17 +651,17 @@ def handle_iters( # Some duplicated code/concepts from Iter.typecheck # But threading the extra needed state through overly complicates everything start = it.start.typecheck(env) - if isinstance(start, list) or not start.is_lit(): + if isinstance(start, list) or not start.is_const(): reporter.error( f"Starting value of virtual def iter not a const: {self!r}" ) - start = Range.lit(0) + start = Range.const(0) stop = it.stop.typecheck(env) - if isinstance(stop, list) or not stop.is_lit(): + if isinstance(stop, list) or not stop.is_const(): reporter.error( f"Ending value of virtual def iter not a const: {self!r}" ) - stop = Range.lit(start.get_lit()) + stop = Range.const(start.get_const()) if isinstance(expected, Range): reporter.error( @@ -669,15 +669,15 @@ def handle_iters( ) return - if not 0 <= start.get_lit() <= stop.get_lit() < len(expected): + if not 0 <= start.get_const() <= stop.get_const() < len(expected): reporter.error( - f"Virtual definition index [{start.get_lit()}, {stop.get_lit()}] out of range for {expected}: {self!r}" + f"Virtual definition index [{start.get_const()}, {stop.get_const()}] out of range for {expected}: {self!r}" ) return - for i in range(start.get_lit(), stop.get_lit() + 1): + for i in range(start.get_const(), stop.get_const() + 1): handle_iters( - env.with_val(it.name, Range.lit(i)), + env.with_val(it.name, Range.const(i)), its, poly, expected[i],