Refactoring of KBigInteger

This commit is contained in:
Peter Klimai 2020-04-15 18:55:13 +03:00
parent 19d1459a55
commit 48cb683bc4
2 changed files with 259 additions and 254 deletions

View File

@ -26,45 +26,256 @@ object KBigIntegerRing: Ring<KBigInteger> {
override fun multiply(a: KBigInteger, b: KBigInteger): KBigInteger = a.times(b)
operator fun String.unaryPlus(): KBigInteger = KBigInteger(this)!!
operator fun String.unaryPlus(): KBigInteger = this.toKBigInteger()!!
operator fun String.unaryMinus(): KBigInteger = -KBigInteger(this)!!
operator fun String.unaryMinus(): KBigInteger = -this.toKBigInteger()!!
}
@kotlin.ExperimentalUnsignedTypes
class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
RingElement<KBigInteger, KBigInteger, KBigIntegerRing>, Comparable<KBigInteger> {
class KBigInteger internal constructor(
private val sign: Byte = 0,
private val magnitude: Magnitude = Magnitude(0)
): Comparable<KBigInteger> {
constructor(x: Int) : this(x.sign, uintArrayOf(kotlin.math.abs(x).toUInt()))
constructor(x: Long) : this(x.sign, uintArrayOf(
constructor(x: Int) : this(x.sign.toByte(), uintArrayOf(kotlin.math.abs(x).toUInt()))
constructor(x: Long) : this(x.sign.toByte(), stripLeadingZeros(uintArrayOf(
(kotlin.math.abs(x).toULong() and BASE).toUInt(),
((kotlin.math.abs(x).toULong() shr BASE_SIZE) and BASE).toUInt()))
((kotlin.math.abs(x).toULong() shr BASE_SIZE) and BASE).toUInt())))
val magnitude = stripLeadingZeros(magnitude)
val sign = if (this.magnitude.isNotEmpty()) sign else 0
val sizeByte: Int = magnitude.size * BASE_SIZE / 4
override fun compareTo(other: KBigInteger): Int {
return when {
(this.sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
this.sign < other.sign -> -1
this.sign > other.sign -> 1
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
}
}
override val context: KBigIntegerRing get() = KBigIntegerRing
override fun equals(other: Any?): Boolean {
if (other is KBigInteger) {
return this.compareTo(other) == 0
}
else error("Can't compare KBigInteger to a different type")
}
override fun unwrap(): KBigInteger = this
override fun KBigInteger.wrap(): KBigInteger = this
override fun hashCode(): Int {
return magnitude.hashCode() + this.sign
}
fun abs(): KBigInteger = if (sign == 0.toByte()) this else KBigInteger(1, magnitude)
operator fun unaryMinus(): KBigInteger {
return if (this.sign == 0.toByte()) this else KBigInteger((-this.sign).toByte(), this.magnitude)
}
operator fun plus(b: KBigInteger): KBigInteger {
return when {
b.sign == 0.toByte() -> this
this.sign == 0.toByte() -> b
this == -b -> ZERO
this.sign == b.sign -> KBigInteger(this.sign, addMagnitudes(this.magnitude, b.magnitude))
else -> {
val comp: Int = compareMagnitudes(this.magnitude, b.magnitude)
if (comp == 1) {
KBigInteger(this.sign, subtractMagnitudes(this.magnitude, b.magnitude))
} else {
KBigInteger((-this.sign).toByte(), subtractMagnitudes(b.magnitude, this.magnitude))
}
}
}
}
operator fun minus(b: KBigInteger): KBigInteger {
return this + (-b)
}
operator fun times(b: KBigInteger): KBigInteger {
return when {
this.sign == 0.toByte() -> ZERO
b.sign == 0.toByte() -> ZERO
// TODO: Karatsuba
else -> KBigInteger((this.sign * b.sign).toByte(), multiplyMagnitudes(this.magnitude, b.magnitude))
}
}
operator fun times(other: UInt): KBigInteger {
return when {
this.sign == 0.toByte() -> ZERO
other == 0U -> ZERO
else -> KBigInteger(this.sign, multiplyMagnitudeByUInt(this.magnitude, other))
}
}
operator fun times(other: Int): KBigInteger {
return if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
}
operator fun div(other: UInt): KBigInteger {
return KBigInteger(this.sign, divideMagnitudeByUInt(this.magnitude, other))
}
operator fun div(other: Int): KBigInteger {
return KBigInteger((this.sign * other.sign).toByte(),
divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt()))
}
private fun division(other: KBigInteger): Pair<KBigInteger, KBigInteger> {
// Long division algorithm:
// https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder
// TODO: Implement more effective algorithm
var q: KBigInteger = ZERO
var r: KBigInteger = ZERO
val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt()
for (i in bitSize downTo 0) {
r = r shl 1
r = r or ((abs(this) shr i) and ONE)
if (r >= abs(other)) {
r -= abs(other)
q += (ONE shl i)
}
}
return Pair(KBigInteger((this.sign * other.sign).toByte(), q.magnitude), r)
}
operator fun div(other: KBigInteger): KBigInteger {
return this.division(other).first
}
infix fun shl(i: Int): KBigInteger {
if (this == ZERO) return ZERO
if (i == 0) return this
val fullShifts = i / BASE_SIZE + 1
val relShift = i % BASE_SIZE
val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift}
val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)}
val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts)
for (j in this.magnitude.indices) {
newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j])
if (j != 0) {
newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1])
}
}
newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last())
return KBigInteger(this.sign, stripLeadingZeros(newMagnitude))
}
infix fun shr(i: Int): KBigInteger {
if (this == ZERO) return ZERO
if (i == 0) return this
val fullShifts = i / BASE_SIZE
val relShift = i % BASE_SIZE
val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift}
val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)}
if (this.magnitude.size - fullShifts <= 0) {
return ZERO
}
val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts)
for (j in fullShifts until this.magnitude.size) {
newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j])
if (j != this.magnitude.size - 1) {
newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1])
}
}
return KBigInteger(this.sign, stripLeadingZeros(newMagnitude))
}
infix fun or(other: KBigInteger): KBigInteger {
if (this == ZERO) return other;
if (other == ZERO) return this;
val resSize = max(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) {
if (i < this.magnitude.size) {
newMagnitude[i] = newMagnitude[i] or this.magnitude[i]
}
if (i < other.magnitude.size) {
newMagnitude[i] = newMagnitude[i] or other.magnitude[i]
}
}
return KBigInteger(1, stripLeadingZeros(newMagnitude))
}
infix fun and(other: KBigInteger): KBigInteger {
if ((this == ZERO) or (other == ZERO)) return ZERO;
val resSize = min(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) {
newMagnitude[i] = this.magnitude[i] and other.magnitude[i]
}
return KBigInteger(1, stripLeadingZeros(newMagnitude))
}
operator fun rem(other: Int): Int {
val res = this - (this / other) * other
return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt()
}
operator fun rem(other: KBigInteger): KBigInteger {
return this - (this / other) * other
}
fun modPow(exponent: KBigInteger, m: KBigInteger): KBigInteger {
return when {
exponent == ZERO -> ONE
exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m
else -> {
val sqRoot = modPow(exponent / 2, m)
(sqRoot * sqRoot) % m
}
}
}
override fun toString(): String {
if (this.sign == 0.toByte()) {
return "0x0"
}
var res: String = if (this.sign == (-1).toByte()) "-0x" else "0x"
var numberStarted = false
for (i in this.magnitude.size - 1 downTo 0) {
for (j in BASE_SIZE / 4 - 1 downTo 0) {
val curByte = (this.magnitude[i] shr 4 * j) and 0xfU
if (numberStarted or (curByte != 0U)) {
numberStarted = true
res += hexMapping[curByte]
}
}
}
return res
}
companion object {
val BASE = 0xffffffffUL
const val BASE = 0xffffffffUL
const val BASE_SIZE: Int = 32
val ZERO: KBigInteger = KBigInteger()
val ONE: KBigInteger = KBigInteger(1)
private val hexMapping: HashMap<UInt, String> =
hashMapOf(
0U to "0", 1U to "1", 2U to "2", 3U to "3", 4U to "4", 5U to "5", 6U to "6", 7U to "7", 8U to "8",
9U to "9", 10U to "a", 11U to "b", 12U to "c", 13U to "d", 14U to "e", 15U to "f"
0U to "0", 1U to "1", 2U to "2", 3U to "3",
4U to "4", 5U to "5", 6U to "6", 7U to "7",
8U to "8", 9U to "9", 10U to "a", 11U to "b",
12U to "c", 13U to "d", 14U to "e", 15U to "f"
)
private fun stripLeadingZeros(mag: Magnitude): Magnitude {
// TODO: optimize performance
if (mag.isEmpty()) {
internal fun stripLeadingZeros(mag: Magnitude): Magnitude {
if (mag.isEmpty() || mag.last() != 0U) {
return mag
}
var resSize: Int = mag.size - 1
@ -108,7 +319,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
carry = (res shr BASE_SIZE)
}
result[resultLength - 1] = carry.toUInt()
return result
return stripLeadingZeros(result)
}
private fun subtractMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
@ -127,7 +338,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
result[i] = res.toUInt()
}
return result
return stripLeadingZeros(result)
}
private fun multiplyMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
@ -142,7 +353,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
}
result[resultLength - 1] = (carry and BASE).toUInt()
return result
return stripLeadingZeros(result)
}
private fun multiplyMagnitudes(mag1: Magnitude, mag2: Magnitude): Magnitude {
@ -159,7 +370,7 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
result[i + mag2.size] = (carry and BASE).toUInt()
}
return result
return stripLeadingZeros(result)
}
private fun divideMagnitudeByUInt(mag: Magnitude, x: UInt): Magnitude {
@ -172,230 +383,15 @@ class KBigInteger(sign: Int = 0, magnitude: Magnitude = Magnitude(0)):
result[i] = (cur / x).toUInt()
carry = cur % x
}
return result
return stripLeadingZeros(result)
}
}
override fun compareTo(other: KBigInteger): Int {
return when {
(this.sign == 0) and (other.sign == 0) -> 0
this.sign < other.sign -> -1
this.sign > other.sign -> 1
else -> this.sign * compareMagnitudes(this.magnitude, other.magnitude)
}
}
override fun equals(other: Any?): Boolean {
if (other is KBigInteger) {
return this.compareTo(other) == 0
}
else error("Can't compare KBigInteger to a different type")
}
override fun hashCode(): Int {
return magnitude.hashCode() + this.sign
}
operator fun unaryMinus(): KBigInteger {
return if (this.sign == 0) this else KBigInteger(-this.sign, this.magnitude)
}
override operator fun plus(b: KBigInteger): KBigInteger {
return when {
b.sign == 0 -> this
this.sign == 0 -> b
this == -b -> ZERO
this.sign == b.sign -> KBigInteger(this.sign, addMagnitudes(this.magnitude, b.magnitude))
else -> {
val comp: Int = compareMagnitudes(this.magnitude, b.magnitude)
if (comp == 1) {
KBigInteger(this.sign, subtractMagnitudes(this.magnitude, b.magnitude))
} else {
KBigInteger(-this.sign, subtractMagnitudes(b.magnitude, this.magnitude))
}
}
}
}
override operator fun minus(b: KBigInteger): KBigInteger {
return this + (-b)
}
override operator fun times(b: KBigInteger): KBigInteger {
return when {
this.sign == 0 -> ZERO
b.sign == 0 -> ZERO
// TODO: Karatsuba
else -> KBigInteger(this.sign * b.sign, multiplyMagnitudes(this.magnitude, b.magnitude))
}
}
operator fun times(other: UInt): KBigInteger {
return when {
this.sign == 0 -> ZERO
other == 0U -> ZERO
else -> KBigInteger(this.sign, multiplyMagnitudeByUInt(this.magnitude, other))
}
}
operator fun times(other: Int): KBigInteger {
return if (other > 0)
this * kotlin.math.abs(other).toUInt()
else
-this * kotlin.math.abs(other).toUInt()
}
operator fun div(other: UInt): KBigInteger {
return KBigInteger(this.sign, divideMagnitudeByUInt(this.magnitude, other))
}
operator fun div(other: Int): KBigInteger {
return KBigInteger(this.sign * other.sign, divideMagnitudeByUInt(this.magnitude, kotlin.math.abs(other).toUInt()))
}
private fun division(other: KBigInteger): Pair<KBigInteger, KBigInteger> {
// Long division algorithm:
// https://en.wikipedia.org/wiki/Division_algorithm#Integer_division_(unsigned)_with_remainder
// TODO: Implement more effective algorithm
var q: KBigInteger = ZERO
var r: KBigInteger = ZERO
val bitSize = (BASE_SIZE * (this.magnitude.size - 1) + log2(this.magnitude.last().toFloat() + 1)).toInt()
for (i in bitSize downTo 0) {
r = r shl 1
r = r or ((abs(this) shr i) and ONE)
if (r >= abs(other)) {
r -= abs(other)
q += (ONE shl i)
}
}
return Pair(KBigInteger(this.sign * other.sign, q.magnitude), r)
}
operator fun div(other: KBigInteger): KBigInteger {
return this.division(other).first
}
infix fun shl(i: Int): KBigInteger {
if (this == ZERO) return ZERO
if (i == 0) return this
val fullShifts = i / BASE_SIZE + 1
val relShift = i % BASE_SIZE
val shiftLeft = {x: UInt -> if (relShift >= 32) 0U else x shl relShift}
val shiftRight = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shr (BASE_SIZE - relShift)}
val newMagnitude: Magnitude = Magnitude(this.magnitude.size + fullShifts)
for (j in this.magnitude.indices) {
newMagnitude[j + fullShifts - 1] = shiftLeft(this.magnitude[j])
if (j != 0) {
newMagnitude[j + fullShifts - 1] = newMagnitude[j + fullShifts - 1] or shiftRight(this.magnitude[j - 1])
}
}
newMagnitude[this.magnitude.size + fullShifts - 1] = shiftRight(this.magnitude.last())
return KBigInteger(this.sign, newMagnitude)
}
infix fun shr(i: Int): KBigInteger {
if (this == ZERO) return ZERO
if (i == 0) return this
val fullShifts = i / BASE_SIZE
val relShift = i % BASE_SIZE
val shiftRight = {x: UInt -> if (relShift >= 32) 0U else x shr relShift}
val shiftLeft = {x: UInt -> if (BASE_SIZE - relShift >= 32) 0U else x shl (BASE_SIZE - relShift)}
if (this.magnitude.size - fullShifts <= 0) {
return ZERO
}
val newMagnitude: Magnitude = Magnitude(this.magnitude.size - fullShifts)
for (j in fullShifts until this.magnitude.size) {
newMagnitude[j - fullShifts] = shiftRight(this.magnitude[j])
if (j != this.magnitude.size - 1) {
newMagnitude[j - fullShifts] = newMagnitude[j - fullShifts] or shiftLeft(this.magnitude[j + 1])
}
}
return KBigInteger(this.sign, newMagnitude)
}
infix fun or(other: KBigInteger): KBigInteger {
if (this == ZERO) return other;
if (other == ZERO) return this;
val resSize = max(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) {
if (i < this.magnitude.size) {
newMagnitude[i] = newMagnitude[i] or this.magnitude[i]
}
if (i < other.magnitude.size) {
newMagnitude[i] = newMagnitude[i] or other.magnitude[i]
}
}
return KBigInteger(1, newMagnitude)
}
infix fun and(other: KBigInteger): KBigInteger {
if ((this == ZERO) or (other == ZERO)) return ZERO;
val resSize = min(this.magnitude.size, other.magnitude.size)
val newMagnitude: Magnitude = Magnitude(resSize)
for (i in 0 until resSize) {
newMagnitude[i] = this.magnitude[i] and other.magnitude[i]
}
return KBigInteger(1, newMagnitude)
}
operator fun rem(other: Int): Int {
val res = this - (this / other) * other
return if (res == ZERO) 0 else res.sign * res.magnitude[0].toInt()
}
operator fun rem(other: KBigInteger): KBigInteger {
return this - (this / other) * other
}
fun modPow(exponent: KBigInteger, m: KBigInteger): KBigInteger {
return when {
exponent == ZERO -> ONE
exponent % 2 == 1 -> (this * modPow(exponent - ONE, m)) % m
else -> {
val sqRoot = modPow(exponent / 2, m)
(sqRoot * sqRoot) % m
}
}
}
override fun toString(): String {
if (this.sign == 0) {
return "0x0"
}
var res: String = if (this.sign == -1) "-0x" else "0x"
var numberStarted = false
for (i in this.magnitude.size - 1 downTo 0) {
for (j in BASE_SIZE / 4 - 1 downTo 0) {
val curByte = (this.magnitude[i] shr 4 * j) and 0xfU
if (numberStarted or (curByte != 0U)) {
numberStarted = true
res += hexMapping[curByte]
}
}
}
return res
}
}
@kotlin.ExperimentalUnsignedTypes
fun abs(x: KBigInteger): KBigInteger {
return if (x.sign == 0) x else KBigInteger(1, x.magnitude)
}
fun abs(x: KBigInteger): KBigInteger = x.abs()
@kotlin.ExperimentalUnsignedTypes
// Can't put it as constructor in class due to platform declaration clash with KBigInteger(Int)
@ -405,26 +401,35 @@ fun KBigInteger(x: UInt): KBigInteger
@kotlin.ExperimentalUnsignedTypes
// Can't put it as constructor in class due to platform declaration clash with KBigInteger(Long)
fun KBigInteger(x: ULong): KBigInteger
= KBigInteger(1, uintArrayOf((x and KBigInteger.BASE).toUInt(), ((x shr KBigInteger.BASE_SIZE) and KBigInteger.BASE).toUInt()))
= KBigInteger(1,
KBigInteger.stripLeadingZeros(uintArrayOf(
(x and KBigInteger.BASE).toUInt(),
((x shr KBigInteger.BASE_SIZE) and KBigInteger.BASE).toUInt())
)
)
val hexChToInt = hashMapOf('0' to 0, '1' to 1, '2' to 2, '3' to 3, '4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11, 'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15)
val hexChToInt = hashMapOf(
'0' to 0, '1' to 1, '2' to 2, '3' to 3,
'4' to 4, '5' to 5, '6' to 6, '7' to 7,
'8' to 8, '9' to 9, 'A' to 10, 'B' to 11,
'C' to 12, 'D' to 13, 'E' to 14, 'F' to 15
)
// Returns None if a valid number can not be read from a string
fun KBigInteger(s: String): KBigInteger? {
fun String.toKBigInteger(): KBigInteger? {
val sign: Int
val sPositive: String
when {
s[0] == '+' -> {
this[0] == '+' -> {
sign = +1
sPositive = s.substring(1)
sPositive = this.substring(1)
}
s[0] == '-' -> {
this[0] == '-' -> {
sign = -1
sPositive = s.substring(1)
sPositive = this.substring(1)
}
else -> {
sPositive = s
sPositive = this
sign = +1
}
}

View File

@ -485,13 +485,13 @@ class KBigIntegerConversionsTest {
@Test
fun testFromString_0x17ead2ffffd11223344() {
val x = KBigInteger("0x17ead2ffffd11223344")!!
val x = "0x17ead2ffffd11223344".toKBigInteger()
assertEquals( "0x17ead2ffffd11223344", x.toString())
}
@Test
fun testFromString_7059135710711894913860() {
val x = KBigInteger("-7059135710711894913860")
val x = "-7059135710711894913860".toKBigInteger()
assertEquals("-0x17ead2ffffd11223344", x.toString())
}
}
@ -509,7 +509,7 @@ class KBigIntegerRingTest {
fun testKBigIntegerRingSum_100_000_000__100_000_000() {
KBigIntegerRing {
val sum = +"100_000_000" + +"100_000_000"
assertEquals(sum, KBigInteger("200_000_000"))
assertEquals(sum, "200_000_000".toKBigInteger())
}
}
@ -517,7 +517,7 @@ class KBigIntegerRingTest {
fun test_mul_3__4() {
KBigIntegerRing {
val prod = +"0x3000_0000_0000" * +"0x4000_0000_0000_0000_0000"
assertEquals(prod, KBigInteger("0xc00_0000_0000_0000_0000_0000_0000_0000"))
assertEquals(prod, "0xc00_0000_0000_0000_0000_0000_0000_0000".toKBigInteger())
}
}