Improve null checking and introduce internal wrapper of permutation object

This commit is contained in:
Iaroslav Postovalov 2021-01-24 16:13:08 +07:00
parent 472e2bf671
commit c50fab027a
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
8 changed files with 64 additions and 58 deletions

View File

@ -51,7 +51,7 @@ private fun KtPsiFactory.createMatrixClass(
${fn("gsl_matrixRset")}(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): $className {
val new = requireNotNull(${fn("gsl_matrixRalloc")}(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(${fn("gsl_matrixRalloc")}(rowNum.toULong(), colNum.toULong()))
${fn("gsl_matrixRmemcpy")}(new, nativeHandle)
return $className(new, scope)
}

View File

@ -31,7 +31,7 @@ private fun KtPsiFactory.createVectorClass(
${fn("gsl_vectorRset")}(nativeHandle, index.toULong(), value)
override fun copy(): $className {
val new = requireNotNull(${fn("gsl_vectorRalloc")}(size.toULong()))
val new = checkNotNull(${fn("gsl_vectorRalloc")}(size.toULong()))
${fn("gsl_vectorRmemcpy")}(new, nativeHandle)
return ${className}(new, scope)
}

View File

@ -12,7 +12,7 @@ kotlin {
val nativeTarget = when (System.getProperty("os.name")) {
// "Mac OS X" -> macosX64()
"Linux" -> linuxX64()
"Linux" -> linuxX64("native")
else -> {
logger.warn("Current OS cannot build any of kmath-gsl targets.")
@ -29,7 +29,7 @@ kotlin {
val test by nativeTarget.compilations.getting
sourceSets {
val nativeMain by creating {
val nativeMain by getting {
val codegen by tasks.creating {
matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
@ -42,11 +42,11 @@ kotlin {
}
}
val nativeTest by creating {
val nativeTest by getting {
dependsOn(nativeMain)
}
main.defaultSourceSet.dependsOn(nativeMain)
test.defaultSourceSet.dependsOn(nativeTest)
// main.defaultSourceSet.dependsOn(nativeMain)
// test.defaultSourceSet.dependsOn(nativeTest)
}
}

View File

@ -44,7 +44,7 @@ internal class GslComplexMatrix(override val rawNativeHandle: CPointer<gsl_matri
gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl())
override fun copy(): GslComplexMatrix {
val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_complex_memcpy(new, nativeHandle)
return GslComplexMatrix(new, scope)
}
@ -70,7 +70,7 @@ internal class GslComplexVector(override val rawNativeHandle: CPointer<gsl_vecto
gsl_vector_complex_set(nativeHandle, index.toULong(), value.toGsl())
override fun copy(): GslComplexVector {
val new = requireNotNull(gsl_vector_complex_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_complex_alloc(size.toULong()))
gsl_vector_complex_memcpy(new, nativeHandle)
return GslComplexVector(new, scope)
}

View File

@ -63,17 +63,17 @@ public abstract class GslMatrixContext<T : Any, H1 : CStructVar, H2 : CStructVar
public class GslRealMatrixContext(internal val scope: AutofreeScope) :
GslMatrixContext<Double, gsl_matrix, gsl_vector>() {
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Double, gsl_matrix> = GslRealMatrix(
rawNativeHandle = requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())),
rawNativeHandle = checkNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())),
scope = scope,
)
override fun produceDirtyVector(size: Int): GslVector<Double, gsl_vector> =
GslRealVector(rawNativeHandle = requireNotNull(gsl_vector_alloc(size.toULong())), scope = scope)
GslRealVector(rawNativeHandle = checkNotNull(gsl_vector_alloc(size.toULong())), scope = scope)
public override fun Matrix<Double>.dot(other: Matrix<Double>): GslMatrix<Double, gsl_matrix> {
val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle
val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2))
val result = checkNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result)
return GslRealMatrix(result, scope = scope)
}
@ -81,7 +81,7 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
public override fun Matrix<Double>.dot(vector: Point<Double>): GslVector<Double, gsl_vector> {
val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_calloc(a.pointed.size))
val result = checkNotNull(gsl_vector_calloc(a.pointed.size))
gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result)
return GslRealVector(result, scope = scope)
}
@ -118,12 +118,11 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
private val lups by lazy {
val lu = m.toGsl().copy()
val n = m.rowNum
val perm = gsl_permutation_alloc(n.toULong())
scope.defer { gsl_permutation_free(perm) }
val perm = GslPermutation(checkNotNull(gsl_permutation_alloc(n.toULong())), scope)
val signum = memScoped {
val i = alloc<IntVar>()
gsl_linalg_LU_decomp(lu.nativeHandle, perm, i.ptr)
gsl_linalg_LU_decomp(lu.nativeHandle, perm.nativeHandle, i.ptr)
i.value
}
@ -135,11 +134,8 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
val one = produce(n, n) { i, j -> if (i == j) 1.0 else 0.0 }
val perm = produce(n, n) { _, _ -> 0.0 }
for (j in 0 until lups.second!!.pointed.size.toInt()) {
val k = gsl_permutation_get(lups.second!!, j.toULong()).toInt()
val col = one.columns[k]
gsl_matrix_set_col(perm.nativeHandle, j.toULong(), col.toGsl().nativeHandle)
}
for (j in 0 until lups.second.size)
gsl_matrix_set_col(perm.nativeHandle, j.toULong(), one.columns[lups.second[j]].toGsl().nativeHandle)
perm
}
@ -165,7 +161,7 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
override val inverse by lazy {
val inv = lups.first.copy()
gsl_linalg_LU_invx(inv.nativeHandle, lups.second)
gsl_linalg_LU_invx(inv.nativeHandle, lups.second.nativeHandle)
inv
}
}
@ -194,17 +190,17 @@ public fun <R> GslRealMatrixContext(block: GslRealMatrixContext.() -> R): R =
public class GslFloatMatrixContext(internal val scope: AutofreeScope) :
GslMatrixContext<Float, gsl_matrix_float, gsl_vector_float>() {
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Float, gsl_matrix_float> = GslFloatMatrix(
rawNativeHandle = requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())),
rawNativeHandle = checkNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())),
scope = scope,
)
override fun produceDirtyVector(size: Int): GslVector<Float, gsl_vector_float> =
GslFloatVector(rawNativeHandle = requireNotNull(value = gsl_vector_float_alloc(size.toULong())), scope = scope)
GslFloatVector(rawNativeHandle = checkNotNull(value = gsl_vector_float_alloc(size.toULong())), scope = scope)
public override fun Matrix<Float>.dot(other: Matrix<Float>): GslMatrix<Float, gsl_matrix_float> {
val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle
val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2))
val result = checkNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result)
return GslFloatMatrix(rawNativeHandle = result, scope = scope)
}
@ -212,7 +208,7 @@ public class GslFloatMatrixContext(internal val scope: AutofreeScope) :
public override fun Matrix<Float>.dot(vector: Point<Float>): GslVector<Float, gsl_vector_float> {
val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size))
val result = checkNotNull(gsl_vector_float_calloc(a.pointed.size))
gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result)
return GslFloatVector(rawNativeHandle = result, scope = scope)
}
@ -254,17 +250,17 @@ public fun <R> GslFloatMatrixContext(block: GslFloatMatrixContext.() -> R): R =
public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
GslMatrixContext<Complex, gsl_matrix_complex, gsl_vector_complex>() {
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix(
rawNativeHandle = requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
scope = scope,
)
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
GslComplexVector(rawNativeHandle = requireNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope)
GslComplexVector(rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope)
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle
val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexMatrix(rawNativeHandle = result, scope = scope)
}
@ -272,7 +268,7 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size))
val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size))
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexVector(result, scope)
}
@ -309,12 +305,11 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
private val lups by lazy {
val lu = m.toGsl().copy()
val n = m.rowNum
val perm = gsl_permutation_alloc(n.toULong())
scope.defer { gsl_permutation_free(perm) }
val perm = GslPermutation(checkNotNull(gsl_permutation_alloc(n.toULong())), scope)
val signum = memScoped {
val i = alloc<IntVar>()
gsl_linalg_complex_LU_decomp(lu.nativeHandle, perm, i.ptr)
gsl_linalg_complex_LU_decomp(lu.nativeHandle, perm.nativeHandle, i.ptr)
i.value
}
@ -326,11 +321,8 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
val one = produce(n, n) { i, j -> if (i == j) 1.0.toComplex() else 0.0.toComplex() }
val perm = produce(n, n) { _, _ -> 0.0.toComplex() }
for (j in 0 until lups.second!!.pointed.size.toInt()) {
val k = gsl_permutation_get(lups.second!!, j.toULong()).toInt()
val col = one.columns[k]
gsl_matrix_complex_set_col(perm.nativeHandle, j.toULong(), col.toGsl().nativeHandle)
}
for (j in 0 until lups.second.size)
gsl_matrix_complex_set_col(perm.nativeHandle, j.toULong(), one.columns[lups.second[j]].toGsl().nativeHandle)
perm
}
@ -352,13 +344,11 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
) { i, j -> if (j >= i) lups.first[i, j] else 0.0.toComplex() } + UFeature
}
override val determinant by lazy {
gsl_linalg_complex_LU_det(lups.first.nativeHandle, lups.third).toKMath()
}
override val determinant by lazy { gsl_linalg_complex_LU_det(lups.first.nativeHandle, lups.third).toKMath() }
override val inverse by lazy {
val inv = lups.first.copy()
gsl_linalg_complex_LU_invx(inv.nativeHandle, lups.second)
gsl_linalg_complex_LU_invx(inv.nativeHandle, lups.second.nativeHandle)
inv
}
}

