Add QR decomposition support in GSL wrapper

This commit is contained in:
Iaroslav Postovalov 2021-01-31 00:57:15 +07:00
parent e66f169655
commit 1cfe5a31a7
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
2 changed files with 37 additions and 8 deletions

View File

@ -12,7 +12,7 @@ kotlin {
val nativeTarget = when (System.getProperty("os.name")) { val nativeTarget = when (System.getProperty("os.name")) {
// "Mac OS X" -> macosX64() // "Mac OS X" -> macosX64()
"Linux" -> linuxX64("native") "Linux" -> linuxX64()
else -> { else -> {
logger.warn("Current OS cannot build any of kmath-gsl targets.") logger.warn("Current OS cannot build any of kmath-gsl targets.")
@ -29,7 +29,7 @@ kotlin {
val test by nativeTarget.compilations.getting val test by nativeTarget.compilations.getting
sourceSets { sourceSets {
val nativeMain by getting { val nativeMain by creating {
val codegen by tasks.creating { val codegen by tasks.creating {
matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt") matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt") vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
@ -42,11 +42,11 @@ kotlin {
} }
} }
val nativeTest by getting { val nativeTest by creating {
dependsOn(nativeMain) dependsOn(nativeMain)
} }
// main.defaultSourceSet.dependsOn(nativeMain) main.defaultSourceSet.dependsOn(nativeMain)
// test.defaultSourceSet.dependsOn(nativeTest) test.defaultSourceSet.dependsOn(nativeTest)
} }
} }

View File

@ -7,6 +7,7 @@ import kscience.kmath.operations.Complex
import kscience.kmath.operations.ComplexField import kscience.kmath.operations.ComplexField
import kscience.kmath.operations.toComplex import kscience.kmath.operations.toComplex
import org.gnu.gsl.* import org.gnu.gsl.*
import kotlin.math.min
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.reflect.cast import kotlin.reflect.cast
@ -179,6 +180,29 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
} }
} }
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
private val qr by lazy {
val a = m.toGsl().copy()
val q = produce(m.rowNum, m.rowNum) { _, _ -> 0.0 }
val r = produce(m.rowNum, m.colNum) { _, _ -> 0.0 }
if (m.rowNum < m.colNum) {
val tau = point(min(m.rowNum, m.colNum)) { 0.0 }
gsl_linalg_QR_decomp(a.nativeHandle, tau.nativeHandle)
gsl_linalg_QR_unpack(a.nativeHandle, tau.nativeHandle, q.nativeHandle, r.nativeHandle)
} else {
val t = produce(m.colNum, m.colNum) { _, _ -> 0.0 }
gsl_linalg_QR_decomp_r(a.nativeHandle, t.nativeHandle)
gsl_linalg_QR_unpack_r(a.nativeHandle, t.nativeHandle, q.nativeHandle, r.nativeHandle)
}
q to r
}
override val q: Matrix<Double> get() = qr.first
override val r: Matrix<Double> get() = qr.second
}
else -> super.getFeature(m, type) else -> super.getFeature(m, type)
}?.let(type::cast) }?.let(type::cast)
} }
@ -261,17 +285,22 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix( override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix(
rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())), rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
scope = scope, scope = scope,
owned = true,
) )
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> = override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
GslComplexVector(rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope) GslComplexVector(
rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())),
scope = scope,
owned = true,
)
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> { public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
val x = toGsl().nativeHandle val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle val a = other.toGsl().nativeHandle
val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2)) val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexMatrix(rawNativeHandle = result, scope = scope) return GslComplexMatrix(rawNativeHandle = result, scope = scope, owned = true)
} }
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> { public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
@ -279,7 +308,7 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
val a = vector.toGsl().nativeHandle val a = vector.toGsl().nativeHandle
val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size)) val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size))
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result) gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexVector(result, scope) return GslComplexVector(rawNativeHandle = result, scope = scope, owned = true)
} }
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> { public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {