Merge pull request #328 from zhelenskiy/dev

Karatsuba added, 2 bugs are fixed
This commit is contained in:
Alexander Nozik 2021-05-14 09:12:25 +03:00 committed by GitHub
commit c1b94ff0bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 250 additions and 75 deletions

View File

@ -10,20 +10,19 @@ import kotlinx.benchmark.Blackhole
import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State
import space.kscience.kmath.operations.BigInt
import space.kscience.kmath.operations.BigIntField
import space.kscience.kmath.operations.JBigIntegerField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.*
import java.math.BigInteger
private fun BigInt.pow(power: Int): BigInt = modPow(BigIntField.number(power), BigInt.ZERO)
@UnstableKMathAPI
@State(Scope.Benchmark)
internal class BigIntBenchmark {
val kmNumber = BigIntField.number(Int.MAX_VALUE)
val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE)
val largeKmNumber = BigIntField { number(11).pow(100_000) }
val largeJvmNumber = JBigIntegerField { number(11).pow(100_000) }
val largeKmNumber = BigIntField { number(11).pow(100_000U) }
val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) }
val bigExponent = 50_000
@Benchmark
@ -36,6 +35,16 @@ internal class BigIntBenchmark {
blackhole.consume(jvmNumber + jvmNumber + jvmNumber)
}
@Benchmark
fun kmAddLarge(blackhole: Blackhole) = BigIntField {
blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber)
}
@Benchmark
fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber)
}
@Benchmark
fun kmMultiply(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber * kmNumber * kmNumber)
@ -56,13 +65,33 @@ internal class BigIntBenchmark {
blackhole.consume(largeJvmNumber*largeJvmNumber)
}
// @Benchmark
// fun kmPower(blackhole: Blackhole) = BigIntField {
// blackhole.consume(kmNumber.pow(bigExponent))
// }
//
// @Benchmark
// fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
// blackhole.consume(jvmNumber.pow(bigExponent))
// }
@Benchmark
fun kmPower(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber.pow(bigExponent.toUInt()))
}
@Benchmark
fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(jvmNumber.pow(bigExponent))
}
@Benchmark
fun kmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("0x7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".parseBigInteger())
}
@Benchmark
fun kmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".parseBigInteger())
}
@Benchmark
fun jvmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".toBigInteger(10))
}
@Benchmark
fun jvmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".toBigInteger(16))
}
}

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.operations
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
/**
* Stub for DSL the [Algebra] is.
@ -247,7 +248,7 @@ public interface RingOperations<T> : GroupOperations<T> {
*/
public interface Ring<T> : Group<T>, RingOperations<T> {
/**
* neutral operation for multiplication
* The neutral element of multiplication
*/
public val one: T
}

View File

