Add QR decomposition support in GSL wrapper
This commit is contained in:
parent
e66f169655
commit
1cfe5a31a7
@ -12,7 +12,7 @@ kotlin {
|
||||
|
||||
val nativeTarget = when (System.getProperty("os.name")) {
|
||||
// "Mac OS X" -> macosX64()
|
||||
"Linux" -> linuxX64("native")
|
||||
"Linux" -> linuxX64()
|
||||
|
||||
else -> {
|
||||
logger.warn("Current OS cannot build any of kmath-gsl targets.")
|
||||
@ -29,7 +29,7 @@ kotlin {
|
||||
val test by nativeTarget.compilations.getting
|
||||
|
||||
sourceSets {
|
||||
val nativeMain by getting {
|
||||
val nativeMain by creating {
|
||||
val codegen by tasks.creating {
|
||||
matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
|
||||
vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
|
||||
@ -42,11 +42,11 @@ kotlin {
|
||||
}
|
||||
}
|
||||
|
||||
val nativeTest by getting {
|
||||
val nativeTest by creating {
|
||||
dependsOn(nativeMain)
|
||||
}
|
||||
|
||||
// main.defaultSourceSet.dependsOn(nativeMain)
|
||||
// test.defaultSourceSet.dependsOn(nativeTest)
|
||||
main.defaultSourceSet.dependsOn(nativeMain)
|
||||
test.defaultSourceSet.dependsOn(nativeTest)
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import kscience.kmath.operations.Complex
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kscience.kmath.operations.toComplex
|
||||
import org.gnu.gsl.*
|
||||
import kotlin.math.min
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
@ -179,6 +180,29 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
|
||||
}
|
||||
}
|
||||
|
||||
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
|
||||
private val qr by lazy {
|
||||
val a = m.toGsl().copy()
|
||||
val q = produce(m.rowNum, m.rowNum) { _, _ -> 0.0 }
|
||||
val r = produce(m.rowNum, m.colNum) { _, _ -> 0.0 }
|
||||
|
||||
if (m.rowNum < m.colNum) {
|
||||
val tau = point(min(m.rowNum, m.colNum)) { 0.0 }
|
||||
gsl_linalg_QR_decomp(a.nativeHandle, tau.nativeHandle)
|
||||
gsl_linalg_QR_unpack(a.nativeHandle, tau.nativeHandle, q.nativeHandle, r.nativeHandle)
|
||||
} else {
|
||||
val t = produce(m.colNum, m.colNum) { _, _ -> 0.0 }
|
||||
gsl_linalg_QR_decomp_r(a.nativeHandle, t.nativeHandle)
|
||||
gsl_linalg_QR_unpack_r(a.nativeHandle, t.nativeHandle, q.nativeHandle, r.nativeHandle)
|
||||
}
|
||||
|
||||
q to r
|
||||
}
|
||||
|
||||
override val q: Matrix<Double> get() = qr.first
|
||||
override val r: Matrix<Double> get() = qr.second
|
||||
}
|
||||
|
||||
else -> super.getFeature(m, type)
|
||||
}?.let(type::cast)
|
||||
}
|
||||
@ -261,17 +285,22 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
|
||||
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix(
|
||||
rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
|
||||
scope = scope,
|
||||
owned = true,
|
||||
)
|
||||
|
||||
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
|
||||
GslComplexVector(rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope)
|
||||
GslComplexVector(
|
||||
rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())),
|
||||
scope = scope,
|
||||
owned = true,
|
||||
)
|
||||
|
||||
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
val x = toGsl().nativeHandle
|
||||
val a = other.toGsl().nativeHandle
|
||||
val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
|
||||
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||
return GslComplexMatrix(rawNativeHandle = result, scope = scope)
|
||||
return GslComplexMatrix(rawNativeHandle = result, scope = scope, owned = true)
|
||||
}
|
||||
|
||||
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
|
||||
@ -279,7 +308,7 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
|
||||
val a = vector.toGsl().nativeHandle
|
||||
val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size))
|
||||
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||
return GslComplexVector(result, scope)
|
||||
return GslComplexVector(rawNativeHandle = result, scope = scope, owned = true)
|
||||
}
|
||||
|
||||
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
|
Loading…
Reference in New Issue
Block a user