From 0a02fd07b478359cafeb4529beb843f79ce40595 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 5 Oct 2020 13:13:06 +0700 Subject: [PATCH] Replace manual matrices classes building with codegen --- build.gradle.kts | 6 +- buildSrc/build.gradle.kts | 9 + .../kscience/kmath/gsl/codegen/PsiTestUtil.kt | 69 ++++ .../kmath/gsl/codegen/StringUtilExt.kt | 9 + .../kmath/gsl/codegen/matricesCodegen.kt | 98 +++++ kmath-gsl/build.gradle.kts | 18 +- .../kotlin/kscience/kmath/gsl/GslMatrices.kt | 340 ------------------ .../kotlin/kscience/kmath/gsl/GslMatrix.kt | 21 ++ .../kotlin/kscience/kmath/gsl/GslVectors.kt | 2 +- kmath-gsl/src/nativeTest/kotlin/RealTest.kt | 9 +- 10 files changed, 229 insertions(+), 352 deletions(-) create mode 100644 buildSrc/build.gradle.kts create mode 100644 buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt create mode 100644 buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt create mode 100644 buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt delete mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt diff --git a/build.gradle.kts b/build.gradle.kts index 92ad16b85..22ba35d18 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -2,9 +2,9 @@ plugins { id("ru.mipt.npm.project") } -val kmathVersion: String by extra("0.2.0-dev-2") -val bintrayRepo: String by extra("kscience") -val githubProject: String by extra("kmath") +internal val kmathVersion: String by extra("0.2.0-dev-2") +internal val bintrayRepo: String by extra("kscience") +internal val githubProject: String by extra("kmath") allprojects { repositories { diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts new file mode 100644 index 000000000..d0e424c37 --- /dev/null +++ b/buildSrc/build.gradle.kts @@ -0,0 +1,9 @@ +plugins { + `kotlin-dsl` +} + +repositories.jcenter() + +dependencies { + implementation(kotlin("compiler-embeddable")) +} diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt new file mode 100644 index 000000000..edc244bb3 --- /dev/null +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt @@ -0,0 +1,69 @@ +package kscience.kmath.gsl.codegen + +import org.jetbrains.kotlin.com.intellij.openapi.util.text.StringUtil +import org.jetbrains.kotlin.com.intellij.psi.PsiFile +import org.jetbrains.kotlin.com.intellij.psi.impl.DebugUtil +import org.jetbrains.kotlin.com.intellij.psi.impl.source.PsiFileImpl +import org.jetbrains.kotlin.com.intellij.testFramework.LightVirtualFile +import org.jetbrains.kotlin.psi.KtFile +import kotlin.math.min + +internal object PsiTestUtil { + fun checkFileStructure(file: KtFile) { + compareFromAllRoots(file) { f -> DebugUtil.psiTreeToString(f, false) } + } + + private fun compareFromAllRoots( + file: KtFile, + function: (PsiFile) -> String + ) { + val dummyFile = createDummyCopy(file) + + val psiTree = StringUtil.join( + file.viewProvider.allFiles, + { param -> function(param) }, + "\n" + ) + + val reparsedTree = StringUtil.join( + dummyFile.viewProvider.allFiles, + { param -> function(param) }, + "\n" + ) + + assertPsiTextTreeConsistency(psiTree, reparsedTree) + } + + private fun assertPsiTextTreeConsistency(psiTree: String, reparsedTree: String) { + var psiTreeMutable = psiTree + var reparsedTreeMutable = reparsedTree + + if (psiTreeMutable != reparsedTreeMutable) { + val psiLines = splitByLinesDontTrim(psiTreeMutable) + val reparsedLines = splitByLinesDontTrim(reparsedTreeMutable) + var i = 0 + + while (true) { + if (i >= psiLines.size || i >= reparsedLines.size || psiLines[i] != reparsedLines[i]) { + psiLines[min(i, psiLines.size - 1)] += " // in PSI structure" + reparsedLines[min(i, reparsedLines.size - 1)] += " // re-created from text" + break + } + + i++ + } + + psiTreeMutable = StringUtil.join(psiLines, "\n") + reparsedTreeMutable = StringUtil.join(reparsedLines, "\n") + assert(reparsedTreeMutable == psiTreeMutable) + } + } + + private fun createDummyCopy(file: KtFile): PsiFile { + val copy = LightVirtualFile(file.name, file.text) + copy.originalFile = file.viewProvider.virtualFile + val dummyCopy = requireNotNull(file.manager.findFile(copy)) + if (dummyCopy is PsiFileImpl) dummyCopy.originalFile = file + return dummyCopy + } +} \ No newline at end of file diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt new file mode 100644 index 000000000..c56655b9f --- /dev/null +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt @@ -0,0 +1,9 @@ +package kscience.kmath.gsl.codegen + +import java.util.regex.Pattern + +private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+") + +internal fun splitByLinesDontTrim(string: String): Array { + return EOL_SPLIT_DONT_TRIM_PATTERN.split(string) +} diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt new file mode 100644 index 000000000..14aafffff --- /dev/null +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt @@ -0,0 +1,98 @@ +package kscience.kmath.gsl.codegen + +import org.jetbrains.kotlin.com.intellij.openapi.project.Project +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtPsiFactory +import org.jetbrains.kotlin.resolve.ImportPath +import java.io.File + +private fun fn(pattern: String, type: String): String { + if (type == "double") return pattern.replace("R", "_") + return pattern.replace("R", "_${type}_") +} + +private fun sn(pattern: String, type: String): String { + if (type == "double") return pattern.replace("R", "") + return pattern.replace("R", "_$type") +} + +private fun KtPsiFactory.createMatrixClass( + f: KtFile, + cTypeName: String, + kotlinTypeName: String, + kotlinTypeAlias: String = kotlinTypeName +) { + val className = "Gsl${kotlinTypeAlias}Matrix" + val structName = sn("gsl_matrixR", cTypeName) + + f += createClass( + """internal class $className( + override val nativeHandle: CPointer<$structName>, + features: Set = emptySet() +) : GslMatrix<$kotlinTypeName, $structName>() { + override val rowNum: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } + + override val colNum: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } + + override val features: Set = features + + override fun suggestFeature(vararg features: MatrixFeature): $className = + ${className}(nativeHandle, this.features + features) + + override operator fun get(i: Int, j: Int): $kotlinTypeName = ${ + fn("gsl_matrixRget", cTypeName) + }(nativeHandle, i.toULong(), j.toULong()) + + override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit = + ${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value) + + override fun copy(): $className = memScoped { + val new = requireNotNull(${fn("gsl_matrixRalloc", cTypeName)}(rowNum.toULong(), colNum.toULong())) + ${fn("gsl_matrixRmemcpy", cTypeName)}(new, nativeHandle) + $className(new, features) + } + + override fun close(): Unit = ${fn("gsl_matrixRfree", cTypeName)}(nativeHandle) + + override fun equals(other: Any?): Boolean { + if (other is $className) ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) + return super.equals(other) + } +}""" + ) + + f += createNewLine(2) +} + +fun matricesCodegen(outputFile: String, project: Project = createProject()) { + val f = KtPsiFactory(project, true).run { + createFile("package kscience.kmath.gsl").also { f -> + f += createNewLine(2) + f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) + f += createNewLine(1) + f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*")) + f += createNewLine(1) + f += createImportDirective(ImportPath.fromString("kscience.kmath.operations.*")) + f += createNewLine(1) + f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*")) + f += createNewLine(2) + createMatrixClass(f, "double", "Double", "Real") + createMatrixClass(f, "float", "Float") + createMatrixClass(f, "short", "Short") + createMatrixClass(f, "ushort", "UShort") + createMatrixClass(f, "long", "Long") + createMatrixClass(f, "ulong", "ULong") + createMatrixClass(f, "int", "Int") + createMatrixClass(f, "uint", "UInt") + } + } + + PsiTestUtil.checkFileStructure(f) + + File(outputFile).apply { + parentFile.mkdirs() + writeText(f.text) + } +} diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index 43538a35d..9ff2c30d8 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -1,5 +1,7 @@ @file:Suppress("UNUSED_VARIABLE") +import kscience.kmath.gsl.codegen.matricesCodegen + plugins { id("ru.mipt.npm.mpp") } @@ -24,10 +26,18 @@ kotlin { } } - sourceSets.commonMain { - dependencies { - api(project(":kmath-core")) - api("org.jetbrains.kotlinx:kotlinx-io:0.2.0-tvis-3") + sourceSets { + val nativeMain by getting { + dependencies { + api(project(":kmath-core")) + api("org.jetbrains.kotlinx:kotlinx-io:0.2.0-tvis-3") + } } } } + +internal val codegen: Task by tasks.creating { + matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt") +} + +kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen)) diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt deleted file mode 100644 index 5eb8794f2..000000000 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrices.kt +++ /dev/null @@ -1,340 +0,0 @@ -package kscience.kmath.gsl - -import kotlinx.cinterop.* -import kscience.kmath.linear.FeaturedMatrix -import kscience.kmath.linear.MatrixFeature -import kscience.kmath.operations.Complex -import kscience.kmath.structures.NDStructure -import org.gnu.gsl.* - -public sealed class GslMatrix : GslMemoryHolder(), FeaturedMatrix { - internal abstract operator fun set(i: Int, j: Int, value: T) - internal abstract fun copy(): GslMatrix - - public override fun equals(other: Any?): Boolean { - return NDStructure.equals(this, other as? NDStructure<*> ?: return false) - } - - public override fun hashCode(): Int { - var result = nativeHandle.hashCode() - result = 31 * result + features.hashCode() - return result - } -} - -internal class GslRealMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix = - GslRealMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Double): Unit = - gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslRealMatrix = memScoped { - val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_memcpy(new, nativeHandle) - GslRealMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslRealMatrix) gsl_matrix_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslFloatMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix = - GslFloatMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Float = - gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Float): Unit = - gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslFloatMatrix = memScoped { - val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_float_memcpy(new, nativeHandle) - GslFloatMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_float_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslFloatMatrix) gsl_matrix_float_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslIntMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix = - GslIntMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Int): Unit = - gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslIntMatrix = memScoped { - val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_int_memcpy(new, nativeHandle) - GslIntMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_int_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslIntMatrix) gsl_matrix_int_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslUIntMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix = - GslUIntMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: UInt): Unit = - gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslUIntMatrix = memScoped { - val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_uint_memcpy(new, nativeHandle) - GslUIntMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_uint_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslUIntMatrix) gsl_matrix_uint_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslLongMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix = - GslLongMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Long): Unit = - gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslLongMatrix = memScoped { - val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_long_memcpy(new, nativeHandle) - GslLongMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_long_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslLongMatrix) gsl_matrix_long_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslULongMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix = - GslULongMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): ULong = - gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: ULong): Unit = - gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslULongMatrix = memScoped { - val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_ulong_memcpy(new, nativeHandle) - GslULongMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_ulong_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslULongMatrix) gsl_matrix_ulong_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslShortMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix = - GslShortMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Short = - gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Short): Unit = - gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslShortMatrix = memScoped { - val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_short_memcpy(new, nativeHandle) - GslShortMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_short_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslShortMatrix) gsl_matrix_short_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslUShortMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix = - GslUShortMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): UShort = - gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: UShort): Unit = - gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslUShortMatrix = memScoped { - val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_ushort_memcpy(new, nativeHandle) - GslUShortMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_ushort_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslUShortMatrix) gsl_matrix_ushort_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslComplexMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix = - GslComplexMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Complex = - gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).toKMath() - - override operator fun set(i: Int, j: Int, value: Complex): Unit = - gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl()) - - override fun copy(): GslComplexMatrix = memScoped { - val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_complex_memcpy(new, nativeHandle) - GslComplexMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_complex_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt new file mode 100644 index 000000000..fadecc4eb --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt @@ -0,0 +1,21 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.CStructVar +import kscience.kmath.linear.FeaturedMatrix +import kscience.kmath.structures.NDStructure + +public abstract class GslMatrix internal constructor(): GslMemoryHolder(), + FeaturedMatrix { + internal abstract operator fun set(i: Int, j: Int, value: T) + internal abstract fun copy(): GslMatrix + + public override fun equals(other: Any?): Boolean { + return NDStructure.equals(this, other as? NDStructure<*> ?: return false) + } + + public override fun hashCode(): Int { + var result = nativeHandle.hashCode() + result = 31 * result + features.hashCode() + return result + } +} \ No newline at end of file diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt index 3806a2e1c..4741d8af1 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt @@ -8,7 +8,7 @@ import kscience.kmath.linear.Point import kscience.kmath.operations.Complex import org.gnu.gsl.* -public sealed class GslVector : GslMemoryHolder(), Point { +public abstract class GslVector internal constructor(): GslMemoryHolder(), Point { public override fun iterator(): Iterator = object : Iterator { private var cursor = 0 diff --git a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt index 8fa664336..41efbd195 100644 --- a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt +++ b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt @@ -1,6 +1,5 @@ package kscience.kmath.gsl -import kotlinx.io.use import kscience.kmath.operations.invoke import kotlin.test.Test import kotlin.test.assertEquals @@ -8,8 +7,10 @@ import kotlin.test.assertEquals internal class RealTest { @Test fun testScale() = GslRealMatrixContext { - GslRealMatrixContext.produce(10, 10) { _, _ -> 0.1 }.use { ma -> - (ma * 20.0).use { mb -> assertEquals(mb[0, 1], 2.0) } - } + val ma = GslRealMatrixContext.produce(10, 10) { _, _ -> 0.1 } + val mb = (ma * 20.0) + assertEquals(mb[0, 1], 2.0) + mb.close() + ma.close() } }