Add QR decomposition support in GSL wrapper
This commit is contained in:
parent
e66f169655
commit
1cfe5a31a7
@ -12,7 +12,7 @@ kotlin {
|
|||||||
|
|
||||||
val nativeTarget = when (System.getProperty("os.name")) {
|
val nativeTarget = when (System.getProperty("os.name")) {
|
||||||
// "Mac OS X" -> macosX64()
|
// "Mac OS X" -> macosX64()
|
||||||
"Linux" -> linuxX64("native")
|
"Linux" -> linuxX64()
|
||||||
|
|
||||||
else -> {
|
else -> {
|
||||||
logger.warn("Current OS cannot build any of kmath-gsl targets.")
|
logger.warn("Current OS cannot build any of kmath-gsl targets.")
|
||||||
@ -29,7 +29,7 @@ kotlin {
|
|||||||
val test by nativeTarget.compilations.getting
|
val test by nativeTarget.compilations.getting
|
||||||
|
|
||||||
sourceSets {
|
sourceSets {
|
||||||
val nativeMain by getting {
|
val nativeMain by creating {
|
||||||
val codegen by tasks.creating {
|
val codegen by tasks.creating {
|
||||||
matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
|
matricesCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
|
||||||
vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
|
vectorsCodegen(kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
|
||||||
@ -42,11 +42,11 @@ kotlin {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val nativeTest by getting {
|
val nativeTest by creating {
|
||||||
dependsOn(nativeMain)
|
dependsOn(nativeMain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// main.defaultSourceSet.dependsOn(nativeMain)
|
main.defaultSourceSet.dependsOn(nativeMain)
|
||||||
// test.defaultSourceSet.dependsOn(nativeTest)
|
test.defaultSourceSet.dependsOn(nativeTest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import kscience.kmath.operations.Complex
|
|||||||
import kscience.kmath.operations.ComplexField
|
import kscience.kmath.operations.ComplexField
|
||||||
import kscience.kmath.operations.toComplex
|
import kscience.kmath.operations.toComplex
|
||||||
import org.gnu.gsl.*
|
import org.gnu.gsl.*
|
||||||
|
import kotlin.math.min
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
import kotlin.reflect.cast
|
import kotlin.reflect.cast
|
||||||
|
|
||||||
@ -179,6 +180,29 @@ public class GslRealMatrixContext(internal val scope: AutofreeScope) :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
|
||||||
|
private val qr by lazy {
|
||||||
|
val a = m.toGsl().copy()
|
||||||
|
val q = produce(m.rowNum, m.rowNum) { _, _ -> 0.0 }
|
||||||
|
val r = produce(m.rowNum, m.colNum) { _, _ -> 0.0 }
|
||||||
|
|
||||||
|
if (m.rowNum < m.colNum) {
|
||||||
|
val tau = point(min(m.rowNum, m.colNum)) { 0.0 }
|
||||||
|
gsl_linalg_QR_decomp(a.nativeHandle, tau.nativeHandle)
|
||||||
|
gsl_linalg_QR_unpack(a.nativeHandle, tau.nativeHandle, q.nativeHandle, r.nativeHandle)
|
||||||
|
} else {
|
||||||
|
val t = produce(m.colNum, m.colNum) { _, _ -> 0.0 }
|
||||||
|
gsl_linalg_QR_decomp_r(a.nativeHandle, t.nativeHandle)
|
||||||
|
gsl_linalg_QR_unpack_r(a.nativeHandle, t.nativeHandle, q.nativeHandle, r.nativeHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
q to r
|
||||||
|
}
|
||||||
|
|
||||||
|
override val q: Matrix<Double> get() = qr.first
|
||||||
|
override val r: Matrix<Double> get() = qr.second
|
||||||
|
}
|
||||||
|
|
||||||
else -> super.getFeature(m, type)
|
else -> super.getFeature(m, type)
|
||||||
}?.let(type::cast)
|
}?.let(type::cast)
|
||||||
}
|
}
|
||||||
@ -261,17 +285,22 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
|
|||||||
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix(
|
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> = GslComplexMatrix(
|
||||||
rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
|
rawNativeHandle = checkNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())),
|
||||||
scope = scope,
|
scope = scope,
|
||||||
|
owned = true,
|
||||||
)
|
)
|
||||||
|
|
||||||
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
|
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
|
||||||
GslComplexVector(rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())), scope = scope)
|
GslComplexVector(
|
||||||
|
rawNativeHandle = checkNotNull(gsl_vector_complex_alloc(size.toULong())),
|
||||||
|
scope = scope,
|
||||||
|
owned = true,
|
||||||
|
)
|
||||||
|
|
||||||
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
|
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
|
||||||
val x = toGsl().nativeHandle
|
val x = toGsl().nativeHandle
|
||||||
val a = other.toGsl().nativeHandle
|
val a = other.toGsl().nativeHandle
|
||||||
val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
|
val result = checkNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
|
||||||
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||||
return GslComplexMatrix(rawNativeHandle = result, scope = scope)
|
return GslComplexMatrix(rawNativeHandle = result, scope = scope, owned = true)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
|
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
|
||||||
@ -279,7 +308,7 @@ public class GslComplexMatrixContext(internal val scope: AutofreeScope) :
|
|||||||
val a = vector.toGsl().nativeHandle
|
val a = vector.toGsl().nativeHandle
|
||||||
val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size))
|
val result = checkNotNull(gsl_vector_complex_calloc(a.pointed.size))
|
||||||
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||||
return GslComplexVector(result, scope)
|
return GslComplexVector(rawNativeHandle = result, scope = scope, owned = true)
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {
|
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {
|
||||||
|
Loading…
Reference in New Issue
Block a user