@ -56,8 +56,7 @@ public class BigInt internal constructor(
else -> sign * compareMagnitudes(magnitude, other.magnitude)
}
public override fun equals(other: Any?): Boolean =
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0
public override fun hashCode(): Int = magnitude.hashCode() + sign
@ -87,20 +86,25 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba
b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
}
public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO
other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
}
public operator fun times(other: Int): BigInt = if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
public fun pow(exponent: UInt): BigInt = BigIntField.power(this@BigInt, exponent)
public operator fun times(other: Int): BigInt = when {
other > 0 -> this * kotlin.math.abs(other).toUInt()
other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
else -> times(other.toBigInt())
}
public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))
@ -238,6 +242,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80
private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3",
@ -276,7 +281,7 @@ public class BigInt internal constructor(
}
result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE)
carry = res shr BASE_SIZE
}
result[resultLength - 1] = carry.toUInt()
@ -318,7 +323,14 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}
private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() ->
naiveMultiplyMagnitudes(mag1, mag2)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}
internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength)
@ -337,6 +349,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result)
}
internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)
val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 - x1) * (y1 - y0) + z0 + z2
return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}
private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size
val result = Magnitude(resultLength)
@ -414,58 +441,90 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, copyOf())
}
private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)
/**
* Returns null if a valid number can not be read from a string
*/
public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
val sign: Int
val sPositive: String
when {
this[0] == '+' -> {
val positivePartIndex = when (this[0]) {
'+' -> {
sign = +1
sPositive = this.substring(1)
1
}
this[0] == '-' -> {
'-' -> {
sign = -1
sPositive = this.substring(1)
1
}
else -> {
sPositive = this
sign = +1
0
}
}
var res = BigInt.ZERO
var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase()
var isEmpty = true
if (sPositiveUpper.startsWith("0X")) { // hex representation
val sHex = sPositiveUpper.substring(2)
return if (this.startsWith("0X", startIndex = positivePartIndex, ignoreCase = true)) {
// hex representation
for (ch in sHex.reversed()) {
if (ch == '_') continue
res += digitValue * (hexChToInt[ch] ?: return null)
digitValue *= 16.toBigInt()
val uInts = ArrayList<UInt>(length).apply { add(0U) }
var offset = 0
fun addDigit(value: UInt) {
uInts[uInts.lastIndex] += value shl offset
offset += 4
if (offset == 32) {
uInts.add(0U)
offset = 0
}
}
} else for (ch in sPositiveUpper.reversed()) {
for (index in lastIndex downTo positivePartIndex + 2) {
when (val ch = this[index]) {
'_' -> continue
in '0'..'9' -> addDigit((ch - '0').toUInt())
in 'A'..'F' -> addDigit((ch - 'A').toUInt() + 10U)
in 'a'..'f' -> addDigit((ch - 'a').toUInt() + 10U)
else -> return null
}
isEmpty = false
}
while (uInts.isNotEmpty() && uInts.last() == 0U)
uInts.removeLast()
if (isEmpty) null else BigInt(sign.toByte(), uInts.toUIntArray())
} else {
// decimal representation
if (ch == '_') continue
if (ch !in '0'..'9') {
return null
}
res += digitValue * (ch.code - '0'.code)
digitValue *= 10.toBigInt()
}
return res * sign
val positivePart = buildList(length) {
for (index in positivePartIndex until length)
when (val a = this@parseBigInteger[index]) {
'_' -> continue
in '0'..'9' -> add(a)
else -> return null
}
}
val offset = positivePart.size % 9
isEmpty = offset == 0
fun parseUInt(fromIndex: Int, toIndex: Int): UInt? {
var res = 0U
for (i in fromIndex until toIndex) {
res = res * 10U + (positivePart[i].digitToIntOrNull()?.toUInt() ?: return null)
}
return res
}
var res = parseUInt(0, offset)?.toBigInt() ?: return null
for (index in offset..positivePart.lastIndex step 9) {
isEmpty = false
res = res * 1_000_000_000U + (parseUInt(index, index + 9) ?: return null).toBigInt()
}
if (isEmpty) null else res * sign
}
}
public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =

View File

@ -153,7 +153,7 @@ public interface PowerOperations<T> : Algebra<T> {
}
/**
* Raises this element to the power [pow].
* Raises this element to the power [power].
*
* @receiver the base.
* @param power the exponent.

View File

@ -97,34 +97,45 @@ public fun <T, S> Sequence<T>.averageWith(space: S): T where S : Ring<T>, S : Sc
//TODO optimized power operation
/**
* Raises [arg] to the natural power [power].
* Raises [arg] to the non-negative integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Evgeniy Zhelenskiy
*/
public fun <T> Ring<T>.power(arg: T, power: Int): T {
require(power >= 0) { "The power can't be negative." }
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
var res = arg
repeat(power - 1) { res *= arg }
return res
public fun <T> Ring<T>.power(arg: T, power: UInt): T = when {
arg == zero && power > 0U -> zero
arg == one -> arg
arg == -one -> powWithoutOptimization(arg, power % 2U)
else -> powWithoutOptimization(arg, power)
}
private fun <T> Ring<T>.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) {
0U -> one
1U -> base
else -> {
val pre = powWithoutOptimization(base, exponent shr 1).let { it * it }
if (exponent and 1U == 0U) pre else pre * base
}
}
/**
* Raises [arg] to the integer power [power].
*
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication and division.
* @param arg the base.
* @param power the exponent.
* @return the base raised to the power.
* @author Iaroslav Postovalov
* @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/
public fun <T> Field<T>.power(arg: T, power: Int): T {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." }
if (power == 0) return one
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
return (this as Ring<T>).power(arg, power)
public fun <T> Field<T>.power(arg: T, power: UInt): T = when {
power < 0 -> one / (this as Ring<T>).power(arg, power)
else -> (this as Ring<T>).power(arg, power)
}

View File

@ -5,7 +5,9 @@
package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.testutils.RingVerifier
import kotlin.math.pow
import kotlin.test.Test
import kotlin.test.assertEquals
@ -21,6 +23,18 @@ internal class BigIntAlgebraTest {
assertEquals(res, 1_000_000.toBigInt())
}
@UnstableKMathAPI
@Test
fun testKBigIntegerRingPow() {
for (num in -5..5)
for (exponent in 0U..10U)
assertEquals(
num.toDouble().pow(exponent.toInt()).toLong().toBigInt(),
num.toBigInt().pow(exponent),
"$num ^ $exponent"
)
}
@Test
fun testKBigIntegerRingSum_100_000_000__100_000_000() {
BigIntField {

View File

@ -7,15 +7,43 @@ package space.kscience.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
@kotlin.ExperimentalUnsignedTypes
class BigIntConversionsTest {
@Test
fun testEmptyString() {
assertNull("".parseBigInteger())
assertNull("+".parseBigInteger())
assertNull("-".parseBigInteger())
assertNull("0x".parseBigInteger())
assertNull("+0x".parseBigInteger())
assertNull("-0x".parseBigInteger())
assertNull("_".parseBigInteger())
assertNull("+_".parseBigInteger())
assertNull("-_".parseBigInteger())
assertNull("0x_".parseBigInteger())
assertNull("+0x_".parseBigInteger())
assertNull("-0x_".parseBigInteger())
}
@Test
fun testToString0x10() {
val x = 0x10.toBigInt()
assertEquals("0x10", x.toString())
}
@Test
fun testUnderscores() {
assertEquals("0x10", "0x_1_0_".parseBigInteger().toString())
assertEquals("0xa", "_1_0_".parseBigInteger().toString())
}
@Test
fun testToString0x17ffffffd() {
val x = 0x17ffffffdL.toBigInt()

View File

@ -5,8 +5,9 @@
package space.kscience.kmath.operations
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.random.Random
import kotlin.random.nextUInt
import kotlin.test.*
@kotlin.ExperimentalUnsignedTypes
class BigIntOperationsTest {
@ -150,6 +151,18 @@ class BigIntOperationsTest {
assertEquals(prod, res)
}
@Test
fun testKaratsuba() {
val x = uintArrayOf(12U, 345U)
val y = uintArrayOf(6U, 789U)
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x, y), BigInt.karatsubaMultiplyMagnitudes(x, y))
repeat(1000) {
val x1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
val y1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x1, y1), BigInt.karatsubaMultiplyMagnitudes(x1, y1))
}
}
@Test
fun test_shr_20() {
val x = 20.toBigInt()
@ -383,4 +396,12 @@ class BigIntOperationsTest {
return assertEquals(res, x % mod)
}
@Test
fun testNotEqualsOtherTypeInstanceButButNotFails() = assertFalse(0.toBigInt().equals(""))
@Test
fun testIntAbsOverflow() {
assertEquals((-Int.MAX_VALUE.toLong().toBigInt() - 1.toBigInt()) * 2, 2.toBigInt() * Int.MIN_VALUE)
}
}

View File

@ -18,4 +18,16 @@ internal class DoubleFieldTest {
val sqrt = DoubleField { sqrt(25 * one) }
assertEquals(5.0, sqrt)
}
@Test
fun testPow() = DoubleField {
val num = 5 * one
assertEquals(5.0, power(num, 1))
assertEquals(25.0, power(num, 2))
assertEquals(1.0, power(num, 0))
assertEquals(0.2, power(num, -1))
assertEquals(0.04, power(num, -2))
assertEquals(0.0, power(num, Int.MIN_VALUE))
assertEquals(1.0, power(zero, 0))
}
}