From f341d029415515915750c85e5530bc5362f6f431 Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Thu, 29 Oct 2020 04:23:36 +0700 Subject: [PATCH] Rename files, improve error handling, minor codegen update --- .../kmath/gsl/codegen/vectorsCodegen.kt | 4 +- kmath-gsl/build.gradle.kts | 10 +- .../kotlin/kscience/kmath/gsl/GslException.kt | 70 ++++++++ .../kscience/kmath/gsl/GslMatrixContext.kt | 168 +++++++++++++++++ .../kscience/kmath/gsl/GslMatrixContexts.kt | 170 ------------------ .../kscience/kmath/gsl/GslMemoryHolder.kt | 1 + .../kotlin/kscience/kmath/gsl/_Vectors.kt | 36 ++-- .../src/nativeTest/kotlin/ErrorHandler.kt | 23 +++ 8 files changed, 288 insertions(+), 194 deletions(-) create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslException.kt delete mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt create mode 100644 kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt index 018d84773..e227dc721 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt @@ -18,8 +18,8 @@ private fun KtPsiFactory.createVectorClass( val structName = sn("gsl_vectorR", cTypeName) @Language("kotlin") val text = - """internal class $className(override val nativeHandle: CPointer<$structName>, scope: DeferScope) - : GslVector<$kotlinTypeName, $structName>(scope) { + """internal class $className(override val nativeHandle: CPointer<$structName>, scope: DeferScope) : + GslVector<$kotlinTypeName, $structName>(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index df1e635be..d7d53e797 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -12,7 +12,7 @@ kotlin { val nativeTarget = when (System.getProperty("os.name")) { "Mac OS X" -> macosX64() - "Linux" -> linuxX64() + "Linux" -> linuxX64("native") else -> { logger.warn("Current OS cannot build any of kmath-gsl targets.") @@ -29,7 +29,7 @@ kotlin { val test by nativeTarget.compilations.getting sourceSets { - val nativeMain by creating { + val nativeMain by getting { val codegen by tasks.creating { matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt") vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt") @@ -42,11 +42,11 @@ kotlin { } } - val nativeTest by creating { + val nativeTest by getting { dependsOn(nativeMain) } - main.defaultSourceSet.dependsOn(nativeMain) - test.defaultSourceSet.dependsOn(nativeTest) +// main.defaultSourceSet.dependsOn(nativeMain) +// test.defaultSourceSet.dependsOn(nativeTest) } } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslException.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslException.kt new file mode 100644 index 000000000..65e466015 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslException.kt @@ -0,0 +1,70 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.staticCFunction +import kotlinx.cinterop.toKString +import org.gnu.gsl.gsl_set_error_handler +import org.gnu.gsl.gsl_set_error_handler_off +import kotlin.native.concurrent.AtomicInt + +private object Container { + val isKmathHandlerRegistered = AtomicInt(0) +} + +internal enum class GslErrnoValue(val code: Int, val text: String) { + GSL_SUCCESS(org.gnu.gsl.GSL_SUCCESS, ""), + GSL_FAILURE(org.gnu.gsl.GSL_FAILURE, ""), + GSL_CONTINUE(org.gnu.gsl.GSL_CONTINUE, "iteration has not converged"), + GSL_EDOM(org.gnu.gsl.GSL_EDOM, "input domain error, e.g sqrt(-1)"), + GSL_ERANGE(org.gnu.gsl.GSL_ERANGE, "output range error, e.g. exp(1e100)"), + GSL_EFAULT(org.gnu.gsl.GSL_EFAULT, "invalid pointer"), + GSL_EINVAL(org.gnu.gsl.GSL_EINVAL, "invalid argument supplied by user"), + GSL_EFAILED(org.gnu.gsl.GSL_EFAILED, "generic failure"), + GSL_EFACTOR(org.gnu.gsl.GSL_EFACTOR, "factorization failed"), + GSL_ESANITY(org.gnu.gsl.GSL_ESANITY, "sanity check failed - shouldn't happen"), + GSL_ENOMEM(org.gnu.gsl.GSL_ENOMEM, "malloc failed"), + GSL_EBADFUNC(org.gnu.gsl.GSL_EBADFUNC, "problem with user-supplied function"), + GSL_ERUNAWAY(org.gnu.gsl.GSL_ERUNAWAY, "iterative process is out of control"), + GSL_EMAXITER(org.gnu.gsl.GSL_EMAXITER, "exceeded max number of iterations"), + GSL_EZERODIV(org.gnu.gsl.GSL_EZERODIV, "tried to divide by zero"), + GSL_EBADTOL(org.gnu.gsl.GSL_EBADTOL, "user specified an invalid tolerance"), + GSL_ETOL(org.gnu.gsl.GSL_ETOL, "failed to reach the specified tolerance"), + GSL_EUNDRFLW(org.gnu.gsl.GSL_EUNDRFLW, "underflow"), + GSL_EOVRFLW(org.gnu.gsl.GSL_EOVRFLW, "overflow "), + GSL_ELOSS(org.gnu.gsl.GSL_ELOSS, "loss of accuracy"), + GSL_EROUND(org.gnu.gsl.GSL_EROUND, "failed because of roundoff error"), + GSL_EBADLEN(org.gnu.gsl.GSL_EBADLEN, "matrix, vector lengths are not conformant"), + GSL_ENOTSQR(org.gnu.gsl.GSL_ENOTSQR, "matrix not square"), + GSL_ESING(org.gnu.gsl.GSL_ESING, "apparent singularity detected"), + GSL_EDIVERGE(org.gnu.gsl.GSL_EDIVERGE, "integral or series is divergent"), + GSL_EUNSUP(org.gnu.gsl.GSL_EUNSUP, "requested feature is not supported by the hardware"), + GSL_EUNIMPL(org.gnu.gsl.GSL_EUNIMPL, "requested feature not (yet) implemented"), + GSL_ECACHE(org.gnu.gsl.GSL_ECACHE, "cache limit exceeded"), + GSL_ETABLE(org.gnu.gsl.GSL_ETABLE, "table limit exceeded"), + GSL_ENOPROG(org.gnu.gsl.GSL_ENOPROG, "iteration is not making progress towards solution"), + GSL_ENOPROGJ(org.gnu.gsl.GSL_ENOPROGJ, "jacobian evaluations are not improving the solution"), + GSL_ETOLF(org.gnu.gsl.GSL_ETOLF, "cannot reach the specified tolerance in F"), + GSL_ETOLX(org.gnu.gsl.GSL_ETOLX, "cannot reach the specified tolerance in X"), + GSL_ETOLG(org.gnu.gsl.GSL_ETOLG, "cannot reach the specified tolerance in gradient"), + GSL_EOF(org.gnu.gsl.GSL_EOF, "end of file"); + + override fun toString(): String = "${name}('$text')" + + companion object { + fun valueOf(code: Int): GslErrnoValue? = values().find { it.code == code } + } +} + +internal class GslException internal constructor(file: String, line: Int, reason: String, errno: Int) : + RuntimeException("$file:$line: $reason. errno - $errno, ${GslErrnoValue.valueOf(errno)}") { +} + +internal fun ensureHasGslErrorHandler() { + if (Container.isKmathHandlerRegistered.value == 1) return + gsl_set_error_handler_off() + + gsl_set_error_handler(staticCFunction { reason, file, line, errno -> + throw GslException(checkNotNull(file).toKString(), line, checkNotNull(reason).toKString(), errno) + }) + + Container.isKmathHandlerRegistered.value = 1 +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt index 8b0f12883..762cf1252 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt @@ -1,7 +1,175 @@ package kscience.kmath.gsl import kotlinx.cinterop.CStructVar +import kotlinx.cinterop.DeferScope +import kotlinx.cinterop.pointed import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.Point +import kscience.kmath.operations.Complex +import kscience.kmath.operations.ComplexField +import kscience.kmath.operations.toComplex import kscience.kmath.structures.Matrix +import org.gnu.gsl.* +internal inline fun GslMatrix.fill(initializer: (Int, Int) -> T): GslMatrix = + apply { + (0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } } + } + +internal inline fun GslVector.fill(initializer: (Int) -> T): GslVector = + apply { (0 until size).forEach { index -> this[index] = initializer(index) } } + +public abstract class GslMatrixContext< + T : Any, + H1 : CStructVar, + H2 : CStructVar> internal constructor(internal val scope: DeferScope) : MatrixContext { + init { + ensureHasGslErrorHandler() + } + + @Suppress("UNCHECKED_CAST") + public fun Matrix.toGsl(): GslMatrix = (if (this is GslMatrix<*, *>) + this as GslMatrix + else + produce(rowNum, colNum) { i, j -> this[i, j] }).copy() + + @Suppress("UNCHECKED_CAST") + public fun Point.toGsl(): GslVector = + (if (this is GslVector<*, *>) this as GslVector else produceDirtyVector(size).fill { this[it] }).copy() + + internal abstract fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix + internal abstract fun produceDirtyVector(size: Int): GslVector + + public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix = + produceDirtyMatrix(rows, columns).fill(initializer) +} + +public class GslRealMatrixContext(scope: DeferScope) : GslMatrixContext(scope) { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = + GslRealMatrix(nativeHandle = requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())), scope = scope) + + override fun produceDirtyVector(size: Int): GslVector = + GslRealVector(nativeHandle = requireNotNull(gsl_vector_alloc(size.toULong())), scope = scope) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result) + return GslRealMatrix(result, scope = scope) + } + + public override fun Matrix.dot(vector: Point): GslVector { + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_calloc(a.pointed.size)) + gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result) + return GslRealVector(result, scope = scope) + } + + public override fun Matrix.times(value: Double): GslMatrix { + val g1 = toGsl() + gsl_matrix_scale(g1.nativeHandle, value) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_scale(g1.nativeHandle, k.toDouble()) + return g1 + } +} + +public class GslFloatMatrixContext(scope: DeferScope) : + GslMatrixContext(scope) { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = + GslFloatMatrix(requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())), scope = scope) + + override fun produceDirtyVector(size: Int): GslVector = + GslFloatVector(requireNotNull(gsl_vector_float_alloc(size.toULong())), scope) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result) + return GslFloatMatrix(nativeHandle = result, scope = scope) + } + + public override fun Matrix.dot(vector: Point): GslVector { + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size)) + gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result) + return GslFloatVector(nativeHandle = result, scope = scope) + } + + public override fun Matrix.times(value: Float): GslMatrix { + val g1 = toGsl() + gsl_matrix_float_scale(g1.nativeHandle, value.toDouble()) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_float_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_float_scale(g1.nativeHandle, k.toDouble()) + return g1 + } +} + +public class GslComplexMatrixContext(scope: DeferScope) : + GslMatrixContext(scope) { + override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = GslComplexMatrix( + nativeHandle = requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())), + scope = scope + ) + + override fun produceDirtyVector(size: Int): GslVector = + GslComplexVector(requireNotNull(gsl_vector_complex_alloc(size.toULong())), scope) + + public override fun Matrix.dot(other: Matrix): GslMatrix { + val x = toGsl().nativeHandle + val a = other.toGsl().nativeHandle + val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2)) + gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) + return GslComplexMatrix(nativeHandle = result, scope = scope) + } + + public override fun Matrix.dot(vector: Point): GslVector { + val x = toGsl().nativeHandle + val a = vector.toGsl().nativeHandle + val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size)) + gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) + return GslComplexVector(result, scope) + } + + public override fun Matrix.times(value: Complex): GslMatrix { + val g1 = toGsl() + gsl_matrix_complex_scale(g1.nativeHandle, value.toGsl()) + return g1 + } + + public override fun add(a: Matrix, b: Matrix): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_complex_add(g1.nativeHandle, b.toGsl().nativeHandle) + return g1 + } + + public override fun multiply(a: Matrix, k: Number): GslMatrix { + val g1 = a.toGsl() + gsl_matrix_complex_scale(g1.nativeHandle, k.toComplex().toGsl()) + return g1 + } +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt deleted file mode 100644 index 3ac6582eb..000000000 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContexts.kt +++ /dev/null @@ -1,170 +0,0 @@ -package kscience.kmath.gsl - -import kotlinx.cinterop.CStructVar -import kotlinx.cinterop.DeferScope -import kotlinx.cinterop.pointed -import kscience.kmath.linear.MatrixContext -import kscience.kmath.linear.Point -import kscience.kmath.operations.Complex -import kscience.kmath.operations.ComplexField -import kscience.kmath.operations.toComplex -import kscience.kmath.structures.Matrix -import org.gnu.gsl.* - -internal inline fun GslMatrix.fill(initializer: (Int, Int) -> T): GslMatrix = - apply { - (0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } } - } - -internal inline fun GslVector.fill(initializer: (Int) -> T): GslVector = - apply { (0 until size).forEach { index -> this[index] = initializer(index) } } - -public abstract class GslMatrixContext internal constructor( - internal val scope: DeferScope -) : MatrixContext { - @Suppress("UNCHECKED_CAST") - public fun Matrix.toGsl(): GslMatrix = (if (this is GslMatrix<*, *>) - this as GslMatrix - else - produce(rowNum, colNum) { i, j -> this[i, j] }).copy() - - @Suppress("UNCHECKED_CAST") - public fun Point.toGsl(): GslVector = - (if (this is GslVector<*, *>) this as GslVector else produceDirtyVector(size).fill { this[it] }).copy() - - internal abstract fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix - internal abstract fun produceDirtyVector(size: Int): GslVector - - public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix = - produceDirtyMatrix(rows, columns).fill(initializer) -} - -public class GslRealMatrixContext(scope: DeferScope) : GslMatrixContext(scope) { - override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = - GslRealMatrix(nativeHandle = requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())), scope = scope) - - override fun produceDirtyVector(size: Int): GslVector = - GslRealVector(nativeHandle = requireNotNull(gsl_vector_alloc(size.toULong())), scope = scope) - - public override fun Matrix.dot(other: Matrix): GslMatrix { - val x = toGsl().nativeHandle - val a = other.toGsl().nativeHandle - val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2)) - gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result) - return GslRealMatrix(result, scope = scope) - } - - public override fun Matrix.dot(vector: Point): GslVector { - val x = toGsl().nativeHandle - val a = vector.toGsl().nativeHandle - val result = requireNotNull(gsl_vector_calloc(a.pointed.size)) - gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result) - return GslRealVector(result, scope = scope) - } - - public override fun Matrix.times(value: Double): GslMatrix { - val g1 = toGsl() - gsl_matrix_scale(g1.nativeHandle, value) - return g1 - } - - public override fun add(a: Matrix, b: Matrix): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_add(g1.nativeHandle, b.toGsl().nativeHandle) - return g1 - } - - public override fun multiply(a: Matrix, k: Number): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_scale(g1.nativeHandle, k.toDouble()) - return g1 - } -} - -public class GslFloatMatrixContext(scope: DeferScope) : - GslMatrixContext(scope) { - override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = - GslFloatMatrix(requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())), scope = scope) - - override fun produceDirtyVector(size: Int): GslVector = - GslFloatVector(requireNotNull(gsl_vector_float_alloc(size.toULong())), scope) - - public override fun Matrix.dot(other: Matrix): GslMatrix { - val x = toGsl().nativeHandle - val a = other.toGsl().nativeHandle - val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2)) - gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result) - return GslFloatMatrix(nativeHandle = result, scope = scope) - } - - public override fun Matrix.dot(vector: Point): GslVector { - val x = toGsl().nativeHandle - val a = vector.toGsl().nativeHandle - val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size)) - gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result) - return GslFloatVector(nativeHandle = result, scope = scope) - } - - public override fun Matrix.times(value: Float): GslMatrix { - val g1 = toGsl() - gsl_matrix_float_scale(g1.nativeHandle, value.toDouble()) - return g1 - } - - public override fun add(a: Matrix, b: Matrix): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_float_add(g1.nativeHandle, b.toGsl().nativeHandle) - return g1 - } - - public override fun multiply(a: Matrix, k: Number): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_float_scale(g1.nativeHandle, k.toDouble()) - return g1 - } -} - -public class GslComplexMatrixContext(scope: DeferScope) : - GslMatrixContext(scope) { - override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = GslComplexMatrix( - nativeHandle = requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())), - scope = scope - ) - - override fun produceDirtyVector(size: Int): GslVector = - GslComplexVector(requireNotNull(gsl_vector_complex_alloc(size.toULong())), scope) - - public override fun Matrix.dot(other: Matrix): GslMatrix { - val x = toGsl().nativeHandle - val a = other.toGsl().nativeHandle - val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2)) - gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) - return GslComplexMatrix(nativeHandle = result, scope = scope) - } - - public override fun Matrix.dot(vector: Point): GslVector { - val x = toGsl().nativeHandle - val a = vector.toGsl().nativeHandle - val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size)) - gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) - return GslComplexVector(result, scope) - } - - public override fun Matrix.times(value: Complex): GslMatrix { - val g1 = toGsl() - gsl_matrix_complex_scale(g1.nativeHandle, value.toGsl()) - return g1 - } - - public override fun add(a: Matrix, b: Matrix): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_complex_add(g1.nativeHandle, b.toGsl().nativeHandle) - return g1 - } - - public override fun multiply(a: Matrix, k: Number): GslMatrix { - val g1 = a.toGsl() - gsl_matrix_complex_scale(g1.nativeHandle, k.toComplex().toGsl()) - return g1 - } -} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt index d3544874d..d52b78ccc 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMemoryHolder.kt @@ -8,6 +8,7 @@ public abstract class GslMemoryHolder internal constructor(inter internal abstract val nativeHandle: CPointer init { + ensureHasGslErrorHandler() scope.defer(::close) } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/_Vectors.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/_Vectors.kt index 8990416b3..7ef8c24e7 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/_Vectors.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/_Vectors.kt @@ -1,10 +1,12 @@ package kscience.kmath.gsl -import kotlinx.cinterop.* +import kotlinx.cinterop.CPointer +import kotlinx.cinterop.DeferScope +import kotlinx.cinterop.pointed import org.gnu.gsl.* -internal class GslRealVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslRealVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -25,8 +27,8 @@ internal class GslRealVector(override val nativeHandle: CPointer, sc override fun close(): Unit = gsl_vector_free(nativeHandle) } -internal class GslFloatVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslFloatVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -47,8 +49,8 @@ internal class GslFloatVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslShortVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -69,8 +71,8 @@ internal class GslShortVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslUShortVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -91,8 +93,8 @@ internal class GslUShortVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslLongVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -113,8 +115,8 @@ internal class GslLongVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslULongVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -135,8 +137,8 @@ internal class GslULongVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslIntVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() @@ -157,8 +159,8 @@ internal class GslIntVector(override val nativeHandle: CPointer, override fun close(): Unit = gsl_vector_int_free(nativeHandle) } -internal class GslUIntVector(override val nativeHandle: CPointer, scope: DeferScope) - : GslVector(scope) { +internal class GslUIntVector(override val nativeHandle: CPointer, scope: DeferScope) : + GslVector(scope) { override val size: Int get() = nativeHandle.pointed.size.toInt() diff --git a/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt b/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt new file mode 100644 index 000000000..1259f58b9 --- /dev/null +++ b/kmath-gsl/src/nativeTest/kotlin/ErrorHandler.kt @@ -0,0 +1,23 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.memScoped +import org.gnu.gsl.gsl_block_calloc +import kotlin.test.Test +import kotlin.test.assertFailsWith + +internal class ErrorHandler { + @Test + fun blockAllocation() { + assertFailsWith { + ensureHasGslErrorHandler() + gsl_block_calloc(ULong.MAX_VALUE) + } + } + + @Test + fun matrixAllocation() { + assertFailsWith { + memScoped { GslRealMatrixContext(this).produce(Int.MAX_VALUE, Int.MAX_VALUE) { _, _ -> 0.0 } } + } + } +}