Rework codegen, use GSL vectors to store vectors, implement MatrixContext for Float, Double and Complex matrices with BLAS

This commit is contained in:
Iaroslav Postovalov 2020-10-11 01:15:37 +07:00
parent 69a107819b
commit d46350e7b7
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
15 changed files with 218 additions and 436 deletions

View File

@ -6,8 +6,12 @@ import org.jetbrains.kotlin.com.intellij.psi.impl.DebugUtil
import org.jetbrains.kotlin.com.intellij.psi.impl.source.PsiFileImpl import org.jetbrains.kotlin.com.intellij.psi.impl.source.PsiFileImpl
import org.jetbrains.kotlin.com.intellij.testFramework.LightVirtualFile import org.jetbrains.kotlin.com.intellij.testFramework.LightVirtualFile
import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtFile
import java.util.regex.Pattern
import kotlin.math.min import kotlin.math.min
private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+")
internal fun splitByLinesDontTrim(string: String): Array<String> = EOL_SPLIT_DONT_TRIM_PATTERN.split(string)
internal object PsiTestUtil { internal object PsiTestUtil {
fun checkFileStructure(file: KtFile) { fun checkFileStructure(file: KtFile) {
compareFromAllRoots(file) { f -> DebugUtil.psiTreeToString(f, false) } compareFromAllRoots(file) { f -> DebugUtil.psiTreeToString(f, false) }

View File

@ -1,9 +0,0 @@
package kscience.kmath.gsl.codegen
import java.util.regex.Pattern
private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+")
internal fun splitByLinesDontTrim(string: String): Array<String> {
return EOL_SPLIT_DONT_TRIM_PATTERN.split(string)
}

View File

@ -22,10 +22,10 @@ private fun KtPsiFactory.createMatrixClass(
features: Set<MatrixFeature> = emptySet() features: Set<MatrixFeature> = emptySet()
) : GslMatrix<$kotlinTypeName, $structName>() { ) : GslMatrix<$kotlinTypeName, $structName>() {
override val rowNum: Int override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } get() = nativeHandle.pointed.size1.toInt()
override val colNum: Int override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } get() = nativeHandle.pointed.size2.toInt()
override val features: Set<MatrixFeature> = features override val features: Set<MatrixFeature> = features
@ -39,30 +39,26 @@ private fun KtPsiFactory.createMatrixClass(
override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit = override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit =
${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value) ${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): $className = memScoped { override fun copy(): $className {
val new = requireNotNull(${fn("gsl_matrixRalloc", cTypeName)}(rowNum.toULong(), colNum.toULong())) val new = requireNotNull(${fn("gsl_matrixRalloc", cTypeName)}(rowNum.toULong(), colNum.toULong()))
${fn("gsl_matrixRmemcpy", cTypeName)}(new, nativeHandle) ${fn("gsl_matrixRmemcpy", cTypeName)}(new, nativeHandle)
$className(new, features) return $className(new, features)
} }
override fun close(): Unit = ${fn("gsl_matrixRfree", cTypeName)}(nativeHandle) override fun close(): Unit = ${fn("gsl_matrixRfree", cTypeName)}(nativeHandle)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (other is $className) ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) if (other is $className) return ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1
return super.equals(other) return super.equals(other)
} }
}""" }"""
f += createClass( f += createClass(text)
text
)
f += createNewLine(2) f += createNewLine(2)
} }
fun matricesCodegen(outputFile: String, project: Project = createProject()) { fun matricesCodegen(outputFile: String, project: Project = createProject()) {
val f = KtPsiFactory(project, true).run { val f = KtPsiFactory(project, true).run {
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> createFile("").also { f ->
f += createNewLine(2)
f += createPackageDirective(FqName("kscience.kmath.gsl")) f += createPackageDirective(FqName("kscience.kmath.gsl"))
f += createNewLine(2) f += createNewLine(2)
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))

View File

@ -20,22 +20,34 @@ private fun KtPsiFactory.createVectorClass(
@Language("kotlin") val text = @Language("kotlin") val text =
"""internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() { """internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() {
override val size: Int override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } get() = nativeHandle.pointed.size.toInt()
override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong()) override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong())
override fun set(index: Int, value: $kotlinTypeName): Unit = ${
fn("gsl_vectorRset", cTypeName)
}(nativeHandle, index.toULong(), value)
override fun copy(): $className {
val new = requireNotNull(${fn("gsl_vectorRalloc", cTypeName)}(size.toULong()))
${fn("gsl_vectorRmemcpy", cTypeName)}(new, nativeHandle)
return ${className}(new)
}
override fun equals(other: Any?): Boolean {
if (other is $className) return ${fn("gsl_vectorRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1
return super.equals(other)
}
override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle) override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle)
}""" }"""
f += createClass(
text
)
f += createClass(text)
f += createNewLine(2) f += createNewLine(2)
} }
fun vectorsCodegen(outputFile: String, project: Project = createProject()) { fun vectorsCodegen(outputFile: String, project: Project = createProject()) {
val f = KtPsiFactory(project, true).run { val f = KtPsiFactory(project, true).run {
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> createFile("").also { f ->
f += createNewLine(2)
f += createPackageDirective(FqName("kscience.kmath.gsl")) f += createPackageDirective(FqName("kscience.kmath.gsl"))
f += createNewLine(2) f += createNewLine(2)
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))

View File

@ -55,7 +55,7 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
* Multiplies an element by a matrix of it. * Multiplies an element by a matrix of it.
* *
* @receiver the multiplicand. * @receiver the multiplicand.
* @param value the multiplier. * @param m the multiplier.
* @receiver the product. * @receiver the product.
*/ */
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this

View File

@ -38,8 +38,8 @@ kotlin {
} }
internal val codegen: Task by tasks.creating { internal val codegen: Task by tasks.creating {
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt") matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt") vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
} }
kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen)) kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen))

