forked from kscience/kmath
Rework codegen, use GSL vectors to store vectors, implement MatrixContext for Float, Double and Complex matrices with BLAS
This commit is contained in:
parent
69a107819b
commit
d46350e7b7
@ -6,8 +6,12 @@ import org.jetbrains.kotlin.com.intellij.psi.impl.DebugUtil
|
||||
import org.jetbrains.kotlin.com.intellij.psi.impl.source.PsiFileImpl
|
||||
import org.jetbrains.kotlin.com.intellij.testFramework.LightVirtualFile
|
||||
import org.jetbrains.kotlin.psi.KtFile
|
||||
import java.util.regex.Pattern
|
||||
import kotlin.math.min
|
||||
|
||||
private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+")
|
||||
internal fun splitByLinesDontTrim(string: String): Array<String> = EOL_SPLIT_DONT_TRIM_PATTERN.split(string)
|
||||
|
||||
internal object PsiTestUtil {
|
||||
fun checkFileStructure(file: KtFile) {
|
||||
compareFromAllRoots(file) { f -> DebugUtil.psiTreeToString(f, false) }
|
||||
|
@ -1,9 +0,0 @@
|
||||
package kscience.kmath.gsl.codegen
|
||||
|
||||
import java.util.regex.Pattern
|
||||
|
||||
private val EOL_SPLIT_DONT_TRIM_PATTERN: Pattern = Pattern.compile("(\r|\n|\r\n)+")
|
||||
|
||||
internal fun splitByLinesDontTrim(string: String): Array<String> {
|
||||
return EOL_SPLIT_DONT_TRIM_PATTERN.split(string)
|
||||
}
|
@ -22,10 +22,10 @@ private fun KtPsiFactory.createMatrixClass(
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<$kotlinTypeName, $structName>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
get() = nativeHandle.pointed.size1.toInt()
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
get() = nativeHandle.pointed.size2.toInt()
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
@ -39,30 +39,26 @@ private fun KtPsiFactory.createMatrixClass(
|
||||
override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit =
|
||||
${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): $className = memScoped {
|
||||
override fun copy(): $className {
|
||||
val new = requireNotNull(${fn("gsl_matrixRalloc", cTypeName)}(rowNum.toULong(), colNum.toULong()))
|
||||
${fn("gsl_matrixRmemcpy", cTypeName)}(new, nativeHandle)
|
||||
$className(new, features)
|
||||
return $className(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = ${fn("gsl_matrixRfree", cTypeName)}(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is $className) ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle)
|
||||
if (other is $className) return ${fn("gsl_matrixRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1
|
||||
return super.equals(other)
|
||||
}
|
||||
}"""
|
||||
f += createClass(
|
||||
text
|
||||
)
|
||||
|
||||
f += createClass(text)
|
||||
f += createNewLine(2)
|
||||
}
|
||||
|
||||
fun matricesCodegen(outputFile: String, project: Project = createProject()) {
|
||||
val f = KtPsiFactory(project, true).run {
|
||||
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
|
||||
f += createNewLine(2)
|
||||
createFile("").also { f ->
|
||||
f += createPackageDirective(FqName("kscience.kmath.gsl"))
|
||||
f += createNewLine(2)
|
||||
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
|
||||
|
@ -20,22 +20,34 @@ private fun KtPsiFactory.createVectorClass(
|
||||
@Language("kotlin") val text =
|
||||
"""internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
get() = nativeHandle.pointed.size.toInt()
|
||||
|
||||
override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong())
|
||||
override fun set(index: Int, value: $kotlinTypeName): Unit = ${
|
||||
fn("gsl_vectorRset", cTypeName)
|
||||
}(nativeHandle, index.toULong(), value)
|
||||
|
||||
override fun copy(): $className {
|
||||
val new = requireNotNull(${fn("gsl_vectorRalloc", cTypeName)}(size.toULong()))
|
||||
${fn("gsl_vectorRmemcpy", cTypeName)}(new, nativeHandle)
|
||||
return ${className}(new)
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is $className) return ${fn("gsl_vectorRequal", cTypeName)}(nativeHandle, other.nativeHandle) == 1
|
||||
return super.equals(other)
|
||||
}
|
||||
|
||||
override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle)
|
||||
}"""
|
||||
f += createClass(
|
||||
text
|
||||
)
|
||||
|
||||
f += createClass(text)
|
||||
f += createNewLine(2)
|
||||
}
|
||||
|
||||
fun vectorsCodegen(outputFile: String, project: Project = createProject()) {
|
||||
val f = KtPsiFactory(project, true).run {
|
||||
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
|
||||
f += createNewLine(2)
|
||||
createFile("").also { f ->
|
||||
f += createPackageDirective(FqName("kscience.kmath.gsl"))
|
||||
f += createNewLine(2)
|
||||
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
|
||||
|
@ -55,7 +55,7 @@ public interface MatrixContext<T : Any> : SpaceOperations<Matrix<T>> {
|
||||
* Multiplies an element by a matrix of it.
|
||||
*
|
||||
* @receiver the multiplicand.
|
||||
* @param value the multiplier.
|
||||
* @param m the multiplier.
|
||||
* @receiver the product.
|
||||
*/
|
||||
public operator fun T.times(m: Matrix<T>): Matrix<T> = m * this
|
||||
|
@ -38,7 +38,7 @@ public object BigIntField : Field<BigInt> {
|
||||
public class BigInt internal constructor(
|
||||
private val sign: Byte,
|
||||
private val magnitude: Magnitude
|
||||
) : Comparable<BigInt> {
|
||||
) : Comparable<BigInt> {
|
||||
public override fun compareTo(other: BigInt): Int = when {
|
||||
(sign == 0.toByte()) and (other.sign == 0.toByte()) -> 0
|
||||
sign < other.sign -> -1
|
||||
|
@ -38,8 +38,8 @@ kotlin {
|
||||
}
|
||||
|
||||
internal val codegen: Task by tasks.creating {
|
||||
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt")
|
||||
vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt")
|
||||
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Matrices.kt")
|
||||
vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/kscience/kmath/gsl/_Vectors.kt")
|
||||
}
|
||||
|
||||
kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen))
|
||||
|
@ -1,280 +0,0 @@
|
||||
@file:Suppress("PackageDirectoryMismatch")
|
||||
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.*
|
||||
import kscience.kmath.linear.*
|
||||
import org.gnu.gsl.*
|
||||
|
||||
internal class GslRealMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Double, gsl_matrix>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslRealMatrix =
|
||||
GslRealMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): Double = gsl_matrix_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: Double): Unit =
|
||||
gsl_matrix_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslRealMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_memcpy(new, nativeHandle)
|
||||
GslRealMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslRealMatrix) gsl_matrix_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslFloatMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_float>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Float, gsl_matrix_float>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslFloatMatrix =
|
||||
GslFloatMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): Float = gsl_matrix_float_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: Float): Unit =
|
||||
gsl_matrix_float_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslFloatMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_float_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_float_memcpy(new, nativeHandle)
|
||||
GslFloatMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_float_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslFloatMatrix) gsl_matrix_float_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslShortMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_short>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Short, gsl_matrix_short>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslShortMatrix =
|
||||
GslShortMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): Short = gsl_matrix_short_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: Short): Unit =
|
||||
gsl_matrix_short_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslShortMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_short_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_short_memcpy(new, nativeHandle)
|
||||
GslShortMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_short_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslShortMatrix) gsl_matrix_short_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslUShortMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_ushort>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<UShort, gsl_matrix_ushort>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslUShortMatrix =
|
||||
GslUShortMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): UShort = gsl_matrix_ushort_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: UShort): Unit =
|
||||
gsl_matrix_ushort_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslUShortMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_ushort_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_ushort_memcpy(new, nativeHandle)
|
||||
GslUShortMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_ushort_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslUShortMatrix) gsl_matrix_ushort_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslLongMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_long>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Long, gsl_matrix_long>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslLongMatrix =
|
||||
GslLongMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): Long = gsl_matrix_long_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: Long): Unit =
|
||||
gsl_matrix_long_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslLongMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_long_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_long_memcpy(new, nativeHandle)
|
||||
GslLongMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_long_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslLongMatrix) gsl_matrix_long_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslULongMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_ulong>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<ULong, gsl_matrix_ulong>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslULongMatrix =
|
||||
GslULongMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): ULong = gsl_matrix_ulong_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: ULong): Unit =
|
||||
gsl_matrix_ulong_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslULongMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_ulong_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_ulong_memcpy(new, nativeHandle)
|
||||
GslULongMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_ulong_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslULongMatrix) gsl_matrix_ulong_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslIntMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_int>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Int, gsl_matrix_int>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslIntMatrix =
|
||||
GslIntMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): Int = gsl_matrix_int_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: Int): Unit =
|
||||
gsl_matrix_int_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslIntMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_int_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_int_memcpy(new, nativeHandle)
|
||||
GslIntMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_int_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslIntMatrix) gsl_matrix_int_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslUIntMatrix(
|
||||
override val nativeHandle: CPointer<gsl_matrix_uint>,
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<UInt, gsl_matrix_uint>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): GslUIntMatrix =
|
||||
GslUIntMatrix(nativeHandle, this.features + features)
|
||||
|
||||
override operator fun get(i: Int, j: Int): UInt = gsl_matrix_uint_get(nativeHandle, i.toULong(), j.toULong())
|
||||
|
||||
override operator fun set(i: Int, j: Int, value: UInt): Unit =
|
||||
gsl_matrix_uint_set(nativeHandle, i.toULong(), j.toULong(), value)
|
||||
|
||||
override fun copy(): GslUIntMatrix = memScoped {
|
||||
val new = requireNotNull(gsl_matrix_uint_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_uint_memcpy(new, nativeHandle)
|
||||
GslUIntMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_uint_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslUIntMatrix) gsl_matrix_uint_equal(nativeHandle, other.nativeHandle)
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
|
@ -17,10 +17,10 @@ internal class GslComplexMatrix(
|
||||
features: Set<MatrixFeature> = emptySet()
|
||||
) : GslMatrix<Complex, gsl_matrix_complex>() {
|
||||
override val rowNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||
get() = nativeHandle.pointed.size1.toInt()
|
||||
|
||||
override val colNum: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||
get() = nativeHandle.pointed.size2.toInt()
|
||||
|
||||
override val features: Set<MatrixFeature> = features
|
||||
|
||||
@ -33,16 +33,16 @@ internal class GslComplexMatrix(
|
||||
override operator fun set(i: Int, j: Int, value: Complex): Unit =
|
||||
gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl())
|
||||
|
||||
override fun copy(): GslComplexMatrix = memScoped {
|
||||
override fun copy(): GslComplexMatrix {
|
||||
val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
|
||||
gsl_matrix_complex_memcpy(new, nativeHandle)
|
||||
GslComplexMatrix(new, features)
|
||||
return GslComplexMatrix(new, features)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_matrix_complex_free(nativeHandle)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle)
|
||||
if (other is GslComplexMatrix) return gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) == 1
|
||||
return super.equals(other)
|
||||
}
|
||||
}
|
||||
@ -50,8 +50,21 @@ internal class GslComplexMatrix(
|
||||
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
|
||||
GslVector<Complex, gsl_vector_complex>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
get() = nativeHandle.pointed.size.toInt()
|
||||
|
||||
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
|
||||
override fun set(index: Int, value: Complex): Unit = gsl_vector_complex_set(nativeHandle, index.toULong(), value.toGsl())
|
||||
|
||||
override fun copy(): GslComplexVector {
|
||||
val new = requireNotNull(gsl_vector_complex_alloc(size.toULong()))
|
||||
gsl_vector_complex_memcpy(new, nativeHandle)
|
||||
return GslComplexVector(new)
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (other is GslComplexVector) return gsl_vector_complex_equal(nativeHandle, other.nativeHandle) == 1
|
||||
return super.equals(other)
|
||||
}
|
||||
|
||||
override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
|
||||
}
|
||||
|
@ -13,9 +13,9 @@ public abstract class GslMatrix<T : Any, H : CStructVar> internal constructor():
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
public final override fun hashCode(): Int {
|
||||
var result = nativeHandle.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,7 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.CStructVar
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.structures.Matrix
|
||||
|
@ -1,40 +1,63 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.CStructVar
|
||||
import kotlinx.cinterop.pointed
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.operations.Complex
|
||||
import kscience.kmath.operations.ComplexField
|
||||
import kscience.kmath.operations.toComplex
|
||||
import kscience.kmath.structures.Matrix
|
||||
import org.gnu.gsl.*
|
||||
|
||||
private inline fun <T : Any, H : CStructVar> GslMatrix<T, H>.fill(initializer: (Int, Int) -> T): GslMatrix<T, H> =
|
||||
internal inline fun <T : Any, H : CStructVar> GslMatrix<T, H>.fill(initializer: (Int, Int) -> T): GslMatrix<T, H> =
|
||||
apply {
|
||||
(0 until rowNum).forEach { row -> (0 until colNum).forEach { col -> this[row, col] = initializer(row, col) } }
|
||||
}
|
||||
|
||||
public sealed class GslMatrixContext<T : Any, H : CStructVar> : MatrixContext<T> {
|
||||
internal inline fun <T : Any, H : CStructVar> GslVector<T, H>.fill(initializer: (Int) -> T): GslVector<T, H> =
|
||||
apply { (0 until size).forEach { index -> this[index] = initializer(index) } }
|
||||
|
||||
public abstract class GslMatrixContext<T : Any, H1 : CStructVar, H2 : CStructVar> internal constructor() :
|
||||
MatrixContext<T> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public fun Matrix<T>.toGsl(): GslMatrix<T, H> =
|
||||
(if (this is GslMatrix<*, *>) this as GslMatrix<T, H> else produce(rowNum, colNum) { i, j -> get(i, j) }).copy()
|
||||
public fun Matrix<T>.toGsl(): GslMatrix<T, H1> = (if (this is GslMatrix<*, *>)
|
||||
this as GslMatrix<T, H1>
|
||||
else
|
||||
produce(rowNum, colNum) { i, j -> this[i, j] }).copy()
|
||||
|
||||
internal abstract fun produceDirty(rows: Int, columns: Int): GslMatrix<T, H>
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public fun Point<T>.toGsl(): GslVector<T, H2> =
|
||||
(if (this is GslVector<*, *>) this as GslVector<T, H2> else produceDirtyVector(size).fill { this[it] }).copy()
|
||||
|
||||
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix<T, H> =
|
||||
produceDirty(rows, columns).fill(initializer)
|
||||
internal abstract fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<T, H1>
|
||||
internal abstract fun produceDirtyVector(size: Int): GslVector<T, H2>
|
||||
|
||||
public override fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): GslMatrix<T, H1> =
|
||||
produceDirtyMatrix(rows, columns).fill(initializer)
|
||||
}
|
||||
|
||||
public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix>() {
|
||||
public override fun produceDirty(rows: Int, columns: Int): GslMatrix<Double, gsl_matrix> =
|
||||
public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix, gsl_vector>() {
|
||||
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Double, gsl_matrix> =
|
||||
GslRealMatrix(requireNotNull(gsl_matrix_alloc(rows.toULong(), columns.toULong())))
|
||||
|
||||
override fun produceDirtyVector(size: Int): GslVector<Double, gsl_vector> =
|
||||
GslRealVector(requireNotNull(gsl_vector_alloc(size.toULong())))
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): GslMatrix<Double, gsl_matrix> {
|
||||
val g1 = toGsl()
|
||||
gsl_matrix_mul_elements(g1.nativeHandle, other.toGsl().nativeHandle)
|
||||
return g1
|
||||
val x = toGsl().nativeHandle
|
||||
val a = other.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_matrix_calloc(a.pointed.size1, a.pointed.size2))
|
||||
gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, x, a, 1.0, result)
|
||||
return GslRealMatrix(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.dot(vector: Point<Double>): GslVector<Double, gsl_vector> {
|
||||
TODO()
|
||||
val x = toGsl().nativeHandle
|
||||
val a = vector.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_vector_calloc(a.pointed.size))
|
||||
gsl_blas_dgemv(CblasNoTrans, 1.0, x, a, 1.0, result)
|
||||
return GslRealVector(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.times(value: Double): GslMatrix<Double, gsl_matrix> {
|
||||
@ -55,3 +78,87 @@ public object GslRealMatrixContext : GslMatrixContext<Double, gsl_matrix>() {
|
||||
return g1
|
||||
}
|
||||
}
|
||||
|
||||
public object GslFloatMatrixContext : GslMatrixContext<Float, gsl_matrix_float, gsl_vector_float>() {
|
||||
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Float, gsl_matrix_float> =
|
||||
GslFloatMatrix(requireNotNull(gsl_matrix_float_alloc(rows.toULong(), columns.toULong())))
|
||||
|
||||
override fun produceDirtyVector(size: Int): GslVector<Float, gsl_vector_float> =
|
||||
GslFloatVector(requireNotNull(gsl_vector_float_alloc(size.toULong())))
|
||||
|
||||
public override fun Matrix<Float>.dot(other: Matrix<Float>): GslMatrix<Float, gsl_matrix_float> {
|
||||
val x = toGsl().nativeHandle
|
||||
val a = other.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_matrix_float_calloc(a.pointed.size1, a.pointed.size2))
|
||||
gsl_blas_sgemm(CblasNoTrans, CblasNoTrans, 1f, x, a, 1f, result)
|
||||
return GslFloatMatrix(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Float>.dot(vector: Point<Float>): GslVector<Float, gsl_vector_float> {
|
||||
val x = toGsl().nativeHandle
|
||||
val a = vector.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_vector_float_calloc(a.pointed.size))
|
||||
gsl_blas_sgemv(CblasNoTrans, 1f, x, a, 1f, result)
|
||||
return GslFloatVector(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Float>.times(value: Float): GslMatrix<Float, gsl_matrix_float> {
|
||||
val g1 = toGsl()
|
||||
gsl_matrix_float_scale(g1.nativeHandle, value.toDouble())
|
||||
return g1
|
||||
}
|
||||
|
||||
public override fun add(a: Matrix<Float>, b: Matrix<Float>): GslMatrix<Float, gsl_matrix_float> {
|
||||
val g1 = a.toGsl()
|
||||
gsl_matrix_float_add(g1.nativeHandle, b.toGsl().nativeHandle)
|
||||
return g1
|
||||
}
|
||||
|
||||
public override fun multiply(a: Matrix<Float>, k: Number): GslMatrix<Float, gsl_matrix_float> {
|
||||
val g1 = a.toGsl()
|
||||
gsl_matrix_float_scale(g1.nativeHandle, k.toDouble())
|
||||
return g1
|
||||
}
|
||||
}
|
||||
|
||||
public object GslComplexMatrixContext : GslMatrixContext<Complex, gsl_matrix_complex, gsl_vector_complex>() {
|
||||
override fun produceDirtyMatrix(rows: Int, columns: Int): GslMatrix<Complex, gsl_matrix_complex> =
|
||||
GslComplexMatrix(requireNotNull(gsl_matrix_complex_alloc(rows.toULong(), columns.toULong())))
|
||||
|
||||
override fun produceDirtyVector(size: Int): GslVector<Complex, gsl_vector_complex> =
|
||||
GslComplexVector(requireNotNull(gsl_vector_complex_alloc(size.toULong())))
|
||||
|
||||
public override fun Matrix<Complex>.dot(other: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
val x = toGsl().nativeHandle
|
||||
val a = other.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_matrix_complex_calloc(a.pointed.size1, a.pointed.size2))
|
||||
gsl_blas_zgemm(CblasNoTrans, CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||
return GslComplexMatrix(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Complex>.dot(vector: Point<Complex>): GslVector<Complex, gsl_vector_complex> {
|
||||
val x = toGsl().nativeHandle
|
||||
val a = vector.toGsl().nativeHandle
|
||||
val result = requireNotNull(gsl_vector_complex_calloc(a.pointed.size))
|
||||
gsl_blas_zgemv(CblasNoTrans, ComplexField.one.toGsl(), x, a, ComplexField.one.toGsl(), result)
|
||||
return GslComplexVector(result)
|
||||
}
|
||||
|
||||
public override fun Matrix<Complex>.times(value: Complex): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
val g1 = toGsl()
|
||||
gsl_matrix_complex_scale(g1.nativeHandle, value.toGsl())
|
||||
return g1
|
||||
}
|
||||
|
||||
public override fun add(a: Matrix<Complex>, b: Matrix<Complex>): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
val g1 = a.toGsl()
|
||||
gsl_matrix_complex_add(g1.nativeHandle, b.toGsl().nativeHandle)
|
||||
return g1
|
||||
}
|
||||
|
||||
public override fun multiply(a: Matrix<Complex>, k: Number): GslMatrix<Complex, gsl_matrix_complex> {
|
||||
val g1 = a.toGsl()
|
||||
gsl_matrix_complex_scale(g1.nativeHandle, k.toComplex().toGsl())
|
||||
return g1
|
||||
}
|
||||
}
|
||||
|
@ -3,8 +3,11 @@ package kscience.kmath.gsl
|
||||
import kotlinx.cinterop.CStructVar
|
||||
import kscience.kmath.linear.Point
|
||||
|
||||
abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> {
|
||||
public override fun iterator(): Iterator<T> = object : Iterator<T> {
|
||||
public abstract class GslVector<T, H : CStructVar> internal constructor() : GslMemoryHolder<H>(), Point<T> {
|
||||
internal abstract operator fun set(index: Int, value: T)
|
||||
internal abstract fun copy(): GslVector<T, H>
|
||||
|
||||
public final override fun iterator(): Iterator<T> = object : Iterator<T> {
|
||||
private var cursor = 0
|
||||
|
||||
override fun hasNext(): Boolean = cursor < size
|
||||
@ -14,4 +17,4 @@ abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHol
|
||||
return this@GslVector[cursor - 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,101 +0,0 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.CPointer
|
||||
import kotlinx.cinterop.CStructVar
|
||||
import kotlinx.cinterop.memScoped
|
||||
import kotlinx.cinterop.pointed
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.operations.Complex
|
||||
import org.gnu.gsl.*
|
||||
|
||||
public abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> {
|
||||
public override fun iterator(): Iterator<T> = object : Iterator<T> {
|
||||
private var cursor = 0
|
||||
|
||||
override fun hasNext(): Boolean = cursor < size
|
||||
|
||||
override fun next(): T {
|
||||
cursor++
|
||||
return this@GslVector[cursor - 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal class GslRealVector(override val nativeHandle: CPointer<gsl_vector>) : GslVector<Double, gsl_vector>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): Double = gsl_vector_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslFloatVector(override val nativeHandle: CPointer<gsl_vector_float>) :
|
||||
GslVector<Float, gsl_vector_float>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): Float = gsl_vector_float_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_float_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslIntVector(override val nativeHandle: CPointer<gsl_vector_int>) : GslVector<Int, gsl_vector_int>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): Int = gsl_vector_int_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_int_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslUIntVector(override val nativeHandle: CPointer<gsl_vector_uint>) :
|
||||
GslVector<UInt, gsl_vector_uint>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): UInt = gsl_vector_uint_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_uint_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslLongVector(override val nativeHandle: CPointer<gsl_vector_long>) :
|
||||
GslVector<Long, gsl_vector_long>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): Long = gsl_vector_long_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_long_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslULongVector(override val nativeHandle: CPointer<gsl_vector_ulong>) :
|
||||
GslVector<ULong, gsl_vector_ulong>() {
|
||||
public override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
public override fun get(index: Int): ULong = gsl_vector_ulong_get(nativeHandle, index.toULong())
|
||||
public override fun close(): Unit = gsl_vector_ulong_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslShortVector(override val nativeHandle: CPointer<gsl_vector_short>) :
|
||||
GslVector<Short, gsl_vector_short>() {
|
||||
public override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
public override fun get(index: Int): Short = gsl_vector_short_get(nativeHandle, index.toULong())
|
||||
public override fun close(): Unit = gsl_vector_short_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslUShortVector(override val nativeHandle: CPointer<gsl_vector_ushort>) :
|
||||
GslVector<UShort, gsl_vector_ushort>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): UShort = gsl_vector_ushort_get(nativeHandle, index.toULong())
|
||||
override fun close(): Unit = gsl_vector_ushort_free(nativeHandle)
|
||||
}
|
||||
|
||||
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
|
||||
GslVector<Complex, gsl_vector_complex>() {
|
||||
override val size: Int
|
||||
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||
|
||||
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
|
||||
override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
|
||||
}
|
@ -1,8 +1,13 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kscience.kmath.linear.RealMatrixContext
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import kscience.kmath.structures.asIterable
|
||||
import kscience.kmath.structures.asSequence
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal class RealTest {
|
||||
@Test
|
||||
@ -13,4 +18,29 @@ internal class RealTest {
|
||||
mb.close()
|
||||
ma.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDotOfMatrixAndVector() {
|
||||
val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
|
||||
val vb = RealBuffer(2) { 0.1 }
|
||||
val res1 = GslRealMatrixContext { ma dot vb }
|
||||
val res2 = RealMatrixContext { ma dot vb }
|
||||
println(res1.asSequence().toList())
|
||||
println(res2.asSequence().toList())
|
||||
assertTrue(res1.contentEquals(res2))
|
||||
res1.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDotOfMatrixAndMatrix() {
|
||||
val ma = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
|
||||
val mb = GslRealMatrixContext.produce(2, 2) { _, _ -> 100.0 }
|
||||
val res1 = GslRealMatrixContext { ma dot mb }
|
||||
val res2 = RealMatrixContext { ma dot mb }
|
||||
println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
assertEquals(res1, res2)
|
||||
ma.close()
|
||||
mb.close()
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user