Implement vector codegen

This commit is contained in:
Iaroslav Postovalov 2020-10-05 17:46:41 +07:00
parent 77dac58efb
commit 69a107819b
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
7 changed files with 149 additions and 20 deletions

View File

@ -16,7 +16,6 @@ import org.jetbrains.kotlin.com.intellij.pom.tree.TreeAspect
import org.jetbrains.kotlin.com.intellij.psi.PsiElement import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.TreeCopyHandler import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.TreeCopyHandler
import org.jetbrains.kotlin.config.CompilerConfiguration import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.psi.KtFile
import sun.reflect.ReflectionFactory import sun.reflect.ReflectionFactory
internal fun createProject(): MockProject { internal fun createProject(): MockProject {
@ -59,3 +58,13 @@ internal fun createProject(): MockProject {
internal operator fun PsiElement.plusAssign(e: PsiElement) { internal operator fun PsiElement.plusAssign(e: PsiElement) {
add(e) add(e)
} }
internal fun fn(pattern: String, type: String): String {
if (type == "double") return pattern.replace("R", "_")
return pattern.replace("R", "_${type}_")
}
internal fun sn(pattern: String, type: String): String {
if (type == "double") return pattern.replace("R", "")
return pattern.replace("R", "_$type")
}

View File

@ -1,21 +1,13 @@
package kscience.kmath.gsl.codegen package kscience.kmath.gsl.codegen
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlin.com.intellij.openapi.project.Project import org.jetbrains.kotlin.com.intellij.openapi.project.Project
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFile import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtPsiFactory import org.jetbrains.kotlin.psi.KtPsiFactory
import org.jetbrains.kotlin.resolve.ImportPath import org.jetbrains.kotlin.resolve.ImportPath
import java.io.File import java.io.File
private fun fn(pattern: String, type: String): String {
if (type == "double") return pattern.replace("R", "_")
return pattern.replace("R", "_${type}_")
}
private fun sn(pattern: String, type: String): String {
if (type == "double") return pattern.replace("R", "")
return pattern.replace("R", "_$type")
}
private fun KtPsiFactory.createMatrixClass( private fun KtPsiFactory.createMatrixClass(
f: KtFile, f: KtFile,
cTypeName: String, cTypeName: String,
@ -25,8 +17,7 @@ private fun KtPsiFactory.createMatrixClass(
val className = "Gsl${kotlinTypeAlias}Matrix" val className = "Gsl${kotlinTypeAlias}Matrix"
val structName = sn("gsl_matrixR", cTypeName) val structName = sn("gsl_matrixR", cTypeName)
f += createClass( @Language("kotlin") val text = """internal class $className(
"""internal class $className(
override val nativeHandle: CPointer<$structName>, override val nativeHandle: CPointer<$structName>,
features: Set<MatrixFeature> = emptySet() features: Set<MatrixFeature> = emptySet()
) : GslMatrix<$kotlinTypeName, $structName>() { ) : GslMatrix<$kotlinTypeName, $structName>() {
@ -61,6 +52,8 @@ private fun KtPsiFactory.createMatrixClass(
return super.equals(other) return super.equals(other)
} }
}""" }"""
f += createClass(
text
) )
f += createNewLine(2) f += createNewLine(2)
@ -68,14 +61,14 @@ private fun KtPsiFactory.createMatrixClass(
fun matricesCodegen(outputFile: String, project: Project = createProject()) { fun matricesCodegen(outputFile: String, project: Project = createProject()) {
val f = KtPsiFactory(project, true).run { val f = KtPsiFactory(project, true).run {
createFile("package kscience.kmath.gsl").also { f -> createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
f += createNewLine(2)
f += createPackageDirective(FqName("kscience.kmath.gsl"))
f += createNewLine(2) f += createNewLine(2)
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*")) f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
f += createNewLine(1) f += createNewLine(1)
f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*")) f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*"))
f += createNewLine(1) f += createNewLine(1)
f += createImportDirective(ImportPath.fromString("kscience.kmath.operations.*"))
f += createNewLine(1)
f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*")) f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*"))
f += createNewLine(2) f += createNewLine(2)
createMatrixClass(f, "double", "Double", "Real") createMatrixClass(f, "double", "Double", "Real")

View File

@ -0,0 +1,62 @@
package kscience.kmath.gsl.codegen
import org.intellij.lang.annotations.Language
import org.jetbrains.kotlin.com.intellij.openapi.project.Project
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.psi.KtPsiFactory
import org.jetbrains.kotlin.resolve.ImportPath
import java.io.File
private fun KtPsiFactory.createVectorClass(
f: KtFile,
cTypeName: String,
kotlinTypeName: String,
kotlinTypeAlias: String = kotlinTypeName
) {
val className = "Gsl${kotlinTypeAlias}Vector"
val structName = sn("gsl_vectorR", cTypeName)
@Language("kotlin") val text =
"""internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong())
override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle)
}"""
f += createClass(
text
)
f += createNewLine(2)
}
fun vectorsCodegen(outputFile: String, project: Project = createProject()) {
val f = KtPsiFactory(project, true).run {
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
f += createNewLine(2)
f += createPackageDirective(FqName("kscience.kmath.gsl"))
f += createNewLine(2)
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
f += createNewLine(1)
f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*"))
f += createNewLine(2)
createVectorClass(f, "double", "Double", "Real")
createVectorClass(f, "float", "Float")
createVectorClass(f, "short", "Short")
createVectorClass(f, "ushort", "UShort")
createVectorClass(f, "long", "Long")
createVectorClass(f, "ulong", "ULong")
createVectorClass(f, "int", "Int")
createVectorClass(f, "uint", "UInt")
}
}
PsiTestUtil.checkFileStructure(f)
File(outputFile).apply {
parentFile.mkdirs()
writeText(f.text)
}
}

View File

@ -1,6 +1,7 @@
@file:Suppress("UNUSED_VARIABLE") @file:Suppress("UNUSED_VARIABLE")
import kscience.kmath.gsl.codegen.matricesCodegen import kscience.kmath.gsl.codegen.matricesCodegen
import kscience.kmath.gsl.codegen.vectorsCodegen
plugins { plugins {
id("ru.mipt.npm.mpp") id("ru.mipt.npm.mpp")
@ -38,6 +39,7 @@ kotlin {
internal val codegen: Task by tasks.creating { internal val codegen: Task by tasks.creating {
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt") matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt")
vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt")
} }
kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen)) kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen))

View File

@ -1,8 +1,9 @@
@file:Suppress("PackageDirectoryMismatch")
package kscience.kmath.gsl package kscience.kmath.gsl
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.linear.* import kscience.kmath.linear.*
import kscience.kmath.operations.*
import org.gnu.gsl.* import org.gnu.gsl.*
internal class GslRealMatrix( internal class GslRealMatrix(

View File

@ -1,8 +1,9 @@
package kscience.kmath.gsl package kscience.kmath.gsl
import kotlinx.cinterop.* import kotlinx.cinterop.*
import kscience.kmath.linear.MatrixFeature
import kscience.kmath.operations.Complex import kscience.kmath.operations.Complex
import org.gnu.gsl.gsl_complex import org.gnu.gsl.*
internal fun CValue<gsl_complex>.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) } internal fun CValue<gsl_complex>.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) }
@ -10,3 +11,47 @@ internal fun Complex.toGsl(): CValue<gsl_complex> = cValue {
dat[0] = re dat[0] = re
dat[1] = im dat[1] = im
} }
internal class GslComplexMatrix(
override val nativeHandle: CPointer<gsl_matrix_complex>,
features: Set<MatrixFeature> = emptySet()
) : GslMatrix<Complex, gsl_matrix_complex>() {
override val rowNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
override val colNum: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
override val features: Set<MatrixFeature> = features
override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix =
GslComplexMatrix(nativeHandle, this.features + features)
override operator fun get(i: Int, j: Int): Complex =
gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).toKMath()
override operator fun set(i: Int, j: Int, value: Complex): Unit =
gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl())
override fun copy(): GslComplexMatrix = memScoped {
val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
gsl_matrix_complex_memcpy(new, nativeHandle)
GslComplexMatrix(new, features)
}
override fun close(): Unit = gsl_matrix_complex_free(nativeHandle)
override fun equals(other: Any?): Boolean {
if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle)
return super.equals(other)
}
}
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
GslVector<Complex, gsl_vector_complex>() {
override val size: Int
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
}

View File

@ -0,0 +1,17 @@
package kscience.kmath.gsl
import kotlinx.cinterop.CStructVar
import kscience.kmath.linear.Point
abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> {
public override fun iterator(): Iterator<T> = object : Iterator<T> {
private var cursor = 0
override fun hasNext(): Boolean = cursor < size
override fun next(): T {
cursor++
return this@GslVector[cursor - 1]
}
}
}