View File

@ -0,0 +1,16 @@
package kscience.kmath.gsl
import kotlinx.cinterop.AutofreeScope
import kotlinx.cinterop.CPointer
import kotlinx.cinterop.pointed
import org.gnu.gsl.gsl_permutation
import org.gnu.gsl.gsl_permutation_free
import org.gnu.gsl.gsl_permutation_get
internal class GslPermutation(override val rawNativeHandle: CPointer<gsl_permutation>, scope: AutofreeScope) :
GslMemoryHolder<gsl_permutation>(scope) {
val size get() = nativeHandle.pointed.size.toInt()
operator fun get(i: Int) = gsl_permutation_get(nativeHandle, i.toULong()).toInt()
override fun close() = gsl_permutation_free(nativeHandle)
}

View File

@ -37,7 +37,7 @@ internal class GslRealMatrix(
gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslRealMatrix {
val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_memcpy(new, nativeHandle)
return GslRealMatrix(new, scope)
}
@ -85,7 +85,7 @@ internal class GslFloatMatrix(
gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslFloatMatrix {
val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_float_memcpy(new, nativeHandle)
return GslFloatMatrix(new, scope)
}
@ -133,7 +133,7 @@ internal class GslShortMatrix(
gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslShortMatrix {
val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_short_memcpy(new, nativeHandle)
return GslShortMatrix(new, scope)
}
@ -181,7 +181,7 @@ internal class GslUShortMatrix(
gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslUShortMatrix {
val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_ushort_memcpy(new, nativeHandle)
return GslUShortMatrix(new, scope)
}
@ -229,7 +229,7 @@ internal class GslLongMatrix(
gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslLongMatrix {
val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_long_memcpy(new, nativeHandle)
return GslLongMatrix(new, scope)
}
@ -277,7 +277,7 @@ internal class GslULongMatrix(
gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslULongMatrix {
val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_ulong_memcpy(new, nativeHandle)
return GslULongMatrix(new, scope)
}
@ -325,7 +325,7 @@ internal class GslIntMatrix(
gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslIntMatrix {
val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_int_memcpy(new, nativeHandle)
return GslIntMatrix(new, scope)
}
@ -373,7 +373,7 @@ internal class GslUIntMatrix(
gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslUIntMatrix {
val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong()))
val new = checkNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_uint_memcpy(new, nativeHandle)
return GslUIntMatrix(new, scope)
}

View File

@ -15,7 +15,7 @@ internal class GslRealVector(override val rawNativeHandle: CPointer<gsl_vector>,
gsl_vector_set(nativeHandle, index.toULong(), value)
override fun copy(): GslRealVector {
val new = requireNotNull(gsl_vector_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_alloc(size.toULong()))
gsl_vector_memcpy(new, nativeHandle)
return GslRealVector(new, scope)
}
@ -42,7 +42,7 @@ internal class GslFloatVector(override val rawNativeHandle: CPointer<gsl_vector_
gsl_vector_float_set(nativeHandle, index.toULong(), value)
override fun copy(): GslFloatVector {
val new = requireNotNull(gsl_vector_float_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_float_alloc(size.toULong()))
gsl_vector_float_memcpy(new, nativeHandle)
return GslFloatVector(new, scope)
}
@ -69,7 +69,7 @@ internal class GslShortVector(override val rawNativeHandle: CPointer<gsl_vector_
gsl_vector_short_set(nativeHandle, index.toULong(), value)
override fun copy(): GslShortVector {
val new = requireNotNull(gsl_vector_short_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_short_alloc(size.toULong()))
gsl_vector_short_memcpy(new, nativeHandle)
return GslShortVector(new, scope)
}
@ -96,7 +96,7 @@ internal class GslUShortVector(override val rawNativeHandle: CPointer<gsl_vector
gsl_vector_ushort_set(nativeHandle, index.toULong(), value)
override fun copy(): GslUShortVector {
val new = requireNotNull(gsl_vector_ushort_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_ushort_alloc(size.toULong()))
gsl_vector_ushort_memcpy(new, nativeHandle)
return GslUShortVector(new, scope)
}
@ -123,7 +123,7 @@ internal class GslLongVector(override val rawNativeHandle: CPointer<gsl_vector_l
gsl_vector_long_set(nativeHandle, index.toULong(), value)
override fun copy(): GslLongVector {
val new = requireNotNull(gsl_vector_long_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_long_alloc(size.toULong()))
gsl_vector_long_memcpy(new, nativeHandle)
return GslLongVector(new, scope)
}
@ -150,7 +150,7 @@ internal class GslULongVector(override val rawNativeHandle: CPointer<gsl_vector_
gsl_vector_ulong_set(nativeHandle, index.toULong(), value)
override fun copy(): GslULongVector {
val new = requireNotNull(gsl_vector_ulong_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_ulong_alloc(size.toULong()))
gsl_vector_ulong_memcpy(new, nativeHandle)
return GslULongVector(new, scope)
}
@ -177,7 +177,7 @@ internal class GslIntVector(override val rawNativeHandle: CPointer<gsl_vector_in
gsl_vector_int_set(nativeHandle, index.toULong(), value)
override fun copy(): GslIntVector {
val new = requireNotNull(gsl_vector_int_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_int_alloc(size.toULong()))
gsl_vector_int_memcpy(new, nativeHandle)
return GslIntVector(new, scope)
}
@ -204,7 +204,7 @@ internal class GslUIntVector(override val rawNativeHandle: CPointer<gsl_vector_u
gsl_vector_uint_set(nativeHandle, index.toULong(), value)
override fun copy(): GslUIntVector {
val new = requireNotNull(gsl_vector_uint_alloc(size.toULong()))
val new = checkNotNull(gsl_vector_uint_alloc(size.toULong()))
gsl_vector_uint_memcpy(new, nativeHandle)
return GslUIntVector(new, scope)
}