From 9500ee0924d828a600771857197a0e369d707fcf Mon Sep 17 00:00:00 2001 From: Peter Klimai Date: Mon, 30 Mar 2020 16:30:16 +0300 Subject: [PATCH 1/3] Initial implementation of multiplatform BigInteger --- .../kmath/operations/KBigInteger.kt | 398 +++++++++++++++ .../kmath/operations/KBigIntegerTest.kt | 480 ++++++++++++++++++ 2 files changed, 878 insertions(+) create mode 100644 kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt create mode 100644 kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt new file mode 100644 index 000000000..0238492d2 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt @@ -0,0 +1,398 @@ +package scientifik.kmath.operations + +import kotlin.math.max +import kotlin.math.sign + +/* + * Kotlin Multiplatform implementation of Big Integer numbers (KBigInteger). + * Initial version from https://github.com/robdrynkin/kotlin-big-integer + */ + +@kotlin.ExperimentalUnsignedTypes +typealias Magnitude = UIntArray + +@kotlin.ExperimentalUnsignedTypes +typealias TBase = ULong + +object KBigIntegerRing: Ring { + override val zero: KBigInteger = KBigInteger.ZERO + override val one: KBigInteger = KBigInteger.ONE + + override fun add(a: KBigInteger, b: KBigInteger): KBigInteger = a.plus(b) + + override fun multiply(a: KBigInteger, k: Number): KBigInteger = a.times(k.toLong()) + + override fun multiply(a: KBigInteger, b: KBigInteger): KBigInteger = a.times(b) + + operator fun String.unaryPlus(): KBigInteger = KBigInteger(this)!! + + operator fun String.unaryMinus(): KBigInteger = -KBigInteger(this)!! + +} + + +@kotlin.ExperimentalUnsignedTypes +class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): + RingElement, Comparable { + + constructor(x: Int) : this(x.sign, uintArrayOf(kotlin.math.abs(x).toUInt())) + constructor(x: Long) : this(x.sign, uintArrayOf( + (kotlin.math.abs(x).toULong() and BASE).toUInt(), + ((kotlin.math.abs(x).toULong() shr BASE_SIZE) and BASE).toUInt())) + + val magnitude = stripLeadingZeros(magnitude) + val sign = if (this.magnitude.isNotEmpty()) sign else 0 + val sizeByte: Int = magnitude.size * BASE_SIZE / 4 + + override val context: KBigIntegerRing get() = KBigIntegerRing + + override fun unwrap(): KBigInteger = this + override fun KBigInteger.wrap(): KBigInteger = this + + companion object { + val BASE = 0xffffffffUL + const val BASE_SIZE: Int = 32 + val ZERO: KBigInteger = KBigInteger() + val ONE: KBigInteger = KBigInteger(1) + + private val hexMapping: HashMap = + hashMapOf( + 0U to "0", 1U to "1", 2U to "2", 3U to "3", 4U to "4", 5U to "5", 6U to "6", 7U to "7", 8U to "8", + 9U to "9", 10U to "a", 11U to "b", 12U to "c", 13U to "d", 14U to "e", 15U to "f" + ) + + private fun stripLeadingZeros(mag: Magnitude): Magnitude { + // TODO: optimize performance + if (mag.isEmpty()) { + return mag + } + var resSize: Int = mag.size - 1 + while (mag[resSize] == 0U) { + if (resSize == 0) + break + resSize -= 1 + } + return mag.sliceArray(IntRange(0, resSize)) + } + + private fun compareMagnitudes(mag1: Magnitude, mag2: Magnitude): Int { + when { + mag1.size > mag2.size -> return 1 + mag1.size < mag2.size -> return -1 + else -> { + for (i in mag1.size - 1 downTo 0) { + if (mag1[i] > mag2[i]) { + return 1 + } else if (mag1[i] < mag2[i]) { + return -1 + } + } + return 0 + } + } + } + + private fun addMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { + val resultLength: Int = max(mag1.size, mag2.size) + 1 + val result = Magnitude(resultLength) + var carry: TBase = 0UL + + for (i in 0 until resultLength - 1) { + val res = when { + i >= mag1.size -> mag2[i].toULong() + carry + i >= mag2.size -> mag1[i].toULong() + carry + else -> mag1[i].toULong() + mag2[i].toULong() + carry + } + result[i] = (res and BASE).toUInt() + carry = (res shr BASE_SIZE) + } + result[resultLength - 1] = carry.toUInt() + return result + } + + private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { + val resultLength: Int = mag1.size + val result = Magnitude(resultLength) + var carry = 0L + + for (i in 0 until resultLength) { + var res: Long = + if (i < mag2.size) mag1[i].toLong() - mag2[i].toLong() - carry + else mag1[i].toLong() - carry + + carry = if (res < 0) 1 else 0 + res += carry * (BASE + 1UL).toLong() + + result[i] = res.toUInt() + } + + return result + } + + private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { + val resultLength: Int = mag.size + 1 + val result = Magnitude(resultLength) + var carry: ULong = 0UL + + for (i in mag.indices) { + val cur: ULong = carry + mag[i].toULong() * x.toULong() + result[i] = (cur and BASE.toULong()).toUInt() + carry = cur shr BASE_SIZE + } + result[resultLength - 1] = (carry and BASE).toUInt() + + return result + } + + private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { + val resultLength: Int = mag1.size + mag2.size + val result = Magnitude(resultLength) + + for (i in mag1.indices) { + var carry: ULong = 0UL + for (j in mag2.indices) { + val cur: ULong = result[i + j].toULong() + mag1[i].toULong() * mag2[j].toULong() + carry + result[i + j] = (cur and BASE.toULong()).toUInt() + carry = cur shr BASE_SIZE + } + result[i + mag2.size] = (carry and BASE).toUInt() + } + + return result + } + + internal fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { + val resultLength: Int = mag.size + val result = Magnitude(resultLength) + var carry: ULong = 0UL + + for (i in mag.size - 1 downTo 0) { + val cur: ULong = mag[i].toULong() + (carry shl BASE_SIZE) + result[i] = (cur / x).toUInt() + carry = cur % x + } + return result + } + + internal fun divideMagnitudes(mag1_: Magnitude, mag2: Magnitude): Magnitude { + val mag1 = ULongArray(mag1_.size) { mag1_[it].toULong() } + + val resultLength: Int = mag1.size - mag2.size + 1 + val result = LongArray(resultLength) + + for (i in mag1.size - 1 downTo mag2.size - 1) { + val div: ULong = mag1[i] / mag2[mag2.size - 1] + result[i - mag2.size + 1] = div.toLong() + for (j in mag2.indices) { + mag1[i - j] -= mag2[mag2.size - 1 - j] * div + } + if (i > 0) { + mag1[i - 1] += (mag1[i] shl BASE_SIZE) + } + } + + val normalizedResult = Magnitude(resultLength) + var carry = 0L + + for (i in result.indices) { + result[i] += carry + if (result[i] < 0L) { + normalizedResult[i] = (result[i] + (BASE + 1UL).toLong()).toUInt() + carry = -1 + } else { + normalizedResult[i] = result[i].toUInt() + carry = 0 + } + } + + return normalizedResult + } + } + + override fun compareTo(other: KBigInteger): Int { + return when { + (this.sign == 0) and (other.sign == 0) -> 0 + this.sign < other.sign -> -1 + this.sign > other.sign -> 1 + else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude) + } + } + + override fun equals(other: Any?): Boolean { + if (other is KBigInteger) { + return this.compareTo(other) == 0 + } + else error("Can't compare KBigInteger to a different type") + } + + override fun hashCode(): Int { + return magnitude.hashCode() + this.sign + } + + operator fun unaryMinus(): KBigInteger { + return if (this.sign == 0) this else KBigInteger(-this.sign, this.magnitude) + } + + override operator fun plus(b: KBigInteger): KBigInteger { + return when { + b.sign == 0 -> this + this.sign == 0 -> b + this == -b -> ZERO + this.sign == b.sign -> KBigInteger(this.sign, addMagnitudes(this.magnitude, b.magnitude)) + else -> { + val comp: Int = compareMagnitudes(this.magnitude, b.magnitude) + + if (comp == 1) { + KBigInteger(this.sign, subtractMagnitudes(this.magnitude, b.magnitude)) + } else { + KBigInteger(-this.sign, subtractMagnitudes(b.magnitude, this.magnitude)) + } + } + } + } + + override operator fun minus(b: KBigInteger): KBigInteger { + return this + (-b) + } + + override operator fun times(b: KBigInteger): KBigInteger { + return when { + this.sign == 0 -> ZERO + b.sign == 0 -> ZERO +// TODO: Karatsuba + else -> KBigInteger(this.sign * b.sign, multiplyMagnitudes(this.magnitude, b.magnitude)) + } + } + + operator fun times(other: UInt): KBigInteger { + return when { + this.sign == 0 -> ZERO + other == 0U -> ZERO + else -> KBigInteger(this.sign, multiplyMagnitudeByUInt(this.magnitude, other)) + } + } + + operator fun times(other: Int): KBigInteger { + return if (other > 0) + this * kotlin.math.abs(other).toUInt() + else + -this * kotlin.math.abs(other).toUInt() + } + + operator fun div(other: UInt): KBigInteger { + return KBigInteger(this.sign, divideMagnitudeByUInt(this.magnitude, other)) + } + + operator fun div(other: Int): KBigInteger { + return KBigInteger(this.sign * other.sign, divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())) + } + + operator fun div(other: KBigInteger): KBigInteger { + return when { + this < other -> ZERO + this == other -> ONE + else -> KBigInteger(this.sign * other.sign, divideMagnitudes(this.magnitude, other.magnitude)) + } + } + + operator fun rem(other: Int): Int { + val res = this - (this / other) * other + return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt() + } + + operator fun rem(other: KBigInteger): KBigInteger { + return this - (this / other) * other + } + + fun modPow(exponent: KBigInteger, m: KBigInteger): KBigInteger { + return when { + exponent == ZERO -> ONE + exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m + else -> { + val sqRoot = modPow(exponent / 2, m) + (sqRoot * sqRoot) % m + } + } + } + + override fun toString(): String { + if (this.sign == 0) { + return "0x0" + } + var res: String = if (this.sign == -1) "-0x" else "0x" + var numberStarted = false + + for (i in this.magnitude.size - 1 downTo 0) { + for (j in BASE_SIZE / 4 - 1 downTo 0) { + val curByte = (this.magnitude[i] shr 4 * j) and 0xfU + if (numberStarted or (curByte != 0U)) { + numberStarted = true + res += hexMapping[curByte] + } + } + } + + return res + } +} + +@kotlin.ExperimentalUnsignedTypes +fun abs(x: KBigInteger): KBigInteger { + return if (x.sign == 0) x else KBigInteger(1, x.magnitude) +} + +@kotlin.ExperimentalUnsignedTypes +// Can't put it as constructor in class due to platform declaration clash with KBigInteger(Int) +fun KBigInteger(x: UInt): KBigInteger + = KBigInteger(1, uintArrayOf(x)) + +@kotlin.ExperimentalUnsignedTypes +// Can't put it as constructor in class due to platform declaration clash with KBigInteger(Long) +fun KBigInteger(x: ULong): KBigInteger + = KBigInteger(1, uintArrayOf((x and KBigInteger.BASE).toUInt(), ((x shr KBigInteger.BASE_SIZE) and KBigInteger.BASE).toUInt())) + +val hexChToInt = hashMapOf('0' to 0, '1' to 1, '2' to 2, '3' to 3, '4' to 4, '5' to 5, '6' to 6, '7' to 7, + '8' to 8, '9' to 9, 'A' to 10, 'B' to 11, 'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15) + +// Returns None if a valid number can not be read from a string +fun KBigInteger(s: String): KBigInteger? { + val sign: Int + val sPositive: String + when { + s[0] == '+' -> { + sign = +1 + sPositive = s.substring(1) + } + s[0] == '-' -> { + sign = -1 + sPositive = s.substring(1) + } + else -> { + sPositive = s + sign = +1 + } + } + var res = KBigInteger.ZERO + var digitValue = KBigInteger.ONE + val sPositiveUpper = sPositive.toUpperCase() + if (sPositiveUpper.startsWith("0X")) { // hex representation + val sHex = sPositiveUpper.substring(2) + for (ch in sHex.reversed()) { + if (ch == '_') continue + res += digitValue * (hexChToInt[ch] ?: return null) + digitValue *= KBigInteger(16) + } + } + else { // decimal representation + val sDecimal = sPositiveUpper + for (ch in sDecimal.reversed()) { + if (ch == '_') continue + if (ch !in '0'..'9') { + return null + } + res += digitValue * (ch.toInt() - '0'.toInt()) + digitValue *= KBigInteger(10) + } + } + return res * sign +} diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt new file mode 100644 index 000000000..10b89f28b --- /dev/null +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt @@ -0,0 +1,480 @@ +package scientifik.kmath.operations + +import kotlin.test.Test +import kotlin.test.assertTrue +import kotlin.test.assertEquals + +@kotlin.ExperimentalUnsignedTypes +class KBigIntegerConstructorTest { + @Test + fun testConstructorZero() { + assertEquals(KBigInteger(0), KBigInteger(0, uintArrayOf())) + } + + @Test + fun testConstructor8() { + assertEquals(KBigInteger(8), KBigInteger(1, uintArrayOf(8U))) + } + + @Test + fun testConstructor_0xffffffffaL() { + val x = KBigInteger(-0xffffffffaL) + val y = KBigInteger(-1, uintArrayOf(0xfffffffaU, 0xfU)) + assertEquals(x, y) + } +} + +@kotlin.ExperimentalUnsignedTypes +class KBigIntegerCompareTest { + @Test + fun testCompare1_2() { + val x = KBigInteger(1) + val y = KBigInteger(2) + assertTrue { x < y } + } + + @Test + fun testCompare0_0() { + val x = KBigInteger(0) + val y = KBigInteger(0) + assertEquals(x, y) + } + + @Test + fun testCompare1__2() { + val x = KBigInteger(1) + val y = KBigInteger(-2) + assertTrue { x > y } + } + + @Test + fun testCompare_1__2() { + val x = KBigInteger(-1) + val y = KBigInteger(-2) + assertTrue { x > y } + } + + @Test + fun testCompare_2__1() { + val x = KBigInteger(-2) + val y = KBigInteger(-1) + assertTrue { x < y } + } + + @Test + fun testCompare12345_12345() { + val x = KBigInteger(12345) + val y = KBigInteger(12345) + assertEquals(x, y) + } + + @Test + fun testEqualsWithLong() { + val x = KBigInteger(12345) + assertTrue { x == KBigInteger(12345L) } + } + + @Test + fun testEqualsWithULong() { + val x = KBigInteger(12345) + assertTrue { x == KBigInteger(12345UL) } + } + + @Test + fun testCompareBigNumbersGreater() { + val x = KBigInteger(0xfffffffffL) + val y = KBigInteger(0xffffffffaL) + assertTrue { x > y } + } + + @Test + fun testCompareBigNumbersEqual() { + val x = KBigInteger(0xffffffffaL) + val y = KBigInteger(0xffffffffaL) + assertEquals(x, y) + } + + @Test + fun testCompareBigNumbersLess() { + val x = KBigInteger(-0xffffffffaL) + val y = KBigInteger(0xffffffffaL) + assertTrue { x < y } + } +} + +@kotlin.ExperimentalUnsignedTypes +class KBigIntegerOperationsTest { + @Test + fun testPlus_1_1() { + val x = KBigInteger(1) + val y = KBigInteger(1) + + val res = x + y + val sum = KBigInteger(2) + + assertEquals(sum, res) + } + + @Test + fun testPlusBigNumbers() { + val x = KBigInteger(0x7fffffff) + val y = KBigInteger(0x7fffffff) + val z = KBigInteger(0x7fffffff) + + val res = x + y + z + val sum = KBigInteger(1, uintArrayOf(0x7ffffffdU, 0x1U)) + + assertEquals(sum, res) + } + + @Test + fun testUnaryMinus() { + val x = KBigInteger(1234) + val y = KBigInteger(-1234) + assertEquals(-x, y) + } + + @Test + fun testMinus_2_1() { + val x = KBigInteger(2) + val y = KBigInteger(1) + + val res = x - y + val sum = KBigInteger(1) + + assertEquals(sum, res) + } + + @Test + fun testMinus__2_1() { + val x = KBigInteger(-2) + val y = KBigInteger(1) + + val res = x - y + val sum = KBigInteger(-3) + + assertEquals(sum, res) + } + + @Test + fun testMinus___2_1() { + val x = KBigInteger(-2) + val y = KBigInteger(1) + + val res = -x - y + val sum = KBigInteger(1) + + assertEquals(sum, res) + } + + @Test + fun testMinusBigNumbers() { + val x = KBigInteger(12345) + val y = KBigInteger(0xffffffffaL) + + val res = x - y + val sum = KBigInteger(-0xfffffcfc1L) + + assertEquals(sum, res) + } + + @Test + fun testMultiply_2_3() { + val x = KBigInteger(2) + val y = KBigInteger(3) + + val res = x * y + val prod = KBigInteger(6) + + assertEquals(prod, res) + } + + @Test + fun testMultiply__2_3() { + val x = KBigInteger(-2) + val y = KBigInteger(3) + + val res = x * y + val prod = KBigInteger(-6) + + assertEquals(prod, res) + } + + @Test + fun testMultiply_0xfff123_0xfff456() { + val x = KBigInteger(0xfff123) + val y = KBigInteger(0xfff456) + + val res = x * y + val prod = KBigInteger(0xffe579ad5dc2L) + + assertEquals(prod, res) + } + + @Test + fun testMultiplyUInt_0xfff123_0xfff456() { + val x = KBigInteger(0xfff123) + val y = 0xfff456U + + val res = x * y + val prod = KBigInteger(0xffe579ad5dc2L) + + assertEquals(prod, res) + } + + @Test + fun testMultiplyInt_0xfff123__0xfff456() { + val x = KBigInteger(0xfff123) + val y = -0xfff456 + + val res = x * y + val prod = KBigInteger(-0xffe579ad5dc2L) + + assertEquals(prod, res) + } + + @Test + fun testMultiply_0xffffffff_0xffffffff() { + val x = KBigInteger(0xffffffffL) + val y = KBigInteger(0xffffffffL) + + val res = x * y + val prod = KBigInteger(0xfffffffe00000001UL) + + assertEquals(prod, res) + } + + @Test + fun test_square_0x11223344U_0xad2ffffdU_0x17eU() { + val num = KBigInteger(-1, uintArrayOf(0x11223344U, 0xad2ffffdU, 0x17eU )) + println(num) + val res = num * num + assertEquals(res, KBigInteger(1, uintArrayOf(0xb0542a10U, 0xbbd85bc8U, 0x2a1fa515U, 0x5069e03bU, 0x23c09U))) + } + + @Test + fun testDivision_6_3() { + val x = KBigInteger(6) + val y = 3U + + val res = x / y + val div = KBigInteger(2) + + assertEquals(div, res) + } + + @Test + fun testBigDivision_6_3() { + val x = KBigInteger(6) + val y = KBigInteger(3) + + val res = x / y + val div = KBigInteger(2) + + assertEquals(div, res) + } + + @Test + fun testDivision_20__3() { + val x = KBigInteger(20) + val y = -3 + + val res = x / y + val div = KBigInteger(-6) + + assertEquals(div, res) + } + + @Test + fun testBigDivision_20__3() { + val x = KBigInteger(20) + val y = KBigInteger(-3) + + val res = x / y + val div = KBigInteger(-6) + + assertEquals(div, res) + } + + @Test + fun testDivision_0xfffffffe00000001_0xffffffff() { + val x = KBigInteger(0xfffffffe00000001UL) + val y = 0xffffffffU + + val res = x / y + val div = KBigInteger(0xffffffffL) + + assertEquals(div, res) + } + + @Test + fun testBigDivision_0xfffffffe00000001_0xffffffff() { + val x = KBigInteger(0xfffffffe00000001UL) + val y = KBigInteger(0xffffffffU) + + val res = x / y + val div = KBigInteger(0xffffffffL) + + assertEquals(div, res) + } + + @Test + fun testMod_20_3() { + val x = KBigInteger(20) + val y = 3 + + val res = x % y + val mod = 2 + + assertEquals(mod, res) + } + + @Test + fun testBigMod_20_3() { + val x = KBigInteger(20) + val y = KBigInteger(3) + + val res = x % y + val mod = KBigInteger(2) + + assertEquals(mod, res) + } + + @Test + fun testMod_0xfffffffe00000001_12345() { + val x = KBigInteger(0xfffffffe00000001UL) + val y = 12345 + + val res = x % y + val mod = 1980 + + assertEquals(mod, res) + } + + @Test + fun testBigMod_0xfffffffe00000001_12345() { + val x = KBigInteger(0xfffffffe00000001UL) + val y = KBigInteger(12345) + + val res = x % y + val mod = KBigInteger(1980) + + assertEquals(mod, res) + } + + @Test + fun testModPow_3_10_17() { + val x = KBigInteger(3) + val exp = KBigInteger(10) + val mod = KBigInteger(17) + + val res = KBigInteger(8) + + return assertEquals(res, x.modPow(exp, mod)) + } + + @Test + fun testModPowBigNumbers() { + val x = KBigInteger(0xfffffffeabcdef01UL) + val exp = KBigInteger(2) + val mod = KBigInteger(0xfffffffeabcUL) + + val res = KBigInteger(0x6deec7895faUL) + + return assertEquals(res, x.modPow(exp, mod)) + } + + @Test + fun testModBigNumbers() { + val x = KBigInteger(0xfffffffeabcdef01UL) + val mod = KBigInteger(0xfffffffeabcUL) + + val res = KBigInteger(0xdef01) + + return assertEquals(res, x % mod) + } +} + +@kotlin.ExperimentalUnsignedTypes +class KBigIntegerConversionsTest { + @Test + fun testToString0x10() { + val x = KBigInteger(0x10) + assertEquals("0x10", x.toString()) + } + + @Test + fun testToString0x17ffffffd() { + val x = KBigInteger(0x17ffffffdL) + assertEquals("0x17ffffffd", x.toString()) + } + + @Test + fun testToString_0x17ead2ffffd() { + val x = KBigInteger(-0x17ead2ffffdL) + assertEquals("-0x17ead2ffffd", x.toString()) + } + + @Test + fun testToString_0x17ead2ffffd11223344() { + val x = KBigInteger(-1, uintArrayOf(0x11223344U, 0xad2ffffdU, 0x17eU )) + assertEquals("-0x17ead2ffffd11223344", x.toString()) + } + + @Test + fun testFromString_0x17ead2ffffd11223344() { + val x = KBigInteger("0x17ead2ffffd11223344")!! + assertEquals( "0x17ead2ffffd11223344", x.toString()) + } + + @Test + fun testFromString_7059135710711894913860() { + val x = KBigInteger("-7059135710711894913860") + assertEquals("-0x17ead2ffffd11223344", x.toString()) + } +} + +@kotlin.ExperimentalUnsignedTypes +class DivisionTests { + // TODO + @Test + fun test_0xfffffffeabcdef01UL_0xfffffffeabc() { + val res = KBigInteger(0xfffffffeabcdef01UL) / KBigInteger(0xfffffffeabc) + assertEquals(res, KBigInteger(0x100000)) + + } + +// println(KBigInteger(+1, uintArrayOf(1000U, 1000U, 1000U)) / KBigInteger(0xfffffffeabc) ) + // >>> hex((1000 + 1000*2**32 + 1000*2**64)/ 0xfffffffeabc) == 0x3e800000L +// println(KBigInteger(+1, KBigInteger.divideMagnitudeByUInt(uintArrayOf(1000U, 1000U, 1000U),456789U))) + // 0x8f789719813969L + +} + +class KBigIntegerRingTest { + @Test + fun testSum() { + val res = KBigIntegerRing { + KBigInteger(1_000L) * KBigInteger(1_000L) + } + assertEquals(res, KBigInteger(1_000_000) ) + } + + @Test + fun test_sum_100_000_000__100_000_000() { + KBigIntegerRing { + val sum = +"100_000_000" + +"100_000_000" + assertEquals(sum, KBigInteger("200_000_000")) + } + } + + @Test + fun test_mul_3__4() { + KBigIntegerRing { + val prod = +"0x3000_0000_0000" * +"0x4000_0000_0000_0000_0000" + assertEquals(prod, KBigInteger("0xc00_0000_0000_0000_0000_0000_0000_0000")) + } + } + +} + From 19d1459a558fb19731a1ba713b184342d5c62bbd Mon Sep 17 00:00:00 2001 From: Peter Klimai Date: Wed, 1 Apr 2020 23:56:39 +0300 Subject: [PATCH 2/3] Fix division and add tests --- .../kmath/operations/KBigInteger.kt | 136 ++++++++++++------ .../kmath/operations/KBigIntegerTest.kt | 103 ++++++++++--- 2 files changed, 179 insertions(+), 60 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt index 0238492d2..036f44a91 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt @@ -1,6 +1,8 @@ package scientifik.kmath.operations +import kotlin.math.log2 import kotlin.math.max +import kotlin.math.min import kotlin.math.sign /* @@ -27,7 +29,6 @@ object KBigIntegerRing: Ring { operator fun String.unaryPlus(): KBigInteger = KBigInteger(this)!! operator fun String.unaryMinus(): KBigInteger = -KBigInteger(this)!! - } @@ -161,7 +162,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return result } - internal fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { + private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { val resultLength: Int = mag.size val result = Magnitude(resultLength) var carry: ULong = 0UL @@ -174,39 +175,6 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return result } - internal fun divideMagnitudes(mag1_: Magnitude, mag2: Magnitude): Magnitude { - val mag1 = ULongArray(mag1_.size) { mag1_[it].toULong() } - - val resultLength: Int = mag1.size - mag2.size + 1 - val result = LongArray(resultLength) - - for (i in mag1.size - 1 downTo mag2.size - 1) { - val div: ULong = mag1[i] / mag2[mag2.size - 1] - result[i - mag2.size + 1] = div.toLong() - for (j in mag2.indices) { - mag1[i - j] -= mag2[mag2.size - 1 - j] * div - } - if (i > 0) { - mag1[i - 1] += (mag1[i] shl BASE_SIZE) - } - } - - val normalizedResult = Magnitude(resultLength) - var carry = 0L - - for (i in result.indices) { - result[i] += carry - if (result[i] < 0L) { - normalizedResult[i] = (result[i] + (BASE + 1UL).toLong()).toUInt() - carry = -1 - } else { - normalizedResult[i] = result[i].toUInt() - carry = 0 - } - } - - return normalizedResult - } } override fun compareTo(other: KBigInteger): Int { @@ -287,12 +255,100 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return KBigInteger(this.sign * other.sign, divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())) } - operator fun div(other: KBigInteger): KBigInteger { - return when { - this < other -> ZERO - this == other -> ONE - else -> KBigInteger(this.sign * other.sign, divideMagnitudes(this.magnitude, other.magnitude)) + private fun division(other: KBigInteger): Pair { + // Long division algorithm: + // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder + // TODO: Implement more effective algorithm + var q: KBigInteger = ZERO + var r: KBigInteger = ZERO + + val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt() + for (i in bitSize downTo 0) { + r = r shl 1 + r = r or ((abs(this) shr i) and ONE) + if (r >= abs(other)) { + r -= abs(other) + q += (ONE shl i) + } } + + return Pair(KBigInteger(this.sign * other.sign, q.magnitude), r) + } + + operator fun div(other: KBigInteger): KBigInteger { + return this.division(other).first + } + + infix fun shl(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + 1 + val relShift = i % BASE_SIZE + val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift} + val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)} + + val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) + + for (j in this.magnitude.indices) { + newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) + if (j != 0) { + newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) + } + } + + newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last()) + + return KBigInteger(this.sign, newMagnitude) + } + + infix fun shr(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + val relShift = i % BASE_SIZE + val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift} + val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)} + if (this.magnitude.size - fullShifts <= 0) { + return ZERO + } + val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts) + + for (j in fullShifts until this.magnitude.size) { + newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) + if (j != this.magnitude.size - 1) { + newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) + } + } + + return KBigInteger(this.sign, newMagnitude) + } + + infix fun or(other: KBigInteger): KBigInteger { + if (this == ZERO) return other; + if (other == ZERO) return this; + val resSize = max(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + if (i < this.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or this.magnitude[i] + } + if (i < other.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or other.magnitude[i] + } + } + return KBigInteger(1, newMagnitude) + } + + infix fun and(other: KBigInteger): KBigInteger { + if ((this == ZERO) or (other == ZERO)) return ZERO; + val resSize = min(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + newMagnitude[i] = this.magnitude[i] and other.magnitude[i] + } + return KBigInteger(1, newMagnitude) } operator fun rem(other: Int): Int { diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt index 10b89f28b..daece0f45 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt @@ -244,6 +244,62 @@ class KBigIntegerOperationsTest { assertEquals(prod, res) } + @Test + fun test_shr_20() { + val x = KBigInteger(20) + assertEquals(KBigInteger(10), x shr 1) + } + + @Test + fun test_shl_20() { + val x = KBigInteger(20) + assertEquals(KBigInteger(40), x shl 1) + } + + @Test + fun test_shl_1_0() { + assertEquals(KBigInteger.ONE, KBigInteger.ONE shl 0) + } + + @Test + fun test_shl_1_32() { + assertEquals(KBigInteger(0x100000000UL), KBigInteger.ONE shl 32) + } + + @Test + fun test_shl_1_33() { + assertEquals(KBigInteger(0x200000000UL), KBigInteger.ONE shl 33) + } + + @Test + fun test_shr_1_33_33() { + assertEquals(KBigInteger.ONE, (KBigInteger.ONE shl 33) shr 33) + } + + @Test + fun test_shr_1_32() { + assertEquals(KBigInteger.ZERO, KBigInteger.ONE shr 32) + } + + @Test + fun test_and_123_456() { + val x = KBigInteger(123) + val y = KBigInteger(456) + assertEquals(KBigInteger(72), x and y) + } + + @Test + fun test_or_123_456() { + val x = KBigInteger(123) + val y = KBigInteger(456) + assertEquals(KBigInteger(507), x or y) + } + + @Test + fun test_asd() { + assertEquals(KBigInteger.ONE, KBigInteger.ZERO or ((KBigInteger(20) shr 4) and KBigInteger.ONE)) + } + @Test fun test_square_0x11223344U_0xad2ffffdU_0x17eU() { val num = KBigInteger(-1, uintArrayOf(0x11223344U, 0xad2ffffdU, 0x17eU )) @@ -307,6 +363,12 @@ class KBigIntegerOperationsTest { assertEquals(div, res) } + @Test + fun testBigDivision_0xfffffffeabcdef01UL_0xfffffffeabc() { + val res = KBigInteger(0xfffffffeabcdef01UL) / KBigInteger(0xfffffffeabc) + assertEquals(res, KBigInteger(0x100000)) + } + @Test fun testBigDivision_0xfffffffe00000001_0xffffffff() { val x = KBigInteger(0xfffffffe00000001UL) @@ -379,7 +441,7 @@ class KBigIntegerOperationsTest { val exp = KBigInteger(2) val mod = KBigInteger(0xfffffffeabcUL) - val res = KBigInteger(0x6deec7895faUL) + val res = KBigInteger(0xc2253cde01) return assertEquals(res, x.modPow(exp, mod)) } @@ -434,26 +496,9 @@ class KBigIntegerConversionsTest { } } -@kotlin.ExperimentalUnsignedTypes -class DivisionTests { - // TODO - @Test - fun test_0xfffffffeabcdef01UL_0xfffffffeabc() { - val res = KBigInteger(0xfffffffeabcdef01UL) / KBigInteger(0xfffffffeabc) - assertEquals(res, KBigInteger(0x100000)) - - } - -// println(KBigInteger(+1, uintArrayOf(1000U, 1000U, 1000U)) / KBigInteger(0xfffffffeabc) ) - // >>> hex((1000 + 1000*2**32 + 1000*2**64)/ 0xfffffffeabc) == 0x3e800000L -// println(KBigInteger(+1, KBigInteger.divideMagnitudeByUInt(uintArrayOf(1000U, 1000U, 1000U),456789U))) - // 0x8f789719813969L - -} - class KBigIntegerRingTest { @Test - fun testSum() { + fun testKBigIntegerRingSum() { val res = KBigIntegerRing { KBigInteger(1_000L) * KBigInteger(1_000L) } @@ -461,7 +506,7 @@ class KBigIntegerRingTest { } @Test - fun test_sum_100_000_000__100_000_000() { + fun testKBigIntegerRingSum_100_000_000__100_000_000() { KBigIntegerRing { val sum = +"100_000_000" + +"100_000_000" assertEquals(sum, KBigInteger("200_000_000")) @@ -476,5 +521,23 @@ class KBigIntegerRingTest { } } + @Test + fun test_div_big_1() { + KBigIntegerRing { + val res = +"1_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000" / + +"555_000_444_000_333_000_222_000_111_000_999_001" + assertEquals(res, +"1801800360360432432518919022699") + } + } + + @Test + fun test_rem_big_1() { + KBigIntegerRing { + val res = +"1_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000" % + +"555_000_444_000_333_000_222_000_111_000_999_001" + assertEquals(res, +"324121220440768000291647788404676301") + } + } + } From 48cb683bc4dd28c1c9f509cbf04dc538952f9082 Mon Sep 17 00:00:00 2001 From: Peter Klimai Date: Wed, 15 Apr 2020 18:55:13 +0300 Subject: [PATCH 3/3] Refactoring of KBigInteger --- .../kmath/operations/KBigInteger.kt | 505 +++++++++--------- .../kmath/operations/KBigIntegerTest.kt | 8 +- 2 files changed, 259 insertions(+), 254 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt index 036f44a91..b9b2bbb81 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt @@ -26,45 +26,256 @@ object KBigIntegerRing: Ring { override fun multiply(a: KBigInteger, b: KBigInteger): KBigInteger = a.times(b) - operator fun String.unaryPlus(): KBigInteger = KBigInteger(this)!! + operator fun String.unaryPlus(): KBigInteger = this.toKBigInteger()!! - operator fun String.unaryMinus(): KBigInteger = -KBigInteger(this)!! + operator fun String.unaryMinus(): KBigInteger = -this.toKBigInteger()!! } - @kotlin.ExperimentalUnsignedTypes -class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): - RingElement, Comparable { +class KBigInteger internal constructor( + private val sign: Byte = 0, + private val magnitude: Magnitude = Magnitude(0) +): Comparable { - constructor(x: Int) : this(x.sign, uintArrayOf(kotlin.math.abs(x).toUInt())) - constructor(x: Long) : this(x.sign, uintArrayOf( + constructor(x: Int) : this(x.sign.toByte(), uintArrayOf(kotlin.math.abs(x).toUInt())) + + constructor(x: Long) : this(x.sign.toByte(), stripLeadingZeros(uintArrayOf( (kotlin.math.abs(x).toULong() and BASE).toUInt(), - ((kotlin.math.abs(x).toULong() shr BASE_SIZE) and BASE).toUInt())) + ((kotlin.math.abs(x).toULong() shr BASE_SIZE) and BASE).toUInt()))) - val magnitude = stripLeadingZeros(magnitude) - val sign = if (this.magnitude.isNotEmpty()) sign else 0 - val sizeByte: Int = magnitude.size * BASE_SIZE / 4 + override fun compareTo(other: KBigInteger): Int { + return when { + (this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 + this.sign < other.sign -> -1 + this.sign > other.sign -> 1 + else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude) + } + } - override val context: KBigIntegerRing get() = KBigIntegerRing + override fun equals(other: Any?): Boolean { + if (other is KBigInteger) { + return this.compareTo(other) == 0 + } + else error("Can't compare KBigInteger to a different type") + } - override fun unwrap(): KBigInteger = this - override fun KBigInteger.wrap(): KBigInteger = this + override fun hashCode(): Int { + return magnitude.hashCode() + this.sign + } + + fun abs(): KBigInteger = if (sign == 0.toByte()) this else KBigInteger(1, magnitude) + + operator fun unaryMinus(): KBigInteger { + return if (this.sign == 0.toByte()) this else KBigInteger((-this.sign).toByte(), this.magnitude) + } + + operator fun plus(b: KBigInteger): KBigInteger { + return when { + b.sign == 0.toByte() -> this + this.sign == 0.toByte() -> b + this == -b -> ZERO + this.sign == b.sign -> KBigInteger(this.sign, addMagnitudes(this.magnitude, b.magnitude)) + else -> { + val comp: Int = compareMagnitudes(this.magnitude, b.magnitude) + + if (comp == 1) { + KBigInteger(this.sign, subtractMagnitudes(this.magnitude, b.magnitude)) + } else { + KBigInteger((-this.sign).toByte(), subtractMagnitudes(b.magnitude, this.magnitude)) + } + } + } + } + + operator fun minus(b: KBigInteger): KBigInteger { + return this + (-b) + } + + operator fun times(b: KBigInteger): KBigInteger { + return when { + this.sign == 0.toByte() -> ZERO + b.sign == 0.toByte() -> ZERO +// TODO: Karatsuba + else -> KBigInteger((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude)) + } + } + + operator fun times(other: UInt): KBigInteger { + return when { + this.sign == 0.toByte() -> ZERO + other == 0U -> ZERO + else -> KBigInteger(this.sign, multiplyMagnitudeByUInt(this.magnitude, other)) + } + } + + operator fun times(other: Int): KBigInteger { + return if (other > 0) + this * kotlin.math.abs(other).toUInt() + else + -this * kotlin.math.abs(other).toUInt() + } + + operator fun div(other: UInt): KBigInteger { + return KBigInteger(this.sign, divideMagnitudeByUInt(this.magnitude, other)) + } + + operator fun div(other: Int): KBigInteger { + return KBigInteger((this.sign * other.sign).toByte(), + divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())) + } + + private fun division(other: KBigInteger): Pair { + // Long division algorithm: + // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder + // TODO: Implement more effective algorithm + var q: KBigInteger = ZERO + var r: KBigInteger = ZERO + + val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt() + for (i in bitSize downTo 0) { + r = r shl 1 + r = r or ((abs(this) shr i) and ONE) + if (r >= abs(other)) { + r -= abs(other) + q += (ONE shl i) + } + } + + return Pair(KBigInteger((this.sign * other.sign).toByte(), q.magnitude), r) + } + + operator fun div(other: KBigInteger): KBigInteger { + return this.division(other).first + } + + infix fun shl(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + 1 + val relShift = i % BASE_SIZE + val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift} + val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)} + + val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) + + for (j in this.magnitude.indices) { + newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) + if (j != 0) { + newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) + } + } + + newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last()) + + return KBigInteger(this.sign, stripLeadingZeros(newMagnitude)) + } + + infix fun shr(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + val relShift = i % BASE_SIZE + val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift} + val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)} + if (this.magnitude.size - fullShifts <= 0) { + return ZERO + } + val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts) + + for (j in fullShifts until this.magnitude.size) { + newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) + if (j != this.magnitude.size - 1) { + newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) + } + } + + return KBigInteger(this.sign, stripLeadingZeros(newMagnitude)) + } + + infix fun or(other: KBigInteger): KBigInteger { + if (this == ZERO) return other; + if (other == ZERO) return this; + val resSize = max(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + if (i < this.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or this.magnitude[i] + } + if (i < other.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or other.magnitude[i] + } + } + return KBigInteger(1, stripLeadingZeros(newMagnitude)) + } + + infix fun and(other: KBigInteger): KBigInteger { + if ((this == ZERO) or (other == ZERO)) return ZERO; + val resSize = min(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + newMagnitude[i] = this.magnitude[i] and other.magnitude[i] + } + return KBigInteger(1, stripLeadingZeros(newMagnitude)) + } + + operator fun rem(other: Int): Int { + val res = this - (this / other) * other + return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt() + } + + operator fun rem(other: KBigInteger): KBigInteger { + return this - (this / other) * other + } + + fun modPow(exponent: KBigInteger, m: KBigInteger): KBigInteger { + return when { + exponent == ZERO -> ONE + exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m + else -> { + val sqRoot = modPow(exponent / 2, m) + (sqRoot * sqRoot) % m + } + } + } + + override fun toString(): String { + if (this.sign == 0.toByte()) { + return "0x0" + } + var res: String = if (this.sign == (-1).toByte()) "-0x" else "0x" + var numberStarted = false + + for (i in this.magnitude.size - 1 downTo 0) { + for (j in BASE_SIZE / 4 - 1 downTo 0) { + val curByte = (this.magnitude[i] shr 4 * j) and 0xfU + if (numberStarted or (curByte != 0U)) { + numberStarted = true + res += hexMapping[curByte] + } + } + } + + return res + } companion object { - val BASE = 0xffffffffUL + const val BASE = 0xffffffffUL const val BASE_SIZE: Int = 32 val ZERO: KBigInteger = KBigInteger() val ONE: KBigInteger = KBigInteger(1) private val hexMapping: HashMap = hashMapOf( - 0U to "0", 1U to "1", 2U to "2", 3U to "3", 4U to "4", 5U to "5", 6U to "6", 7U to "7", 8U to "8", - 9U to "9", 10U to "a", 11U to "b", 12U to "c", 13U to "d", 14U to "e", 15U to "f" + 0U to "0", 1U to "1", 2U to "2", 3U to "3", + 4U to "4", 5U to "5", 6U to "6", 7U to "7", + 8U to "8", 9U to "9", 10U to "a", 11U to "b", + 12U to "c", 13U to "d", 14U to "e", 15U to "f" ) - private fun stripLeadingZeros(mag: Magnitude): Magnitude { - // TODO: optimize performance - if (mag.isEmpty()) { + internal fun stripLeadingZeros(mag: Magnitude): Magnitude { + if (mag.isEmpty() || mag.last() != 0U) { return mag } var resSize: Int = mag.size - 1 @@ -108,7 +319,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): carry = (res shr BASE_SIZE) } result[resultLength - 1] = carry.toUInt() - return result + return stripLeadingZeros(result) } private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { @@ -127,7 +338,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): result[i] = res.toUInt() } - return result + return stripLeadingZeros(result) } private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { @@ -142,7 +353,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): } result[resultLength - 1] = (carry and BASE).toUInt() - return result + return stripLeadingZeros(result) } private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude { @@ -159,7 +370,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): result[i + mag2.size] = (carry and BASE).toUInt() } - return result + return stripLeadingZeros(result) } private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { @@ -172,230 +383,15 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): result[i] = (cur / x).toUInt() carry = cur % x } - return result + return stripLeadingZeros(result) } } - override fun compareTo(other: KBigInteger): Int { - return when { - (this.sign == 0) and (other.sign == 0) -> 0 - this.sign < other.sign -> -1 - this.sign > other.sign -> 1 - else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude) - } - } - - override fun equals(other: Any?): Boolean { - if (other is KBigInteger) { - return this.compareTo(other) == 0 - } - else error("Can't compare KBigInteger to a different type") - } - - override fun hashCode(): Int { - return magnitude.hashCode() + this.sign - } - - operator fun unaryMinus(): KBigInteger { - return if (this.sign == 0) this else KBigInteger(-this.sign, this.magnitude) - } - - override operator fun plus(b: KBigInteger): KBigInteger { - return when { - b.sign == 0 -> this - this.sign == 0 -> b - this == -b -> ZERO - this.sign == b.sign -> KBigInteger(this.sign, addMagnitudes(this.magnitude, b.magnitude)) - else -> { - val comp: Int = compareMagnitudes(this.magnitude, b.magnitude) - - if (comp == 1) { - KBigInteger(this.sign, subtractMagnitudes(this.magnitude, b.magnitude)) - } else { - KBigInteger(-this.sign, subtractMagnitudes(b.magnitude, this.magnitude)) - } - } - } - } - - override operator fun minus(b: KBigInteger): KBigInteger { - return this + (-b) - } - - override operator fun times(b: KBigInteger): KBigInteger { - return when { - this.sign == 0 -> ZERO - b.sign == 0 -> ZERO -// TODO: Karatsuba - else -> KBigInteger(this.sign * b.sign, multiplyMagnitudes(this.magnitude, b.magnitude)) - } - } - - operator fun times(other: UInt): KBigInteger { - return when { - this.sign == 0 -> ZERO - other == 0U -> ZERO - else -> KBigInteger(this.sign, multiplyMagnitudeByUInt(this.magnitude, other)) - } - } - - operator fun times(other: Int): KBigInteger { - return if (other > 0) - this * kotlin.math.abs(other).toUInt() - else - -this * kotlin.math.abs(other).toUInt() - } - - operator fun div(other: UInt): KBigInteger { - return KBigInteger(this.sign, divideMagnitudeByUInt(this.magnitude, other)) - } - - operator fun div(other: Int): KBigInteger { - return KBigInteger(this.sign * other.sign, divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())) - } - - private fun division(other: KBigInteger): Pair { - // Long division algorithm: - // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder - // TODO: Implement more effective algorithm - var q: KBigInteger = ZERO - var r: KBigInteger = ZERO - - val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt() - for (i in bitSize downTo 0) { - r = r shl 1 - r = r or ((abs(this) shr i) and ONE) - if (r >= abs(other)) { - r -= abs(other) - q += (ONE shl i) - } - } - - return Pair(KBigInteger(this.sign * other.sign, q.magnitude), r) - } - - operator fun div(other: KBigInteger): KBigInteger { - return this.division(other).first - } - - infix fun shl(i: Int): KBigInteger { - if (this == ZERO) return ZERO - if (i == 0) return this - - val fullShifts = i / BASE_SIZE + 1 - val relShift = i % BASE_SIZE - val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift} - val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)} - - val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) - - for (j in this.magnitude.indices) { - newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) - if (j != 0) { - newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) - } - } - - newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last()) - - return KBigInteger(this.sign, newMagnitude) - } - - infix fun shr(i: Int): KBigInteger { - if (this == ZERO) return ZERO - if (i == 0) return this - - val fullShifts = i / BASE_SIZE - val relShift = i % BASE_SIZE - val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift} - val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)} - if (this.magnitude.size - fullShifts <= 0) { - return ZERO - } - val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts) - - for (j in fullShifts until this.magnitude.size) { - newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) - if (j != this.magnitude.size - 1) { - newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) - } - } - - return KBigInteger(this.sign, newMagnitude) - } - - infix fun or(other: KBigInteger): KBigInteger { - if (this == ZERO) return other; - if (other == ZERO) return this; - val resSize = max(this.magnitude.size, other.magnitude.size) - val newMagnitude: Magnitude = Magnitude(resSize) - for (i in 0 until resSize) { - if (i < this.magnitude.size) { - newMagnitude[i] = newMagnitude[i] or this.magnitude[i] - } - if (i < other.magnitude.size) { - newMagnitude[i] = newMagnitude[i] or other.magnitude[i] - } - } - return KBigInteger(1, newMagnitude) - } - - infix fun and(other: KBigInteger): KBigInteger { - if ((this == ZERO) or (other == ZERO)) return ZERO; - val resSize = min(this.magnitude.size, other.magnitude.size) - val newMagnitude: Magnitude = Magnitude(resSize) - for (i in 0 until resSize) { - newMagnitude[i] = this.magnitude[i] and other.magnitude[i] - } - return KBigInteger(1, newMagnitude) - } - - operator fun rem(other: Int): Int { - val res = this - (this / other) * other - return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt() - } - - operator fun rem(other: KBigInteger): KBigInteger { - return this - (this / other) * other - } - - fun modPow(exponent: KBigInteger, m: KBigInteger): KBigInteger { - return when { - exponent == ZERO -> ONE - exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m - else -> { - val sqRoot = modPow(exponent / 2, m) - (sqRoot * sqRoot) % m - } - } - } - - override fun toString(): String { - if (this.sign == 0) { - return "0x0" - } - var res: String = if (this.sign == -1) "-0x" else "0x" - var numberStarted = false - - for (i in this.magnitude.size - 1 downTo 0) { - for (j in BASE_SIZE / 4 - 1 downTo 0) { - val curByte = (this.magnitude[i] shr 4 * j) and 0xfU - if (numberStarted or (curByte != 0U)) { - numberStarted = true - res += hexMapping[curByte] - } - } - } - - return res - } } @kotlin.ExperimentalUnsignedTypes -fun abs(x: KBigInteger): KBigInteger { - return if (x.sign == 0) x else KBigInteger(1, x.magnitude) -} +fun abs(x: KBigInteger): KBigInteger = x.abs() @kotlin.ExperimentalUnsignedTypes // Can't put it as constructor in class due to platform declaration clash with KBigInteger(Int) @@ -405,26 +401,35 @@ fun KBigInteger(x: UInt): KBigInteger @kotlin.ExperimentalUnsignedTypes // Can't put it as constructor in class due to platform declaration clash with KBigInteger(Long) fun KBigInteger(x: ULong): KBigInteger - = KBigInteger(1, uintArrayOf((x and KBigInteger.BASE).toUInt(), ((x shr KBigInteger.BASE_SIZE) and KBigInteger.BASE).toUInt())) + = KBigInteger(1, + KBigInteger.stripLeadingZeros(uintArrayOf( + (x and KBigInteger.BASE).toUInt(), + ((x shr KBigInteger.BASE_SIZE) and KBigInteger.BASE).toUInt()) + ) + ) -val hexChToInt = hashMapOf('0' to 0, '1' to 1, '2' to 2, '3' to 3, '4' to 4, '5' to 5, '6' to 6, '7' to 7, - '8' to 8, '9' to 9, 'A' to 10, 'B' to 11, 'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15) +val hexChToInt = hashMapOf( + '0' to 0, '1' to 1, '2' to 2, '3' to 3, + '4' to 4, '5' to 5, '6' to 6, '7' to 7, + '8' to 8, '9' to 9, 'A' to 10, 'B' to 11, + 'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15 +) // Returns None if a valid number can not be read from a string -fun KBigInteger(s: String): KBigInteger? { +fun String.toKBigInteger(): KBigInteger? { val sign: Int val sPositive: String when { - s[0] == '+' -> { + this[0] == '+' -> { sign = +1 - sPositive = s.substring(1) + sPositive = this.substring(1) } - s[0] == '-' -> { + this[0] == '-' -> { sign = -1 - sPositive = s.substring(1) + sPositive = this.substring(1) } else -> { - sPositive = s + sPositive = this sign = +1 } } diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt index daece0f45..96e2dee9a 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt @@ -485,13 +485,13 @@ class KBigIntegerConversionsTest { @Test fun testFromString_0x17ead2ffffd11223344() { - val x = KBigInteger("0x17ead2ffffd11223344")!! + val x = "0x17ead2ffffd11223344".toKBigInteger() assertEquals( "0x17ead2ffffd11223344", x.toString()) } @Test fun testFromString_7059135710711894913860() { - val x = KBigInteger("-7059135710711894913860") + val x = "-7059135710711894913860".toKBigInteger() assertEquals("-0x17ead2ffffd11223344", x.toString()) } } @@ -509,7 +509,7 @@ class KBigIntegerRingTest { fun testKBigIntegerRingSum_100_000_000__100_000_000() { KBigIntegerRing { val sum = +"100_000_000" + +"100_000_000" - assertEquals(sum, KBigInteger("200_000_000")) + assertEquals(sum, "200_000_000".toKBigInteger()) } } @@ -517,7 +517,7 @@ class KBigIntegerRingTest { fun test_mul_3__4() { KBigIntegerRing { val prod = +"0x3000_0000_0000" * +"0x4000_0000_0000_0000_0000" - assertEquals(prod, KBigInteger("0xc00_0000_0000_0000_0000_0000_0000_0000")) + assertEquals(prod, "0xc00_0000_0000_0000_0000_0000_0000_0000".toKBigInteger()) } }