Merge pull request #328 from zhelenskiy/dev

Karatsuba added, 2 bugs are fixed
This commit is contained in:
Alexander Nozik 2021-05-14 09:12:25 +03:00 committed by GitHub
commit c1b94ff0bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 250 additions and 75 deletions

View File

@ -10,20 +10,19 @@ import kotlinx.benchmark.Blackhole
import org.openjdk.jmh.annotations.Benchmark import org.openjdk.jmh.annotations.Benchmark
import org.openjdk.jmh.annotations.Scope import org.openjdk.jmh.annotations.Scope
import org.openjdk.jmh.annotations.State import org.openjdk.jmh.annotations.State
import space.kscience.kmath.operations.BigInt import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.BigIntField import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.JBigIntegerField import java.math.BigInteger
import space.kscience.kmath.operations.invoke
private fun BigInt.pow(power: Int): BigInt = modPow(BigIntField.number(power), BigInt.ZERO)
@UnstableKMathAPI
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class BigIntBenchmark { internal class BigIntBenchmark {
val kmNumber = BigIntField.number(Int.MAX_VALUE) val kmNumber = BigIntField.number(Int.MAX_VALUE)
val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE) val jvmNumber = JBigIntegerField.number(Int.MAX_VALUE)
val largeKmNumber = BigIntField { number(11).pow(100_000) } val largeKmNumber = BigIntField { number(11).pow(100_000U) }
val largeJvmNumber = JBigIntegerField { number(11).pow(100_000) } val largeJvmNumber: BigInteger = JBigIntegerField { number(11).pow(100_000) }
val bigExponent = 50_000 val bigExponent = 50_000
@Benchmark @Benchmark
@ -36,6 +35,16 @@ internal class BigIntBenchmark {
blackhole.consume(jvmNumber + jvmNumber + jvmNumber) blackhole.consume(jvmNumber + jvmNumber + jvmNumber)
} }
@Benchmark
fun kmAddLarge(blackhole: Blackhole) = BigIntField {
blackhole.consume(largeKmNumber + largeKmNumber + largeKmNumber)
}
@Benchmark
fun jvmAddLarge(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume(largeJvmNumber + largeJvmNumber + largeJvmNumber)
}
@Benchmark @Benchmark
fun kmMultiply(blackhole: Blackhole) = BigIntField { fun kmMultiply(blackhole: Blackhole) = BigIntField {
blackhole.consume(kmNumber * kmNumber * kmNumber) blackhole.consume(kmNumber * kmNumber * kmNumber)
@ -56,13 +65,33 @@ internal class BigIntBenchmark {
blackhole.consume(largeJvmNumber*largeJvmNumber) blackhole.consume(largeJvmNumber*largeJvmNumber)
} }
// @Benchmark @Benchmark
// fun kmPower(blackhole: Blackhole) = BigIntField { fun kmPower(blackhole: Blackhole) = BigIntField {
// blackhole.consume(kmNumber.pow(bigExponent)) blackhole.consume(kmNumber.pow(bigExponent.toUInt()))
// } }
//
// @Benchmark @Benchmark
// fun jvmPower(blackhole: Blackhole) = JBigIntegerField { fun jvmPower(blackhole: Blackhole) = JBigIntegerField {
// blackhole.consume(jvmNumber.pow(bigExponent)) blackhole.consume(jvmNumber.pow(bigExponent))
// } }
@Benchmark
fun kmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("0x7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".parseBigInteger())
}
@Benchmark
fun kmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".parseBigInteger())
}
@Benchmark
fun jvmParsing10(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("236656783929183747565738292847574838922010".toBigInteger(10))
}
@Benchmark
fun jvmParsing16(blackhole: Blackhole) = JBigIntegerField {
blackhole.consume("7f57ed8b89c29a3b9a85c7a5b84ca3929c7b7488593".toBigInteger(16))
}
} }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.Symbol import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.misc.UnstableKMathAPI
/** /**
* Stub for DSL the [Algebra] is. * Stub for DSL the [Algebra] is.
@ -247,7 +248,7 @@ public interface RingOperations<T> : GroupOperations<T> {
*/ */
public interface Ring<T> : Group<T>, RingOperations<T> { public interface Ring<T> : Group<T>, RingOperations<T> {
/** /**
* neutral operation for multiplication * The neutral element of multiplication
*/ */
public val one: T public val one: T
} }

