From 1cfe5a31a77398e9a1f3511017a94cf2a3a1290a Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Sun, 31 Jan 2021 00:57:15 +0700 Subject: [PATCH] Add QR decomposition support in GSL wrapper --- kmath-gsl/build.gradle.kts | 10 +++--- .../kscience/kmath/gsl/GslMatrixContext.kt | 35 +++++++++++++++++-- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index 83859130c..df405326f 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -12,7 +12,7 @@ kotlin { val nativeTarget = when (System.getProperty("os.name")) { // "Mac OS X" -> macosX64() - "Linux" -> linuxX64("native") + "Linux" -> linuxX64() else -> { logger.warn("Current OS cannot build any of kmath-gsl targets.") @@ -29,7 +29,7 @@ kotlin { val test by nativeTarget.compilations.getting sourceSets { - val nativeMain by getting { + val nativeMain by creating { val codegen by tasks.creating { matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt") vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt") @@ -42,11 +42,11 @@ kotlin { } } - val nativeTest by getting { + val nativeTest by creating { dependsOn(nativeMain) } -// main.defaultSourceSet.dependsOn(nativeMain) -// test.defaultSourceSet.dependsOn(nativeTest) + main.defaultSourceSet.dependsOn(nativeMain) + test.defaultSourceSet.dependsOn(nativeTest) } } diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt index e86d3676e..cc8bba662 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslMatrixContext.kt @@ -7,6 +7,7 @@ import kscience.kmath.operations.Complex import kscience.kmath.operations.ComplexField import kscience.kmath.operations.toComplex import org.gnu.gsl.* +import kotlin.math.min import kotlin.reflect.KClass import kotlin.reflect.cast @@ -179,6 +180,29 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) : } } + QRDecompositionFeature::class -> object : QRDecompositionFeature { + private val qr by lazy { + val a = m.toGsl().copy() + val q = produce(m.rowNum, m.rowNum) { _, _ -> 0.0 } + val r = produce(m.rowNum, m.colNum) { _, _ -> 0.0 } + + if (m.rowNum < m.colNum) { + val tau = point(min(m.rowNum, m.colNum)) { 0.0 } + gsl_linalg_QR_decomp(a.nativeHandle, tau.nativeHandle) + gsl_linalg_QR_unpack(a.nativeHandle, tau.nativeHandle, q.nativeHandle, r.nativeHandle) + } else { + val t = produce(m.colNum, m.colNum) { _, _ -> 0.0 } + gsl_linalg_QR_decomp_r(a.nativeHandle, t.nativeHandle) + gsl_linalg_QR_unpack_r(a.nativeHandle, t.nativeHandle, q.nativeHandle, r.nativeHandle) + } + + q to r + } + + override val q: Matrix get() = qr.first + override val r: Matrix get() = qr.second + } + else -> super.getFeature(m, type) }?.let(type::cast) } @@ -261,17 +285,22 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) : override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix = GslComplexMatrix( rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())), scope = scope, + owned = true, ) override fun produceDirtyVector(size: Int): GslVector = - GslComplexVector(rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope) + GslComplexVector( + rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), + scope = scope, + owned = true, + ) public override fun Matrix.dot(other: Matrix): GslMatrix { val x = toGsl().nativeHandle val a = other.toGsl().nativeHandle val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2)) gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) - return GslComplexMatrix(rawNativeHandle = result, scope = scope) + return GslComplexMatrix(rawNativeHandle = result, scope = scope, owned = true) } public override fun Matrix.dot(vector: Point): GslVector { @@ -279,7 +308,7 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) : val a = vector.toGsl().nativeHandle val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size)) gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) - return GslComplexVector(result, scope) + return GslComplexVector(rawNativeHandle = result, scope = scope, owned = true) } public override fun Matrix.times(value: Complex): GslMatrix {