From abcde808dc4c0f9515d2cbad1fa5ddda2b4bca31 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sun, 4 Oct 2020 20:17:44 +0700 Subject: [PATCH] Add first working test, use kotlinx-io fork, major rework of GSL API --- build.gradle.kts | 1 + examples/build.gradle.kts | 1 - .../kscience/kmath/linear/MatrixContext.kt | 9 +- .../kscience/kmath/ejml/EjmlMatrixContext.kt | 9 +- kmath-gsl/build.gradle.kts | 2 + .../kotlin/kscience/kmath/gsl/GslComplex.kt | 12 + .../kotlin/kscience/kmath/gsl/GslMatrices.kt | 290 ++++++++++++------ .../kscience/kmath/gsl/GslMatrixContexts.kt | 57 ++++ .../kscience/kmath/gsl/GslMemoryHolder.kt | 9 + .../kotlin/kscience/kmath/gsl/GslVectors.kt | 71 +++-- .../kscience/kmath/gsl/nativeUtilities.kt | 8 - kmath-gsl/src/nativeTest/kotlin/RealTest.kt | 15 + 12 files changed, 348 insertions(+), 136 deletions(-) create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt delete mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/nativeUtilities.kt create mode 100644 kmath-gsl/src/nativeTest/kotlin/RealTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index 05e2d5979..92ad16b85 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,6 +12,7 @@ allprojects { maven("https://dl.bintray.com/kotlin/kotlin-eap") maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/hotkeytlt/maven") + maven("https://dl.bintray.com/commandertvis/kotlinx-io/") } group = "kscience.kmath" diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index c8b173d1c..a1cc4b6a2 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -26,7 +26,6 @@ dependencies { implementation(project(":kmath-viktor")) implementation(project(":kmath-dimensions")) implementation(project(":kmath-ejml")) - implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") implementation("org.slf4j:slf4j-simple:1.7.30") "benchmarksImplementation"("org.jetbrains.kotlinx:kotlinx.benchmark.runtime-jvm:0.2.0-dev-8") diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index f4dbce89a..300b963a6 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -18,10 +18,11 @@ public interface MatrixContext : SpaceOperations> { */ public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): Matrix - public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = when (operation) { - "dot" -> left dot right - else -> super.binaryOperation(operation, left, right) - } + public override fun binaryOperation(operation: String, left: Matrix, right: Matrix): Matrix = + when (operation) { + "dot" -> left dot right + else -> super.binaryOperation(operation, left, right) + } /** * Computes the dot product of this matrix and another one. diff --git a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt index 52826a7b1..64dc605f1 100644 --- a/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt +++ b/kmath-ejml/src/main/kotlin/kscience/kmath/ejml/EjmlMatrixContext.kt @@ -1,11 +1,11 @@ package kscience.kmath.ejml -import org.ejml.simple.SimpleMatrix import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.Point import kscience.kmath.operations.Space import kscience.kmath.operations.invoke import kscience.kmath.structures.Matrix +import org.ejml.simple.SimpleMatrix /** * Represents context of basic operations operating with [EjmlMatrix]. @@ -29,8 +29,8 @@ public class EjmlMatrixContext(private val space: Space) : MatrixContext override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> Double): EjmlMatrix = EjmlMatrix(SimpleMatrix(rows, columns).also { - (0 until it.numRows()).forEach { row -> - (0 until it.numCols()).forEach { col -> it[row, col] = initializer(row, col) } + (0 until rows).forEach { row -> + (0 until columns).forEach { col -> it[row, col] = initializer(row, col) } } }) @@ -49,7 +49,8 @@ public class EjmlMatrixContext(private val space: Space) : MatrixContext public override fun multiply(a: Matrix, k: Number): EjmlMatrix = produce(a.rowNum, a.colNum) { i, j -> space { a[i, j] * k } } - public override operator fun Matrix.times(value: Double): EjmlMatrix = EjmlMatrix(toEjml().origin.scale(value)) + public override operator fun Matrix.times(value: Double): EjmlMatrix = + EjmlMatrix(toEjml().origin.scale(value)) public companion object } diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index b73b1c82c..43538a35d 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -8,6 +8,7 @@ kotlin { val nativeTarget = when (System.getProperty("os.name")) { "Mac OS X" -> macosX64("native") "Linux" -> linuxX64("native") + else -> { logger.warn("Current OS cannot build any of kmath-gsl targets.") return@kotlin @@ -26,6 +27,7 @@ kotlin { sourceSets.commonMain { dependencies { api(project(":kmath-core")) + api("org.jetbrains.kotlinx:kotlinx-io:0.2.0-tvis-3") } } } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt new file mode 100644 index 000000000..0282afb54 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt @@ -0,0 +1,12 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.* +import kscience.kmath.operations.Complex +import org.gnu.gsl.gsl_complex + +internal fun CValue.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) } + +internal fun Complex.toGsl(): CValue = cValue { + dat[0] = re + dat[1] = im +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt index 95f918209..5eb8794f2 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt @@ -7,7 +7,10 @@ import kscience.kmath.operations.Complex import kscience.kmath.structures.NDStructure import org.gnu.gsl.* -public sealed class GslMatrix : StructHolder(), FeaturedMatrix { +public sealed class GslMatrix : GslMemoryHolder(), FeaturedMatrix { + internal abstract operator fun set(i: Int, j: Int, value: T) + internal abstract fun copy(): GslMatrix + public override fun equals(other: Any?): Boolean { return NDStructure.equals(this, other as? NDStructure<*> ?: return false) } @@ -19,215 +22,318 @@ public sealed class GslMatrix : StructHolder(), FeaturedMatrix { } } -public class GslRealMatrix(protected override val nativeHandle: CValues, features: Set) : - GslMatrix() { - - public override val rowNum: Int +internal class GslRealMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix = GslRealMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Double): Unit = + gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslRealMatrix = memScoped { + val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_memcpy(new, nativeHandle) + GslRealMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslRealMatrix) gsl_matrix_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslFloatMatrix( - protected override val nativeHandle: CValues, - features: Set -) : - GslMatrix() { - public override val rowNum: Int +internal class GslFloatMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix = GslFloatMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Float = gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): Float = + gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Float): Unit = + gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslFloatMatrix = memScoped { + val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_float_memcpy(new, nativeHandle) + GslFloatMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_float_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslFloatMatrix) gsl_matrix_float_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslIntMatrix(protected override val nativeHandle: CValues, features: Set) : - GslMatrix() { +internal class GslIntMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix = GslIntMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Int): Unit = + gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslIntMatrix = memScoped { + val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_int_memcpy(new, nativeHandle) + GslIntMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_int_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslIntMatrix) gsl_matrix_int_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslUIntMatrix( - protected override val nativeHandle: CValues, - features: Set -) : GslMatrix() { +internal class GslUIntMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix = GslUIntMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: UInt): Unit = + gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslUIntMatrix = memScoped { + val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_uint_memcpy(new, nativeHandle) + GslUIntMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_uint_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslUIntMatrix) gsl_matrix_uint_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslLongMatrix( - protected override val nativeHandle: CValues, - features: Set -) : - GslMatrix() { +internal class GslLongMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix = GslLongMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Long): Unit = + gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslLongMatrix = memScoped { + val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_long_memcpy(new, nativeHandle) + GslLongMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_long_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslLongMatrix) gsl_matrix_long_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslULongMatrix( - protected override val nativeHandle: CValues, - features: Set -) : GslMatrix() { +internal class GslULongMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix = GslULongMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): ULong = gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): ULong = + gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: ULong): Unit = + gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslULongMatrix = memScoped { + val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_ulong_memcpy(new, nativeHandle) + GslULongMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_ulong_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslULongMatrix) gsl_matrix_ulong_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslShortMatrix( - protected override val nativeHandle: CValues, - features: Set -) : GslMatrix() { +internal class GslShortMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix = GslShortMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Short = gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): Short = + gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Short): Unit = + gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslShortMatrix = memScoped { + val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_short_memcpy(new, nativeHandle) + GslShortMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_short_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslShortMatrix) gsl_matrix_short_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslUShortMatrix( - protected override val nativeHandle: CValues, - features: Set -) : GslMatrix() { +internal class GslUShortMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { - public override val rowNum: Int + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix = GslUShortMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): UShort = gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong()) + override operator fun get(i: Int, j: Int): UShort = + gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong()) - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: UShort): Unit = + gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): GslUShortMatrix = memScoped { + val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_ushort_memcpy(new, nativeHandle) + GslUShortMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_ushort_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslUShortMatrix) gsl_matrix_ushort_equal(nativeHandle, other.nativeHandle) return super.equals(other) } } -public class GslComplexMatrix( - protected override val nativeHandle: CValues, - features: Set -) : GslMatrix() { - - public override val rowNum: Int +internal class GslComplexMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { + override val rowNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - public override val colNum: Int + override val colNum: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - public override val features: Set = features + override val features: Set = features - public override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix = + override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix = GslComplexMatrix(nativeHandle, this.features + features) - public override fun get(i: Int, j: Int): Complex = - gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).useContents { Complex(dat[0], dat[1]) } + override operator fun get(i: Int, j: Int): Complex = + gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).toKMath() - public override fun equals(other: Any?): Boolean { + override operator fun set(i: Int, j: Int, value: Complex): Unit = + gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl()) + + override fun copy(): GslComplexMatrix = memScoped { + val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_complex_memcpy(new, nativeHandle) + GslComplexMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_complex_free(nativeHandle) + + override fun equals(other: Any?): Boolean { if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) return super.equals(other) } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt new file mode 100644 index 000000000..a27a22110 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt @@ -0,0 +1,57 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.CStructVar +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.operations.invoke +import kscience.kmath.structures.Matrix +import org.gnu.gsl.* + +private inline fun GslMatrix.fill(initializer: (Int, Int) -> T): GslMatrix = + apply { + (0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } } + } + +public sealed class GslMatrixContext : MatrixContext { + @Suppress("UNCHECKED_CAST") + public fun Matrix.toGsl(): GslMatrix = + (if (this is GslMatrix<*, *>) this as GslMatrix else produce(rowNum, colNum) { i, j -> get(i, j) }).copy() + + internal abstract fun produceDirty(rows: Int, columns: Int): GslMatrix + + public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix = + produceDirty(rows, columns).fill(initializer) +} + +public object GslRealMatrixContext : GslMatrixContext() { + public override fun produceDirty(rows: Int, columns: Int): GslMatrix = + GslRealMatrix(requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong()))) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val g1 = toGsl() + gsl_matrix_mul_elements(g1.nativeHandle, other.toGsl().nativeHandle) + return g1 + } + + public override fun Matrix.dot(vector: Point): GslVector { + TODO() + } + + public override fun Matrix.times(value: Double): GslMatrix { + val g1 = toGsl() + gsl_matrix_scale(g1.nativeHandle, value) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_scale(g1.nativeHandle, k.toDouble()) + return g1 + } +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt new file mode 100644 index 000000000..4c6aa17d7 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt @@ -0,0 +1,9 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.CStructVar +import kotlinx.io.Closeable + +public abstract class GslMemoryHolder internal constructor() : Closeable { + internal abstract val nativeHandle: CPointer +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt index fc6e8d329..3806a2e1c 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt @@ -1,11 +1,14 @@ package kscience.kmath.gsl -import kotlinx.cinterop.* +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.CStructVar +import kotlinx.cinterop.memScoped +import kotlinx.cinterop.pointed import kscience.kmath.linear.Point import kscience.kmath.operations.Complex import org.gnu.gsl.* -public sealed class GslVector : StructHolder(), Point { +public sealed class GslVector : GslMemoryHolder(), Point { public override fun iterator(): Iterator = object : Iterator { private var cursor = 0 @@ -18,67 +21,81 @@ public sealed class GslVector : StructHolder(), Point { } } -public class GslRealVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslRealVector(override val nativeHandle: CPointer) : GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): Double = gsl_vector_get(nativeHandle, index.toULong()) + override fun get(index: Int): Double = gsl_vector_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_free(nativeHandle) } -public class GslFloatVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslFloatVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): Float = gsl_vector_float_get(nativeHandle, index.toULong()) + override fun get(index: Int): Float = gsl_vector_float_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_float_free(nativeHandle) } -public class GslIntVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslIntVector(override val nativeHandle: CPointer) : GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): Int = gsl_vector_int_get(nativeHandle, index.toULong()) + override fun get(index: Int): Int = gsl_vector_int_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_int_free(nativeHandle) } -public class GslUIntVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslUIntVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): UInt = gsl_vector_uint_get(nativeHandle, index.toULong()) + override fun get(index: Int): UInt = gsl_vector_uint_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_uint_free(nativeHandle) } -public class GslLongVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslLongVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): Long = gsl_vector_long_get(nativeHandle, index.toULong()) + override fun get(index: Int): Long = gsl_vector_long_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_long_free(nativeHandle) } -public class GslULongVector(override val nativeHandle: CValues) : GslVector() { +internal class GslULongVector(override val nativeHandle: CPointer) : + GslVector() { public override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } public override fun get(index: Int): ULong = gsl_vector_ulong_get(nativeHandle, index.toULong()) + public override fun close(): Unit = gsl_vector_ulong_free(nativeHandle) } -public class GslShortVector(override val nativeHandle: CValues) : GslVector() { +internal class GslShortVector(override val nativeHandle: CPointer) : + GslVector() { public override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } public override fun get(index: Int): Short = gsl_vector_short_get(nativeHandle, index.toULong()) + public override fun close(): Unit = gsl_vector_short_free(nativeHandle) } -public class GslUShortVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslUShortVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): UShort = gsl_vector_ushort_get(nativeHandle, index.toULong()) + override fun get(index: Int): UShort = gsl_vector_ushort_get(nativeHandle, index.toULong()) + override fun close(): Unit = gsl_vector_ushort_free(nativeHandle) } -public class GslComplexVector(override val nativeHandle: CValues) : GslVector() { - public override val size: Int +internal class GslComplexVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - public override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).useContents { - Complex(dat[0], dat[1]) - } + override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath() + override fun close(): Unit = gsl_vector_complex_free(nativeHandle) } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/nativeUtilities.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/nativeUtilities.kt deleted file mode 100644 index fed31dcaa..000000000 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/nativeUtilities.kt +++ /dev/null @@ -1,8 +0,0 @@ -package kscience.kmath.gsl - -import kotlinx.cinterop.CStructVar -import kotlinx.cinterop.CValues - -public abstract class StructHolder internal constructor() { - protected abstract val nativeHandle: CValues -} diff --git a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt new file mode 100644 index 000000000..8fa664336 --- /dev/null +++ b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt @@ -0,0 +1,15 @@ +package kscience.kmath.gsl + +import kotlinx.io.use +import kscience.kmath.operations.invoke +import kotlin.test.Test +import kotlin.test.assertEquals + +internal class RealTest { + @Test + fun testScale() = GslRealMatrixContext { + GslRealMatrixContext.produce(10, 10) { _, _ -> 0.1 }.use { ma -> + (ma * 20.0).use { mb -> assertEquals(mb[0, 1], 2.0) } + } + } +}