View File

@ -56,8 +56,7 @@ public class BigInt internal constructor(
else -> sign * compareMagnitudes(magnitude, other.magnitude) else -> sign * compareMagnitudes(magnitude, other.magnitude)
} }
public override fun equals(other: Any?): Boolean = public override fun equals(other: Any?): Boolean = other is BigInt && compareTo(other) == 0
if (other is BigInt) compareTo(other) == 0 else error("Can't compare KBigInteger to a different type")
public override fun hashCode(): Int = magnitude.hashCode() + sign public override fun hashCode(): Int = magnitude.hashCode() + sign
@ -87,20 +86,25 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when { public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
} }
public operator fun times(other: UInt): BigInt = when { public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO sign == 0.toByte() -> ZERO
other == 0U -> ZERO other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other)) else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
} }
public operator fun times(other: Int): BigInt = if (other > 0) public fun pow(exponent: UInt): BigInt = BigIntField.power(this@BigInt, exponent)
this * kotlin.math.abs(other).toUInt()
else public operator fun times(other: Int): BigInt = when {
-this * kotlin.math.abs(other).toUInt() other > 0 -> this * kotlin.math.abs(other).toUInt()
other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
else -> times(other.toBigInt())
}
public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other)) public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))
@ -238,6 +242,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32 public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf()) public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u)) public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80
private val hexMapping: HashMap<UInt, String> = hashMapOf( private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3", 0U to "0", 1U to "1", 2U to "2", 3U to "3",
@ -276,7 +281,7 @@ public class BigInt internal constructor(
} }
result[i] = (res and BASE).toUInt() result[i] = (res and BASE).toUInt()
carry = (res shr BASE_SIZE) carry = res shr BASE_SIZE
} }
result[resultLength - 1] = carry.toUInt() result[resultLength - 1] = carry.toUInt()
@ -318,7 +323,14 @@ public class BigInt internal constructor(
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() ->
naiveMultiplyMagnitudes(mag1, mag2)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}
internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
@ -337,6 +349,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)
val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 - x1) * (y1 - y0) + z0 + z2
return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}
private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size val resultLength = mag.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
@ -414,58 +441,90 @@ public fun UIntArray.toBigInt(sign: Byte): BigInt {
return BigInt(sign, copyOf()) return BigInt(sign, copyOf())
} }
private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)
/** /**
* Returns null if a valid number can not be read from a string * Returns null if a valid number can not be read from a string
*/ */
public fun String.parseBigInteger(): BigInt? { public fun String.parseBigInteger(): BigInt? {
if (this.isEmpty()) return null
val sign: Int val sign: Int
val sPositive: String
when { val positivePartIndex = when (this[0]) {
this[0] == '+' -> { '+' -> {
sign = +1 sign = +1
sPositive = this.substring(1) 1
} }
this[0] == '-' -> { '-' -> {
sign = -1 sign = -1
sPositive = this.substring(1) 1
} }
else -> { else -> {
sPositive = this
sign = +1 sign = +1
0
} }
} }
var res = BigInt.ZERO var isEmpty = true
var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase()
if (sPositiveUpper.startsWith("0X")) { // hex representation return if (this.startsWith("0X", startIndex = positivePartIndex, ignoreCase = true)) {
val sHex = sPositiveUpper.substring(2) // hex representation
for (ch in sHex.reversed()) { val uInts = ArrayList<UInt>(length).apply { add(0U) }
if (ch == '_') continue var offset = 0
res += digitValue * (hexChToInt[ch] ?: return null) fun addDigit(value: UInt) {
digitValue *= 16.toBigInt() uInts[uInts.lastIndex] += value shl offset
offset += 4
if (offset == 32) {
uInts.add(0U)
offset = 0
}
} }
} else for (ch in sPositiveUpper.reversed()) {
for (index in lastIndex downTo positivePartIndex + 2) {
when (val ch = this[index]) {
'_' -> continue
in '0'..'9' -> addDigit((ch - '0').toUInt())
in 'A'..'F' -> addDigit((ch - 'A').toUInt() + 10U)
in 'a'..'f' -> addDigit((ch - 'a').toUInt() + 10U)
else -> return null
}
isEmpty = false
}
while (uInts.isNotEmpty() && uInts.last() == 0U)
uInts.removeLast()
if (isEmpty) null else BigInt(sign.toByte(), uInts.toUIntArray())
} else {
// decimal representation // decimal representation
if (ch == '_') continue
if (ch !in '0'..'9') {
return null
}
res += digitValue * (ch.code - '0'.code)
digitValue *= 10.toBigInt()
}
return res * sign val positivePart = buildList(length) {
for (index in positivePartIndex until length)
when (val a = this@parseBigInteger[index]) {
'_' -> continue
in '0'..'9' -> add(a)
else -> return null
}
}
val offset = positivePart.size % 9
isEmpty = offset == 0
fun parseUInt(fromIndex: Int, toIndex: Int): UInt? {
var res = 0U
for (i in fromIndex until toIndex) {
res = res * 10U + (positivePart[i].digitToIntOrNull()?.toUInt() ?: return null)
}
return res
}
var res = parseUInt(0, offset)?.toBigInt() ?: return null
for (index in offset..positivePart.lastIndex step 9) {
isEmpty = false
res = res * 1_000_000_000U + (parseUInt(index, index + 9) ?: return null).toBigInt()
}
if (isEmpty) null else res * sign
}
} }
public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =

