mirror of
https://github.com/vlang/v.git
synced 2023-08-10 21:13:21 +03:00
math.big: add mod_inverse and improve big_mod_pow to allow for large exponents and moduli (#18461)
This commit is contained in:
@@ -415,16 +415,55 @@ pub fn (a Integer) % (b Integer) Integer {
|
||||
return r
|
||||
}
|
||||
|
||||
// mask_bits is the equivalent of `a % 2^n` (only when `a >= 0`), however doing a full division
|
||||
// run for this would be a lot of work when we can simply "cut off" all bits to the left of
|
||||
// the `n`th bit.
|
||||
[direct_array_access]
|
||||
fn (a Integer) mask_bits(n u32) Integer {
|
||||
$if debug {
|
||||
assert a.signum >= 0
|
||||
}
|
||||
|
||||
if a.digits.len == 0 || n == 0 {
|
||||
return zero_int
|
||||
}
|
||||
|
||||
w := n / 32
|
||||
b := n % 32
|
||||
|
||||
if w >= a.digits.len {
|
||||
return a
|
||||
}
|
||||
|
||||
return Integer{
|
||||
digits: if b == 0 {
|
||||
mut storage := []u32{len: int(w)}
|
||||
for i := 0; i < storage.len; i++ {
|
||||
storage[i] = a.digits[i]
|
||||
}
|
||||
storage
|
||||
} else {
|
||||
mut storage := []u32{len: int(w) + 1}
|
||||
for i := 0; i < storage.len; i++ {
|
||||
storage[i] = a.digits[i]
|
||||
}
|
||||
storage[w] &= ~(u32(-1) << b)
|
||||
storage
|
||||
}
|
||||
signum: 1
|
||||
}
|
||||
}
|
||||
|
||||
// pow returns the integer `a` raised to the power of the u32 `exponent`.
|
||||
pub fn (a Integer) pow(exponent u32) Integer {
|
||||
pub fn (base Integer) pow(exponent u32) Integer {
|
||||
if exponent == 0 {
|
||||
return one_int
|
||||
}
|
||||
if exponent == 1 {
|
||||
return a.clone()
|
||||
return base.clone()
|
||||
}
|
||||
mut n := exponent
|
||||
mut x := a
|
||||
mut x := base
|
||||
mut y := one_int
|
||||
for n > 1 {
|
||||
if n & 1 == 1 {
|
||||
@@ -436,63 +475,74 @@ pub fn (a Integer) pow(exponent u32) Integer {
|
||||
return x * y
|
||||
}
|
||||
|
||||
// mod_pow returns the integer `a` raised to the power of the u32 `exponent` modulo the integer `divisor`.
|
||||
pub fn (a Integer) mod_pow(exponent u32, divisor Integer) Integer {
|
||||
// mod_pow returns the integer `a` raised to the power of the u32 `exponent` modulo the integer `modulus`.
|
||||
pub fn (base Integer) mod_pow(exponent u32, modulus Integer) Integer {
|
||||
if exponent == 0 {
|
||||
return one_int
|
||||
}
|
||||
if exponent == 1 {
|
||||
return a % divisor
|
||||
return base % modulus
|
||||
}
|
||||
mut n := exponent
|
||||
mut x := a % divisor
|
||||
mut x := base % modulus
|
||||
mut y := one_int
|
||||
for n > 1 {
|
||||
if n & 1 == 1 {
|
||||
y *= x % divisor
|
||||
y *= x % modulus
|
||||
}
|
||||
x *= x % divisor
|
||||
x *= x % modulus
|
||||
n >>= 1
|
||||
}
|
||||
return x * y % divisor
|
||||
return x * y % modulus
|
||||
}
|
||||
|
||||
// big_mod_power returns the integer `a` raised to the power of the integer `exponent` modulo the integer `divisor`.
|
||||
// big_mod_pow returns the integer `base` raised to the power of the integer `exponent` modulo the integer `modulus`.
|
||||
[direct_array_access]
|
||||
pub fn (a Integer) big_mod_pow(exponent Integer, divisor Integer) Integer {
|
||||
pub fn (base Integer) big_mod_pow(exponent Integer, modulus Integer) !Integer {
|
||||
if exponent.signum < 0 {
|
||||
panic('Exponent needs to be non-negative.')
|
||||
return error('math.big: Exponent needs to be non-negative.')
|
||||
}
|
||||
if exponent.signum == 0 {
|
||||
|
||||
// this goes first as otherwise 1 could be returned incorrectly if base == 1
|
||||
if modulus.bit_len() <= 1 {
|
||||
return zero_int
|
||||
}
|
||||
|
||||
// x^0 == 1 || 1^x == 1
|
||||
if exponent.signum == 0 || base.bit_len() == 1 {
|
||||
return one_int
|
||||
}
|
||||
mut x := a % divisor
|
||||
mut y := one_int
|
||||
mut n := u32(0)
|
||||
|
||||
// For all but the last digit of the exponent
|
||||
for index in 0 .. exponent.digits.len - 1 {
|
||||
n = exponent.digits[index]
|
||||
for _ in 0 .. 32 {
|
||||
if n & 1 == 1 {
|
||||
y *= x % divisor
|
||||
}
|
||||
x *= x % divisor
|
||||
n >>= 1
|
||||
}
|
||||
// 0^x == 0 (x != 0 due to previous clause)
|
||||
if base.signum == 0 {
|
||||
return one_int
|
||||
}
|
||||
|
||||
// Last digit of the exponent
|
||||
n = exponent.digits.last()
|
||||
for n > 1 {
|
||||
if n & 1 == 1 {
|
||||
y *= x % divisor
|
||||
if exponent.bit_len() == 1 {
|
||||
// x^1 without mod == x
|
||||
if modulus.signum == 0 {
|
||||
return base
|
||||
}
|
||||
x *= x % divisor
|
||||
n >>= 1
|
||||
// x^1 (mod m) === x % m
|
||||
return base % modulus
|
||||
}
|
||||
|
||||
return x * y % divisor
|
||||
// the amount of precomputation in windowed exponentiation (done in the montgomery and binary
|
||||
// windowed exponentiation algorithms) is far too costly for small sized exponents, so
|
||||
// we redirect the call to mod_pow
|
||||
return if exponent.digits.len > 1 {
|
||||
if modulus.is_odd() {
|
||||
// modulus is odd, therefore we use the normal
|
||||
// montgomery modular exponentiation algorithm
|
||||
base.mont_odd(exponent, modulus)
|
||||
} else if modulus.is_power_of_2() {
|
||||
base.exp_binary(exponent, modulus)
|
||||
} else {
|
||||
base.mont_even(exponent, modulus)
|
||||
}
|
||||
} else {
|
||||
base.mod_pow(exponent.digits[0], modulus)
|
||||
}
|
||||
}
|
||||
|
||||
// inc returns the integer `a` incremented by 1.
|
||||
@@ -956,6 +1006,98 @@ fn gcd_binary(x Integer, y Integer) Integer {
|
||||
return b.lshift(shift)
|
||||
}
|
||||
|
||||
// mod_inverse calculates the multiplicative inverse of the integer `a` in the ring `ℤ/nℤ`.
|
||||
// Therefore, the return value `x` satisfies `a * x == 1 (mod m)`.
|
||||
// -----
|
||||
// An error is returned if `a` and `n` are not relatively prime, i.e. `gcd(a, n) != 1` or
|
||||
// if n <= 1
|
||||
[inline]
|
||||
pub fn (a Integer) mod_inverse(n Integer) !Integer {
|
||||
return if n.bit_len() <= 1 {
|
||||
error('math.big: Modulus `n` must be greater than 1')
|
||||
} else if a.gcd(n) != one_int {
|
||||
error('math.big: No multiplicative inverse')
|
||||
} else {
|
||||
a.mod_inv(n)
|
||||
}
|
||||
}
|
||||
|
||||
// this is an internal function, therefore we assume valid inputs,
|
||||
// i.e. m > 1 and gcd(a, m) = 1
|
||||
// see pub fn mod_inverse for details on the result
|
||||
// -----
|
||||
// the algorithm is based on the Extended Euclidean algorithm which computes `ax + by = d`
|
||||
// in this case `b` is the input integer `a` and `a` is the input modulus `m`. The extended
|
||||
// Euclidean algorithm calculates the greatest common divisor `d` and two coefficients `x` and `y`
|
||||
// satisfying the above equality.
|
||||
//
|
||||
// For the sake of clarity, we refer to the input integer `a` as `b` and the integer `m` as `a`.
|
||||
// If `gcd(a, b) = d = 1` then the coefficient `y` is known to be the multiplicative inverse of
|
||||
// `b` in ring `Z/aZ`, since reducing `ax + by = 1` by `a` yields `by == 1 (mod a)`.
|
||||
[direct_array_access]
|
||||
fn (a Integer) mod_inv(m Integer) Integer {
|
||||
mut n := Integer{
|
||||
digits: m.digits.clone()
|
||||
signum: 1
|
||||
}
|
||||
mut b := a
|
||||
mut x := one_int
|
||||
mut y := zero_int
|
||||
if b.signum < 0 || b.abs_cmp(n) >= 0 {
|
||||
b = b % n
|
||||
}
|
||||
mut sign := -1
|
||||
|
||||
for b != zero_int {
|
||||
q, r := if n.bit_len() == b.bit_len() {
|
||||
one_int, n - b
|
||||
} else {
|
||||
n.div_mod(b)
|
||||
}
|
||||
|
||||
n = b
|
||||
b = r
|
||||
|
||||
// tmp := q * x + y
|
||||
tmp := if q == one_int {
|
||||
x
|
||||
} else if q.digits.len == 1 && q.digits[0] & (q.digits[0] - 1) == 0 {
|
||||
x.lshift(u32(bits.trailing_zeros_32(q.digits[0])))
|
||||
} else {
|
||||
q * x
|
||||
} + y
|
||||
|
||||
y = x
|
||||
x = tmp
|
||||
sign = -sign
|
||||
}
|
||||
|
||||
if sign < 0 {
|
||||
y = m - y
|
||||
}
|
||||
|
||||
$if debug {
|
||||
assert n == one_int
|
||||
}
|
||||
|
||||
return if y.signum > 0 && y.abs_cmp(m) < 0 {
|
||||
y
|
||||
} else {
|
||||
y % m
|
||||
}
|
||||
}
|
||||
|
||||
[direct_array_access; inline]
|
||||
fn (x Integer) is_odd() bool {
|
||||
return x.digits[0] & 1 == 1
|
||||
}
|
||||
|
||||
// is_power_of_2 returns true when the integer `x` satisfies `2^n`, where `n >= 0`
|
||||
[inline]
|
||||
pub fn (x Integer) is_power_of_2() bool {
|
||||
return x.bitwise_and(x - one_int).bit_len() == 0
|
||||
}
|
||||
|
||||
// bit_len returns the number of bits required to represent the integer `a`.
|
||||
[inline]
|
||||
pub fn (x Integer) bit_len() int {
|
||||
|
||||
Reference in New Issue
Block a user