1
0
mirror of https://github.com/vlang/v.git synced 2023-08-10 21:13:21 +03:00

math.big: add mod_inverse and improve big_mod_pow to allow for large exponents and moduli (#18461)

This commit is contained in:
phoebe
2023-06-19 16:59:49 +02:00
committed by GitHub
parent 396d46d9ca
commit a3f24caffc
3 changed files with 599 additions and 63 deletions

View File

@@ -415,16 +415,55 @@ pub fn (a Integer) % (b Integer) Integer {
return r
}
// mask_bits is the equivalent of `a % 2^n` (only when `a >= 0`), however doing a full division
// run for this would be a lot of work when we can simply "cut off" all bits to the left of
// the `n`th bit.
[direct_array_access]
fn (a Integer) mask_bits(n u32) Integer {
$if debug {
assert a.signum >= 0
}
if a.digits.len == 0 || n == 0 {
return zero_int
}
w := n / 32
b := n % 32
if w >= a.digits.len {
return a
}
return Integer{
digits: if b == 0 {
mut storage := []u32{len: int(w)}
for i := 0; i < storage.len; i++ {
storage[i] = a.digits[i]
}
storage
} else {
mut storage := []u32{len: int(w) + 1}
for i := 0; i < storage.len; i++ {
storage[i] = a.digits[i]
}
storage[w] &= ~(u32(-1) << b)
storage
}
signum: 1
}
}
// pow returns the integer `a` raised to the power of the u32 `exponent`.
pub fn (a Integer) pow(exponent u32) Integer {
pub fn (base Integer) pow(exponent u32) Integer {
if exponent == 0 {
return one_int
}
if exponent == 1 {
return a.clone()
return base.clone()
}
mut n := exponent
mut x := a
mut x := base
mut y := one_int
for n > 1 {
if n & 1 == 1 {
@@ -436,63 +475,74 @@ pub fn (a Integer) pow(exponent u32) Integer {
return x * y
}
// mod_pow returns the integer `a` raised to the power of the u32 `exponent` modulo the integer `divisor`.
pub fn (a Integer) mod_pow(exponent u32, divisor Integer) Integer {
// mod_pow returns the integer `a` raised to the power of the u32 `exponent` modulo the integer `modulus`.
pub fn (base Integer) mod_pow(exponent u32, modulus Integer) Integer {
if exponent == 0 {
return one_int
}
if exponent == 1 {
return a % divisor
return base % modulus
}
mut n := exponent
mut x := a % divisor
mut x := base % modulus
mut y := one_int
for n > 1 {
if n & 1 == 1 {
y *= x % divisor
y *= x % modulus
}
x *= x % divisor
x *= x % modulus
n >>= 1
}
return x * y % divisor
return x * y % modulus
}
// big_mod_power returns the integer `a` raised to the power of the integer `exponent` modulo the integer `divisor`.
// big_mod_pow returns the integer `base` raised to the power of the integer `exponent` modulo the integer `modulus`.
[direct_array_access]
pub fn (a Integer) big_mod_pow(exponent Integer, divisor Integer) Integer {
pub fn (base Integer) big_mod_pow(exponent Integer, modulus Integer) !Integer {
if exponent.signum < 0 {
panic('Exponent needs to be non-negative.')
return error('math.big: Exponent needs to be non-negative.')
}
if exponent.signum == 0 {
// this goes first as otherwise 1 could be returned incorrectly if base == 1
if modulus.bit_len() <= 1 {
return zero_int
}
// x^0 == 1 || 1^x == 1
if exponent.signum == 0 || base.bit_len() == 1 {
return one_int
}
mut x := a % divisor
mut y := one_int
mut n := u32(0)
// For all but the last digit of the exponent
for index in 0 .. exponent.digits.len - 1 {
n = exponent.digits[index]
for _ in 0 .. 32 {
if n & 1 == 1 {
y *= x % divisor
}
x *= x % divisor
n >>= 1
}
// 0^x == 0 (x != 0 due to previous clause)
if base.signum == 0 {
return one_int
}
// Last digit of the exponent
n = exponent.digits.last()
for n > 1 {
if n & 1 == 1 {
y *= x % divisor
if exponent.bit_len() == 1 {
// x^1 without mod == x
if modulus.signum == 0 {
return base
}
x *= x % divisor
n >>= 1
// x^1 (mod m) === x % m
return base % modulus
}
return x * y % divisor
// the amount of precomputation in windowed exponentiation (done in the montgomery and binary
// windowed exponentiation algorithms) is far too costly for small sized exponents, so
// we redirect the call to mod_pow
return if exponent.digits.len > 1 {
if modulus.is_odd() {
// modulus is odd, therefore we use the normal
// montgomery modular exponentiation algorithm
base.mont_odd(exponent, modulus)
} else if modulus.is_power_of_2() {
base.exp_binary(exponent, modulus)
} else {
base.mont_even(exponent, modulus)
}
} else {
base.mod_pow(exponent.digits[0], modulus)
}
}
// inc returns the integer `a` incremented by 1.
@@ -956,6 +1006,98 @@ fn gcd_binary(x Integer, y Integer) Integer {
return b.lshift(shift)
}
// mod_inverse calculates the multiplicative inverse of the integer `a` in the ring `/n`.
// Therefore, the return value `x` satisfies `a * x == 1 (mod m)`.
// -----
// An error is returned if `a` and `n` are not relatively prime, i.e. `gcd(a, n) != 1` or
// if n <= 1
[inline]
pub fn (a Integer) mod_inverse(n Integer) !Integer {
return if n.bit_len() <= 1 {
error('math.big: Modulus `n` must be greater than 1')
} else if a.gcd(n) != one_int {
error('math.big: No multiplicative inverse')
} else {
a.mod_inv(n)
}
}
// this is an internal function, therefore we assume valid inputs,
// i.e. m > 1 and gcd(a, m) = 1
// see pub fn mod_inverse for details on the result
// -----
// the algorithm is based on the Extended Euclidean algorithm which computes `ax + by = d`
// in this case `b` is the input integer `a` and `a` is the input modulus `m`. The extended
// Euclidean algorithm calculates the greatest common divisor `d` and two coefficients `x` and `y`
// satisfying the above equality.
//
// For the sake of clarity, we refer to the input integer `a` as `b` and the integer `m` as `a`.
// If `gcd(a, b) = d = 1` then the coefficient `y` is known to be the multiplicative inverse of
// `b` in ring `Z/aZ`, since reducing `ax + by = 1` by `a` yields `by == 1 (mod a)`.
[direct_array_access]
fn (a Integer) mod_inv(m Integer) Integer {
mut n := Integer{
digits: m.digits.clone()
signum: 1
}
mut b := a
mut x := one_int
mut y := zero_int
if b.signum < 0 || b.abs_cmp(n) >= 0 {
b = b % n
}
mut sign := -1
for b != zero_int {
q, r := if n.bit_len() == b.bit_len() {
one_int, n - b
} else {
n.div_mod(b)
}
n = b
b = r
// tmp := q * x + y
tmp := if q == one_int {
x
} else if q.digits.len == 1 && q.digits[0] & (q.digits[0] - 1) == 0 {
x.lshift(u32(bits.trailing_zeros_32(q.digits[0])))
} else {
q * x
} + y
y = x
x = tmp
sign = -sign
}
if sign < 0 {
y = m - y
}
$if debug {
assert n == one_int
}
return if y.signum > 0 && y.abs_cmp(m) < 0 {
y
} else {
y % m
}
}
[direct_array_access; inline]
fn (x Integer) is_odd() bool {
return x.digits[0] & 1 == 1
}
// is_power_of_2 returns true when the integer `x` satisfies `2^n`, where `n >= 0`
[inline]
pub fn (x Integer) is_power_of_2() bool {
return x.bitwise_and(x - one_int).bit_len() == 0
}
// bit_len returns the number of bits required to represent the integer `a`.
[inline]
pub fn (x Integer) bit_len() int {