View File

@ -153,7 +153,7 @@ public interface PowerOperations<T> : Algebra<T> {
} }
/** /**
* Raises this element to the power [pow]. * Raises this element to the power [power].
* *
* @receiver the base. * @receiver the base.
* @param power the exponent. * @param power the exponent.

View File

@ -97,34 +97,45 @@ public fun <T, S> Sequence<T>.averageWith(space: S): T where S : Ring<T>, S : Sc
//TODO optimized power operation //TODO optimized power operation
/** /**
* Raises [arg] to the natural power [power]. * Raises [arg] to the non-negative integer power [power].
*
* Special case: 0 ^ 0 is 1.
* *
* @receiver the algebra to provide multiplication. * @receiver the algebra to provide multiplication.
* @param arg the base. * @param arg the base.
* @param power the exponent. * @param power the exponent.
* @return the base raised to the power. * @return the base raised to the power.
* @author Evgeniy Zhelenskiy
*/ */
public fun <T> Ring<T>.power(arg: T, power: Int): T { public fun <T> Ring<T>.power(arg: T, power: UInt): T = when {
require(power >= 0) { "The power can't be negative." } arg == zero && power > 0U -> zero
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } arg == one -> arg
if (power == 0) return one arg == -one -> powWithoutOptimization(arg, power % 2U)
var res = arg else -> powWithoutOptimization(arg, power)
repeat(power - 1) { res *= arg }
return res
} }
private fun <T> Ring<T>.powWithoutOptimization(base: T, exponent: UInt): T = when (exponent) {
0U -> one
1U -> base
else -> {
val pre = powWithoutOptimization(base, exponent shr 1).let { it * it }
if (exponent and 1U == 0U) pre else pre * base
}
}
/** /**
* Raises [arg] to the integer power [power]. * Raises [arg] to the integer power [power].
* *
* Special case: 0 ^ 0 is 1.
*
* @receiver the algebra to provide multiplication and division. * @receiver the algebra to provide multiplication and division.
* @param arg the base. * @param arg the base.
* @param power the exponent. * @param power the exponent.
* @return the base raised to the power. * @return the base raised to the power.
* @author Iaroslav Postovalov * @author Iaroslav Postovalov, Evgeniy Zhelenskiy
*/ */
public fun <T> Field<T>.power(arg: T, power: Int): T { public fun <T> Field<T>.power(arg: T, power: UInt): T = when {
require(power != 0 || arg != zero) { "The $zero raised to $power is not defined." } power < 0 -> one / (this as Ring<T>).power(arg, power)
if (power == 0) return one else -> (this as Ring<T>).power(arg, power)
if (power < 0) return one / (this as Ring<T>).power(arg, -power)
return (this as Ring<T>).power(arg, power)
} }

View File

