From 69a107819bbf5af2a96da2305cfce2e34e89adac Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Mon, 5 Oct 2020 17:46:41 +0700 Subject: [PATCH] Implement vector codegen --- .../kmath/gsl/codegen/codegenUtilities.kt | 11 +++- .../kmath/gsl/codegen/matricesCodegen.kt | 27 +++----- .../kmath/gsl/codegen/vectorsCodegen.kt | 62 +++++++++++++++++++ kmath-gsl/build.gradle.kts | 2 + .../nativeMain/kotlin/generated/Matrices.kt | 3 +- .../kotlin/kscience/kmath/gsl/GslComplex.kt | 47 +++++++++++++- .../kotlin/kscience/kmath/gsl/GslVector.kt | 17 +++++ 7 files changed, 149 insertions(+), 20 deletions(-) create mode 100644 buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt create mode 100644 kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/codegenUtilities.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/codegenUtilities.kt index 0df7141d7..7f930324f 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/codegenUtilities.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/codegenUtilities.kt @@ -16,7 +16,6 @@ import org.jetbrains.kotlin.com.intellij.pom.tree.TreeAspect import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.TreeCopyHandler import org.jetbrains.kotlin.config.CompilerConfiguration -import org.jetbrains.kotlin.psi.KtFile import sun.reflect.ReflectionFactory internal fun createProject(): MockProject { @@ -59,3 +58,13 @@ internal fun createProject(): MockProject { internal operator fun PsiElement.plusAssign(e: PsiElement) { add(e) } + +internal fun fn(pattern: String, type: String): String { + if (type == "double") return pattern.replace("R", "_") + return pattern.replace("R", "_${type}_") +} + +internal fun sn(pattern: String, type: String): String { + if (type == "double") return pattern.replace("R", "") + return pattern.replace("R", "_$type") +} diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt index 14aafffff..1c507cc3d 100644 --- a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/matricesCodegen.kt @@ -1,21 +1,13 @@ package kscience.kmath.gsl.codegen +import org.intellij.lang.annotations.Language import org.jetbrains.kotlin.com.intellij.openapi.project.Project +import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtPsiFactory import org.jetbrains.kotlin.resolve.ImportPath import java.io.File -private fun fn(pattern: String, type: String): String { - if (type == "double") return pattern.replace("R", "_") - return pattern.replace("R", "_${type}_") -} - -private fun sn(pattern: String, type: String): String { - if (type == "double") return pattern.replace("R", "") - return pattern.replace("R", "_$type") -} - private fun KtPsiFactory.createMatrixClass( f: KtFile, cTypeName: String, @@ -25,8 +17,7 @@ private fun KtPsiFactory.createMatrixClass( val className = "Gsl${kotlinTypeAlias}Matrix" val structName = sn("gsl_matrixR", cTypeName) - f += createClass( - """internal class $className( + @Language("kotlin") val text = """internal class $className( override val nativeHandle: CPointer<$structName>, features: Set = emptySet() ) : GslMatrix<$kotlinTypeName, $structName>() { @@ -42,8 +33,8 @@ private fun KtPsiFactory.createMatrixClass( ${className}(nativeHandle, this.features + features) override operator fun get(i: Int, j: Int): $kotlinTypeName = ${ - fn("gsl_matrixRget", cTypeName) - }(nativeHandle, i.toULong(), j.toULong()) + fn("gsl_matrixRget", cTypeName) + }(nativeHandle, i.toULong(), j.toULong()) override operator fun set(i: Int, j: Int, value: ${kotlinTypeName}): Unit = ${fn("gsl_matrixRset", cTypeName)}(nativeHandle, i.toULong(), j.toULong(), value) @@ -61,6 +52,8 @@ private fun KtPsiFactory.createMatrixClass( return super.equals(other) } }""" + f += createClass( + text ) f += createNewLine(2) @@ -68,14 +61,14 @@ private fun KtPsiFactory.createMatrixClass( fun matricesCodegen(outputFile: String, project: Project = createProject()) { val f = KtPsiFactory(project, true).run { - createFile("package kscience.kmath.gsl").also { f -> + createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> + f += createNewLine(2) + f += createPackageDirective(FqName("kscience.kmath.gsl")) f += createNewLine(2) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) f += createNewLine(1) f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*")) f += createNewLine(1) - f += createImportDirective(ImportPath.fromString("kscience.kmath.operations.*")) - f += createNewLine(1) f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*")) f += createNewLine(2) createMatrixClass(f, "double", "Double", "Real") diff --git a/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt new file mode 100644 index 000000000..f29def5b0 --- /dev/null +++ b/buildSrc/src/main/kotlin/kscience/kmath/gsl/codegen/vectorsCodegen.kt @@ -0,0 +1,62 @@ +package kscience.kmath.gsl.codegen + +import org.intellij.lang.annotations.Language +import org.jetbrains.kotlin.com.intellij.openapi.project.Project +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.psi.KtFile +import org.jetbrains.kotlin.psi.KtPsiFactory +import org.jetbrains.kotlin.resolve.ImportPath +import java.io.File + +private fun KtPsiFactory.createVectorClass( + f: KtFile, + cTypeName: String, + kotlinTypeName: String, + kotlinTypeAlias: String = kotlinTypeName +) { + val className = "Gsl${kotlinTypeAlias}Vector" + val structName = sn("gsl_vectorR", cTypeName) + + @Language("kotlin") val text = + """internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() { + override val size: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } + + override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong()) + override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle) +}""" + f += createClass( + text + ) + + f += createNewLine(2) +} + +fun vectorsCodegen(outputFile: String, project: Project = createProject()) { + val f = KtPsiFactory(project, true).run { + createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f -> + f += createNewLine(2) + f += createPackageDirective(FqName("kscience.kmath.gsl")) + f += createNewLine(2) + f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) + f += createNewLine(1) + f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*")) + f += createNewLine(2) + createVectorClass(f, "double", "Double", "Real") + createVectorClass(f, "float", "Float") + createVectorClass(f, "short", "Short") + createVectorClass(f, "ushort", "UShort") + createVectorClass(f, "long", "Long") + createVectorClass(f, "ulong", "ULong") + createVectorClass(f, "int", "Int") + createVectorClass(f, "uint", "UInt") + } + } + + PsiTestUtil.checkFileStructure(f) + + File(outputFile).apply { + parentFile.mkdirs() + writeText(f.text) + } +} diff --git a/kmath-gsl/build.gradle.kts b/kmath-gsl/build.gradle.kts index 9ff2c30d8..048d7afed 100644 --- a/kmath-gsl/build.gradle.kts +++ b/kmath-gsl/build.gradle.kts @@ -1,6 +1,7 @@ @file:Suppress("UNUSED_VARIABLE") import kscience.kmath.gsl.codegen.matricesCodegen +import kscience.kmath.gsl.codegen.vectorsCodegen plugins { id("ru.mipt.npm.mpp") @@ -38,6 +39,7 @@ kotlin { internal val codegen: Task by tasks.creating { matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt") + vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt") } kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen)) diff --git a/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt b/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt index 47cb145c8..8d471beaa 100644 --- a/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt +++ b/kmath-gsl/src/nativeMain/kotlin/generated/Matrices.kt @@ -1,8 +1,9 @@ +@file:Suppress("PackageDirectoryMismatch") + package kscience.kmath.gsl import kotlinx.cinterop.* import kscience.kmath.linear.* -import kscience.kmath.operations.* import org.gnu.gsl.* internal class GslRealMatrix( diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt index 0282afb54..68bde7ef2 100644 --- a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslComplex.kt @@ -1,8 +1,9 @@ package kscience.kmath.gsl import kotlinx.cinterop.* +import kscience.kmath.linear.MatrixFeature import kscience.kmath.operations.Complex -import org.gnu.gsl.gsl_complex +import org.gnu.gsl.* internal fun CValue.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) } @@ -10,3 +11,47 @@ internal fun Complex.toGsl(): CValue = cValue { dat[0] = re dat[1] = im } + +internal class GslComplexMatrix( + override val nativeHandle: CPointer, + features: Set = emptySet() +) : GslMatrix() { + override val rowNum: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() } + + override val colNum: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() } + + override val features: Set = features + + override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix = + GslComplexMatrix(nativeHandle, this.features + features) + + override operator fun get(i: Int, j: Int): Complex = + gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).toKMath() + + override operator fun set(i: Int, j: Int, value: Complex): Unit = + gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl()) + + override fun copy(): GslComplexMatrix = memScoped { + val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong())) + gsl_matrix_complex_memcpy(new, nativeHandle) + GslComplexMatrix(new, features) + } + + override fun close(): Unit = gsl_matrix_complex_free(nativeHandle) + + override fun equals(other: Any?): Boolean { + if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle) + return super.equals(other) + } +} + +internal class GslComplexVector(override val nativeHandle: CPointer) : + GslVector() { + override val size: Int + get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() } + + override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath() + override fun close(): Unit = gsl_vector_complex_free(nativeHandle) +} diff --git a/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt new file mode 100644 index 000000000..22b80c435 --- /dev/null +++ b/kmath-gsl/src/nativeMain/kotlin/kscience/kmath/gsl/GslVector.kt @@ -0,0 +1,17 @@ +package kscience.kmath.gsl + +import kotlinx.cinterop.CStructVar +import kscience.kmath.linear.Point + +abstract class GslVector internal constructor(): GslMemoryHolder(), Point { + public override fun iterator(): Iterator = object : Iterator { + private var cursor = 0 + + override fun hasNext(): Boolean = cursor < size + + override fun next(): T { + cursor++ + return this@GslVector[cursor - 1] + } + } +} \ No newline at end of file