Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions spec/mul.typ
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ The following range checks are assumed to be performed/enforced outside of this
#render_chip_assumptions(chip, config)

== Constraints

=== Overview
When `lhs` and `rhs` are _unsigned_ integers, computing their product $mod 2^128$ comes down to evaluating
$
Expand Down Expand Up @@ -67,14 +66,14 @@ We let `raw_product` capture the second summation in this last formula (see @mul
By construction, $#`raw_product`_i < 2^51$ for all $i in [0, 3]$, far exceeding the 32-bits that fit in a single `Word`-limb.
What remains then is to reduce each limb of `raw_product` $mod 2^32$, carrying the overflow of each limb to the next, constructing the output `res` in doing so.

This reduce-and-carry operation is constrained @mul:a:res and @mul:c:carry, combined with `carry`'s definition.
This reduce-and-carry operation is constrained by @mul:c:range_lo/@mul:c:range_hi and @mul:c:carry, combined with `carry`'s definition.
@mul:c:carry and `carry`'s definition enforce that
$
forall i in [0, 3]: #`raw_product`_i + #`carry`_(i-1) - #`res`_i in { k dot 2^32 | k in [0, 2^20) }
$
with $#`carry`_(-1) = 0$ for simplicity.
In other words: $#`res`_i equiv #`raw_product`_i + #`carry`_(i-1) (mod 2^32)$.
With @mul:a:res forcing $#`res`_i < 2^32$, $#`res`_i$ can only assume one value: $#`raw_product`_i + #`carry`_(i-1) mod 2^32$.
With @mul:c:range_lo/@mul:c:range_hi forcing $#`res`_i < 2^32$, $#`res`_i$ can only assume one value: $#`raw_product`_i + #`carry`_(i-1) mod 2^32$.

*Note*: one may have observed that @mul:c:carry requires $#`carry`_i in [0, 2^20)$, while no limb of a valid carry value would ever exceed $2^19$.
This is indeed the case.
Expand All @@ -83,7 +82,7 @@ In fact, in this situation it suffices to assert that $#`carry`_i < frac(p, 2^32
Given that other chips also use 20-bit lookups, using `IS_B20` makes for a simpler design.

=== Definitions
We constrain `lhs_is_negative` and `rhs_is_negative` according to their definition; `carry` is appropriately range checked.
We constrain `lhs_is_negative` and `rhs_is_negative` according to their definition; `lo`, `hi` and `carry` are appropriately range checked.
#render_constraint_table(chip, config, groups: "def")

=== Product
Expand All @@ -99,3 +98,12 @@ The #mul chip contributes the following to the lookup:
The table can be padded to the next power of two with the following value assignments:

#render_chip_padding_table(chip, config)

== Notes
- `lo` and `hi` are stored in `DWordHL`s (rather than `DWordWL`s) because of their values being range checked.
Since it is not required that both `μ_lo` and `μ_hi` are non-zero at the same time, one cannot safely assume their range to be checked elsewhere.

As an optimization, one might be able to use a `DWordWL` and `DWordHL` to store `lo` and `hi`,
where one would decide which to store in which based on the multiplicities `μ_lo` and `μ_hi`;
the value sent into the lookup could then be assumed range-checked by the other side of the relation.
This optimization was not included at this moment because of its negative impact on the readability and verifiability of the chip.
68 changes: 47 additions & 21 deletions spec/src/mul.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,15 @@ pad = 0
# Output

[[variables.output]]
name = "res"
type = "QuadWL"
desc = "the (extended) multiplication result"
name = "lo"
type = "DWordHL"
desc = "the lower limbs of the (extended) multiplication result"
pad = 0

[[variables.output]]
name = "hi"
type = "DWordHL"
desc = "the upper limbs of the (extended) multiplication result"
pad = 0

# Auxiliary
Expand Down Expand Up @@ -63,26 +69,36 @@ name = "lhs_ext"
type = ["Half", 8]
desc = "sign-extended value of `lhs`"
def = {idx="i", polys=[
{range=[0, 3], poly=["idx", "lhs", "i"]},
{range=[4, 7], poly=["*", 0xFFFF, "lhs_is_negative"]},
{iter=[0, 3], poly=["idx", "lhs", "i"]},
{iter=[4, 7], poly=["*", 0xFFFF, "lhs_is_negative"]},
]}

[[variables.virtual]]
name = "rhs_ext"
type = ["Half", 8]
desc = "sign-extended value of `rhs`"
def = {idx="i", polys=[
{range=[0, 3], poly=["idx", "rhs", "i"]},
{range=[4, 7], poly=["*", 0xFFFF, "rhs_is_negative"]},
{iter=[0, 3], poly=["idx", "rhs", "i"]},
{iter=[4, 7], poly=["*", 0xFFFF, "rhs_is_negative"]},
]}

[[variables.virtual]]
name = "res"
type = "QuadWL"
desc = "concatenation of `lo` and `hi`."
def = {idx="i", polys=[
{iter=[0, 1], poly=["idx", ["cast", "lo", "DWordWL"], "i"]},
{iter=[2, 3], poly=["idx", ["cast", "hi", "DWordWL"], ["-", "i", 2]]},
]}


[[variables.virtual]]
name = "carry"
type = ["B20", 4]
desc = "carry values"
def = {idx="i", polys=[
{range=0, poly=["*", ["^", 2, -32], ["-", ["idx", "raw_product", 0], ["idx", "res", 0]]]},
{range=[1, 3], poly=["*", ["^", 2, -32], ["-", ["+", ["idx", "raw_product", "i"], ["idx", "carry", ["-", "i", 1]]], ["idx", "res", "i"]]]},
{iter=0, poly=["*", ["^", 2, -32], ["-", ["idx", "raw_product", 0], ["idx", "res", 0]]]},
{iter=[1, 3], poly=["*", ["^", 2, -32], ["-", ["+", ["idx", "raw_product", "i"], ["idx", "carry", ["-", "i", 1]]], ["idx", "res", "i"]]]},
]}

[[variables.virtual]]
Expand All @@ -109,17 +125,11 @@ pad = 0

[[assumptions]]
desc = "`IS_HALF[lhs[i]]`"
range = ["i", 0, 3]
iter = ["i", 0, 3]

[[assumptions]]
desc = "`IS_HALF[rhs[i]]`"
range = ["i", 0, 3]

[[assumptions]]
desc = "`IS_WORD[res[i]]`"
range = ["i", 0, 3]
ref = "mul:a:res"

iter = ["i", 0, 3]

# Constraints

Expand All @@ -140,11 +150,27 @@ input = [["idx", "rhs", 3], "rhs_signed"]
output = "rhs_is_negative"
ref = "mul:c:rhs_is_negative"

[[constraints.def]]
kind = "interaction"
tag = "IS_HALF"
input = [["idx", "lo", "i"]]
iter = ["i", 0, 3]
multiplicity = "μ_sum"
ref = "mul:c:range_lo"

[[constraints.def]]
kind = "interaction"
tag = "IS_HALF"
input = [["idx", "hi", "i"]]
iter = ["i", 0, 3]
multiplicity = "μ_sum"
ref = "mul:c:range_hi"

[[constraints.def]]
kind = "interaction"
tag = "IS_B20"
input = [["idx", "carry", "i"]]
range = ["i", 0, 3]
iter = ["i", 0, 3]
multiplicity = "μ_sum"
ref = "mul:c:carry"

Expand All @@ -156,7 +182,7 @@ name = "prod"
kind = "arith"
constraint = "$#`raw_product[i]` = sum_(#`k`=0)^1 2^(16k) sum_(#`j`=0)^(2i+k) #`lhs_ext[j]` dot #`rhs_ext[2i+k-j]`$"
poly = ["-", ["sum", ["=", "k", 0], "1", ["*", ["^", 2, ["*", 16, "k"]], ["sum", ["=", "j", 0], ["+", ["*", 2, "i"], "k"], ["*", ["idx", "lhs_ext", "j"], ["idx", "rhs_ext", ["-", ["+", ["*", 2, "i"], "k"], "j"]]]]]], ["idx", "raw_product", "i"]]
range = ["i", 0, 3]
iter = ["i", 0, 3]
ref = "mul:c:raw_product"

[[constraint_groups]]
Expand All @@ -166,14 +192,14 @@ name = "lookup"
kind = "interaction"
tag = "MUL"
input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "0"]
output = ["idx", "res", "0:4"]
output = ["cast", "lo", "DWordWL"]
multiplicity = ["-", "μ_lo"]
ref = "mul:c:lookup_lo"

[[constraints.lookup]]
kind = "interaction"
tag = "MUL"
input = ["lhs", "lhs_signed", "rhs", "rhs_signed", "1"]
output = ["idx", "res", "4:8"]
output = ["cast", "hi", "DWordWL"]
multiplicity = ["-", "μ_hi"]
ref = "mul:c:lookup_hi"