diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt index 0238492d2..036f44a91 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/KBigInteger.kt @@ -1,6 +1,8 @@ package scientifik.kmath.operations +import kotlin.math.log2 import kotlin.math.max +import kotlin.math.min import kotlin.math.sign /* @@ -27,7 +29,6 @@ object KBigIntegerRing: Ring { operator fun String.unaryPlus(): KBigInteger = KBigInteger(this)!! operator fun String.unaryMinus(): KBigInteger = -KBigInteger(this)!! - } @@ -161,7 +162,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return result } - internal fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { + private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude { val resultLength: Int = mag.size val result = Magnitude(resultLength) var carry: ULong = 0UL @@ -174,39 +175,6 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return result } - internal fun divideMagnitudes(mag1_: Magnitude, mag2: Magnitude): Magnitude { - val mag1 = ULongArray(mag1_.size) { mag1_[it].toULong() } - - val resultLength: Int = mag1.size - mag2.size + 1 - val result = LongArray(resultLength) - - for (i in mag1.size - 1 downTo mag2.size - 1) { - val div: ULong = mag1[i] / mag2[mag2.size - 1] - result[i - mag2.size + 1] = div.toLong() - for (j in mag2.indices) { - mag1[i - j] -= mag2[mag2.size - 1 - j] * div - } - if (i > 0) { - mag1[i - 1] += (mag1[i] shl BASE_SIZE) - } - } - - val normalizedResult = Magnitude(resultLength) - var carry = 0L - - for (i in result.indices) { - result[i] += carry - if (result[i] < 0L) { - normalizedResult[i] = (result[i] + (BASE + 1UL).toLong()).toUInt() - carry = -1 - } else { - normalizedResult[i] = result[i].toUInt() - carry = 0 - } - } - - return normalizedResult - } } override fun compareTo(other: KBigInteger): Int { @@ -287,12 +255,100 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)): return KBigInteger(this.sign * other.sign, divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt())) } - operator fun div(other: KBigInteger): KBigInteger { - return when { - this < other -> ZERO - this == other -> ONE - else -> KBigInteger(this.sign * other.sign, divideMagnitudes(this.magnitude, other.magnitude)) + private fun division(other: KBigInteger): Pair { + // Long division algorithm: + // https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder + // TODO: Implement more effective algorithm + var q: KBigInteger = ZERO + var r: KBigInteger = ZERO + + val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt() + for (i in bitSize downTo 0) { + r = r shl 1 + r = r or ((abs(this) shr i) and ONE) + if (r >= abs(other)) { + r -= abs(other) + q += (ONE shl i) + } } + + return Pair(KBigInteger(this.sign * other.sign, q.magnitude), r) + } + + operator fun div(other: KBigInteger): KBigInteger { + return this.division(other).first + } + + infix fun shl(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + 1 + val relShift = i % BASE_SIZE + val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift} + val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)} + + val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts) + + for (j in this.magnitude.indices) { + newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j]) + if (j != 0) { + newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1]) + } + } + + newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last()) + + return KBigInteger(this.sign, newMagnitude) + } + + infix fun shr(i: Int): KBigInteger { + if (this == ZERO) return ZERO + if (i == 0) return this + + val fullShifts = i / BASE_SIZE + val relShift = i % BASE_SIZE + val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift} + val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)} + if (this.magnitude.size - fullShifts <= 0) { + return ZERO + } + val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts) + + for (j in fullShifts until this.magnitude.size) { + newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j]) + if (j != this.magnitude.size - 1) { + newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1]) + } + } + + return KBigInteger(this.sign, newMagnitude) + } + + infix fun or(other: KBigInteger): KBigInteger { + if (this == ZERO) return other; + if (other == ZERO) return this; + val resSize = max(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + if (i < this.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or this.magnitude[i] + } + if (i < other.magnitude.size) { + newMagnitude[i] = newMagnitude[i] or other.magnitude[i] + } + } + return KBigInteger(1, newMagnitude) + } + + infix fun and(other: KBigInteger): KBigInteger { + if ((this == ZERO) or (other == ZERO)) return ZERO; + val resSize = min(this.magnitude.size, other.magnitude.size) + val newMagnitude: Magnitude = Magnitude(resSize) + for (i in 0 until resSize) { + newMagnitude[i] = this.magnitude[i] and other.magnitude[i] + } + return KBigInteger(1, newMagnitude) } operator fun rem(other: Int): Int { diff --git a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt index 10b89f28b..daece0f45 100644 --- a/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt +++ b/kmath-core/src/commonTest/kotlin/scientifik/kmath/operations/KBigIntegerTest.kt @@ -244,6 +244,62 @@ class KBigIntegerOperationsTest { assertEquals(prod, res) } + @Test + fun test_shr_20() { + val x = KBigInteger(20) + assertEquals(KBigInteger(10), x shr 1) + } + + @Test + fun test_shl_20() { + val x = KBigInteger(20) + assertEquals(KBigInteger(40), x shl 1) + } + + @Test + fun test_shl_1_0() { + assertEquals(KBigInteger.ONE, KBigInteger.ONE shl 0) + } + + @Test + fun test_shl_1_32() { + assertEquals(KBigInteger(0x100000000UL), KBigInteger.ONE shl 32) + } + + @Test + fun test_shl_1_33() { + assertEquals(KBigInteger(0x200000000UL), KBigInteger.ONE shl 33) + } + + @Test + fun test_shr_1_33_33() { + assertEquals(KBigInteger.ONE, (KBigInteger.ONE shl 33) shr 33) + } + + @Test + fun test_shr_1_32() { + assertEquals(KBigInteger.ZERO, KBigInteger.ONE shr 32) + } + + @Test + fun test_and_123_456() { + val x = KBigInteger(123) + val y = KBigInteger(456) + assertEquals(KBigInteger(72), x and y) + } + + @Test + fun test_or_123_456() { + val x = KBigInteger(123) + val y = KBigInteger(456) + assertEquals(KBigInteger(507), x or y) + } + + @Test + fun test_asd() { + assertEquals(KBigInteger.ONE, KBigInteger.ZERO or ((KBigInteger(20) shr 4) and KBigInteger.ONE)) + } + @Test fun test_square_0x11223344U_0xad2ffffdU_0x17eU() { val num = KBigInteger(-1, uintArrayOf(0x11223344U, 0xad2ffffdU, 0x17eU )) @@ -307,6 +363,12 @@ class KBigIntegerOperationsTest { assertEquals(div, res) } + @Test + fun testBigDivision_0xfffffffeabcdef01UL_0xfffffffeabc() { + val res = KBigInteger(0xfffffffeabcdef01UL) / KBigInteger(0xfffffffeabc) + assertEquals(res, KBigInteger(0x100000)) + } + @Test fun testBigDivision_0xfffffffe00000001_0xffffffff() { val x = KBigInteger(0xfffffffe00000001UL) @@ -379,7 +441,7 @@ class KBigIntegerOperationsTest { val exp = KBigInteger(2) val mod = KBigInteger(0xfffffffeabcUL) - val res = KBigInteger(0x6deec7895faUL) + val res = KBigInteger(0xc2253cde01) return assertEquals(res, x.modPow(exp, mod)) } @@ -434,26 +496,9 @@ class KBigIntegerConversionsTest { } } -@kotlin.ExperimentalUnsignedTypes -class DivisionTests { - // TODO - @Test - fun test_0xfffffffeabcdef01UL_0xfffffffeabc() { - val res = KBigInteger(0xfffffffeabcdef01UL) / KBigInteger(0xfffffffeabc) - assertEquals(res, KBigInteger(0x100000)) - - } - -// println(KBigInteger(+1, uintArrayOf(1000U, 1000U, 1000U)) / KBigInteger(0xfffffffeabc) ) - // >>> hex((1000 + 1000*2**32 + 1000*2**64)/ 0xfffffffeabc) == 0x3e800000L -// println(KBigInteger(+1, KBigInteger.divideMagnitudeByUInt(uintArrayOf(1000U, 1000U, 1000U),456789U))) - // 0x8f789719813969L - -} - class KBigIntegerRingTest { @Test - fun testSum() { + fun testKBigIntegerRingSum() { val res = KBigIntegerRing { KBigInteger(1_000L) * KBigInteger(1_000L) } @@ -461,7 +506,7 @@ class KBigIntegerRingTest { } @Test - fun test_sum_100_000_000__100_000_000() { + fun testKBigIntegerRingSum_100_000_000__100_000_000() { KBigIntegerRing { val sum = +"100_000_000" + +"100_000_000" assertEquals(sum, KBigInteger("200_000_000")) @@ -476,5 +521,23 @@ class KBigIntegerRingTest { } } + @Test + fun test_div_big_1() { + KBigIntegerRing { + val res = +"1_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000" / + +"555_000_444_000_333_000_222_000_111_000_999_001" + assertEquals(res, +"1801800360360432432518919022699") + } + } + + @Test + fun test_rem_big_1() { + KBigIntegerRing { + val res = +"1_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000" % + +"555_000_444_000_333_000_222_000_111_000_999_001" + assertEquals(res, +"324121220440768000291647788404676301") + } + } + }