From 774d6460b6de3d8dd670305851bd74a5ff07c938 Mon Sep 17 00:00:00 2001 From: Robin Jadoul Date: Fri, 6 Mar 2026 14:56:29 +0100 Subject: [PATCH] spec: Losing some MEMW weight --- spec/memw.typ | 42 +++++-- spec/src/memw.toml | 107 ++++------------- spec/src/memw_aligned.toml | 228 +++++++++++++++++++++++++++++++++++++ spec/tooling/chip.py | 2 +- 4 files changed, 285 insertions(+), 94 deletions(-) create mode 100644 spec/src/memw_aligned.toml diff --git a/spec/memw.typ b/spec/memw.typ index 4b644218a..57907e26c 100644 --- a/spec/memw.typ +++ b/spec/memw.typ @@ -38,6 +38,13 @@ we document it here, keeping the type information as a reading help. = Constraints +We can compute the addresses for the later bytes based on a single bit each, +indicating whether adding `i` to `base_address` overflows the lower limb. +We can safely assume that additions for which this bit is not correctly set +will have either an overflow on the upper or lower word, and hence not match +any existing memory tokens, which are only initialized for correctly formatted +and range-checked doublewords (see @memory). + #render_constraint_table(chip, config, groups: "consistency") As long as `timestamp` is properly range-checked, the presence of `old_timestamp` @@ -45,9 +52,9 @@ in the memory argument automatically ensures appropriate range checking (as long as no external entities provide negative multiplicities without range checking the timestamp). This ensures the assumptions for `LT` are satisfied. -We additionally check that the address does not overflow -for more significant bytes of the access. -#render_constraint_table(chip, config, groups: "overflow") +There is no need to check that the address does not overflow, +as our address calculations are not performed modulo $2^64$ here, +and any overflow will result in an address without matching initialization. The chip adds the following tuples to the lookup argument, to effectuate that part of the memory argument. 
@@ -56,11 +63,32 @@ to effectuate that part of the memory argument. This chip contributes the following to the lookup argument. #render_constraint_table(chip, config, groups: "output") += Read-size aligned fast path + +#let alignedchip = load_chip("src/memw_aligned.toml", config) +#let aligned = raw(alignedchip.name) + +When a memory access happens at an address with proper alignment +(that is, enough trailing zeros) for its access size, and all accessed +elements were last accessed at the same timestamp, we can +instead use the #aligned chip to save on total column count. +The saving comes from only requiring a single old timestamp to be stored, +as well as being able to guarantee that all values of `add_limb_overflow` would be zero. +A minor extra cost is introduced in the form of a check that the alignment is indeed correct, +and the corresponding decomposition of the `base_address`. + +Further logic remains essentially the same, so we briefly present the relevant tables for this chip. +#let nr_variables = total_nr_variables(alignedchip) +#let nr_columns = total_nr_instantiated_columns(alignedchip, config) + +The #aligned chip only needs #nr_variables variables, expressed through #nr_columns columns. +#render_chip_column_table(alignedchip, config) +#render_chip_assumptions(alignedchip, config) +#render_constraint_table(alignedchip, config) + = Future optimization ideas -- Fast path for aligned memory access where all bytes have the same old timestamp -- MEMB chip that deals does a one-byte write to remove old_timestamp from here (uncertain tradeoffs) -- Compute `base_address[1] + 1` once and have high words of `address_add` as Words -- Improve overflow trapping somehow so we don't need `LT` (could tie into previous one by checking carry bit of the +1) +- `MEMB` chip that does a one-byte write to remove old_timestamp from here (uncertain tradeoffs) +- Additional fast path for registers? 
(Always guaranteed same timestamp, alignment could be an assumption, always only two values) - Adding `μ_sum`/`w2`/`w4`/`write8` multiplicities to the `IS_HALF` lookups may make some GKR things faster if there are known zeroes. diff --git a/spec/src/memw.toml b/spec/src/memw.toml index 0ae3b9410..c9519e115 100644 --- a/spec/src/memw.toml +++ b/spec/src/memw.toml @@ -48,9 +48,9 @@ Only the elements corresponding to the `writeN` bits are guaranteed""" # Auxiliary [[variables.auxiliary]] -name = "address_add" -type = ["DWordHL", 7] -desc = "`address_add[i] = base_address + i + 1`" +name = "add_limb_overflow" +type = ["Bit", 7] +desc = "Whether adding `i + 1` to `base_address[0]` as a field element reaches or exceeds $2^32$" [[variables.auxiliary]] name = "old_timestamp" @@ -71,6 +71,15 @@ type = "Bit" desc = "writing at least 4 bytes" def = ["+", "write4", "write8"] +[[variables.virtual]] +name = "address_add" +type = ["DWordWL", 7] +desc = "`address_add[i] = base_address + i + 1`" +def.iter = ["i", 0, 6] +def.poly = ["arr", + ["+", ["idx", "base_address", 0], "i", 1, ["*", ["-", ["^", 2, 32]], ["idx", "add_limb_overflow", "i"]]], + ["+", ["idx", "base_address", 1], ["idx", "add_limb_overflow", "i"]]] + [[variables.virtual]] name = "μ_sum" type = "Bit" @@ -126,56 +135,9 @@ poly = ["*", "w2", ["not", "μ_sum"]] [[constraints.consistency]] kind = "template" -tag = "ADD" -input = ["base_address", ["cast", 1, "DWordWL"]] -output = ["cast", ["idx", "address_add", 0], "DWordWL"] -cond = "w2" - -[[constraints.consistency]] -kind = "template" -tag = "ADD" -input = ["base_address", ["cast", ["+", "i", 1], "DWordWL"]] -output = ["cast", ["idx", "address_add", "i"], "DWordWL"] -iter = ["i", 1, 2] -cond = "w4" - -[[constraints.consistency]] -kind = "template" -tag = "ADD" -input = ["base_address", ["cast", ["+", "i", 1], "DWordWL"]] -output = ["cast", ["idx", "address_add", "i"], "DWordWL"] -iter = ["i", 3, 6] -cond = "write8" - -[[constraints.consistency]] -kind = "interaction" -tag = "IS_HALF" 
-input = [["idx", ["idx", "address_add", "i"], "j"]] -iters = [ - ["i", 0, 0], - ["j", 0, 3], -] -multiplicity = "w2" - -[[constraints.consistency]] -kind = "interaction" -tag = "IS_HALF" -input = [["idx", ["idx", "address_add", "i"], "j"]] -iters = [ - ["i", 1, 2], - ["j", 0, 3], -] -multiplicity = "w4" - -[[constraints.consistency]] -kind = "interaction" -tag = "IS_HALF" -input = [["idx", ["idx", "address_add", "i"], "j"]] -iters = [ - ["i", 3, 6], - ["j", 0, 3], -] -multiplicity = "write8" +tag = "IS_BIT" +input = [["idx", "add_limb_overflow", "i"]] +iter = ["i", 0, 6] [[constraints.consistency]] kind = "interaction" @@ -207,33 +169,6 @@ output = 1 iter = ["i", 4, 7] multiplicity = "write8" - -[[constraint_groups]] -name = "overflow" -prefix = "R" - -[[constraints.overflow]] -kind = "interaction" -tag = "LT" -input = ["base_address", ["cast", ["idx", "address_add", 0], "DWordWL"], 0] -output = 1 -multiplicity = "write2" - -[[constraints.overflow]] -kind = "interaction" -tag = "LT" -input = ["base_address", ["cast", ["idx", "address_add", 2], "DWordWL"], 0] -output = 1 -multiplicity = "write4" - -[[constraints.overflow]] -kind = "interaction" -tag = "LT" -input = ["base_address", ["cast", ["idx", "address_add", 6], "DWordWL"], 0] -output = 1 -multiplicity = "write8" - - [[constraint_groups]] name = "memory" prefix = "M" @@ -253,40 +188,40 @@ multiplicity = ["-", "μ_sum"] [[constraints.memory]] kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", 0], "DWordWL"], ["idx", "old_timestamp", 1], ["idx", "old", 1]] +input = ["is_register", ["idx", "address_add", 0], ["idx", "old_timestamp", 1], ["idx", "old", 1]] multiplicity = "w2" [[constraints.memory]] kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", 0], "DWordWL"], "timestamp", ["idx", "value", 1]] +input = ["is_register", ["idx", "address_add", 0], "timestamp", ["idx", "value", 1]] multiplicity = ["-", "w2"] [[constraints.memory]] 
kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", ["-", "i", 1]], "DWordWL"], ["idx", "old_timestamp", "i"], ["idx", "old", "i"]] +input = ["is_register", ["idx", "address_add", ["-", "i", 1]], ["idx", "old_timestamp", "i"], ["idx", "old", "i"]] multiplicity = "w4" iter = ["i", 2, 3] [[constraints.memory]] kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", ["-", "i", 1]], "DWordWL"], "timestamp", ["idx", "value", "i"]] +input = ["is_register", ["idx", "address_add", ["-", "i", 1]], "timestamp", ["idx", "value", "i"]] multiplicity = ["-", "w4"] iter = ["i", 2, 3] [[constraints.memory]] kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", ["-", "i", 1]], "DWordWL"], ["idx", "old_timestamp", "i"], ["idx", "old", "i"]] +input = ["is_register", ["idx", "address_add", ["-", "i", 1]], ["idx", "old_timestamp", "i"], ["idx", "old", "i"]] multiplicity = "write8" iter = ["i", 4, 7] [[constraints.memory]] kind = "interaction" tag = "memory" -input = ["is_register", ["cast", ["idx", "address_add", ["-", "i", 1]], "DWordWL"], "timestamp", ["idx", "value", "i"]] +input = ["is_register", ["idx", "address_add", ["-", "i", 1]], "timestamp", ["idx", "value", "i"]] multiplicity = ["-", "write8"] iter = ["i", 4, 7] diff --git a/spec/src/memw_aligned.toml b/spec/src/memw_aligned.toml new file mode 100644 index 000000000..715f57c85 --- /dev/null +++ b/spec/src/memw_aligned.toml @@ -0,0 +1,228 @@ +name = "MEMW-A" + +# Input + +[[variables.input]] +name = "is_register" +type = "Bit" +desc = "Whether the address represents a register index" + +[[variables.input]] +name = "base_address_high" +type = "Word" +desc = "The high word of the base address to read/write from/to, gets offset by $[0, 7]$, depending on how big the access is" + +[[variables.input]] +name = "base_address_mid" +type = "Half" +desc = "The middle halfword of the base address to read/write from/to, 
gets offset by $[0, 7]$, depending on how big the access is" + +[[variables.input]] +name = "base_address_low" +type = ["Byte", 2] +desc = "The low bytes of the base address to read/write from/to, gets offset by $[0, 7]$, depending on how big the access is" + +[[variables.input]] +name = "value" +type = ["BaseField", 8] +desc = "The values to store in memory. For regular memory, these should be (up to) 8 range-checked `Byte`s; registers are stored as two range-checked `Word`s" + +[[variables.input]] +name = "timestamp" +type = "DWordWL" +desc = "The timestamp at which this memory access is said to occur" + +[[variables.input]] +name = "write2" +type = "Bit" +desc = "Whether to write exactly 2 values" + +[[variables.input]] +name = "write4" +type = "Bit" +desc = "Whether to write exactly 4 values" + +[[variables.input]] +name = "write8" +type = "Bit" +desc = "Whether to write exactly 8 values" + +# Output + +[[variables.output]] +name = "old" +type = ["BaseField", 8] +desc = """The old value written at `base_address`. See `value` for information about representation. 
+Only the elements corresponding to the `writeN` bits are guaranteed""" + +# Auxiliary + +[[variables.auxiliary]] +name = "old_timestamp" +type = "DWordWL" +desc = "The timestamp at which the address was last accessed" + +# Virtual + +[[variables.virtual]] +name = "base_address" +type = "DWordWL" +desc = "Recomposing the base address from its parts" +defs = {idx = "i", polys = [["+", ["*", ["^", 2, 16], "base_address_mid"], ["*", ["^", 2, 8], ["idx", "base_address_low", 1]], ["idx", "base_address_low", 0]], "base_address_high"]} + +[[variables.virtual]] +name = "w2" +type = "Bit" +desc = "writing at least 2 bytes" +def = ["+", "write2", "write4", "write8"] + +[[variables.virtual]] +name = "w4" +type = "Bit" +desc = "writing at least 4 bytes" +def = ["+", "write4", "write8"] + +[[variables.virtual]] +name = "μ_sum" +type = "Bit" +desc = "Whether we are performing any access (a read or a write)" +def = ["+", "μ_read", "μ_write"] + +# Multiplicity + +[[variables.multiplicity]] +name = "μ_read" +type = "Bit" +desc = "Whether we are performing a read (and hence return `old`)" + +[[variables.multiplicity]] +name = "μ_write" +type = "Bit" +desc = "Whether we are performing a write (and hence do not return `old`)" + +[[assumptions]] +desc = "`IS_WORD[base_address_high]`" + +[[assumptions]] +desc = "`IS_HALF[base_address_mid]`" + +[[assumptions]] +desc = "`IS_BYTE[base_address_low[i]]`" +iter = ["i", 0, 1] + +[[assumptions]] +desc = "`IS_BIT`" + +[[assumptions]] +desc = "`IS_BIT`" + +[[assumptions]] +desc = "`IS_BIT`" + +[[assumptions]] +desc = "`IS_BIT`" + +[[assumptions]] +desc = "`IS_WORD[timestamp[i]]`" +iter = ["i", 0, 1] + + +[[constraint_groups]] +name = "consistency" + +[[constraints.consistency]] +kind = "template" +tag = "AND_BYTE" +input = [["idx", "base_address_low", 0], ["+", ["*", "write2", 1], ["*", "write4", 3], ["*", "write8", 7]]] +output = 0 + +[[constraints.consistency]] +kind = "template" +tag = "IS_BIT" +input = ["μ_sum"] + +[[constraints.consistency]] +kind = "arith" +constraint = "$#`w2` => #`μ_sum`$" +poly 
= ["*", "w2", ["not", "μ_sum"]] + +[[constraints.consistency]] +kind = "interaction" +tag = "LT" +input = ["old_timestamp", "timestamp", 0] +output = 1 +multiplicity = "μ_sum" + +[[constraint_groups]] +name = "memory" +prefix = "M" + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", "base_address", "old_timestamp", ["idx", "old", 0]] +multiplicity = "μ_sum" + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", "base_address", "timestamp", ["idx", "value", 0]] +multiplicity = ["-", "μ_sum"] + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", 1, "DWordWL"]], "old_timestamp", ["idx", "old", 1]] +multiplicity = "w2" + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", 1, "DWordWL"]], "timestamp", ["idx", "value", 1]] +multiplicity = ["-", "w2"] + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", "i", "DWordWL"]], "old_timestamp", ["idx", "old", "i"]] +multiplicity = "w4" +iter = ["i", 2, 3] + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", "i", "DWordWL"]], "timestamp", ["idx", "value", "i"]] +multiplicity = ["-", "w4"] +iter = ["i", 2, 3] + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", "i", "DWordWL"]], "old_timestamp", ["idx", "old", "i"]] +multiplicity = "write8" +iter = ["i", 4, 7] + +[[constraints.memory]] +kind = "interaction" +tag = "memory" +input = ["is_register", ["+", "base_address", ["cast", "i", "DWordWL"]], "timestamp", ["idx", "value", "i"]] +multiplicity = ["-", "write8"] +iter = ["i", 4, 7] + + +[[constraint_groups]] +name = "output" +prefix = "O" + +[[constraints.output]] +kind = "interaction" +tag = "MEMW" +input = ["is_register", 
"base_address", "value", "timestamp", "write2", "write4", "write8"] +output = "old" +multiplicity = "μ_read" + +[[constraints.output]] +kind = "interaction" +tag = "MEMW" +input = ["is_register", "base_address", "value", "timestamp", "write2", "write4", "write8"] +multiplicity = "μ_write" diff --git a/spec/tooling/chip.py b/spec/tooling/chip.py index 58deb4b3c..688743754 100644 --- a/spec/tooling/chip.py +++ b/spec/tooling/chip.py @@ -627,7 +627,7 @@ class VirtualVariable(Variable): def_: VirtualDef def __init__(self, config: Config, category: str, data: dict): - assert_no_unexpected(data, set(Variable.__annotations__.keys()) | {"def"}) + assert_no_unexpected(data, (set(Variable.__annotations__.keys()) | {"def"}) - {"pad"}) reporter.asserts("def" in data, f"Missing def for virtual column: {data!r}") def_ = data.pop("def", {}) super().__init__(config, category, data)