View File

@ -1,280 +0,0 @@
@file:Suppress("PackageDirectoryMismatch")
package kscience.kmath.gsl
import kotlinx.cinterop.*
import kscience.kmath.linear.*
import org.gnu.gsl.*
internal class GslRealMatrix(
override val nativeHandle: CPointer<gsl_matrix>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Double, gsl_matrix>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix =
GslRealMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: Double): Unit =
gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslRealMatrix = memScoped {
val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_memcpy(new, nativeHandle)
GslRealMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslRealMatrix) gsl_matrix_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslFloatMatrix(
override val nativeHandle: CPointer<gsl_matrix_float>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Float, gsl_matrix_float>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix =
GslFloatMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Float = gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: Float): Unit =
gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslFloatMatrix = memScoped {
val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_float_memcpy(new, nativeHandle)
GslFloatMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_float_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslFloatMatrix) gsl_matrix_float_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslShortMatrix(
override val nativeHandle: CPointer<gsl_matrix_short>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Short, gsl_matrix_short>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix =
GslShortMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Short = gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: Short): Unit =
gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslShortMatrix = memScoped {
val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_short_memcpy(new, nativeHandle)
GslShortMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_short_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslShortMatrix) gsl_matrix_short_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslUShortMatrix(
override val nativeHandle: CPointer<gsl_matrix_ushort>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<UShort, gsl_matrix_ushort>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix =
GslUShortMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): UShort = gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: UShort): Unit =
gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslUShortMatrix = memScoped {
val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_ushort_memcpy(new, nativeHandle)
GslUShortMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_ushort_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslUShortMatrix) gsl_matrix_ushort_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslLongMatrix(
override val nativeHandle: CPointer<gsl_matrix_long>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Long, gsl_matrix_long>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix =
GslLongMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: Long): Unit =
gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslLongMatrix = memScoped {
val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_long_memcpy(new, nativeHandle)
GslLongMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_long_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslLongMatrix) gsl_matrix_long_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslULongMatrix(
override val nativeHandle: CPointer<gsl_matrix_ulong>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<ULong, gsl_matrix_ulong>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix =
GslULongMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): ULong = gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: ULong): Unit =
gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslULongMatrix = memScoped {
val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_ulong_memcpy(new, nativeHandle)
GslULongMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_ulong_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslULongMatrix) gsl_matrix_ulong_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslIntMatrix(
override val nativeHandle: CPointer<gsl_matrix_int>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Int, gsl_matrix_int>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix =
GslIntMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: Int): Unit =
gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslIntMatrix = memScoped {
val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_int_memcpy(new, nativeHandle)
GslIntMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_int_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslIntMatrix) gsl_matrix_int_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslUIntMatrix(
override val nativeHandle: CPointer<gsl_matrix_uint>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<UInt, gsl_matrix_uint>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix =
GslUIntMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong())
override operator fun set(i: Int, j: Int, value: UInt): Unit =
gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value)
override fun copy(): GslUIntMatrix = memScoped {
val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_uint_memcpy(new, nativeHandle)
GslUIntMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_uint_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslUIntMatrix) gsl_matrix_uint_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}

View File

