diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt index edc244bb3..2981b9502 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/PsiTestUtil.kt @@ -6,8 +6,12 @@ import org.jetbrains.kotlin.com.intellij.psi.impl.DebugUtil import org.jetbrains.kotlin.com.intellij.psi.impl.source.PsiFileImpl import org.jetbrains.kotlin.com.intellij.testFramework.LightVirtualFile import org.jetbrains.kotlin.psi.KtFile +import java.util.regex.Pattern import kotlin.math.min +private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+") +internal fun splitByLinesDontTrim(string: String): Array = EOL_SPLIT_DONT_TRIM_PATTERN.split(string) + internal object PsiTestUtil { fun checkFileStructure(file: KtFile) { compareFromAllRoots(file) { f -> DebugUtil.psiTreeToString(f, false) } diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt deleted file mode 100644 index c56655b9f..000000000 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/StringUtilExt.kt +++ /dev/null @@ -1,9 +0,0 @@ -package kscience.kmath.gsl.codegen - -import java.util.regex.Pattern - -private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+") - -internal fun splitByLinesDontTrim(string: String): Array { - return EOL_SPLIT_DONT_TRIM_PATTERN.split(string) -} diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt index 1c507cc3d..92125f835 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt @@ -22,10 +22,10 @@ private fun KtPsiFactory.createMatrixClass( features: Set = emptySet() ) : GslMatrix<$kotlinTypeName, $structName>() { override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } + get() = nativeHandle.pointed.size1.toInt() override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } + get() = nativeHandle.pointed.size2.toInt() override val features: Set = features @@ -39,30 +39,26 @@ private fun KtPsiFactory.createMatrixClass( override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit = ${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value) - override fun copy(): $className = memScoped { + override fun copy(): $className { val new = requireNotNull(${fn("gsl_matrixRalloc", cTypeName)}(rowNum.toULong(), colNum.toULong())) ${fn("gsl_matrixRmemcpy", cTypeName)}(new, nativeHandle) - $className(new, features) + return $className(new, features) } override fun close(): Unit = ${fn("gsl_matrixRfree", cTypeName)}(nativeHandle) override fun equals(other: Any?): Boolean { - if (other is $className) ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) + if (other is $className) return ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1 return super.equals(other) } }""" - f += createClass( - text - ) - + f += createClass(text) f += createNewLine(2) } fun matricesCodegen(outputFile: String, project: Project = createProject()) { val f = KtPsiFactory(project, true).run { - createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> - f += createNewLine(2) + createFile("").also { f -> f += createPackageDirective(FqName("kscience.kmath.gsl")) f += createNewLine(2) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt index f29def5b0..830e858c4 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt @@ -20,22 +20,34 @@ private fun KtPsiFactory.createVectorClass( @Language("kotlin") val text = """internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() { override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } + get() = nativeHandle.pointed.size.toInt() override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong()) + override fun set(index: Int, value: $kotlinTypeName): Unit = ${ + fn("gsl_vectorRset", cTypeName) + }(nativeHandle, index.toULong(), value) + + override fun copy(): $className { + val new = requireNotNull(${fn("gsl_vectorRalloc", cTypeName)}(size.toULong())) + ${fn("gsl_vectorRmemcpy", cTypeName)}(new, nativeHandle) + return ${className}(new) + } + + override fun equals(other: Any?): Boolean { + if (other is $className) return ${fn("gsl_vectorRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1 + return super.equals(other) + } + override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle) }""" - f += createClass( - text - ) + f += createClass(text) f += createNewLine(2) } fun vectorsCodegen(outputFile: String, project: Project = createProject()) { val f = KtPsiFactory(project, true).run { - createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> - f += createNewLine(2) + createFile("").also { f -> f += createPackageDirective(FqName("kscience.kmath.gsl")) f += createNewLine(2) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt index 300b963a6..c1c4a2d5e 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/linear/MatrixContext.kt @@ -55,7 +55,7 @@ public interface MatrixContext : SpaceOperations> { * Multiplies an element by a matrix of it. * * @receiver the multiplicand. - * @param value the multiplier. + * @param m the multiplier. * @receiver the product. */ public operator fun T.times(m: Matrix): Matrix = m * this diff --git a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt index 20f289596..a4f8406fb 100644 --- a/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt +++ b/kmath-core/src/commonMain/kotlin/kscience/kmath/operations/BigInt.kt @@ -38,7 +38,7 @@ public object BigIntField : Field { public class BigInt internal constructor( private val sign: Byte, private val magnitude: Magnitude -) : Comparable { + ) : Comparable { public override fun compareTo(other: BigInt): Int = when { (sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0 sign < other.sign -> -1 diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index 048d7afed..09f614d4d 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -38,8 +38,8 @@ kotlin { } internal val codegen: Task by tasks.creating { - matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt") - vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt") + matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt") + vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt") } kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen)) diff --git a/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt b/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt deleted file mode 100644 index 8d471beaa..000000000 --- a/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt +++ /dev/null @@ -1,280 +0,0 @@ -@file:Suppress("PackageDirectoryMismatch") - -package kscience.kmath.gsl - -import kotlinx.cinterop.* -import kscience.kmath.linear.* -import org.gnu.gsl.* - -internal class GslRealMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix = - GslRealMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Double): Unit = - gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslRealMatrix = memScoped { - val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_memcpy(new, nativeHandle) - GslRealMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslRealMatrix) gsl_matrix_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslFloatMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix = - GslFloatMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Float = gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Float): Unit = - gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslFloatMatrix = memScoped { - val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_float_memcpy(new, nativeHandle) - GslFloatMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_float_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslFloatMatrix) gsl_matrix_float_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslShortMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix = - GslShortMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Short = gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Short): Unit = - gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslShortMatrix = memScoped { - val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_short_memcpy(new, nativeHandle) - GslShortMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_short_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslShortMatrix) gsl_matrix_short_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslUShortMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix = - GslUShortMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): UShort = gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: UShort): Unit = - gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslUShortMatrix = memScoped { - val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_ushort_memcpy(new, nativeHandle) - GslUShortMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_ushort_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslUShortMatrix) gsl_matrix_ushort_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslLongMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix = - GslLongMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Long): Unit = - gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslLongMatrix = memScoped { - val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_long_memcpy(new, nativeHandle) - GslLongMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_long_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslLongMatrix) gsl_matrix_long_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslULongMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix = - GslULongMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): ULong = gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: ULong): Unit = - gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslULongMatrix = memScoped { - val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_ulong_memcpy(new, nativeHandle) - GslULongMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_ulong_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslULongMatrix) gsl_matrix_ulong_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslIntMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix = - GslIntMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: Int): Unit = - gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslIntMatrix = memScoped { - val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_int_memcpy(new, nativeHandle) - GslIntMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_int_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslIntMatrix) gsl_matrix_int_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - -internal class GslUIntMatrix( - override val nativeHandle: CPointer, - features: Set = emptySet() -) : GslMatrix() { - override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } - - override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } - - override val features: Set = features - - override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix = - GslUIntMatrix(nativeHandle, this.features + features) - - override operator fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong()) - - override operator fun set(i: Int, j: Int, value: UInt): Unit = - gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value) - - override fun copy(): GslUIntMatrix = memScoped { - val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong())) - gsl_matrix_uint_memcpy(new, nativeHandle) - GslUIntMatrix(new, features) - } - - override fun close(): Unit = gsl_matrix_uint_free(nativeHandle) - - override fun equals(other: Any?): Boolean { - if (other is GslUIntMatrix) gsl_matrix_uint_equal(nativeHandle, other.nativeHandle) - return super.equals(other) - } -} - diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt index 68bde7ef2..38d091414 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt @@ -17,10 +17,10 @@ internal class GslComplexMatrix( features: Set = emptySet() ) : GslMatrix() { override val rowNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } + get() = nativeHandle.pointed.size1.toInt() override val colNum: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } + get() = nativeHandle.pointed.size2.toInt() override val features: Set = features @@ -33,16 +33,16 @@ internal class GslComplexMatrix( override operator fun set(i: Int, j: Int, value: Complex): Unit = gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl()) - override fun copy(): GslComplexMatrix = memScoped { + override fun copy(): GslComplexMatrix { val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong())) gsl_matrix_complex_memcpy(new, nativeHandle) - GslComplexMatrix(new, features) + return GslComplexMatrix(new, features) } override fun close(): Unit = gsl_matrix_complex_free(nativeHandle) override fun equals(other: Any?): Boolean { - if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) + if (other is GslComplexMatrix) return gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) == 1 return super.equals(other) } } @@ -50,8 +50,21 @@ internal class GslComplexMatrix( internal class GslComplexVector(override val nativeHandle: CPointer) : GslVector() { override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } + get() = nativeHandle.pointed.size.toInt() override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath() + override fun set(index: Int, value: Complex): Unit = gsl_vector_complex_set(nativeHandle, index.toULong(), value.toGsl()) + + override fun copy(): GslComplexVector { + val new = requireNotNull(gsl_vector_complex_alloc(size.toULong())) + gsl_vector_complex_memcpy(new, nativeHandle) + return GslComplexVector(new) + } + + override fun equals(other: Any?): Boolean { + if (other is GslComplexVector) return gsl_vector_complex_equal(nativeHandle, other.nativeHandle) == 1 + return super.equals(other) + } + override fun close(): Unit = gsl_vector_complex_free(nativeHandle) } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt index fadecc4eb..c7323437d 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrix.kt @@ -13,9 +13,9 @@ public abstract class GslMatrix internal constructor(): return NDStructure.equals(this, other as? NDStructure<*> ?: return false) } - public override fun hashCode(): Int { + public final override fun hashCode(): Int { var result = nativeHandle.hashCode() result = 31 * result + features.hashCode() return result } -} \ No newline at end of file +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt new file mode 100644 index 000000000..8b0f12883 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt @@ -0,0 +1,7 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.CStructVar +import kscience.kmath.linear.MatrixContext +import kscience.kmath.linear.Point +import kscience.kmath.structures.Matrix + diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt index a27a22110..4428b0185 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt @@ -1,40 +1,63 @@ package kscience.kmath.gsl import kotlinx.cinterop.CStructVar +import kotlinx.cinterop.pointed import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.Point -import kscience.kmath.operations.invoke +import kscience.kmath.operations.Complex +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.toComplex import kscience.kmath.structures.Matrix import org.gnu.gsl.* -private inline fun GslMatrix.fill(initializer: (Int, Int) -> T): GslMatrix = +internal inline fun GslMatrix.fill(initializer: (Int, Int) -> T): GslMatrix = apply { (0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } } } -public sealed class GslMatrixContext : MatrixContext { +internal inline fun GslVector.fill(initializer: (Int) -> T): GslVector = + apply { (0 until size).forEach { index -> this[index] = initializer(index) } } + +public abstract class GslMatrixContext internal constructor() : + MatrixContext { @Suppress("UNCHECKED_CAST") - public fun Matrix.toGsl(): GslMatrix = - (if (this is GslMatrix<*, *>) this as GslMatrix else produce(rowNum, colNum) { i, j -> get(i, j) }).copy() + public fun Matrix.toGsl(): GslMatrix = (if (this is GslMatrix<*, *>) + this as GslMatrix + else + produce(rowNum, colNum) { i, j -> this[i, j] }).copy() - internal abstract fun produceDirty(rows: Int, columns: Int): GslMatrix + @Suppress("UNCHECKED_CAST") + public fun Point.toGsl(): GslVector = + (if (this is GslVector<*, *>) this as GslVector else produceDirtyVector(size).fill { this[it] }).copy() - public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix = - produceDirty(rows, columns).fill(initializer) + internal abstract fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix + internal abstract fun produceDirtyVector(size: Int): GslVector + + public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix = + produceDirtyMatrix(rows, columns).fill(initializer) } -public object GslRealMatrixContext : GslMatrixContext() { - public override fun produceDirty(rows: Int, columns: Int): GslMatrix = +public object GslRealMatrixContext : GslMatrixContext() { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = GslRealMatrix(requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong()))) + override fun produceDirtyVector(size: Int): GslVector = + GslRealVector(requireNotNull(gsl_vector_alloc(size.toULong()))) + public override fun Matrix.dot(other: Matrix): GslMatrix { - val g1 = toGsl() - gsl_matrix_mul_elements(g1.nativeHandle, other.toGsl().nativeHandle) - return g1 + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result) + return GslRealMatrix(result) } public override fun Matrix.dot(vector: Point): GslVector { - TODO() + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_calloc(a.pointed.size)) + gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result) + return GslRealVector(result) } public override fun Matrix.times(value: Double): GslMatrix { @@ -55,3 +78,87 @@ public object GslRealMatrixContext : GslMatrixContext() { return g1 } } + +public object GslFloatMatrixContext : GslMatrixContext() { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = + GslFloatMatrix(requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong()))) + + override fun produceDirtyVector(size: Int): GslVector = + GslFloatVector(requireNotNull(gsl_vector_float_alloc(size.toULong()))) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result) + return GslFloatMatrix(result) + } + + public override fun Matrix.dot(vector: Point): GslVector { + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size)) + gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result) + return GslFloatVector(result) + } + + public override fun Matrix.times(value: Float): GslMatrix { + val g1 = toGsl() + gsl_matrix_float_scale(g1.nativeHandle, value.toDouble()) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_float_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_float_scale(g1.nativeHandle, k.toDouble()) + return g1 + } +} + +public object GslComplexMatrixContext : GslMatrixContext() { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = + GslComplexMatrix(requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong()))) + + override fun produceDirtyVector(size: Int): GslVector = + GslComplexVector(requireNotNull(gsl_vector_complex_alloc(size.toULong()))) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) + return GslComplexMatrix(result) + } + + public override fun Matrix.dot(vector: Point): GslVector { + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size)) + gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) + return GslComplexVector(result) + } + + public override fun Matrix.times(value: Complex): GslMatrix { + val g1 = toGsl() + gsl_matrix_complex_scale(g1.nativeHandle, value.toGsl()) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_complex_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_complex_scale(g1.nativeHandle, k.toComplex().toGsl()) + return g1 + } +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt index 22b80c435..87fbce607 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt @@ -3,8 +3,11 @@ package kscience.kmath.gsl import kotlinx.cinterop.CStructVar import kscience.kmath.linear.Point -abstract class GslVector internal constructor(): GslMemoryHolder(), Point { - public override fun iterator(): Iterator = object : Iterator { +public abstract class GslVector internal constructor() : GslMemoryHolder(), Point { + internal abstract operator fun set(index: Int, value: T) + internal abstract fun copy(): GslVector + + public final override fun iterator(): Iterator = object : Iterator { private var cursor = 0 override fun hasNext(): Boolean = cursor < size @@ -14,4 +17,4 @@ abstract class GslVector internal constructor(): GslMemoryHol return this@GslVector[cursor - 1] } } -} \ No newline at end of file +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt deleted file mode 100644 index 4741d8af1..000000000 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVectors.kt +++ /dev/null @@ -1,101 +0,0 @@ -package kscience.kmath.gsl - -import kotlinx.cinterop.CPointer -import kotlinx.cinterop.CStructVar -import kotlinx.cinterop.memScoped -import kotlinx.cinterop.pointed -import kscience.kmath.linear.Point -import kscience.kmath.operations.Complex -import org.gnu.gsl.* - -public abstract class GslVector internal constructor(): GslMemoryHolder(), Point { - public override fun iterator(): Iterator = object : Iterator { - private var cursor = 0 - - override fun hasNext(): Boolean = cursor < size - - override fun next(): T { - cursor++ - return this@GslVector[cursor - 1] - } - } -} - -internal class GslRealVector(override val nativeHandle: CPointer) : GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): Double = gsl_vector_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_free(nativeHandle) -} - -internal class GslFloatVector(override val nativeHandle: CPointer) : - GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): Float = gsl_vector_float_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_float_free(nativeHandle) -} - -internal class GslIntVector(override val nativeHandle: CPointer) : GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): Int = gsl_vector_int_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_int_free(nativeHandle) -} - -internal class GslUIntVector(override val nativeHandle: CPointer) : - GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): UInt = gsl_vector_uint_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_uint_free(nativeHandle) -} - -internal class GslLongVector(override val nativeHandle: CPointer) : - GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): Long = gsl_vector_long_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_long_free(nativeHandle) -} - -internal class GslULongVector(override val nativeHandle: CPointer) : - GslVector() { - public override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - public override fun get(index: Int): ULong = gsl_vector_ulong_get(nativeHandle, index.toULong()) - public override fun close(): Unit = gsl_vector_ulong_free(nativeHandle) -} - -internal class GslShortVector(override val nativeHandle: CPointer) : - GslVector() { - public override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - public override fun get(index: Int): Short = gsl_vector_short_get(nativeHandle, index.toULong()) - public override fun close(): Unit = gsl_vector_short_free(nativeHandle) -} - -internal class GslUShortVector(override val nativeHandle: CPointer) : - GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): UShort = gsl_vector_ushort_get(nativeHandle, index.toULong()) - override fun close(): Unit = gsl_vector_ushort_free(nativeHandle) -} - -internal class GslComplexVector(override val nativeHandle: CPointer) : - GslVector() { - override val size: Int - get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } - - override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath() - override fun close(): Unit = gsl_vector_complex_free(nativeHandle) -} diff --git a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt index 41efbd195..1208443c5 100644 --- a/kmath-gsl/src/nativeTest/kotlin/RealTest.kt +++ b/kmath-gsl/src/nativeTest/kotlin/RealTest.kt @@ -1,8 +1,13 @@ package kscience.kmath.gsl +import kscience.kmath.linear.RealMatrixContext import kscience.kmath.operations.invoke +import kscience.kmath.structures.RealBuffer +import kscience.kmath.structures.asIterable +import kscience.kmath.structures.asSequence import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue internal class RealTest { @Test @@ -13,4 +18,29 @@ internal class RealTest { mb.close() ma.close() } + + @Test + fun testDotOfMatrixAndVector() { + val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 } + val vb = RealBuffer(2) { 0.1 } + val res1 = GslRealMatrixContext { ma dot vb } + val res2 = RealMatrixContext { ma dot vb } + println(res1.asSequence().toList()) + println(res2.asSequence().toList()) + assertTrue(res1.contentEquals(res2)) + res1.close() + } + + @Test + fun testDotOfMatrixAndMatrix() { + val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 } + val mb = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 } + val res1 = GslRealMatrixContext { ma dot mb } + val res2 = RealMatrixContext { ma dot mb } + println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) + println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList)) + assertEquals(res1, res2) + ma.close() + mb.close() + } }