diff --git a/.github/workflows/build-wheel.yml b/.github/workflows/build-wheel.yml new file mode 100644 index 0000000..8a177a4 --- /dev/null +++ b/.github/workflows/build-wheel.yml @@ -0,0 +1,48 @@ +name: Build Wheel + +on: + push: + branches: [main, master] + tags: + - "v*" + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install build backend + run: python -m pip install --upgrade pdm-backend build + + - name: Install test dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-dev.lock + pip install pytest + + - name: Run tests with pytest + run: pytest + + - name: Build wheel + run: python -m build --wheel + + - name: Create Release + id: create_release + uses: softprops/action-gh-release@v1 + with: + files: | + dist/*.whl + draft: false + prerelease: false + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..3e388a4 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13.2 diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..aa38025 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,10 @@ +MIT License + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/libnum/ecc.py b/libnum/ecc.py deleted file mode 100644 index 9c10bd1..0000000 --- a/libnum/ecc.py +++ /dev/null @@ -1,167 +0,0 @@ -import random - -from .sqrtmod import sqrtmod_prime_power, has_sqrtmod_prime_power -from .modular import invmod - -__all__ = ('NULL_POINT', 'Curve') - -NULL_POINT = (None, None) - - -class Curve: - def __init__(self, a, b, p, g=None, - order=None, - cofactor=None, - seed=None): - self.a = a - self.b = b - self.module = p - - self.g = g - self.order = order - self.cofactor = cofactor - self.seed = seed - self.points_count = None - if self.cofactor == 1 and self.order is not None: - self.points_count = self.order - return None - - def is_null(self, p): - """ - Check if a point is curve's null point - """ - return p == NULL_POINT - - def is_opposite(self, p1, p2): - """ - Check if one point is opposite to another (p1 == -p2) - """ - x1, y1 = p1 - x2, y2 = p2 - return (x1 == x2 and y1 == -y2 % self.module) - - def check(self, p): - """ - Check if point is on the curve - """ - x, y = p - if self.is_null(p): - return True - left = (y ** 2) % self.module - right = self.right(x) - return left == right - - def check_x(self, x): - """ - Check if there is a point on the curve with given @x coordinate - """ - if x > self.module or x < 0: - raise ValueError("Value " + str(x) + - " is not in range [0; ]") - a = self.right(x) - n = self.module - - if not has_sqrtmod_prime_power(a, n): - return False - - ys = sqrtmod_prime_power(a, n) - return map(lambda y: (x, y), ys) - - def right(self, x): - """ - Right part of the curve equation: x^3 + a*x + b (mod p) - """ - return (x ** 3 + self.a * x + self.b) % self.module - - def find_points_in_range(self, start=0, end=None): - """ - List of points in given range for x coordinate - """ - points = [] - - if end is None: - end = self.module - 1 - - for x in range(start, end + 1): - p = self.check_x(x) - if not p: - continue - points.extend(p) - - return points - - def find_points_rand(self, number=1): - """ - List of @number random points on the curve - """ - points = [] - - while len(points) < number: - x = random.randint(0, self.module) - p = self.check_x(x) - if not p: - continue - points.append(p) - - return points - - def add(self, p1, p2): - """ - Sum of two points - """ - if self.is_null(p1): - return p2 - - if self.is_null(p2): - return p1 - - if self.is_opposite(p1, p2): - return NULL_POINT - - x1, y1 = p1 - x2, y2 = p2 - - l = 0 - if x1 != x2: - l = (y2 - y1) * invmod(x2 - x1, self.module) - else: - l = (3 * x1 ** 2 + self.a) * invmod(2 * y1, self.module) - - x = (l * l - x1 - x2) % self.module - y = (l * (x1 - x) - y1) % self.module # yes, it's that new x - return (x, y) - - def power(self, p, n): - """ - n✕P or (P + P + ... + P) n times - """ - if n == 0 or self.is_null(p): - return NULL_POINT - - res = NULL_POINT - while n: - if n & 1: - res = self.add(res, p) - p = self.add(p, p) - n >>= 1 - return res - - def generate(self, n): - """ - Too lazy to give self.g to self.power - """ - return self.power(self.g, n) - - def get_order(self, p, limit=None): - """ - Tries to calculate order of @p, returns None if @limit is reached - (SLOW method) - """ - order = 1 - res = p - while not self.is_null(res): - res = self.add(res, p) - order += 1 - if limit is not None and order >= limit: - return None - return order diff --git a/pyproject.toml b/pyproject.toml index 8d8d98f..3ade360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ -[tool.poetry] +[project] name = "libnum" -version = "1.7.1" +version = "1.7.2" description = "Working with numbers (primes, modular, etc.)" -authors = ["hellman"] +authors = [{ "name" = "hellman" }] license = "MIT" readme = "README.md" keywords = ["numbers", "modular", "cryptography", "number theory"] @@ -11,15 +11,13 @@ classifiers = [ 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Security :: Cryptography', ] - -[tool.poetry.urls] -homepage = "http://github.com/hellman/libnum" - -[tool.poetry.dependencies] -python = "^3.4" - -[tool.poetry.dev-dependencies] +dependencies = [] +requires-python = ">= 3.6" [build-system] -requires = ["poetry-core>=1.0.0a5"] -build-backend = "poetry.core.masonry.api" +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[tool.rye] +managed = true +dev-dependencies = ["pytest>=8.3.5"] diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..2d3da02 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,19 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. +iniconfig==2.1.0 + # via pytest +packaging==25.0 + # via pytest +pluggy==1.5.0 + # via pytest +pytest==8.3.5 diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..505fd45 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,12 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. diff --git a/libnum/__init__.py b/src/libnum/__init__.py similarity index 100% rename from libnum/__init__.py rename to src/libnum/__init__.py index 75e64b4..c4888ea 100644 --- a/libnum/__init__.py +++ b/src/libnum/__init__.py @@ -8,15 +8,15 @@ # commonly used things from fractions import Fraction -from .primes import * -from .factorize import * +from . import ecc +from .chains import * from .common import * +from .factorize import * from .modular import * +from .primes import * from .sqrtmod import * -from .stuff import * from .strings import * -from .chains import * -from . import ecc +from .stuff import * # TODO: Add doctest after we have better docs diff --git a/libnum/chains.py b/src/libnum/chains.py similarity index 89% rename from libnum/chains.py rename to src/libnum/chains.py index 82fab02..ba0a3a3 100644 --- a/libnum/chains.py +++ b/src/libnum/chains.py @@ -1,7 +1,9 @@ -from libnum import nroot, gcd, Fraction +from fractions import Fraction +from .common import gcd, nroot -class Chain(object): + +class Chain: def __init__(self, *args): self._chain = [] self._frac = None @@ -15,11 +17,11 @@ def _calcFrac(self): r = Fraction(0, 1) for x in reversed(self.chain): r = Fraction(1, x + r) - return 1/r + return 1 / r def _calcChain(self): r = [] - a, b = self.frac.numerator, self.frac.denominator + a, b = self.frac.numerator, self.frac.denominator # type: ignore while b != 1: r.append(a / b) a, b = b, a % b @@ -63,9 +65,9 @@ def convergents(self): def sqrt_chained_fractions(n, limit=None): - ''' + """ E.g. sqrt_chained_fractions(13) = [3,(1,1,1,1,6)] - ''' + """ s = nroot(n, 2) if s**2 == n: return [s] @@ -90,14 +92,14 @@ def sqrt_chained_fractions(n, limit=None): def _sqrt_iter(n, s, t, a, b): - ''' + """ take t*(sqrt(n)+a)/b s = floor(sqrt(n)) return (v, next fraction params t, a, b) - ''' + """ v = t * (s + a) // b t2 = b - b2 = t * (n - (b * v - a)**2) + b2 = t * (n - (b * v - a) ** 2) a2 = b * v - a g = gcd(t2, b2) t2 //= g @@ -105,6 +107,6 @@ def _sqrt_iter(n, s, t, a, b): return v, (t2, a2, b2) -if __name__ == '__main__': +if __name__ == "__main__": for v in (2, 3, 5, 6, 7, 8, 10, 11, 12, 13, 1337, 31337): print("sqrt(%d): %s" % (v, repr(sqrt_chained_fractions(v)))) diff --git a/libnum/common.py b/src/libnum/common.py similarity index 81% rename from libnum/common.py rename to src/libnum/common.py index ca2b6da..b0759e2 100644 --- a/libnum/common.py +++ b/src/libnum/common.py @@ -1,10 +1,10 @@ import math import random - from functools import reduce +from typing import Tuple -def len_in_bits(n): +def len_in_bits(n: int) -> int: """ Return number of bits in binary representation of @n. Probably deprecated by .bit_length(). @@ -14,18 +14,18 @@ def len_in_bits(n): return n.bit_length() -def randint_bits(size): +def randint_bits(size: int) -> int: return random.getrandbits(size) | (1 << (size - 1)) -def ceil(x, y): +def ceil(x: int, y: int) -> int: """ Divide x by y with ceiling. """ return (x + y - 1) // y -def nroot(x, n): +def nroot(x: int, n: int) -> int: """ Return truncated n'th root of x. Using binary search. @@ -44,15 +44,16 @@ def nroot(x, n): raise ValueError("can't extract even root of negative") high = 1 - while high ** n <= x: + while high**n <= x: high <<= 1 low = high >> 1 + mid = (low + high) >> 1 while low < high: mid = (low + high) >> 1 - if low < mid and mid ** n < x: + if low < mid and mid**n < x: low = mid - elif high > mid and mid ** n > x: + elif high > mid and mid**n > x: high = mid else: return sign * mid @@ -62,7 +63,7 @@ def nroot(x, n): _gcd = math.gcd -def _lcm(a, b): +def _lcm(a: int, b: int) -> int: """ Return lowest common multiple. """ @@ -71,21 +72,21 @@ def _lcm(a, b): return abs(a * b) // _gcd(a, b) -def gcd(*lst): +def gcd(*lst: int) -> int: """ Return gcd of a variable number of arguments. """ return abs(reduce(lambda a, b: _gcd(a, b), lst)) -def lcm(*lst): +def lcm(*lst: int) -> int: """ Return lcm of a variable number of arguments. """ return abs(reduce(lambda a, b: _lcm(a, b), lst)) -def xgcd(a, b): +def xgcd(a: int, b: int) -> Tuple[int, int, int]: """ Extented Euclid GCD algorithm. Return (x, y, g) : a * x + b * y = gcd(a, b) = g. @@ -109,7 +110,7 @@ def xgcd(a, b): return ppx, ppy, a -def extract_prime_power(a, p): +def extract_prime_power(a: int, p: int) -> Tuple[int, int]: """ Return s, t such that a = p**s * t, t % p = 0 """ diff --git a/src/libnum/ecc.py b/src/libnum/ecc.py new file mode 100644 index 0000000..3b19695 --- /dev/null +++ b/src/libnum/ecc.py @@ -0,0 +1,189 @@ +import random +from typing import List, Optional, Tuple, Union + +from .modular import invmod +from .sqrtmod import has_sqrtmod_prime_power, sqrtmod_prime_power + +__all__ = ("NULL_POINT", "Curve") + +Point = Tuple[Optional[int], Optional[int]] +NULL_POINT: Point = (None, None) + + +class Curve: + def __init__( + self, + a: int, + b: int, + p: int, + g: Optional[Point] = None, + order: Optional[int] = None, + cofactor: Optional[int] = None, + seed: Optional[int] = None, + ) -> None: + self.a: int = a + self.b: int = b + self.module: int = p + + self.g: Optional[Point] = g + self.order: Optional[int] = order + self.cofactor: Optional[int] = cofactor + self.seed: Optional[int] = seed + self.points_count: Optional[int] = None + if self.cofactor == 1 and self.order is not None: + self.points_count = self.order + return None + + def is_null(self, p: Point) -> bool: + """ + Check if a point is curve's null point + """ + return p == NULL_POINT + + def is_opposite(self, p1: Point, p2: Point) -> bool: + """ + Check if one point is opposite to another (p1 == -p2) + """ + if self.is_null(p1) or self.is_null(p2): + return False + x1, y1 = p1 + x2, y2 = p2 + if x1 is None or y1 is None or x2 is None or y2 is None: + return False + return x1 == x2 and y1 == -y2 % self.module + + def check(self, p: Point) -> bool: + """ + Check if point is on the curve + """ + x, y = p + if self.is_null(p): + return True + if x is None or y is None: + return False + left = (y**2) % self.module + right = self.right(x) + return left == right + + def check_x(self, x: int) -> Union[bool, List[Point]]: + """ + Check if there is a point on the curve with given @x coordinate + """ + if x > self.module or x < 0: + raise ValueError("Value " + str(x) + " is not in range [0; ]") + a = self.right(x) + n = self.module + + if not has_sqrtmod_prime_power(a, n): + return False + + ys = sqrtmod_prime_power(a, n) + return list(map(lambda y: (x, y), ys)) + + def right(self, x: int) -> int: + """ + Right part of the curve equation: x^3 + a*x + b (mod p) + """ + return (x**3 + self.a * x + self.b) % self.module + + def find_points_in_range( + self, start: int = 0, end: Optional[int] = None + ) -> List[Point]: + """ + List of points in given range for x coordinate + """ + points: List[Point] = [] + + if end is None: + end = self.module - 1 + + for x in range(start, end + 1): + p = self.check_x(x) + if p is False: + continue + if isinstance(p, list): + points.extend(p) + + return points + + def find_points_rand(self, number: int = 1) -> List[Point]: + """ + List of @number random points on the curve + """ + points: List[Point] = [] + + while len(points) < number: + x = random.randint(0, self.module) + p = self.check_x(x) + if p is False: + continue + if isinstance(p, list): + points.append(p[0]) # Take first point found + + return points + + def add(self, p1: Point, p2: Point) -> Point: + """ + Sum of two points + """ + if self.is_null(p1): + return p2 + + if self.is_null(p2): + return p1 + + if self.is_opposite(p1, p2): + return NULL_POINT + + x1, y1 = p1 + x2, y2 = p2 + + if x1 is None or y1 is None or x2 is None or y2 is None: + return NULL_POINT + + slope = 0 + if x1 != x2: + slope = (y2 - y1) * invmod(x2 - x1, self.module) + else: + slope = (3 * x1**2 + self.a) * invmod(2 * y1, self.module) + + x = (slope * slope - x1 - x2) % self.module + y = (slope * (x1 - x) - y1) % self.module # yes, it's that new x + return (x, y) + + def power(self, p: Point, n: int) -> Point: + """ + n✕P or (P + P + ... + P) n times + """ + if n == 0 or self.is_null(p): + return NULL_POINT + + res = NULL_POINT + while n: + if n & 1: + res = self.add(res, p) + p = self.add(p, p) + n >>= 1 + return res + + def generate(self, n: int) -> Point: + """ + Too lazy to give self.g to self.power + """ + if self.g is None: + return NULL_POINT + return self.power(self.g, n) + + def get_order(self, p: Point, limit: Optional[int] = None) -> Optional[int]: + """ + Tries to calculate order of @p, returns None if @limit is reached + (SLOW method) + """ + order = 1 + res = p + while not self.is_null(res): + res = self.add(res, p) + order += 1 + if limit is not None and order >= limit: + return None + return order diff --git a/libnum/factorize.py b/src/libnum/factorize.py similarity index 74% rename from libnum/factorize.py rename to src/libnum/factorize.py index 9fc2fa2..b96c0f7 100644 --- a/libnum/factorize.py +++ b/src/libnum/factorize.py @@ -4,20 +4,20 @@ import math import random - from functools import reduce -from .primes import primes, prime_test -from .common import gcd, nroot +from typing import Callable, Dict, List, Optional +from .common import gcd, nroot +from .primes import prime_test, primes -__all__ = "factorize unfactorize".split() +__all__ = ["factorize", "unfactorize"] -_PRIMES_CHECK = primes(100) -_PRIMES_P1 = primes(100) +_PRIMES_CHECK: List[int] = primes(100) +_PRIMES_P1: List[int] = primes(100) -def rho_pollard_reduce(n, f): +def rho_pollard_reduce(n: int, f: Callable[[int], int]) -> int: # use Pollard's (p-1) method to narrow down search a = random.randint(2, n - 2) for p in _PRIMES_P1: @@ -38,10 +38,11 @@ def rho_pollard_reduce(n, f): return g -_FUNC_REDUCE = lambda n: rho_pollard_reduce(n, lambda x: (pow(x, 2, n) + 1) % n) +def _FUNC_REDUCE(n: int) -> int: + return rho_pollard_reduce(n, lambda x: (pow(x, 2, n) + 1) % n) -def factorize(n): +def factorize(n: int) -> Dict[int, int]: """ Use _FUNC_REDUCE (defaults to rho-pollard method) to factorize @n Return a dict like {p: e} @@ -49,7 +50,7 @@ def factorize(n): if n in (0, 1): return {n: 1} - prime_factors = {} + prime_factors: Dict[int, int] = {} if n < 0: n = -n @@ -60,7 +61,7 @@ def factorize(n): prime_factors[p] = prime_factors.get(p, 0) + 1 n //= p - factors = [n] + factors: List[int] = [n] if n == 1: if not prime_factors: prime_factors[1] = 1 @@ -69,6 +70,7 @@ def factorize(n): while factors: n = factors.pop() + p: Optional[int] = None if prime_test(n): p = n prime_factors[p] = prime_factors.get(p, 0) + 1 @@ -76,8 +78,8 @@ def factorize(n): is_pp = is_power(n) if is_pp: + p, e = is_pp if prime_test(p): - p, e = is_pp prime_factors[p] = prime_factors.get(p, 0) + e continue # else we need to factor @p and remember power @@ -92,14 +94,14 @@ def factorize(n): return prime_factors -def unfactorize(factors): +def unfactorize(factors: Dict[int, int]) -> int: return reduce(lambda acc, p_e: acc * (p_e[0] ** p_e[1]), factors.items(), 1) -def is_power(n): +def is_power(n: int) -> Optional[tuple[int, int]]: limit = int(math.log(n, 2)) for power in range(limit, 1, -1): p = nroot(n, power) if pow(p, power) == n: return p, power - return False + return None diff --git a/libnum/modular.py b/src/libnum/modular.py similarity index 82% rename from libnum/modular.py rename to src/libnum/modular.py index 79844ae..fea840b 100644 --- a/libnum/modular.py +++ b/src/libnum/modular.py @@ -1,12 +1,12 @@ import operator - from functools import reduce +from typing import Dict, List from .common import gcd, xgcd -from .stuff import factorial_get_prime_pow, factorial +from .stuff import factorial, factorial_get_prime_pow -def has_invmod(a, modulus): +def has_invmod(a: int, modulus: int) -> bool: """ Check if @a can be inversed under @modulus. Call this before calling invmod. @@ -20,7 +20,7 @@ def has_invmod(a, modulus): return True -def invmod(a, n): +def invmod(a: int, n: int) -> int: """ Return 1 / a (mod n). @a and @n must be co-primes. @@ -36,7 +36,7 @@ def invmod(a, n): return x % n -def solve_crt(remainders, modules): +def solve_crt(remainders: List[int], modules: List[int]) -> int: """ Solve Chinese Remainder Theorem. @modules and @remainders are lists. @@ -64,26 +64,26 @@ def solve_crt(remainders, modules): return x % N -def nCk_mod(n, k, factors): +def nCk_mod(n: int, k: int, factors: Dict[int, int]) -> int: """ Compute nCk modulo, factorization of modulus is needed """ - rems = [] - mods = [] + rems: List[int] = [] + mods: List[int] = [] for p, e in factors.items(): rems.append(nCk_mod_prime_power(n, k, p, e)) - mods.append(p ** e) + mods.append(p**e) return solve_crt(rems, mods) -def factorial_mod(n, factors): +def factorial_mod(n: int, factors: Dict[int, int]) -> int: """ Compute factorial modulo, factorization of modulus is needed """ - rems = [] - mods = [] + rems: List[int] = [] + mods: List[int] = [] for p, e in factors.items(): - pe = p ** e + pe = p**e if n >= pe or factorial_get_prime_pow(n, p) >= e: factmod = 0 else: @@ -93,7 +93,7 @@ def factorial_mod(n, factors): return solve_crt(rems, mods) -def nCk_mod_prime_power(n, k, p, e): +def nCk_mod_prime_power(n: int, k: int, p: int, e: int) -> int: """ Compute nCk mod small prime power: p**e Algorithm by Andrew Granville: @@ -110,7 +110,7 @@ def nCk_get_prime_pow(n, k, p): return res def nCk_get_non_prime_part(n, k, p, e): - pe = p ** e + pe = p**e r = n - k fact_pe = [1] @@ -154,5 +154,5 @@ def nCk_get_non_prime_part(n, k, p, e): modpow = e - prime_part_pow - r = nCk_get_non_prime_part(n, k, p, modpow) % (p ** modpow) - return ((p ** prime_part_pow) * r) % (p ** e) + r = nCk_get_non_prime_part(n, k, p, modpow) % (p**modpow) + return ((p**prime_part_pow) * r) % (p**e) diff --git a/libnum/primes.py b/src/libnum/primes.py similarity index 84% rename from libnum/primes.py rename to src/libnum/primes.py index 1865dd3..9edc33f 100644 --- a/libnum/primes.py +++ b/src/libnum/primes.py @@ -1,20 +1,20 @@ import math -import random import operator - +import random from functools import reduce +from typing import List, Optional +from .common import extract_prime_power, gcd, len_in_bits, randint_bits from .sqrtmod import jacobi -from .common import len_in_bits, gcd, extract_prime_power, randint_bits from .strings import s2n -_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] -_small_primes_product = 1 -_primes_bits = [[] for i in range(11)] -_primes_mask = [] +_primes: List[int] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] +_small_primes_product: int = 1 +_primes_bits: List[List[int]] = [[] for i in range(11)] +_primes_mask: List[bool] = [] -def _init(): +def _init() -> None: global _small_primes_product, _primes, _primes_bits, _primes_mask _primes = primes(1024) for p in _primes: @@ -24,7 +24,7 @@ def _init(): return -def primes(until): +def primes(until: int) -> List[int]: """ Return list of primes not greater than @until. Rather slow. """ @@ -51,7 +51,7 @@ def primes(until): return _primes -def generate_prime(size, k=25): +def generate_prime(size: int, k: int = 25) -> int: """ Generate a pseudo-prime with @size bits length. Optional arg @k=25 defines number of tests. @@ -73,7 +73,7 @@ def generate_prime(size, k=25): return -def generate_prime_from_string(s, size=None, k=25): +def generate_prime_from_string(s: str, size: Optional[int] = None, k: int = 25) -> int: """ Generate a pseudo-prime starting with @s in string representation. Optional arg @size defines length in bits, if is not set than +some bytes. @@ -94,7 +94,7 @@ def generate_prime_from_string(s, size=None, k=25): extend_len = size - len(s) * 8 visible_part = s2n(s) << extend_len - hi = 2 ** extend_len + hi = 2**extend_len while True: n = visible_part | random.randint(1, hi) | 1 # only even @@ -107,7 +107,7 @@ def generate_prime_from_string(s, size=None, k=25): return -def prime_test_ferma(p, k=25): +def prime_test_ferma(p: int, k: int = 25) -> bool: """ Test for primality based on Ferma's Little Theorem Totally fails in Carmichael'e numbers @@ -130,7 +130,7 @@ def prime_test_ferma(p, k=25): return True -def prime_test_solovay_strassen(p, k=25): +def prime_test_solovay_strassen(p: int, k: int = 25) -> bool: """ Test for primality by Solovai-Strassen Stronger than Ferma's test @@ -156,11 +156,13 @@ def prime_test_solovay_strassen(p, k=25): return True -def prime_test_miller_rabin(p, k=25): +def prime_test_miller_rabin(p: Optional[int], k: int = 25) -> bool: """ Test for primality by Miller-Rabin Stronger than Solovay-Strassen's test """ + if p is None: + return False if p < 2: return False if p <= 3: @@ -191,7 +193,7 @@ def prime_test_miller_rabin(p, k=25): if i < s - 1: break # good else: - return False # bad + return False # bad else: # result is not 1 return False diff --git a/libnum/ranges.py b/src/libnum/ranges.py similarity index 78% rename from libnum/ranges.py rename to src/libnum/ranges.py index 82ec222..fb5b6a9 100644 --- a/libnum/ranges.py +++ b/src/libnum/ranges.py @@ -1,6 +1,6 @@ import json - from functools import reduce +from typing import Any, Iterator, List, Tuple """ TODO: fix properties for empty @@ -24,12 +24,12 @@ class Ranges(object): - add_range method - unite with (x, y) range """ - def __init__(self, *ranges): - self._segments = [] - for (a, b) in ranges: + def __init__(self, *ranges: Tuple[int, int]): + self._segments: List[Tuple[int, int]] = [] + for a, b in ranges: self.add_range(a, b) - def add_range(self, x, y): + def add_range(self, x: int, y: int) -> None: if y < x: raise ValueError("end is smaller than start: %d < %d" % (y, x)) @@ -58,7 +58,7 @@ def add_range(self, x, y): self._segments.append((x, y)) return - def __or__(self, other): + def __or__(self, other: "Ranges") -> "Ranges": res = Ranges() for x, y in self._segments: res.add_range(x, y) @@ -66,7 +66,7 @@ def __or__(self, other): res.add_range(x, y) return res - def __and__(self, other): + def __and__(self, other: "Ranges") -> "Ranges": res = [] index1 = 0 index2 = 0 @@ -97,48 +97,45 @@ def __and__(self, other): index1 += 1 return Ranges(*res) - def __iter__(self): + def __iter__(self) -> Iterator[int]: for a, b in self._segments: while a <= b: yield a a += 1 return - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.segments == other.segments @property - def len(self): - return reduce( - lambda acc, ab: acc + 1 + ab[1] - ab[0], - self._segments, 0 - ) + def len(self) -> int: + return reduce(lambda acc, ab: acc + 1 + ab[1] - ab[0], self._segments, 0) @property - def min(self): + def min(self) -> int: return self._segments[0][0] @property - def max(self): + def max(self) -> int: return self._segments[-1][1] @property - def segments(self): + def segments(self) -> Tuple[Tuple[int, int], ...]: return tuple(self._segments) - def __str__(self): + def __str__(self) -> str: return str(self.segments) - def __contains__(self, other): + def __contains__(self, other: int) -> bool: assert isinstance(other, int) for a, b in self._segments: if a <= other <= b: return True return False - def to_json(self): + def to_json(self) -> str: return json.dumps(self._segments) @classmethod - def from_json(cls, j): + def from_json(cls, j: str) -> "Ranges": return Ranges(*json.loads(j)) diff --git a/libnum/sqrtmod.py b/src/libnum/sqrtmod.py similarity index 85% rename from libnum/sqrtmod.py rename to src/libnum/sqrtmod.py index 1a5f198..a4bb168 100644 --- a/libnum/sqrtmod.py +++ b/src/libnum/sqrtmod.py @@ -1,11 +1,12 @@ import random from itertools import product +from typing import Dict, Iterator, List, Tuple from .common import extract_prime_power -from .modular import solve_crt, invmod +from .modular import invmod, solve_crt -def has_sqrtmod(a, factors=None): +def has_sqrtmod(a: int, factors: Dict[int, int]) -> bool: """ Check if @a is quadratic residue, factorization needed @factors - list of (prime, power) tuples @@ -22,28 +23,24 @@ def has_sqrtmod(a, factors=None): return True -def sqrtmod(a, factors): +def sqrtmod(a: int, factors: Dict[int, int]) -> Iterator[int]: """ x ^ 2 = a (mod *factors). Yield square roots by product of @factors as modulus. @factors - list of (prime, power) tuples """ - coprime_factors = [p ** k for p, k in factors.items()] + coprime_factors = [p**k for p, k in factors.items()] - sqrts = [] + sqrts: List[List[int]] = [] for i, (p, k) in enumerate(factors.items()): - # it's bad that all roots by each modulus are calculated here - # - we can start yielding roots faster - sqrts.append( - list(sqrtmod_prime_power(a % coprime_factors[i], p, k)) - ) + sqrts.append(list(sqrtmod_prime_power(a % coprime_factors[i], p, k))) for rems in product(*sqrts): - yield solve_crt(rems, coprime_factors) + yield solve_crt(list(rems), coprime_factors) return -def has_sqrtmod_prime_power(a, p, n=1): +def has_sqrtmod_prime_power(a: int, p: int, n: int = 1) -> bool: """ Check if @a (mod @p**@n) is quadratic residue, @p is prime. """ @@ -53,7 +50,7 @@ def has_sqrtmod_prime_power(a, p, n=1): if n < 1: raise ValueError("Prime power must be positive: " + str(n)) - a = a % (p ** n) + a = a % (p**n) if a in (0, 1): return True @@ -71,7 +68,7 @@ def has_sqrtmod_prime_power(a, p, n=1): return jacobi(a, p) == 1 -def sqrtmod_prime_power(a, p, k=1): +def sqrtmod_prime_power(a: int, p: int, k: int = 1) -> Iterator[int]: """ Yield square roots of @a mod @p**@k, @p - prime @@ -87,11 +84,11 @@ def sqrtmod_prime_power(a, p, k=1): powers.append(pow_p) # x**2 == a (mod p), p is prime - def sqrtmod_prime(a, p): + def sqrtmod_prime(a: int, p: int) -> Tuple[int, ...]: if a == 0: return (0,) if a == 1: - return (1, p-1) if p != 2 else (1,) + return (1, p - 1) if p != 2 else (1,) if jacobi(a, p) == -1: raise ValueError("No square root for %d (mod %d)" % (a, p)) @@ -115,11 +112,11 @@ def sqrtmod_prime(a, p): return (r, (-r) % p) # both roots # x**2 == a (mod p**k), p is prime, gcd(a, p) == 1 - def sqrtmod_prime_power_for_coprime(a, p, k): + def sqrtmod_prime_power_for_coprime(a: int, p: int, k: int) -> Tuple[int, ...]: if a == 1: if p == 2: if k == 1: - return (1, ) + return (1,) if k == 2: return (1, 3) if k == 3: @@ -147,7 +144,7 @@ def sqrtmod_prime_power_for_coprime(a, p, k): roots = next_roots roots = [pow_p - r for r in roots] + list(roots) - return roots + return tuple(roots) else: # p >= 3 r = sqrtmod_prime(a, p)[0] # any root @@ -156,7 +153,7 @@ def sqrtmod_prime_power_for_coprime(a, p, k): next_powind = min(powind * 2, k) # Represent root: x = +- (r + p**powind * t1) b = (a - r**2) % powers[next_powind] - b = (b * invmod(2*r, powers[next_powind])) % powers[next_powind] + b = (b * invmod(2 * r, powers[next_powind])) % powers[next_powind] if b: if b % powers[powind]: raise ValueError("No square root for given value") @@ -172,7 +169,7 @@ def sqrtmod_prime_power_for_coprime(a, p, k): return # x**2 == 0 (mod p**k), p is prime - def sqrt_for_zero(p, k): + def sqrt_for_zero(p: int, k: int) -> List[int]: roots = [0] start_k = (k // 2 + 1) if k & 1 else (k // 2) @@ -222,7 +219,7 @@ def sqrt_for_zero(p, k): return -def jacobi(a, n): +def jacobi(a: int, n: int) -> int: """ Return Jacobi symbol (or Legendre symbol if n is prime) """ diff --git a/libnum/strings.py b/src/libnum/strings.py similarity index 82% rename from libnum/strings.py rename to src/libnum/strings.py index 679e8b7..40c3ad5 100644 --- a/libnum/strings.py +++ b/src/libnum/strings.py @@ -1,7 +1,9 @@ +from typing import Union + from .common import len_in_bits -def s2n(s): +def s2n(s: Union[str, bytes]) -> int: r""" String to number (big endian). @@ -15,7 +17,7 @@ def s2n(s): return int.from_bytes(s, "big") -def n2s(n): +def n2s(n: int) -> bytes: r""" Number to string (big endian). @@ -29,7 +31,7 @@ def n2s(n): return n.to_bytes(nbytes, "big") -def s2b(s): +def s2b(s: Union[str, bytes]) -> str: """ String to binary. @@ -40,7 +42,7 @@ def s2b(s): return "0" * ((8 - len(res)) % 8) + res -def b2s(b): +def b2s(b: str) -> bytes: """ Binary to string. diff --git a/libnum/stuff.py b/src/libnum/stuff.py similarity index 67% rename from libnum/stuff.py rename to src/libnum/stuff.py index 7432582..3d6af06 100644 --- a/libnum/stuff.py +++ b/src/libnum/stuff.py @@ -1,13 +1,12 @@ import operator - from functools import reduce -def grey_code(n): +def grey_code(n: int) -> int: return n ^ (n >> 1) -def rev_grey_code(g): +def rev_grey_code(g: int) -> int: n = 0 while g: n ^= g @@ -15,7 +14,7 @@ def rev_grey_code(g): return n -def factorial(n): +def factorial(n: int) -> int: res = 1 while n > 1: res *= n @@ -23,7 +22,7 @@ def factorial(n): return res -def factorial_get_prime_pow(n, p): +def factorial_get_prime_pow(n: int, p: int) -> int: """ Return power of prime @p in @n! """ @@ -35,7 +34,7 @@ def factorial_get_prime_pow(n, p): return count -def nCk(n, k): +def nCk(n: int, k: int) -> int: """ Combinations number """ @@ -45,14 +44,13 @@ def nCk(n, k): return 0 if k in (0, n): return 1 - if k in (1, n-1): + if k in (1, n - 1): return n low_min = 1 low_max = min(n, k) high_min = max(1, n - k + 1) high_max = n - return ( - reduce(operator.mul, range(high_min, high_max + 1), 1) - // reduce(operator.mul, range(low_min, low_max + 1), 1) + return reduce(operator.mul, range(high_min, high_max + 1), 1) // reduce( + operator.mul, range(low_min, low_max + 1), 1 ) diff --git a/tests/test_common.py b/tests/test_common.py index 910e560..7576d4e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,5 +1,6 @@ import pytest -from libnum import len_in_bits, gcd, lcm, nroot + +from libnum import gcd, lcm, len_in_bits, nroot def test_len_in_bits(): @@ -23,7 +24,7 @@ def test_len_in_bits(): def test_nroot(): for x in range(0, 100): for p in range(1, 3): - n = x ** p + n = x**p assert nroot(n, p) == x assert nroot(-64, 3) == -4