@ -17,10 +17,10 @@ internal class GslComplexMatrix(
features: Set<MatrixFeature> = emptySet() features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Complex, gsl_matrix_complex>() { ) : GslMatrix<Complex, gsl_matrix_complex>() {
override val rowNum: Int override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } get() = nativeHandle.pointed.size1.toInt()
override val colNum: Int override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } get() = nativeHandle.pointed.size2.toInt()
override val features: Set<MatrixFeature> = features override val features: Set<MatrixFeature> = features
@ -33,16 +33,16 @@ internal class GslComplexMatrix(
override operator fun set(i: Int, j: Int, value: Complex): Unit = override operator fun set(i: Int, j: Int, value: Complex): Unit =
gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl()) gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl())
override fun copy(): GslComplexMatrix = memScoped { override fun copy(): GslComplexMatrix {
val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong())) val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_complex_memcpy(new, nativeHandle) gsl_matrix_complex_memcpy(new, nativeHandle)
GslComplexMatrix(new, features) return GslComplexMatrix(new, features)
} }
override fun close(): Unit = gsl_matrix_complex_free(nativeHandle) override fun close(): Unit = gsl_matrix_complex_free(nativeHandle)
override fun equals(other: Any?): Boolean { override fun equals(other: Any?): Boolean {
if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) if (other is GslComplexMatrix) return gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) == 1
return super.equals(other) return super.equals(other)
} }
} }
@ -50,8 +50,21 @@ internal class GslComplexMatrix(
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) : internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
GslVector<Complex, gsl_vector_complex>() { GslVector<Complex, gsl_vector_complex>() {
override val size: Int override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } get() = nativeHandle.pointed.size.toInt()
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath() override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
override fun set(index: Int, value: Complex): Unit = gsl_vector_complex_set(nativeHandle, index.toULong(), value.toGsl())
override fun copy(): GslComplexVector {
val new = requireNotNull(gsl_vector_complex_alloc(size.toULong()))
gsl_vector_complex_memcpy(new, nativeHandle)
return GslComplexVector(new)
}
override fun equals(other: Any?): Boolean {
if (other is GslComplexVector) return gsl_vector_complex_equal(nativeHandle, other.nativeHandle) == 1
return super.equals(other)
}
override fun close(): Unit = gsl_vector_complex_free(nativeHandle) override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
} }

View File

