diff --git a/vlib/math/big/array_ops.v b/vlib/math/big/array_ops.v index 536c43ea80..bf853008ef 100644 --- a/vlib/math/big/array_ops.v +++ b/vlib/math/big/array_ops.v @@ -110,14 +110,20 @@ fn subtract_digit_array(operand_a []u32, operand_b []u32, mut storage []u32) { shrink_tail_zeros(mut storage) } -const karatsuba_multiplication_limit = 1_000_000 +const karatsuba_multiplication_limit = 240 -// set limit to choose algorithm +const toom3_multiplication_limit = 10_000 [inline] fn multiply_digit_array(operand_a []u32, operand_b []u32, mut storage []u32) { - if operand_a.len >= big.karatsuba_multiplication_limit - || operand_b.len >= big.karatsuba_multiplication_limit { + max_len := if operand_a.len >= operand_b.len { + operand_a.len + } else { + operand_b.len + } + if max_len >= big.toom3_multiplication_limit { + toom3_multiply_digit_array(operand_a, operand_b, mut storage) + } else if max_len >= big.karatsuba_multiplication_limit { karatsuba_multiply_digit_array(operand_a, operand_b, mut storage) } else { simple_multiply_digit_array(operand_a, operand_b, mut storage) diff --git a/vlib/math/big/big.v b/vlib/math/big/big.v index a883472216..5068a6ca4c 100644 --- a/vlib/math/big/big.v +++ b/vlib/math/big/big.v @@ -16,4 +16,9 @@ pub const ( signum: 1 is_const: true } + three_int = Integer{ + digits: [u32(3)] + signum: 1 + is_const: true + } ) diff --git a/vlib/math/big/special_array_ops.v b/vlib/math/big/special_array_ops.v index 3c7c2d4156..7043a75f70 100644 --- a/vlib/math/big/special_array_ops.v +++ b/vlib/math/big/special_array_ops.v @@ -61,8 +61,9 @@ fn newton_divide_array_by_array(operand_a []u32, operand_b []u32, mut quotient [ shrink_tail_zeros(mut remainder) } +// bit_length returns the number of bits needed to represent the absolute value of the integer a. [inline] -fn bit_length(a Integer) int { +pub fn bit_length(a Integer) int { return a.digits.len * 32 - bits.leading_zeros_32(a.digits.last()) } @@ -82,32 +83,37 @@ fn debug_u32_str(a []u32) string { return sb.str() } +[direct_array_access; inline] +fn found_multiplication_base_case(operand_a []u32, operand_b []u32, mut storage []u32) bool { + // base case necessary to end recursion + if operand_a.len == 0 || operand_b.len == 0 { + storage.clear() + return true + } + + if operand_a.len < operand_b.len { + multiply_digit_array(operand_b, operand_a, mut storage) + return true + } + + if operand_b.len == 1 { + multiply_array_by_digit(operand_a, operand_b[0], mut storage) + return true + } + return false +} + // karatsuba algorithm for multiplication // possible optimisations: // - transform one or all the recurrences in loops [direct_array_access] fn karatsuba_multiply_digit_array(operand_a []u32, operand_b []u32, mut storage []u32) { - // base case necessary to end recursion - if operand_a.len == 0 || operand_b.len == 0 { - storage.clear() + if found_multiplication_base_case(operand_a, operand_b, mut storage) { return } - if operand_a.len < operand_b.len { - multiply_digit_array(operand_b, operand_a, mut storage) - return - } - - if operand_b.len == 1 { - multiply_array_by_digit(operand_a, operand_b[0], mut storage) - return - } - // karatsuba // thanks to the base cases we can pass zero-length arrays to the mult func half := math.max(operand_a.len, operand_b.len) / 2 - if half <= 0 { - panic('Unreachable. Both array have 1 length and multiply_array_by_digit should have been called') - } a_l := operand_a[0..half] a_h := operand_a[half..] mut b_l := []u32{} @@ -137,14 +143,107 @@ fn karatsuba_multiply_digit_array(operand_a []u32, operand_b []u32, mut storage subtract_in_place(mut p_2, p_3) // return p_1.lshift(2 * u32(half * 32)) + p_2.lshift(u32(half * 32)) + p_3 - lshift_byte_in_place(mut storage, 2 * half) - lshift_byte_in_place(mut p_2, half) + lshift_digits_in_place(mut storage, 2 * half) + lshift_digits_in_place(mut p_2, half) add_in_place(mut storage, p_2) add_in_place(mut storage, p_3) shrink_tail_zeros(mut storage) } +[direct_array_access] +fn toom3_multiply_digit_array(operand_a []u32, operand_b []u32, mut storage []u32) { + if found_multiplication_base_case(operand_a, operand_b, mut storage) { + return + } + + // After the base case, we have operand_a as the larger integer in terms of digit length + + // k is the length (in u32 digits) of the lower order slices + k := (operand_a.len + 2) / 3 + k2 := 2 * k + + // The pieces of the calculation need to be worked on as proper big.Integers + // because the intermediate results can be negative. After recombination, the + // final result will be positive. + + // Slices of a and b + a0 := Integer{ + digits: operand_a[0..k] + signum: 1 + } + a1 := Integer{ + digits: operand_a[k..k2] + signum: 1 + } + a2 := Integer{ + digits: operand_a[k2..] + signum: 1 + } + + // Zero arrays by default + mut b0 := zero_int.clone() + mut b1 := zero_int.clone() + mut b2 := zero_int.clone() + + if operand_b.len < k { + b0 = Integer{ + digits: operand_b + signum: 1 + } + } else if operand_b.len < k2 { + b0 = Integer{ + digits: operand_b[0..k] + signum: 1 + } + b1 = Integer{ + digits: operand_b[k..] + signum: 1 + } + } else { + b0 = Integer{ + digits: operand_b[0..k] + signum: 1 + } + b1 = Integer{ + digits: operand_b[k..k2] + signum: 1 + } + b2 = Integer{ + digits: operand_b[k2..] + signum: 1 + } + } + + // https://en.wikipedia.org/wiki/Toom%E2%80%93Cook_multiplication#Details + // DOI: 10.1007/978-3-540-73074-3_10 + + p0 := a0 * b0 + mut ptemp := a2 + a0 + mut qtemp := b2 + b0 + vm1 := (ptemp - a1) * (qtemp - b1) + ptemp += a1 + qtemp += b1 + p1 := ptemp * qtemp + p2 := ((ptemp + a2).lshift(1) - a0) * ((qtemp + b2).lshift(1) - b0) + pinf := a2 * b2 + + mut t2 := (p2 - vm1) / three_int + mut tm1 := (p1 - vm1).rshift(1) + mut t1 := p1 - p0 + t2 = (t2 - t1).rshift(1) + t1 = (t1 - tm1 - pinf) + t2 = t2 - pinf.lshift(1) + tm1 = tm1 - t2 + + // shift amount + s := u32(k) << 5 + + result := (((pinf.lshift(s) + t2).lshift(s) + t1).lshift(s) + tm1).lshift(s) + p0 + + storage = result.digits +} + [inline] fn pow2(k int) Integer { mut ret := []u32{len: (k >> 5) + 1} @@ -155,22 +254,34 @@ fn pow2(k int) Integer { } } -// optimized left shift of full u8(s) in place. byte_nb must be positive +// optimized left shift in place. amount must be positive [direct_array_access] -fn lshift_byte_in_place(mut a []u32, byte_nb int) { +fn lshift_digits_in_place(mut a []u32, amount int) { a_len := a.len // control or allocate capacity - for _ in a_len .. a_len + byte_nb { + for _ in a_len .. a_len + amount { a << u32(0) } for index := a_len - 1; index >= 0; index-- { - a[index + byte_nb] = a[index] + a[index + amount] = a[index] } - for index in 0 .. byte_nb { + for index in 0 .. amount { a[index] = u32(0) } } +// optimized right shift in place. amount must be positive +[direct_array_access] +fn rshift_digits_in_place(mut a []u32, amount int) { + for index := 0; index < a.len - amount; index++ { + a[index] = a[index + amount] + } + for index := a.len - amount; index < a.len; index++ { + a[index] = u32(0) + } + shrink_tail_zeros(mut a) +} + // operand b can be greater than operand a // the capacity of both array is supposed to be sufficient [direct_array_access; inline] @@ -210,20 +321,20 @@ fn subtract_in_place(mut a []u32, b []u32) { mut carry := u32(0) mut new_carry := u32(0) for index in 0 .. min { - if a[index] < (b[index] + carry) { - new_carry = 1 + new_carry = if a[index] < (b[index] + carry) { + u32(1) } else { - new_carry = 0 + u32(0) } a[index] -= (b[index] + carry) carry = new_carry } if len_a >= len_b { for index in min .. max { - if a[index] < carry { - new_carry = 1 + new_carry = if a[index] < carry { + u32(1) } else { - new_carry = 0 + u32(0) } a[index] -= carry carry = new_carry diff --git a/vlib/math/big/special_array_ops_test.v b/vlib/math/big/special_array_ops_test.v index 7d381742c8..f4e8c668b9 100644 --- a/vlib/math/big/special_array_ops_test.v +++ b/vlib/math/big/special_array_ops_test.v @@ -24,9 +24,9 @@ fn test_add_in_place() { assert a == [u32(0x17ff72ad), 0x1439] } -fn test_lshift_byte_in_place() { +fn test_lshift_digits_in_place() { mut a := [u32(5), 6, 7, 8] - lshift_byte_in_place(mut a, 2) + lshift_digits_in_place(mut a, 2) assert a == [u32(0), 0, 5, 6, 7, 8] }