Karatsuba added

Incorrect equals fix test
Incorrect overflow handling support
This commit is contained in:
zhelenskiy 2021-05-11 01:40:45 +03:00
parent 3ad7f32ada
commit 2da0648d73
2 changed files with 61 additions and 10 deletions

View File

@ -86,20 +86,23 @@ public class BigInt internal constructor(
public operator fun times(b: BigInt): BigInt = when { public operator fun times(b: BigInt): BigInt = when {
this.sign == 0.toByte() -> ZERO this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba b.magnitude.size == 1 -> this * b.magnitude[0] * b.sign.toInt()
this.magnitude.size == 1 -> b * this.magnitude[0] * this.sign.toInt()
else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) else -> BigInt((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
} }
public operator fun times(other: UInt): BigInt = when { public operator fun times(other: UInt): BigInt = when {
sign == 0.toByte() -> ZERO sign == 0.toByte() -> ZERO
other == 0U -> ZERO other == 0U -> ZERO
other == 1U -> this
else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other)) else -> BigInt(sign, multiplyMagnitudeByUInt(magnitude, other))
} }
public operator fun times(other: Int): BigInt = if (other > 0) public operator fun times(other: Int): BigInt = when {
this * kotlin.math.abs(other).toUInt() other > 0 -> this * kotlin.math.abs(other).toUInt()
else other != Int.MIN_VALUE -> -this * kotlin.math.abs(other).toUInt()
-this * kotlin.math.abs(other).toUInt() else -> times(other.toBigInt())
}
public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other)) public operator fun div(other: UInt): BigInt = BigInt(this.sign, divideMagnitudeByUInt(this.magnitude, other))
@ -237,6 +240,7 @@ public class BigInt internal constructor(
public const val BASE_SIZE: Int = 32 public const val BASE_SIZE: Int = 32
public val ZERO: BigInt = BigInt(0, uintArrayOf()) public val ZERO: BigInt = BigInt(0, uintArrayOf())
public val ONE: BigInt = BigInt(1, uintArrayOf(1u)) public val ONE: BigInt = BigInt(1, uintArrayOf(1u))
private const val KARATSUBA_THRESHOLD = 80
private val hexMapping: HashMap<UInt, String> = hashMapOf( private val hexMapping: HashMap<UInt, String> = hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3", 0U to "0", 1U to "1", 2U to "2", 3U to "3",
@ -317,7 +321,16 @@ public class BigInt internal constructor(
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { internal fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude = when {
mag1.size + mag2.size < KARATSUBA_THRESHOLD || mag1.isEmpty() || mag2.isEmpty() -> naiveMultiplyMagnitudes(
mag1,
mag2
)
// TODO implement Fourier
else -> karatsubaMultiplyMagnitudes(mag1, mag2)
}
internal fun naiveMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
val resultLength = mag1.size + mag2.size val resultLength = mag1.size + mag2.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
@ -336,6 +349,21 @@ public class BigInt internal constructor(
return stripLeadingZeros(result) return stripLeadingZeros(result)
} }
internal fun karatsubaMultiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
//https://en.wikipedia.org/wiki/Karatsuba_algorithm
val halfSize = min(mag1.size, mag2.size) / 2
val x0 = mag1.sliceArray(0 until halfSize).toBigInt(1)
val x1 = mag1.sliceArray(halfSize until mag1.size).toBigInt(1)
val y0 = mag2.sliceArray(0 until halfSize).toBigInt(1)
val y1 = mag2.sliceArray(halfSize until mag2.size).toBigInt(1)
val z0 = x0 * y0
val z2 = x1 * y1
val z1 = (x0 + x1) * (y1 + y0) - z0 - z2
return (z2.shl(2 * halfSize * BASE_SIZE) + z1.shl(halfSize * BASE_SIZE) + z0).magnitude
}
private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
val resultLength = mag.size val resultLength = mag.size
val result = Magnitude(resultLength) val result = Magnitude(resultLength)
@ -426,7 +454,7 @@ private val hexChToInt: MutableMap<Char, Int> = hashMapOf(
public fun String.parseBigInteger(): BigInt? { public fun String.parseBigInteger(): BigInt? {
val sign: Int val sign: Int
val sPositive: String val sPositive: String
//TODO substring = O(n). Can be replaced by some drop that is O(1)
when { when {
this[0] == '+' -> { this[0] == '+' -> {
sign = +1 sign = +1
@ -446,8 +474,10 @@ public fun String.parseBigInteger(): BigInt? {
var digitValue = BigInt.ONE var digitValue = BigInt.ONE
val sPositiveUpper = sPositive.uppercase() val sPositiveUpper = sPositive.uppercase()
if (sPositiveUpper.startsWith("0X")) { // hex representation if (sPositiveUpper.startsWith("0X")) {
// hex representation
val sHex = sPositiveUpper.substring(2) val sHex = sPositiveUpper.substring(2)
// TODO optimize O(n2) -> O(n)
for (ch in sHex.reversed()) { for (ch in sHex.reversed()) {
if (ch == '_') continue if (ch == '_') continue

View File

@ -5,8 +5,9 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import kotlin.test.Test import kotlin.random.Random
import kotlin.test.assertEquals import kotlin.random.nextUInt
import kotlin.test.*
@kotlin.ExperimentalUnsignedTypes @kotlin.ExperimentalUnsignedTypes
class BigIntOperationsTest { class BigIntOperationsTest {
@ -150,6 +151,18 @@ class BigIntOperationsTest {
assertEquals(prod, res) assertEquals(prod, res)
} }
@Test
fun testKaratsuba() {
val x = uintArrayOf(12U, 345U)
val y = uintArrayOf(6U, 789U)
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x, y), BigInt.karatsubaMultiplyMagnitudes(x, y))
repeat(1000) {
val x1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
val y1 = UIntArray(Random.nextInt(100, 1000)) { Random.nextUInt() }
assertContentEquals(BigInt.naiveMultiplyMagnitudes(x1, y1), BigInt.karatsubaMultiplyMagnitudes(x1, y1))
}
}
@Test @Test
fun test_shr_20() { fun test_shr_20() {
val x = 20.toBigInt() val x = 20.toBigInt()
@ -383,4 +396,12 @@ class BigIntOperationsTest {
return assertEquals(res, x % mod) return assertEquals(res, x % mod)
} }
@Test
fun testNotEqualsOtherTypeInstanceButButNotFails() = assertFalse(0.toBigInt().equals(""))
@Test
fun testIntAbsOverflow() {
assertEquals((-Int.MAX_VALUE.toLong().toBigInt() - 1.toBigInt()) * 2, 2.toBigInt() * Int.MIN_VALUE)
}
} }