diff --git a/spec/chip.typ b/spec/chip.typ index 10479943e..4749b886e 100644 --- a/spec/chip.typ +++ b/spec/chip.typ @@ -113,7 +113,6 @@ } if "poly" in def { - // assert(false, message: repr(index_all(var_name, gather_indices(def)))) ( [], table.cell(align: right, emph[definition]), diff --git a/spec/memory.typ b/spec/memory.typ index 1fcb7b54e..62059de37 100644 --- a/spec/memory.typ +++ b/spec/memory.typ @@ -229,3 +229,4 @@ add the required balancing terms to the LogUp sum. = Future topics of interest - Optimize memory systems after determining factual bottlenecks (e.g. taking inspiration from Twist and Shout, or other recent research) +- Double check whether IS_BYTE constraints are needed for fini diff --git a/spec/src/bitwise.toml b/spec/src/bitwise.toml index 9b4a3f951..75e8faee4 100644 --- a/spec/src/bitwise.toml +++ b/spec/src/bitwise.toml @@ -4,67 +4,67 @@ name = "BITWISE" name = "X" type = "Byte" desc = "" -precomputed = "true" +precomputed = true [[variables.input]] name = "Y" type = "Byte" desc = "" -precomputed = "true" +precomputed = true [[variables.input]] name = "Z" type = "B4" desc = "" -precomputed = "true" +precomputed = true [[variables.output]] name = "AND" type = "Byte" desc = "the binary AND of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "OR" type = "Byte" desc = "the binary OR of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "XOR" type = "Byte" desc = "the binary XOR of `X` and `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "MSB8" type = "Bit" desc = "the most significant bit of `X`" -precomputed = "true" +precomputed = true [[variables.output]] name = "MSB16" type = "Bit" desc = "the most significant bit of `Y`" -precomputed = "true" +precomputed = true [[variables.output]] name = "ZERO" type = "Bit" desc = "whether $#`X` = 0$, $#`Y` = 0$ and $#`Z` = 0$." -precomputed = "true" +precomputed = true [[variables.output]] name = "SLL" type = "Half" desc = "`X||Y` logically left-shifted by `Z`: $((#`X` + 256#`Y`) #`<<` #`Z`) mod 2^16$" -precomputed = "true" +precomputed = true [[variables.output]] name = "SLLC" type = "Half" desc = "`X||Y` logically right-shifted by `Z`: $(#`X` + 256#`Y`) #`>>` (16 - #`Z`)$" -precomputed = "true" +precomputed = true [[variables.multiplicity]] name = "μ_AND" @@ -197,4 +197,4 @@ kind = "interaction" tag = "HWSLC" input = [["+", "X", ["*", 256, "Y"]], "Z"] output = "SLLC" -multiplicity = ["-", "μ_HWSLC"] \ No newline at end of file +multiplicity = ["-", "μ_HWSLC"] diff --git a/spec/src/branch.toml b/spec/src/branch.toml index e66974c8e..beb3c1922 100644 --- a/spec/src/branch.toml +++ b/spec/src/branch.toml @@ -11,7 +11,7 @@ pad = 0 [[variables.input]] name = "offset" -type = "Word" +type = "DWordWL" desc = "The offset from the base address to jump to" pad = 0 @@ -59,7 +59,7 @@ name = "next_pc_unmasked" type = "DWordWL" desc = "The combination of `next_pc_high`, `next_pc_low[1]` and `unmasked_low_byte` to constrain the addition. This is the computed value for the next pc, before masking off the LSB as required by the ISA." def = {idx = "i", polys = [ - {iter = 0, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 0]], ["*", ["^", 2, 8], ["idx", "next_pc_low", 1]], ["idx", "unmasked_low_byte", 0]]}, + {iter = 0, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 0]], ["*", ["^", 2, 8], ["idx", "next_pc_low", 1]], "unmasked_low_byte"]}, {iter = 1, poly = ["+", ["*", ["^", 2, 16], ["idx", "next_pc_high", 2]], ["idx", "next_pc_high", 1]]}, ]} @@ -124,7 +124,7 @@ multiplicity = "μ" [[constraints.all]] kind = "interaction" tag = "AND_BYTE" -input = [["idx", "unmasked_low_byte", 0], 254] +input = ["unmasked_low_byte", 254] output = ["idx", "next_pc_low", 0] multiplicity = "μ" @@ -145,4 +145,4 @@ kind = "interaction" tag = "BRANCH" input = ["pc", "offset", "register", "JALR"] output = "next_pc" -multiplicity = "-μ" +multiplicity = ["-", "μ"] diff --git a/spec/src/config.toml b/spec/src/config.toml index d836f80e5..0f6ef11d6 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -4,63 +4,49 @@ version = 1 [[variables.types]] label = "BaseField" subtypes = ["BaseField"] +range = [0, "18446744069414584320"] desc = "Variable that can assume any value in the base field." [[variables.types]] label = "Bit" subtypes = ["BaseField"] +range = [0, 1] desc = "Variable that can only assume values in the set ${0,1}$." [[variables.types]] label = "B4" subtypes = ["BaseField"] +range = [0, 15] desc = "Variable that can only assume values in the range $[0, 2^4)$." [[variables.types]] label = "Byte" subtypes = ["BaseField"] -count = 1 +range = [0, 255] desc = "Variable that can only assume values in the range $[0, 2^8)$." [[variables.types]] label = "Half" subtypes = ["BaseField"] +range = [0, 65535] desc = "Variable that can only assume values in the range $[0, 2^16)$." [[variables.types]] label = "B20" subtypes = ["BaseField"] +range = [0, 1048575] desc = "Variable that can only assume values in the range $[0, 2^20)$." [[variables.types]] label = "Word" subtypes = ["BaseField"] +range = [0, 4294967295] desc = "Variable that can only assume values in the range $[0, 2^32)$." -[[variables.types]] -label = "WordHL" -subtypes = ["Half", "Half"] -desc = """\ - Variable that can only assume values in the range $[0, 2^32)$. \\ - Represented as an array of two `Half` variables.\ - """ - -[[variables.types]] -label = "WordBL" -subtypes = ["Byte", "Byte", "Byte", "Byte"] -desc = """\ - Variable that can only assume values in the range $[0, 2^32)$. \\ - Represented as an array of four `Byte` variables.\ - """ - -[[variables.types]] -label = "B35" -subtypes = ["BaseField"] -desc = "Variable that can only assume values in the range $[0, 2^35)$." - [[variables.types]] label = "B51" subtypes = ["BaseField"] +range = [0, 2251799813685247] desc = "Variable that can only assume values in the range $[0, 2^51)$." [[variables.types]] @@ -96,6 +82,15 @@ desc = """\ The `Word` is the *least* significant digit. """ +[[variables.types]] +label = "DWordWHH" +subtypes = ["Half", "Half", "Word"] +desc = """\ + Variable that can only assume values in the range $[0, 2^64)$. \\ + Represented as a `Word` and two `Half` variables.\ + The `Word` is the *most* significant digit. + """ + [[variables.types]] label = "QuadHL" subtypes = ["Half", "Half", "Half", "Half", "Half", "Half", "Half", "Half"] @@ -112,15 +107,6 @@ desc = """\ Represented as an array of four `Word` variables.\ """ -[[variables.types]] -label = "DWordWHH" -subtypes = ["Half", "Half", "Word"] -desc = """\ - Variable that can only assume values in the range $[0, 2^64)$. \\ - Represented as a `Word` and two `Half` variables.\ - The `Word` is the *most* significant digit. - """ - [[variables.types]] label = "Timestamp" subtypes = ["DWordWL"] diff --git a/spec/src/cpu.toml b/spec/src/cpu.toml index c151b6eff..994fda508 100644 --- a/spec/src/cpu.toml +++ b/spec/src/cpu.toml @@ -343,6 +343,7 @@ name = "decode" kind = "interaction" tag = "DECODE" input = ["pc", "imm", "packed_decode"] +multiplicity = 1 [[constraint_groups]] @@ -515,34 +516,40 @@ ref = "cpu:c:range_EBREAK" kind = "interaction" tag = "IS_BYTE" input = ["rs1"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = ["rs2"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = ["rd"] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "arg1", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "arg2", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraints.range]] kind = "interaction" tag = "IS_BYTE" input = [["idx", "res", "i"]] iter = ["i", 0, 7] +multiplicity = 1 [[constraint_groups]] @@ -611,7 +618,7 @@ multiplicity = "SHIFT" [[constraints.alu]] kind = "template" tag = "ADD" -input = ["pc", ["cast", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], "DWordWL"]] +input = ["pc", ["*", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], ["cast", 1, "DWordWL"]]] output = ["cast", "res", "DWordWL"] cond = "JALR" @@ -640,7 +647,7 @@ prefix = "M" [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rs1"], "rv1", ["+", "timestamp", 0], 1, 0, 0] +input = [1, ["*", 2, "rs1"], "rv1", ["+", "timestamp", ["cast", 0, "DWordWL"]], 1, 0, 0] output = "rv1" multiplicity = "read_register1" @@ -654,7 +661,7 @@ iter = ["i", 0, 2] [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rs2"], "rv2", ["+", "timestamp", 1], 1, 0, 0] +input = [1, ["*", 2, "rs2"], "rv2", ["+", "timestamp", ["cast", 1, "DWordWL"]], 1, 0, 0] output = "rv2" multiplicity = "read_register2" @@ -668,13 +675,13 @@ iter = ["i", 0, 2] [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, "rd"], "rvd", ["+", "timestamp", 2], 1, 0, 0] +input = [1, ["*", 2, "rd"], "rvd", ["+", "timestamp", ["cast", 2, "DWordWL"]], 1, 0, 0] multiplicity = "write_register" [[constraints.mem]] kind = "interaction" tag = "LOAD" -input = [0, "res", ["+", "timestamp", 0], "memory_2bytes", "memory_4bytes", "memory_8bytes", "signed"] +input = [0, "res", ["+", "timestamp", ["cast", 0, "DWordWL"]], "memory_2bytes", "memory_4bytes", "memory_8bytes", "signed"] output = "rvd" multiplicity = "LOAD" @@ -682,14 +689,14 @@ multiplicity = "LOAD" [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [0, "res", "rv2", ["+", "timestamp", 1], "memory_2bytes", "memory_4bytes", "memory_8bytes"] +input = [0, "res", "rv2", ["+", "timestamp", ["cast", 1, "DWordWL"]], "memory_2bytes", "memory_4bytes", "memory_8bytes"] multiplicity = "STORE" # TODO: no types available, so no casting yet [[constraints.mem]] kind = "interaction" tag = "MEMW" -input = [1, ["*", 2, 255], "next_pc", ["+", "timestamp", 1], 1, 0, 0] +input = [1, ["*", 2, 255], "next_pc", ["+", "timestamp", ["cast", 1, "DWordWL"]], 1, 0, 0] output = "pc" multiplicity = ["not", "pad"] @@ -795,7 +802,7 @@ poly = ["+", ["-", "branch_cond"], "JALR", ["*", ["idx", "res", 0], ["not", "mp_selector"], "BLT"], - ["*", ["not", ["idx", "res", 0]], "mp_selector", "BLT"], + ["*", ["-", 1, ["idx", "res", 0]], "mp_selector", "BLT"], ["*", "is_equal", ["not", "mp_selector"], "BEQ"], ["*", ["not", "is_equal"], "mp_selector", "BEQ"] ] @@ -810,6 +817,6 @@ multiplicity = "branch_cond" [[constraints.misc]] kind = "template" tag = "ADD" -input = ["pc", ["cast", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], "DWordWL"]] +input = ["pc", ["*", ["+", ["*", 2, "c_type_instruction"], ["*", 4, ["not", "c_type_instruction"]]], ["cast", 1, "DWordWL"]]] output = "next_pc" desc = "Increment `pc` to `next_pc` if we're not branching" diff --git a/spec/src/dvrm.toml b/spec/src/dvrm.toml index d93449228..52583907c 100644 --- a/spec/src/dvrm.toml +++ b/spec/src/dvrm.toml @@ -376,13 +376,13 @@ desc = "Each row contributes the following to the LogUp sum" [[constraints.output]] kind = "interaction" tag = "DVRM" -input = ["n", "d", "signed", "0"] +input = ["n", "d", "signed", 0] output = ["cast", "q", "DWordWL"] -multiplicity = "-μ_q" +multiplicity = ["-", "μ_q"] [[constraints.output]] kind = "interaction" tag = "DVRM" -input = ["n", "d", "signed", "1"] +input = ["n", "d", "signed", 1] output = ["cast", "r", "DWordWL"] -multiplicity = "-μ_r" \ No newline at end of file +multiplicity = ["-", "μ_r"] diff --git a/spec/src/is_bit.toml b/spec/src/is_bit.toml index 47e96a27e..a72b5f648 100644 --- a/spec/src/is_bit.toml +++ b/spec/src/is_bit.toml @@ -16,5 +16,5 @@ name = "all" [[constraints.all]] kind = "arith" constraint = "$#`cond` => #`X` (1-#`X`) = 0$" -poly = ["*", "cond", "X", ["not", "X"]] +poly = ["*", "cond", "X", ["-", 1, "X"]] ref = "isbit:c:isbit" diff --git a/spec/src/lt.toml b/spec/src/lt.toml index 10497b637..1941dbb7a 100644 --- a/spec/src/lt.toml +++ b/spec/src/lt.toml @@ -160,4 +160,4 @@ kind = "interaction" tag = "LT" input = [["cast", "lhs", "DWordWL"], ["cast", "rhs", "DWordWL"], "signed"] output = "lt" -multiplicity = "-μ" +multiplicity = ["-", "μ"] diff --git a/spec/src/memw.toml b/spec/src/memw.toml index f7276a9cd..af005c2b4 100644 --- a/spec/src/memw.toml +++ b/spec/src/memw.toml @@ -129,7 +129,7 @@ kind = "template" tag = "ADD" input = ["base_address", 1] output = ["cast", ["idx", "address_add", 0], "DWordWL"] -multiplicity = "w2" +cond = "w2" [[constraints.consistency]] kind = "template" @@ -137,7 +137,7 @@ tag = "ADD" input = ["base_address", ["+", "i", 1]] output = ["cast", ["idx", "address_add", "i"], "DWordWL"] iter = ["i", 1, 2] -multiplicity = "w4" +cond = "w4" [[constraints.consistency]] kind = "template" @@ -145,16 +145,37 @@ tag = "ADD" input = ["base_address", ["+", "i", 1]] output = ["cast", ["idx", "address_add", "i"], "DWordWL"] iter = ["i", 3, 6] -multiplicity = "write8" +cond = "write8" + +[[constraints.consistency]] +kind = "interaction" +tag = "IS_HALFWORD" +input = [["idx", ["idx", "address_add", "i"], "j"]] +iters = [ + ["i", 0, 0], + ["j", 0, 3], +] +multiplicity = "w2" [[constraints.consistency]] kind = "interaction" tag = "IS_HALFWORD" input = [["idx", ["idx", "address_add", "i"], "j"]] iters = [ - ["i", 0, 6], + ["i", 1, 2], ["j", 0, 3], ] +multiplicity = "w4" + +[[constraints.consistency]] +kind = "interaction" +tag = "IS_HALFWORD" +input = [["idx", ["idx", "address_add", "i"], "j"]] +iters = [ + ["i", 3, 6], + ["j", 0, 3], +] +multiplicity = "write8" [[constraints.consistency]] kind = "interaction" diff --git a/spec/src/mul.toml b/spec/src/mul.toml index 238bfe01f..a798c682d 100644 --- a/spec/src/mul.toml +++ b/spec/src/mul.toml @@ -182,7 +182,7 @@ name = "prod" [[constraints.prod]] kind = "arith" constraint = "$#`raw_product[i]` = sum_(#`k`=0)^1 2^(16k) sum_(#`j`=0)^(2i+k) #`lhs_ext[j]` dot #`rhs_ext[2i+k-j]`$" -poly = ["-", ["sum", ["=", "k", 0], "1", ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]] +poly = ["-", ["sum", ["=", "k", 0], 1, ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]] iter = ["i", 0, 3] ref = "mul:c:raw_product" @@ -192,7 +192,7 @@ name = "lookup" [[constraints.lookup]] kind = "interaction" tag = "MUL" -input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "0"] +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", 0] output = ["cast", "lo", "DWordWL"] multiplicity = ["-", "μ_lo"] ref = "mul:c:lookup_lo" @@ -200,7 +200,7 @@ ref = "mul:c:lookup_lo" [[constraints.lookup]] kind = "interaction" tag = "MUL" -input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "1"] +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", 1] output = ["cast", "hi", "DWordWL"] multiplicity = ["-", "μ_hi"] ref = "mul:c:lookup_hi" diff --git a/spec/src/page.toml b/spec/src/page.toml index 8053d63df..21ec76757 100644 --- a/spec/src/page.toml +++ b/spec/src/page.toml @@ -2,6 +2,12 @@ name = "PAGE" # Input +# TODO: add `page` as "constant" column or smth +[[variables.input]] +name = "page" +type = "DWordWL" +desc = "Constant column containing the page base address; should be integrated into the constraints directly" + [[variables.input]] name = "offset" type = "RowIndex" @@ -28,7 +34,7 @@ desc = "The timestamp at which this address was last accessed" name = "address" type = "DWordWL" desc = "Adding `offset` to the page base address `page`. `page` is a constant with respect to a single instance of this table." -def = ["+", "page", ["cast", "offset", "DWordWL"]] +def = ["+", "page", ["*", "offset", ["cast", 1, "DWordWL"]]] [[constraint_groups]] @@ -38,11 +44,13 @@ name = "all" kind = "interaction" tag = "IS_BYTE" input = ["init"] +multiplicity = 1 [[constraints.all]] kind = "interaction" tag = "IS_BYTE" input = ["fini"] +multiplicity = 1 [[constraints.all]] kind = "interaction" diff --git a/spec/src/shift.toml b/spec/src/shift.toml index 591efb839..bd6c471a6 100644 --- a/spec/src/shift.toml +++ b/spec/src/shift.toml @@ -292,5 +292,5 @@ kind = "interaction" tag = "SHIFT" input = ["in", "shift", "direction", "signed", "word_instr"] output = "out" -multiplicity = "-μ" +multiplicity = ["-", "μ"] ref = "shift:c:lookup" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py new file mode 100644 index 000000000..8a15ae338 --- /dev/null +++ b/spec/tooling/chip.py @@ -0,0 +1,988 @@ +import sys +import tomllib +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from pathlib import Path +from typing import Never, Optional, Self + + +class ErrorReporter: + reported: bool + location: str + + def __init__(self, location: str): + self.reported = False + self.location = location + + def update_location(self, loc: str): + self.reported = False + self.location = loc + + def error(self, message: str): + self.reported = True + print(f"ERROR {self.location}: {message}", file=sys.stderr) + + def asserts(self, condition: bool, message: str): + if not condition: + self.error(message) + + +reporter = ErrorReporter("unknown") + + +def assert_no_unexpected(data: dict, possible_keys: Iterable[str]): + for key in data.keys(): + reporter.asserts(key in possible_keys, f"Unexpected key: {key!r}") + + +@dataclass(frozen=True) +class Range: + low: int + high: int + + @classmethod + def const(cls, x: int) -> Self: + return cls(x, x) + + def is_bool(self): + return self.low >= 0 and self.high <= 1 + + def is_const(self): + return self.low == self.high + + def get_const(self) -> int: + assert self.is_const() + return self.low + + +type Type = list[Type] | Range + +DEFAULT_TYPE: Type = Range.const(0) + +type Expr = ( + LitExpr + | VarExpr + | IdxExpr + | CastExpr + | MulExpr + | AddExpr + | SubExpr + | PowExpr + | SumExpr + | NotExpr + | DummyExpr +) + + +@dataclass +class Environment: + config: "Config" + valmap: dict[str, Range] + typemap: dict[str, Type] + + def with_val(self, key: str, val: Range) -> Self: + return type(self)(self.config, {**self.valmap, key: val}, self.typemap) + + +@dataclass +class LitExpr: + lit: int + + def typecheck(self, _env: Environment) -> Type: + return Range.const(self.lit) + + +@dataclass +class VarExpr: + name: str + + def typecheck(self, env: Environment) -> Type: + if self.name in env.valmap: + return env.valmap[self.name] + if self.name in env.typemap: + return env.typemap[self.name] + reporter.error(f"Unknown variable: {self.name!r}") + return DEFAULT_TYPE + + +@dataclass +class ArrExpr: + elems: list[Expr] + + def typecheck(self, env: Environment) -> Type: + reporter.asserts(self.elems != [], f"Empty array: {self!r}") + return [e.typecheck(env) for e in self.elems] + + +@dataclass +class IdxExpr: + base: Expr + idx: Expr + + def typecheck(self, env: Environment) -> Type: + base = self.base.typecheck(env) + idx = self.idx.typecheck(env) + if not isinstance(idx, Range) or not idx.is_const(): + reporter.error(f"Invalid index: {idx!r}") + return Range.const(-1) + idxconst = idx.get_const() + if isinstance(base, Range): + reporter.error(f"Indexing into non-array type: {self!r}") + return DEFAULT_TYPE + if not (0 <= idxconst < len(base)): + reporter.error(f"Index out of range {self!r}") + idxconst = 0 + return base[idxconst] + + +@dataclass +class CastExpr: + base: Expr + type: Type + + def typecheck(self, env: Environment) -> Type: + base = self.base.typecheck(env) + # TODO? Detect more sorts of invalid casts + baselen = len(base) if isinstance(base, list) else 1 + castlen = len(self.type) if isinstance(self.type, list) else 1 + reporter.asserts( + baselen >= castlen or (isinstance(base, Range) and base.is_const()), + f"Casting from fewer columns to more: {self!r} {base} {self.type}", + ) + return self.type + + +@dataclass +class MulExpr: + factors: list[Expr] + + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + reporter.error(f"Multiplication of non-scalar types: {self!r}") + return DEFAULT_TYPE + elif not isinstance(a, Range): + return [self.type_match(x, b) for x in a] + elif isinstance(b, list): + return self.type_match(b, a) + else: + extrema = [x * y for x in [a.low, a.high] for y in [b.low, b.high]] + return Range(min(extrema), max(extrema)) + + def typecheck(self, env: Environment) -> Type: + reporter.asserts(self.factors != [], f"Empty product: {self!r}") + t: Type = Range.const(1) + for f in self.factors: + t = self.type_match(t, f.typecheck(env)) + return t + + +@dataclass +class AddExpr: + terms: list[Expr] + + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + reporter.error(f"Adding array types of different length {self!r}") + return [DEFAULT_TYPE for _ in b] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Adding of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low + b.low, a.high + b.high) + + def typecheck(self, env: Environment) -> Type: + if not self.terms: + reporter.error("Empty add") + return Range.const(0) + t: Type = self.terms[0].typecheck(env) + for term in self.terms[1:]: + t = self.type_match(t, term.typecheck(env)) + return t + + +@dataclass +class SubExpr: + head: Expr + subs: list[Expr] + + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + reporter.error(f"Subtracting array types of different length {self!r}") + return [DEFAULT_TYPE for _ in a] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Subtraction of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low - b.high, a.high - b.low) + + def typecheck(self, env: Environment) -> Type: + t = self.head.typecheck(env) + if not self.subs: + if not isinstance(t, Range): + reporter.error(f"Negating a non-scalar type: {self!r}") + return t + return Range(-t.high, -t.low) + for term in self.subs: + t = self.type_match(t, term.typecheck(env)) + return t + + +@dataclass +class PowExpr: + base: Expr + exp: Expr + + def typecheck(self, env: Environment) -> Type: + base = self.base.typecheck(env) + exp = self.exp.typecheck(env) + if isinstance(base, list) or not base.is_const(): + reporter.error(f"Invalid exponentiation with non-const base: {self.base!r}") + return DEFAULT_TYPE + if isinstance(exp, list) or not exp.is_const(): + reporter.error( + f"Invalid exponentiation with non-const exponent: {self.exp!r}" + ) + return DEFAULT_TYPE + val = pow(base.get_const(), exp.get_const(), env.config.variables.prime) + return Range.const(val) + + +@dataclass +class SumExpr: + iter: "Iter" + terms: Expr + + def type_match(self, a: Type, b: Type) -> Type: + if isinstance(a, list) and isinstance(b, list): + if len(a) != len(b): + reporter.error(f"Summing array types of different length {self!r}") + return [DEFAULT_TYPE for _ in b] + return [self.type_match(x, y) for x, y in zip(a, b)] + elif isinstance(a, list) or isinstance(b, list): + reporter.error(f"Summing of scalar and array types {self!r}") + return DEFAULT_TYPE + else: + return Range(a.low + b.low, a.high + b.high) + + def typecheck(self, env: Environment) -> Type: + t: Type = Range.const(0) + for tc in self.iter.typecheck(env, lambda e: [self.terms.typecheck(e)]): + t = self.type_match(t, tc) + return t + + +@dataclass +class NotExpr: + inner: Expr + + def typecheck(self, env: Environment) -> Type: + inner = self.inner.typecheck(env) + if isinstance(inner, list) or not inner.is_bool(): + reporter.error(f"Not a bool passed to `not`: {self.inner!r}") + return Range(0, 1) + return Range(1 - inner.high, 1 - inner.low) + + +@dataclass +class DummyExpr: + def typecheck(self, _env: Environment) -> Type: + return DEFAULT_TYPE + + +def build_expr(config: Optional["Config"], data: object) -> Expr: + # Does this need config, or do we delay any config-checking to when we use the expr? + match data: + case int(x): + return LitExpr(x) + case str(x): + reporter.asserts( + x.isidentifier(), f"Invalid identifier name for variable {x!r}" + ) + return VarExpr(x) + case ["idx", x, y]: + return IdxExpr(build_expr(config, x), build_expr(config, y)) + case ["cast", x, t]: + assert config is not None + assert isinstance(t, (list, str)) + return CastExpr(build_expr(config, x), build_type(config, t)) + case ["*", *factors]: + return MulExpr([build_expr(config, f) for f in factors]) + case ["+", *terms]: + return AddExpr([build_expr(config, t) for t in terms]) + case ["-", head, *subs]: + return SubExpr( + build_expr(config, head), [build_expr(config, s) for s in subs] + ) + case ["^", base, exp]: + return PowExpr(build_expr(config, base), build_expr(config, exp)) + case ["sum", ["=", str(var), start], stop, terms]: + assert config is not None + return SumExpr(Iter(config, var, start, stop), build_expr(config, terms)) + case ["not", e]: + return NotExpr(build_expr(config, e)) + case other: + reporter.error(f"Unknown expression: {other!r}") + return DummyExpr() + + +@dataclass +class Iter: + name: str + start: Expr + stop: Expr + + def __init__(self, config: "Config", name: str, start: object, stop: object): + self.name = name + reporter.asserts( + isinstance(self.name, str), f"iter name is not a string: {self.name!r}" + ) + reporter.asserts( + self.name.isidentifier(), f"Not a valid identifier: {self.name!r}" + ) + self.start = build_expr(config, start) + self.stop = build_expr(config, stop) + + def typecheck[T]( + self, env: Environment, callback: Callable[[Environment], Iterable[T]] + ) -> Iterable[T]: + start = self.start.typecheck(env) + if isinstance(start, list) or not start.is_const(): + reporter.error(f"Starting value of iterator not a const: {self!r}") + start = Range.const(0) + stop = self.stop.typecheck(env) + if isinstance(stop, list) or not stop.is_const(): + reporter.error(f"Ending value of iterator not a const: {self!r}") + stop = Range.const(start.get_const()) + + # While it's tempting to replace this loop by an assignment of Range(start, stop + 1) to self.name + # that would break both detection of consts, and narrowing down to the correct type for indexing + # heterogenous array types + for i in range(start.get_const(), stop.get_const() + 1): + yield from callback(env.with_val(self.name, Range.const(i))) + + +def iters_of(obj: dict, name=None) -> list[Iter]: + """Return a list of iterators needed by `obj`. Taken from `iters` or `iter`. + Prepend `name` to every iterator, if given. + Adapted from the corresponding typst implementation.""" + + def clean_iter(it): + arr = it if isinstance(it, list) else [it] + if name is not None: + arr = [name] + arr + + if len(arr) == 2: + # Assume single-element range + arr.append(arr[-1]) + + if len(arr) != 3: + reporter.error(f"Invalid length iter: {arr!r}") + return Iter(config, "_", 0, 0) + return Iter(config, *arr) + + if "iters" in obj: + reporter.asserts( + "iter" not in obj, f"Object has both `iters` and `iter`: {obj!r}" + ) + return [clean_iter(it) for it in obj["iters"]] + elif "iter" in obj: + return [clean_iter(obj["iter"])] + else: + return [] + + +@dataclass +class TypeConfig: + label: str + subtypes: list[Type] + range: Optional[Range] + desc: str + preprocessed: bool + + def __init__(self, default_name: str, lookup: Callable[[str], Type], data: dict): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.label = data["label"] + if "range" in data: + reporter.asserts( + data["subtypes"] == [default_name], + f"Specified a range on a non-base composite type: {data!r}", + ) + reporter.asserts( + isinstance(data["range"], list) and len(data["range"]) == 2, + f"Invalid range: {data!r}", + ) + start, stop = data["range"] + if not isinstance(start, int) and not ( + isinstance(start, str) and start.isdigit() + ): + reporter.error(f"Range start not an int: {data!r}") + start = 0 + if not isinstance(stop, int) and not ( + isinstance(stop, str) and stop.isdigit() + ): + reporter.error(f"Range end not an int: {data!r}") + stop = start + reporter.asserts(int(start) <= int(stop), f"Inverted range: {data!r}") + self.range = Range(int(start), int(stop)) + self.subtypes = [] + else: + self.range = None + self.subtypes = [lookup(tp) for tp in data["subtypes"]] + self.desc = data["desc"] + self.preprocessed = data.get("preprocessed", False) + + def as_type(self) -> Type: + return self.range or self.subtypes[:] + + +@dataclass +class ConfigCategories: + all: list[str] + instantiated: list[str] + + def __init__(self, data: dict): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.all = data["all"] + self.instantiated = data["instantiated"] + reporter.asserts( + all(isinstance(v, str) for v in self.all), + f"Something's not a string: {self.all}", + ) + reporter.asserts( + all(isinstance(v, str) for v in self.instantiated), + f"Something's not a string: {self.instantiated}", + ) + reporter.asserts( + set(self.instantiated) <= set(self.all), + f"Instantiated not a subset of all: {self!r}", + ) + + +@dataclass +class ConfigVariables: + types: list[TypeConfig] + categories: ConfigCategories + prime: int + + def __init__(self, data: dict): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.types = [] + base_type = data["types"][0]["label"] + for tp in data["types"]: + self.types.append(TypeConfig(base_type, self.lookup_type, tp)) + self.categories = ConfigCategories(data["categories"]) + basefield = self.lookup_type(base_type) + assert isinstance(basefield, Range) + self.prime = basefield.high + 1 + + def lookup_type(self, typename: str) -> Type: + matches = [t for t in self.types if t.label == typename] + if len(matches) != 1: + reporter.error(f"Couldn't lookup type by name: {typename!r}") + return DEFAULT_TYPE + return matches[0].as_type() + + +@dataclass +class ConfigMetadata: + version: int + + def __init__(self, data: dict): + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.version = data["version"] + reporter.asserts( + isinstance(self.version, int), f"version {self.version!r} is not an int" + ) + + +@dataclass +class Config: + metadata: ConfigMetadata + variables: ConfigVariables + + def __init__(self, data: dict): + """Construct a Config from toml-parsed data""" + assert_no_unexpected(data, type(self).__annotations__.keys()) + self.metadata = ConfigMetadata(data["metadata"]) + self.variables = ConfigVariables(data["variables"]) + + @classmethod + def from_file(cls, filename: str | Path) -> Self: + reporter.update_location(str(filename)) + return cls(tomllib.load(open(filename, "rb"))) + + @classmethod + def from_string(cls, s: str) -> Self: + reporter.update_location("") + return cls(tomllib.loads(s)) + + +def build_type(config: Config, data: list | str): + if isinstance(data, list): + if len(data) != 2: + reporter.error(f"Invalid type: {data!r}") + return DEFAULT_TYPE + return [build_type(config, data[0]) for _ in range(data[1])] + else: + return config.variables.lookup_type(data) + + +@dataclass +class Variable: + category: str + name: str + type: Type + desc: str + pad: Expr + precomputed: bool + + def __init__(self, config: Config, category: str, data: dict): + self.category = category + assert_no_unexpected(data, Variable.__annotations__.keys()) + self.name = data["name"] + reporter.asserts(isinstance(self.name, str), f"{self.name!r} is not a string") + reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") + self.type = build_type(config, data["type"]) + self.desc = data["desc"] + reporter.asserts(isinstance(self.desc, str), f"{self.desc!r} is not a string") + self.pad = build_expr(None, data.get("pad", 0)) + self.precomputed = data.get("precomputed", False) + reporter.asserts( + isinstance(self.precomputed, bool), + f"precomputed is not a bool: {self.precomputed!r}", + ) + + +def all_iters[T]( + its: list[Iter], env: Environment, callback: Callable[[Environment], Iterable[T]] +) -> Iterable[T]: + if not its: + yield from callback(env) + else: + yield from its[0].typecheck(env, lambda e: all_iters(its[1:], e, callback)) + + +@dataclass +class PolyWithIters: + poly: Expr + iters: list[Iter] + + +@dataclass +class VirtualDef: + # A list of polynomials with each a set of iters they range over + defs: list[PolyWithIters] + + def __init__(self, config: Config, name: str, tp: Type, data: dict): + if "poly" in data: + idx = data.get("idx", None) + self.defs = [ + PolyWithIters( + build_expr(config, data["poly"]), iters_of(data, name=idx) + ) + ] + elif "polys" in data: + idx = data.get("idx", None) + self.defs = [ + PolyWithIters( + build_expr(config, poly["poly"]), iters_of(poly, name=idx) + ) + for poly in data["polys"] + ] + else: + self.defs = [PolyWithIters(build_expr(config, data), [])] + + +@dataclass +class VirtualVariable(Variable): + def_: VirtualDef + + def __init__(self, config: Config, category: str, data: dict): + assert_no_unexpected(data, set(Variable.__annotations__.keys()) | {"def"}) + reporter.asserts("def" in data, f"Missing def for virtual column: {data!r}") + def_ = data.pop("def", {}) + super().__init__(config, category, data) + self.def_ = VirtualDef(config, self.name, self.type, def_) + + def typecheck(self, env: Environment) -> Type: + def structure_matches(a: Type, b: Type) -> bool: + if isinstance(a, Range) and isinstance(b, (Range, type(None))): + return True + elif isinstance(a, list) and isinstance(b, list): + return len(a) == len(b) and all( + structure_matches(x, y) for x, y in zip(a, b) + ) + else: + return False + + def handle_iters( + env: Environment, + iters: list[Iter], + poly: Expr, + expected: Type, + indices: list[int], + seen: set[tuple], + ): + if not iters: + # Check not doubly defined + for s in seen: + ln = min(len(s), len(indices)) + if s[:ln] == tuple(indices[:ln]): + reporter.error( + f"Double definition for virtual column: {self!r} at index {indices}" + ) + break + + val = poly.typecheck(env) + # check val structure matches assigned + reporter.asserts( + structure_matches(val, expected), + f"Invalid structure for definition to virtual column: {self!r}", + ) + # Check type fits? + + seen.add(tuple(indices)) + else: + it, *its = iters + # Some duplicated code/concepts from Iter.typecheck + # But threading the extra needed state through overly complicates everything + start = it.start.typecheck(env) + if isinstance(start, list) or not start.is_const(): + reporter.error( + f"Starting value of virtual def iter not a const: {self!r}" + ) + start = Range.const(0) + stop = it.stop.typecheck(env) + if isinstance(stop, list) or not stop.is_const(): + reporter.error( + f"Ending value of virtual def iter not a const: {self!r}" + ) + stop = Range.const(start.get_const()) + + if isinstance(expected, Range): + reporter.error( + f"Virtual definition has an iter for a scalar: {self!r}" + ) + return + + if not 0 <= start.get_const() <= stop.get_const() < len(expected): + reporter.error( + f"Virtual definition index [{start.get_const()}, {stop.get_const()}] out of range for {expected}: {self!r}" + ) + return + + for i in range(start.get_const(), stop.get_const() + 1): + handle_iters( + env.with_val(it.name, Range.const(i)), + its, + poly, + expected[i], + indices + [i], + seen, + ) + + def is_covered(seen: set[tuple], indices: list[int]) -> bool: + for s in seen: + if len(s) <= len(indices) and s == tuple(indices[: len(s)]): + return True + return False + + def check_covered(t: Type, seen: set[tuple], indices: list[int]): + if isinstance(t, Range): + reporter.asserts( + is_covered(seen, indices), + f"Virtual column {self.name!r} not completely defined", + ) + else: + for i, elt in enumerate(t): + check_covered(elt, seen, indices + [i]) + + # Special case for better error messages + if isinstance(self.type, Range): + reporter.asserts( + len(self.def_.defs) == 1 and not self.def_.defs[0].iters, + f"Invalid def for scalar column: {self!r}", + ) + assigned_type = self.def_.defs[0].poly.typecheck(env) + if not isinstance(assigned_type, Range): + reporter.error( + f"Assigning non-scalar type to scalar virtual column: {self!r}" + ) + return self.type + # Check type fits? + # Leaving this out because it produces too much noise with one-hot assumptions + # reporter.asserts(self.type.low <= assigned_type.low <= assigned_type.high <= self.type.high, f"Definition may not fit in virtual column: {self!r}") + else: + # Check no indices are covered twice + seen: set[tuple] = set() + for poly_iters in self.def_.defs: + handle_iters( + env, poly_iters.iters, poly_iters.poly, self.type, [], seen + ) + # Check everything is covered + check_covered(self.type, seen, []) + return self.type + + +@dataclass +class Assumption: + desc: str + iters: list[Iter] + + def __init__(self, config: Config, data: dict): + assert_no_unexpected( + data, set(self.__annotations__.keys()) | {"iter", "iters", "ref"} + ) + self.desc = data["desc"] + self.iters = iters_of(data) + + +@dataclass +class ArithConstraint: + constraint: str + desc: str + poly: Expr + iters: list[Iter] + + def __init__(self, config: Config, data: dict): + assert_no_unexpected( + data, set(self.__annotations__.keys()) | {"kind", "ref", "iter", "iters"} + ) + assert data["kind"] == "arith" + self.constraint = data["constraint"] + reporter.asserts( + isinstance(self.constraint, str), + f"Constraint not a string: {self.constraint!r}", + ) + self.desc = data.get("desc", "") + reporter.asserts( + isinstance(self.desc, str), f"desc is not a string: {self.desc!r}" + ) + self.poly = build_expr(config, data["poly"]) + self.iters = iters_of(data) + + def typecheck(self, env: Environment) -> Iterable[Never]: + # TODO? Should we check that there's no overflow of the modulus? + # This would probably struggle due to things like one-hot invariants + + def check_includes_zero(t: Type): + if isinstance(t, Range): + reporter.asserts( + t.low <= 0 <= t.high, + f"Unsatisfiable constraint, 0 not in range: {self!r} {t}", + ) + else: + reporter.error( + f"Non-scalar value for polynomial constraint: {self!r} {t}" + ) + + for t in all_iters(self.iters, env, lambda e: [self.poly.typecheck(e)]): + check_includes_zero(t) + return [] + + +@dataclass +class Signature: + tag: str + input: list[Type] + output: Optional[Type] + + +@dataclass +class InteractionLike: + kind: str + conditional_name: str + conditional_required: bool + signature: type[Signature] + + tag: str + desc: str + input: list[Expr] + output: Optional[Expr] + conditional: Optional[Expr] + iters: list[Iter] + + def __init__(self, config: Config, data: dict): + assert_no_unexpected( + data, + { + "tag", + "desc", + "input", + "output", + self.conditional_name, + "kind", + "ref", + "iter", + "iters", + }, + ) + assert data["kind"] == self.kind + self.tag = data["tag"] + reporter.asserts( + isinstance(self.tag, str), f"tag is not a string: {self.tag!r}" + ) + self.desc = data.get("desc", "") + reporter.asserts( + isinstance(self.desc, str), f"Description is not a string: {self.desc!r}" + ) + self.input = [build_expr(config, inp) for inp in data["input"]] + if "output" in data: + self.output = build_expr(config, data["output"]) + else: + self.output = None + if self.conditional_name in data: + self.conditional = build_expr(config, data[self.conditional_name]) + else: + reporter.asserts( + not self.conditional_required, + f"Missing {self.conditional_name}: {data!r}", + ) + self.conditional = None + self.iters = iters_of(data) + + def typecheck(self, env: Environment) -> Iterable[Signature]: + def callback(e: Environment) -> Iterable[Signature]: + # TODO: Should we be able to check cond/multiplicity somehow? + if self.conditional is not None: + self.conditional.typecheck(e) + return [ + self.signature( + self.tag, + [inp.typecheck(e) for inp in self.input], + self.output.typecheck(e) if self.output else None, + ) + ] + + return all_iters(self.iters, env, callback) + + +class TemplateSignature(Signature): + pass + + +class TemplateConstraint(InteractionLike): + kind = "template" + conditional_name = "cond" + conditional_required = False + signature = TemplateSignature + + +class InteractionSignature(Signature): + pass + + +class InteractionConstraint(InteractionLike): + kind = "interaction" + conditional_name = "multiplicity" + conditional_required = True + signature = InteractionSignature + + +@dataclass +class DummyConstraint: + def typecheck(self, env: Environment) -> list[Never]: + return [] + + +type Constraint = ( + ArithConstraint | TemplateConstraint | InteractionConstraint | DummyConstraint +) + + +def build_constraint(config, data: dict) -> Constraint: + match data["kind"]: + case "arith": + return ArithConstraint(config, data) + case "template": + return TemplateConstraint(config, data) + case "interaction": + return InteractionConstraint(config, data) + case other: + reporter.error(f"Unknown constraint kind: {other!r}") + return DummyConstraint() + + +@dataclass +class Chip: + config: Config + name: str + variables: list[Variable] + assumptions: list[Assumption] + constraints: list[Constraint] + + def __init__(self, config: Config, data: dict): + """Construct a chip from toml-parsed data""" + assert_no_unexpected( + data, set(type(self).__annotations__.keys()) | {"constraint_groups"} + ) + assert_no_unexpected(data["variables"], config.variables.categories.all) + self.config = config + self.name = data["name"] + reporter.asserts( + isinstance(self.name, str), f"name is not a string: {self.name!r}" + ) + reporter.asserts(self.name.isidentifier(), f"Invalid identifier: {self.name!r}") + self.variables = [ + (Variable if cat != "virtual" else VirtualVariable)(config, cat, var) + for cat, vars in data["variables"].items() + for var in vars + ] + self.assumptions = [Assumption(config, a) for a in data.get("assumptions", [])] + constraint_groups = [grp["name"] for grp in data.get("constraint_groups", [])] + assert_no_unexpected(data.get("constraints", {}), constraint_groups) + self.constraints = [ + build_constraint(config, constraint) + for group in data.get("constraints", {}).values() + for constraint in group + ] + + @classmethod + def from_file(cls, config: Config, filename: str | Path) -> Self: + reporter.update_location(str(filename)) + return cls(config, tomllib.load(open(filename, "rb"))) + + @classmethod + def from_string(cls, config: Config, s: str) -> Self: + reporter.update_location("") + return cls(config, tomllib.loads(s)) + + def typecheck(self) -> Iterable[Signature]: + typemap = {} + for v in self.variables: + if isinstance(v.type, list) and len(v.type) == 1: + t = v.type[0] + else: + t = v.type + typemap[v.name] = t + + env = Environment(self.config, {}, typemap) + for v in self.variables: + if isinstance(v, VirtualVariable): + v.typecheck(env) + for c in self.constraints: + yield from c.typecheck(env) + + +if __name__ == "__main__": + config = Config.from_file(sys.argv[1]) + signatures = sys.argv[2] # Later + if reporter.reported: + sys.exit(1) + reported = False + chips: list[Chip] = [] + for file in sys.argv[3:]: + if file in sys.argv[1:3]: + continue + chips.append(Chip.from_file(config, file)) + reported |= reporter.reported + if not reported: + for chip in chips: + reporter.update_location(f"Chip {chip.name}") + # TODO: do something with the signatures + # Use list for the sideeffect of forcing the generator until we use the content + list(chip.typecheck())