@ -5,7 +5,9 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.testutils.RingVerifier import space.kscience.kmath.testutils.RingVerifier
import kotlin.math.pow
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -21,6 +23,18 @@ internal class BigIntAlgebraTest {
assertEquals(res, 1_000_000.toBigInt()) assertEquals(res, 1_000_000.toBigInt())
} }
@UnstableKMathAPI
@Test
fun testKBigIntegerRingPow() {
for (num in -5..5)
for (exponent in 0U..10U)
assertEquals(
num.toDouble().pow(exponent.toInt()).toLong().toBigInt(),
num.toBigInt().pow(exponent),
"$num ^ $exponent"
)
}
@Test @Test
fun testKBigIntegerRingSum_100_000_000__100_000_000() { fun testKBigIntegerRingSum_100_000_000__100_000_000() {
BigIntField { BigIntField {

View File

@ -7,15 +7,43 @@ package space.kscience.kmath.operations
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNull
@kotlin.ExperimentalUnsignedTypes @kotlin.ExperimentalUnsignedTypes
class BigIntConversionsTest { class BigIntConversionsTest {
@Test
fun testEmptyString() {
assertNull("".parseBigInteger())
assertNull("+".parseBigInteger())
assertNull("-".parseBigInteger())
assertNull("0x".parseBigInteger())
assertNull("+0x".parseBigInteger())
assertNull("-0x".parseBigInteger())
assertNull("_".parseBigInteger())
assertNull("+_".parseBigInteger())
assertNull("-_".parseBigInteger())
assertNull("0x_".parseBigInteger())
assertNull("+0x_".parseBigInteger())
assertNull("-0x_".parseBigInteger())
}
@Test @Test
fun testToString0x10() { fun testToString0x10() {
val x = 0x10.toBigInt() val x = 0x10.toBigInt()
assertEquals("0x10", x.toString()) assertEquals("0x10", x.toString())
} }
@Test
fun testUnderscores() {
assertEquals("0x10", "0x_1_0_".parseBigInteger().toString())
assertEquals("0xa", "_1_0_".parseBigInteger().toString())
}
@Test @Test
fun testToString0x17ffffffd() { fun testToString0x17ffffffd() {
val x = 0x17ffffffdL.toBigInt() val x = 0x17ffffffdL.toBigInt()

View File

@ -5,8 +5,9 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import kotlin.test.Test import kotlin.random.Random
import kotlin.test.assertEquals import kotlin.random.nextUInt
import kotlin.test.*
@kotlin.ExperimentalUnsignedTypes @kotlin.ExperimentalUnsignedTypes
class BigIntOperationsTest { class BigIntOperationsTest {
@ -150,6 +151,18 @@ class BigIntOperationsTest {
assertEquals(prod, res) assertEquals(prod, res)
} }
@Test
fun testKaratsuba() {
val x = uintArrayOf(12U, 345U)
val y = uintArrayOf(6U, 789U)
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x, y), BigInt.karatsubaMultiplyMagnitudes(x, y))
repeat(1000) {
val x1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
val y1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x1, y1), BigInt.karatsubaMultiplyMagnitudes(x1, y1))
}
}
@Test @Test
fun test_shr_20() { fun test_shr_20() {
val x = 20.toBigInt() val x = 20.toBigInt()
@ -383,4 +396,12 @@ class BigIntOperationsTest {
return assertEquals(res, x % mod) return assertEquals(res, x % mod)
} }
@Test
fun testNotEqualsOtherTypeInstanceButButNotFails() = assertFalse(0.toBigInt().equals(""))
@Test
fun testIntAbsOverflow() {
assertEquals((-Int.MAX_VALUE.toLong().toBigInt() - 1.toBigInt()) * 2, 2.toBigInt() * Int.MIN_VALUE)
}
} }

View File

@ -18,4 +18,16 @@ internal class DoubleFieldTest {
val sqrt = DoubleField { sqrt(25 * one) } val sqrt = DoubleField { sqrt(25 * one) }
assertEquals(5.0, sqrt) assertEquals(5.0, sqrt)
} }
@Test
fun testPow() = DoubleField {
val num = 5 * one
assertEquals(5.0, power(num, 1))
assertEquals(25.0, power(num, 2))
assertEquals(1.0, power(num, 0))
assertEquals(0.2, power(num, -1))
assertEquals(0.04, power(num, -2))
assertEquals(0.0, power(num, Int.MIN_VALUE))
assertEquals(1.0, power(zero, 0))
}
} }