diff --git a/spec/book.typ b/spec/book.typ index 3e001e1e0..01a362879 100644 --- a/spec/book.typ +++ b/spec/book.typ @@ -12,6 +12,7 @@ #chapter("shift.typ")[SHIFT chip] #chapter("branch.typ")[BRANCH] #chapter("lt.typ")[LT] + #chapter("mul.typ")[MUL chip] ] ) diff --git a/spec/expr.typ b/spec/expr.typ index 547f2cad2..c452d70d3 100644 --- a/spec/expr.typ +++ b/spec/expr.typ @@ -145,8 +145,8 @@ "*": (pp, rec, e) => mwrap($#e.slice(1).map(rec.with(PREC.mul)).join($dot$)$, pp < PREC.mul), "/": (pp, rec, e) => $#rec(PREC.div, e.at(1)) / #rec(PREC.div, e.at(2))$, "^": (pp, rec, e) => { - assert(type(e.at(1)) == int and type(e.at(2)) == int, message: "Can only exponentiate constants") - $#e.at(1)^#e.at(2)$ + assert(type(e.at(1)) == int, message: "Can only exponentiate constants") + $#e.at(1)^#rec(PREC.MAX, e.at(2))$ }, "=": (pp, rec, e) => $#rec(PREC.eq, e.at(1)) = #rec(PREC.eq, e.at(2))$, ":=": (pp, rec, e) => $#rec(PREC.eq, e.at(1)) := #rec(PREC.eq, e.at(2))$, diff --git a/spec/mul.typ b/spec/mul.typ new file mode 100644 index 000000000..92fafe26b --- /dev/null +++ b/spec/mul.typ @@ -0,0 +1,94 @@ +#import "/book.typ": book-page +#import "/src.typ": load_config, load_chip +#import "/chip.typ": ( + render_chip_column_table, + total_nr_variables, + total_nr_instantiated_columns, + render_constraint_table, + render_chip_assumptions, +) + +#let config = load_config() +#let chip = load_chip("src/mul.toml", config) + +#show: book-page.with(title: "MUL chip") + +#let mul = raw(chip.name) + += #mul chip + +== Columns +#let nr_variables = total_nr_variables(chip) +#let nr_columns = total_nr_instantiated_columns(chip, config) + +The `MUL` chip is comprised of #nr_variables variables that are expressed using #nr_columns columns: +#render_chip_column_table(chip, config) + +#let stackrel(top, bottom) = { + $mat(delim: #none, top; bottom)$ +} + +== Assumptions +The following range checks are assumed to be performed/enforced outside of this chip: +#render_chip_assumptions(chip, config) + +== Constraints + +=== Overview +When `lhs` and `rhs` are _unsigned_ integers, computing their product $mod 2^128$ comes down to evaluating +$ +(sum_(j=0)^3 2^(16j) dot #`lhs`_j) dot (sum_(i=0)^3 2^(16i) dot #`rhs`_i) mod 2^128. +$ +If `lhs` and `rhs` are signed instead, the computation remains nearly identical: +based on their signs, one must either zero or one-extend `lhs` and `rhs` --- forming `lhs_ext` and `rhs_ext` respectively --- and compute their product $mod 2^128$: +$ +(sum_(j=0)^7 2^(16j) dot #`lhs_ext`_j) dot (sum_(i=0)^7 2^(16i) dot #`rhs_ext`_i) mod 2^128. +$ +where `lhs_ext` and `rhs_ext` are treated as _unsigned_ integers. +Note that by setting the extension limbs of `lhs` and/or `rhs` to $0$ when the integer is (i) unsigned or (ii) signed and non-negative, this second formula still applies. +For the purposes of constraining the multiplication operation, we rewrite this formula as +#show math.equation: set block(breakable: true) +$ + &(sum_(j=0)^7 2^(16j) dot #`lhs_ext`_j) dot (sum_(i=0)^7 2^(16i) dot #`rhs_ext`_i) mod 2^128 \ + &equiv sum_(j=0)^7 sum_(i=0)^7 2^(16(i+j)) dot #`lhs_ext`_j dot #`rhs_ext`_i mod 2^128 \ + &stackrel(triangle, equiv) sum_(j=0)^7 sum_(i=0)^(7-j) 2^(16(i+j)) dot #`lhs_ext`_j dot #`rhs_ext`_i mod 2^128 \ + &stackrel(square, equiv) sum_(j=0)^7 sum_(i=j)^(7) 2^(16i) dot #`lhs_ext`_j dot #`rhs_ext`_(i-j) mod 2^128 \ + &stackrel(penta, equiv) sum_(i=0)^7 sum_(j=0)^(i) 2^(16i) dot #`lhs_ext`_j dot #`rhs_ext`_(i-j) mod 2^128 \ + &equiv sum_(i=0)^3 sum_(k=0)^1 sum_(j=0)^(2i+k) 2^(16(2i+k)) dot #`lhs_ext`_j dot #`rhs_ext`_(2i+k-j) mod 2^128 \ + &equiv sum_(i=0)^3 2^(32i) dot sum_(k=0)^1 2^(16k) dot sum_(j=0)^(2i+k) #`lhs_ext`_j dot #`rhs_ext`_(2i+k-j) mod 2^128 +$ +where at step +- $triangle$ we can ignore $i > 7-j$, since that makes $2^(16(i+j)) equiv 0 mod 2^128$, +- $square$ we rewrite the second summation such that $i$ iterates from $j$ to 7, rather than $0$ to $7-j$, and +- $penta$ we swap the sums. + +We let `raw_product` capture the second summation in this last formula (see @mul:c:raw_product). +By construction, $#`raw_product`_i < 2^51$ for all $i in [0, 3]$, far exceeding the 32-bits that fit in a single `Word`-limb. +What remains then is to reduce each limb of `raw_product` $mod 2^32$, carrying the overflow of each limb to the next, constructing the output `res` in doing so. + +This reduce-and-carry operation is constrained @mul:a:res and @mul:c:carry, combined with `carry`'s definition. +@mul:c:carry and `carry`'s definition enforce that +$ + forall i in [0, 3]: #`raw_product`_i + #`carry`_(i-1) - #`res`_i in { k dot 2^32 | k in [0, 2^20) } +$ +with $#`carry`_(-1) = 0$ for simplicity. +In other words: $#`res`_i equiv #`raw_product`_i + #`carry`_(i-1) (mod 2^32)$. +With @mul:a:res forcing $#`res`_i < 2^32$, $#`res`_i$ can only assume one value: $#`raw_product`_i + #`carry`_(i-1) mod 2^32$. + +*Note*: one may have observed that @mul:c:carry requires $#`carry`_i in [0, 2^20)$, while no limb of a valid carry value would ever exceed $2^19$. +This is indeed the case. +However, there is some slack in how tight one has to constrain the `carry` values. +In fact, in this situation it suffices to assert that $#`carry`_i < frac(p, 2^32, style: "skewed") approx 2^31$, where $p$ denotes the field's modulus. +Given that other chips also use 20-bit lookups, using `IS_B20` makes for a simpler design. + +=== Definitions +We constrain `lhs_is_negative` and `rhs_is_negative` according to their definition; `carry` is appropriately range checked. +#render_constraint_table(chip, config, groups: "def") + +=== Product +@mul:c:raw_product defines `raw_product` in terms of the (sign extended) input values `lhs` and `rhs`. +#render_constraint_table(chip, config, groups: "prod") + +=== Lookup +The #mul chip contributes the following to the lookup: +#render_constraint_table(chip, config, groups: "lookup") \ No newline at end of file diff --git a/spec/src/config.toml b/spec/src/config.toml index b66639e2a..389e4b16a 100644 --- a/spec/src/config.toml +++ b/spec/src/config.toml @@ -27,6 +27,11 @@ label = "Half" subtypes = ["BaseField"] desc = "Variable that can only assume values in the range $[0, 2^16)$." +[[variables.types]] +label = "B20" +subtypes = ["BaseField"] +desc = "Variable that can only assume values in the range $[0, 2^20)$." + [[variables.types]] label = "Word" subtypes = ["BaseField"] @@ -48,6 +53,16 @@ desc = """\ Represented as an array of four `Byte` variables.\ """ +[[variables.types]] +label = "B35" +subtypes = ["BaseField"] +desc = "Variable that can only assume values in the range $[0, 2^35)$." + +[[variables.types]] +label = "B51" +subtypes = ["BaseField"] +desc = "Variable that can only assume values in the range $[0, 2^51)$." + [[variables.types]] label = "DWordBL" subtypes = ["Byte", "Byte", "Byte", "Byte", "Byte", "Byte", "Byte", "Byte"] @@ -81,6 +96,22 @@ desc = """\ The `Word` is the *least* significant digit. """ +[[variables.types]] +label = "QuadHL" +subtypes = ["Half", "Half", "Half", "Half", "Half", "Half", "Half", "Half"] +desc = """\ + Variable that can only assume values in the range $[0, 2^128)$. \\ + Represented as an array of eight `Half` variables.\ + """ + +[[variables.types]] +label = "QuadWL" +subtypes = ["Word", "Word", "Word", "Word"] +desc = """\ + Variable that can only assume values in the range $[0, 2^128)$. \\ + Represented as an array of four `Word` variables.\ + """ + [[variables.types]] label = "DWordWHH" subtypes = ["Half", "Half", "Word"] diff --git a/spec/src/mul.toml b/spec/src/mul.toml new file mode 100644 index 000000000..bf9ffc276 --- /dev/null +++ b/spec/src/mul.toml @@ -0,0 +1,179 @@ +name = "MUL" + + +# Input + +[[variables.input]] +name = "lhs" +type = "DWordHL" +desc = "the left hand operator." +pad = 0 + +[[variables.input]] +name = "lhs_signed" +type = "Bit" +desc = "whether to interpret `lhs` as a signed integer (1) or not (0)." +pad = 0 + +[[variables.input]] +name = "rhs" +type = "DWordHL" +desc = "the right hand operator." +pad = 0 + +[[variables.input]] +name = "rhs_signed" +type = "Bit" +desc = "whether to interpret `rhs` as a signed integer (1) or not (0)." +pad = 0 + + +# Output + +[[variables.output]] +name = "res" +type = "QuadWL" +desc = "the (extended) multiplication result" +pad = 0 + +# Auxiliary + +[[variables.auxiliary]] +name = "lhs_is_negative" +type = "Bit" +desc = "whether `lhs` is negative (1) or not (0)" +pad = 0 + +[[variables.auxiliary]] +name = "rhs_is_negative" +type = "Bit" +desc = "whether `rhs` is negative (1) or not (0)" +pad = 0 + +[[variables.auxiliary]] +name = "raw_product" +type = ["B51", 4] +desc = "raw multiplication output" +pad = 0 + +# Virtual + +[[variables.virtual]] +name = "lhs_ext" +type = ["Half", 8] +desc = "sign-extended value of `lhs`" +def = {idx="i", polys=[ + {range=[0, 3], poly=["idx", "lhs", "i"]}, + {range=[4, 7], poly=["*", 0xFFFF, "lhs_is_negative"]}, +]} + +[[variables.virtual]] +name = "rhs_ext" +type = ["Half", 8] +desc = "sign-extended value of `rhs`" +def = {idx="i", polys=[ + {range=[0, 3], poly=["idx", "rhs", "i"]}, + {range=[4, 7], poly=["*", 0xFFFF, "rhs_is_negative"]}, +]} + +[[variables.virtual]] +name = "carry" +type = ["B20", 4] +desc = "carry values" +def = {idx="i", polys=[ + {range=0, poly=["*", ["^", 2, -32], ["-", ["idx", "raw_product", 0], ["idx", "res", 0]]]}, + {range=[1, 3], poly=["*", ["^", 2, -32], ["-", ["+", ["idx", "raw_product", "i"], ["idx", "carry", ["-", "i", 1]]], ["idx", "res", "i"]]]}, +]} + +[[variables.virtual]] +name = "μ_sum" +type = "BaseField" +desc = "sum of multiplicies" +def = ["+", "μ_lo", "μ_hi"] + +# Multiplicity + +[[variables.multiplicity]] +name = "μ_lo" +type = "BaseField" +desc = "" +pad = 0 + +[[variables.multiplicity]] +name = "μ_hi" +type = "BaseField" +desc = "" +pad = 0 + +# Assumptions + +[[assumptions]] +desc = "`IS_HALF[lhs[i]]`" +range = ["i", 0, 3] + +[[assumptions]] +desc = "`IS_HALF[rhs[i]]`" +range = ["i", 0, 3] + +[[assumptions]] +desc = "`IS_WORD[res[i]]`" +range = ["i", 0, 3] +ref = "mul:a:res" + + +# Constraints + +[[constraint_groups]] +name = "def" + +[[constraints.def]] +kind = "template" +tag = "SIGN" +input = [["idx", "lhs", 3], "lhs_signed"] +output = "lhs_is_negative" +ref = "mul:c:lhs_is_negative" + +[[constraints.def]] +kind = "template" +tag = "SIGN" +input = [["idx", "rhs", 3], "rhs_signed"] +output = "rhs_is_negative" +ref = "mul:c:rhs_is_negative" + +[[constraints.def]] +kind = "interaction" +tag = "IS_B20" +input = [["idx", "carry", "i"]] +range = ["i", 0, 3] +multiplicity = "μ_sum" +ref = "mul:c:carry" + +[[constraint_groups]] +name = "prod" + + +[[constraints.prod]] +kind = "arith" +constraint = "$#`raw_product[i]` = sum_(#`k`=0)^1 2^(16k) sum_(#`j`=0)^(2i+k) #`lhs_ext[j]` dot #`rhs_ext[2i+k-j]`$" +poly = ["-", ["sum", ["=", "k", 0], "1", ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]] +range = ["i", 0, 3] +ref = "mul:c:raw_product" + +[[constraint_groups]] +name = "lookup" + +[[constraints.lookup]] +kind = "interaction" +tag = "MUL" +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "0"] +output = ["idx", "res", "0:4"] +multiplicity = ["-", "μ_lo"] +ref = "mul:c:lookup_lo" + +[[constraints.lookup]] +kind = "interaction" +tag = "MUL" +input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "1"] +output = ["idx", "res", "4:8"] +multiplicity = ["-", "μ_hi"] +ref = "mul:c:lookup_hi" \ No newline at end of file