@ -13,7 +13,7 @@ public abstract class GslMatrix<T : Any, H : CStructVar> internal constructor():
return NDStructure.equals(this, other as? NDStructure<*> ?: return false) return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
} }
public override fun hashCode(): Int { public final override fun hashCode(): Int {
var result = nativeHandle.hashCode() var result = nativeHandle.hashCode()
result = 31 * result + features.hashCode() result = 31 * result + features.hashCode()
return result return result

View File

@ -0,0 +1,7 @@
package kscience.kmath.gsl
import kotlinx.cinterop.CStructVar
import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.Point
import kscience.kmath.structures.Matrix

View File

@ -1,40 +1,63 @@
package kscience.kmath.gsl package kscience.kmath.gsl
import kotlinx.cinterop.CStructVar import kotlinx.cinterop.CStructVar
import kotlinx.cinterop.pointed
import kscience.kmath.linear.MatrixContext import kscience.kmath.linear.MatrixContext
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
import kscience.kmath.operations.invoke import kscience.kmath.operations.Complex
import kscience.kmath.operations.ComplexField
import kscience.kmath.operations.toComplex
import kscience.kmath.structures.Matrix import kscience.kmath.structures.Matrix
import org.gnu.gsl.* import org.gnu.gsl.*
private inline fun <T : Any, H : CStructVar> GslMatrix<T, H>.fill(initializer: (Int, Int) -> T): GslMatrix<T, H> = internal inline fun <T : Any, H : CStructVar> GslMatrix<T, H>.fill(initializer: (Int, Int) -> T): GslMatrix<T, H> =
apply { apply {
(0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } } (0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } }
} }
public sealed class GslMatrixContext<T : Any, H : CStructVar> : MatrixContext<T> { internal inline fun <T : Any, H : CStructVar> GslVector<T, H>.fill(initializer: (Int) -> T): GslVector<T, H> =
apply { (0 until size).forEach { index -> this[index] = initializer(index) } }
public abstract class GslMatrixContext<T : Any, H1 : CStructVar, H2 : CStructVar> internal constructor() :
MatrixContext<T> {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public fun Matrix<T>.toGsl(): GslMatrix<T, H> = public fun Matrix<T>.toGsl(): GslMatrix<T, H1> = (if (this is GslMatrix<*, *>)
(if (this is GslMatrix<*, *>) this as GslMatrix<T, H> else produce(rowNum, colNum) { i, j -> get(i, j) }).copy() this as GslMatrix<T, H1>
else
produce(rowNum, colNum) { i, j -> this[i, j] }).copy()
internal abstract fun produceDirty(rows: Int, columns: Int): GslMatrix<T, H> @Suppress("UNCHECKED_CAST")
public fun Point<T>.toGsl(): GslVector<T, H2> =
(if (this is GslVector<*, *>) this as GslVector<T, H2> else produceDirtyVector(size).fill { this[it] }).copy()
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix<T, H> = internal abstract fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<T, H1>
produceDirty(rows, columns).fill(initializer) internal abstract fun produceDirtyVector(size: Int): GslVector<T, H2>
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix<T, H1> =
produceDirtyMatrix(rows, columns).fill(initializer)
} }
public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix>() { public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix, gsl_vector>() {
public override fun produceDirty(rows: Int, columns: Int): GslMatrix<Double, gsl_matrix> = override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Double, gsl_matrix> =
GslRealMatrix(requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong()))) GslRealMatrix(requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())))
override fun produceDirtyVector(size: Int): GslVector<Double, gsl_vector> =
GslRealVector(requireNotNull(gsl_vector_alloc(size.toULong())))
public override fun Matrix<Double>.dot(other: Matrix<Double>): GslMatrix<Double, gsl_matrix> { public override fun Matrix<Double>.dot(other: Matrix<Double>): GslMatrix<Double, gsl_matrix> {
val g1 = toGsl() val x = toGsl().nativeHandle
gsl_matrix_mul_elements(g1.nativeHandle, other.toGsl().nativeHandle) val a = other.toGsl().nativeHandle
return g1 val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result)
return GslRealMatrix(result)
} }
public override fun Matrix<Double>.dot(vector: Point<Double>): GslVector<Double, gsl_vector> { public override fun Matrix<Double>.dot(vector: Point<Double>): GslVector<Double, gsl_vector> {
TODO() val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_calloc(a.pointed.size))
gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result)
return GslRealVector(result)
} }
public override fun Matrix<Double>.times(value: Double): GslMatrix<Double, gsl_matrix> { public override fun Matrix<Double>.times(value: Double): GslMatrix<Double, gsl_matrix> {
@ -55,3 +78,87 @@ public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix>() {
return g1 return g1
} }
} }
public object GslFloatMatrixContext : GslMatrixContext<Float, gsl_matrix_float, gsl_vector_float>() {
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Float, gsl_matrix_float> =
GslFloatMatrix(requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())))
override fun produceDirtyVector(size: Int): GslVector<Float, gsl_vector_float> =
GslFloatVector(requireNotNull(gsl_vector_float_alloc(size.toULong())))
public override fun Matrix<Float>.dot(other: Matrix<Float>): GslMatrix<Float, gsl_matrix_float> {
val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle
val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result)
return GslFloatMatrix(result)
}
public override fun Matrix<Float>.dot(vector: Point<Float>): GslVector<Float, gsl_vector_float> {
val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size))
gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result)
return GslFloatVector(result)
}
public override fun Matrix<Float>.times(value: Float): GslMatrix<Float, gsl_matrix_float> {
val g1 = toGsl()
gsl_matrix_float_scale(g1.nativeHandle, value.toDouble())
return g1
}
public override fun add(a: Matrix<Float>, b: Matrix<Float>): GslMatrix<Float, gsl_matrix_float> {
val g1 = a.toGsl()
gsl_matrix_float_add(g1.nativeHandle, b.toGsl().nativeHandle)
return g1
}
public override fun multiply(a: Matrix<Float>, k: Number): GslMatrix<Float, gsl_matrix_float> {
val g1 = a.toGsl()
gsl_matrix_float_scale(g1.nativeHandle, k.toDouble())
return g1
}
}
public object GslComplexMatrixContext : GslMatrixContext<Complex, gsl_matrix_complex, gsl_vector_complex>() {
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> =
GslComplexMatrix(requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())))
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
GslComplexVector(requireNotNull(gsl_vector_complex_alloc(size.toULong())))
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
val x = toGsl().nativeHandle
val a = other.toGsl().nativeHandle
val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexMatrix(result)
}
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
val x = toGsl().nativeHandle
val a = vector.toGsl().nativeHandle
val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size))
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
return GslComplexVector(result)
}
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {
val g1 = toGsl()
gsl_matrix_complex_scale(g1.nativeHandle, value.toGsl())
return g1
}
public override fun add(a: Matrix<Complex>, b: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
val g1 = a.toGsl()
gsl_matrix_complex_add(g1.nativeHandle, b.toGsl().nativeHandle)
return g1
}
public override fun multiply(a: Matrix<Complex>, k: Number): GslMatrix<Complex, gsl_matrix_complex> {
val g1 = a.toGsl()
gsl_matrix_complex_scale(g1.nativeHandle, k.toComplex().toGsl())
return g1
}
}

View File

@ -3,8 +3,11 @@ package kscience.kmath.gsl
import kotlinx.cinterop.CStructVar import kotlinx.cinterop.CStructVar
import kscience.kmath.linear.Point import kscience.kmath.linear.Point
abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> { public abstract class GslVector<T, H : CStructVar> internal constructor() : GslMemoryHolder<H>(), Point<T> {
public override fun iterator(): Iterator<T> = object : Iterator<T> { internal abstract operator fun set(index: Int, value: T)
internal abstract fun copy(): GslVector<T, H>
public final override fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor = 0 private var cursor = 0
override fun hasNext(): Boolean = cursor < size override fun hasNext(): Boolean = cursor < size

View File

@ -1,101 +0,0 @@
package kscience.kmath.gsl
import kotlinx.cinterop.CPointer
import kotlinx.cinterop.CStructVar
import kotlinx.cinterop.memScoped
import kotlinx.cinterop.pointed
import kscience.kmath.linear.Point
import kscience.kmath.operations.Complex
import org.gnu.gsl.*
public abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> {
public override fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor = 0
override fun hasNext(): Boolean = cursor < size
override fun next(): T {
cursor++
return this@GslVector[cursor - 1]
}
}
}
internal class GslRealVector(override val nativeHandle: CPointer<gsl_vector>) : GslVector<Double, gsl_vector>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Double = gsl_vector_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_free(nativeHandle)
}
internal class GslFloatVector(override val nativeHandle: CPointer<gsl_vector_float>) :
GslVector<Float, gsl_vector_float>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Float = gsl_vector_float_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_float_free(nativeHandle)
}
internal class GslIntVector(override val nativeHandle: CPointer<gsl_vector_int>) : GslVector<Int, gsl_vector_int>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Int = gsl_vector_int_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_int_free(nativeHandle)
}
internal class GslUIntVector(override val nativeHandle: CPointer<gsl_vector_uint>) :
GslVector<UInt, gsl_vector_uint>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): UInt = gsl_vector_uint_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_uint_free(nativeHandle)
}
internal class GslLongVector(override val nativeHandle: CPointer<gsl_vector_long>) :
GslVector<Long, gsl_vector_long>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Long = gsl_vector_long_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_long_free(nativeHandle)
}
internal class GslULongVector(override val nativeHandle: CPointer<gsl_vector_ulong>) :
GslVector<ULong, gsl_vector_ulong>() {
public override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
public override fun get(index: Int): ULong = gsl_vector_ulong_get(nativeHandle, index.toULong())
public override fun close(): Unit = gsl_vector_ulong_free(nativeHandle)
}
internal class GslShortVector(override val nativeHandle: CPointer<gsl_vector_short>) :
GslVector<Short, gsl_vector_short>() {
public override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
public override fun get(index: Int): Short = gsl_vector_short_get(nativeHandle, index.toULong())
public override fun close(): Unit = gsl_vector_short_free(nativeHandle)
}
internal class GslUShortVector(override val nativeHandle: CPointer<gsl_vector_ushort>) :
GslVector<UShort, gsl_vector_ushort>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): UShort = gsl_vector_ushort_get(nativeHandle, index.toULong())
override fun close(): Unit = gsl_vector_ushort_free(nativeHandle)
}
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
GslVector<Complex, gsl_vector_complex>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
}

View File

@ -1,8 +1,13 @@
package kscience.kmath.gsl package kscience.kmath.gsl
import kscience.kmath.linear.RealMatrixContext
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import kscience.kmath.structures.RealBuffer
import kscience.kmath.structures.asIterable
import kscience.kmath.structures.asSequence
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal class RealTest { internal class RealTest {
@Test @Test
@ -13,4 +18,29 @@ internal class RealTest {
mb.close() mb.close()
ma.close() ma.close()
} }
@Test
fun testDotOfMatrixAndVector() {
val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
val vb = RealBuffer(2) { 0.1 }
val res1 = GslRealMatrixContext { ma dot vb }
val res2 = RealMatrixContext { ma dot vb }
println(res1.asSequence().toList())
println(res2.asSequence().toList())
assertTrue(res1.contentEquals(res2))
res1.close()
}
@Test
fun testDotOfMatrixAndMatrix() {
val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
val mb = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
val res1 = GslRealMatrixContext { ma dot mb }
val res2 = RealMatrixContext { ma dot mb }
println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
assertEquals(res1, res2)
ma.close()
mb.close()
}
} }