Implement vector codegen
This commit is contained in:
parent
77dac58efb
commit
69a107819b
@ -16,7 +16,6 @@ import org.jetbrains.kotlin.com.intellij.pom.tree.TreeAspect
|
|||||||
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
|
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
|
||||||
import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.TreeCopyHandler
|
import org.jetbrains.kotlin.com.intellij.psi.impl.source.tree.TreeCopyHandler
|
||||||
import org.jetbrains.kotlin.config.CompilerConfiguration
|
import org.jetbrains.kotlin.config.CompilerConfiguration
|
||||||
import org.jetbrains.kotlin.psi.KtFile
|
|
||||||
import sun.reflect.ReflectionFactory
|
import sun.reflect.ReflectionFactory
|
||||||
|
|
||||||
internal fun createProject(): MockProject {
|
internal fun createProject(): MockProject {
|
||||||
@ -59,3 +58,13 @@ internal fun createProject(): MockProject {
|
|||||||
internal operator fun PsiElement.plusAssign(e: PsiElement) {
|
internal operator fun PsiElement.plusAssign(e: PsiElement) {
|
||||||
add(e)
|
add(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun fn(pattern: String, type: String): String {
|
||||||
|
if (type == "double") return pattern.replace("R", "_")
|
||||||
|
return pattern.replace("R", "_${type}_")
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun sn(pattern: String, type: String): String {
|
||||||
|
if (type == "double") return pattern.replace("R", "")
|
||||||
|
return pattern.replace("R", "_$type")
|
||||||
|
}
|
||||||
|
@ -1,21 +1,13 @@
|
|||||||
package kscience.kmath.gsl.codegen
|
package kscience.kmath.gsl.codegen
|
||||||
|
|
||||||
|
import org.intellij.lang.annotations.Language
|
||||||
import org.jetbrains.kotlin.com.intellij.openapi.project.Project
|
import org.jetbrains.kotlin.com.intellij.openapi.project.Project
|
||||||
|
import org.jetbrains.kotlin.name.FqName
|
||||||
import org.jetbrains.kotlin.psi.KtFile
|
import org.jetbrains.kotlin.psi.KtFile
|
||||||
import org.jetbrains.kotlin.psi.KtPsiFactory
|
import org.jetbrains.kotlin.psi.KtPsiFactory
|
||||||
import org.jetbrains.kotlin.resolve.ImportPath
|
import org.jetbrains.kotlin.resolve.ImportPath
|
||||||
import java.io.File
|
import java.io.File
|
||||||
|
|
||||||
private fun fn(pattern: String, type: String): String {
|
|
||||||
if (type == "double") return pattern.replace("R", "_")
|
|
||||||
return pattern.replace("R", "_${type}_")
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun sn(pattern: String, type: String): String {
|
|
||||||
if (type == "double") return pattern.replace("R", "")
|
|
||||||
return pattern.replace("R", "_$type")
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun KtPsiFactory.createMatrixClass(
|
private fun KtPsiFactory.createMatrixClass(
|
||||||
f: KtFile,
|
f: KtFile,
|
||||||
cTypeName: String,
|
cTypeName: String,
|
||||||
@ -25,8 +17,7 @@ private fun KtPsiFactory.createMatrixClass(
|
|||||||
val className = "Gsl${kotlinTypeAlias}Matrix"
|
val className = "Gsl${kotlinTypeAlias}Matrix"
|
||||||
val structName = sn("gsl_matrixR", cTypeName)
|
val structName = sn("gsl_matrixR", cTypeName)
|
||||||
|
|
||||||
f += createClass(
|
@Language("kotlin") val text = """internal class $className(
|
||||||
"""internal class $className(
|
|
||||||
override val nativeHandle: CPointer<$structName>,
|
override val nativeHandle: CPointer<$structName>,
|
||||||
features: Set<MatrixFeature> = emptySet()
|
features: Set<MatrixFeature> = emptySet()
|
||||||
) : GslMatrix<$kotlinTypeName, $structName>() {
|
) : GslMatrix<$kotlinTypeName, $structName>() {
|
||||||
@ -61,6 +52,8 @@ private fun KtPsiFactory.createMatrixClass(
|
|||||||
return super.equals(other)
|
return super.equals(other)
|
||||||
}
|
}
|
||||||
}"""
|
}"""
|
||||||
|
f += createClass(
|
||||||
|
text
|
||||||
)
|
)
|
||||||
|
|
||||||
f += createNewLine(2)
|
f += createNewLine(2)
|
||||||
@ -68,14 +61,14 @@ private fun KtPsiFactory.createMatrixClass(
|
|||||||
|
|
||||||
fun matricesCodegen(outputFile: String, project: Project = createProject()) {
|
fun matricesCodegen(outputFile: String, project: Project = createProject()) {
|
||||||
val f = KtPsiFactory(project, true).run {
|
val f = KtPsiFactory(project, true).run {
|
||||||
createFile("package kscience.kmath.gsl").also { f ->
|
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
|
||||||
|
f += createNewLine(2)
|
||||||
|
f += createPackageDirective(FqName("kscience.kmath.gsl"))
|
||||||
f += createNewLine(2)
|
f += createNewLine(2)
|
||||||
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
|
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
|
||||||
f += createNewLine(1)
|
f += createNewLine(1)
|
||||||
f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*"))
|
f += createImportDirective(ImportPath.fromString("kscience.kmath.linear.*"))
|
||||||
f += createNewLine(1)
|
f += createNewLine(1)
|
||||||
f += createImportDirective(ImportPath.fromString("kscience.kmath.operations.*"))
|
|
||||||
f += createNewLine(1)
|
|
||||||
f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*"))
|
f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*"))
|
||||||
f += createNewLine(2)
|
f += createNewLine(2)
|
||||||
createMatrixClass(f, "double", "Double", "Real")
|
createMatrixClass(f, "double", "Double", "Real")
|
||||||
|
@ -0,0 +1,62 @@
|
|||||||
|
package kscience.kmath.gsl.codegen
|
||||||
|
|
||||||
|
import org.intellij.lang.annotations.Language
|
||||||
|
import org.jetbrains.kotlin.com.intellij.openapi.project.Project
|
||||||
|
import org.jetbrains.kotlin.name.FqName
|
||||||
|
import org.jetbrains.kotlin.psi.KtFile
|
||||||
|
import org.jetbrains.kotlin.psi.KtPsiFactory
|
||||||
|
import org.jetbrains.kotlin.resolve.ImportPath
|
||||||
|
import java.io.File
|
||||||
|
|
||||||
|
private fun KtPsiFactory.createVectorClass(
|
||||||
|
f: KtFile,
|
||||||
|
cTypeName: String,
|
||||||
|
kotlinTypeName: String,
|
||||||
|
kotlinTypeAlias: String = kotlinTypeName
|
||||||
|
) {
|
||||||
|
val className = "Gsl${kotlinTypeAlias}Vector"
|
||||||
|
val structName = sn("gsl_vectorR", cTypeName)
|
||||||
|
|
||||||
|
@Language("kotlin") val text =
|
||||||
|
"""internal class $className(override val nativeHandle: CPointer<$structName>) : GslVector<$kotlinTypeName, $structName>() {
|
||||||
|
override val size: Int
|
||||||
|
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||||
|
|
||||||
|
override fun get(index: Int): $kotlinTypeName = ${fn("gsl_vectorRget", cTypeName)}(nativeHandle, index.toULong())
|
||||||
|
override fun close(): Unit = ${fn("gsl_vectorRfree", cTypeName)}(nativeHandle)
|
||||||
|
}"""
|
||||||
|
f += createClass(
|
||||||
|
text
|
||||||
|
)
|
||||||
|
|
||||||
|
f += createNewLine(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun vectorsCodegen(outputFile: String, project: Project = createProject()) {
|
||||||
|
val f = KtPsiFactory(project, true).run {
|
||||||
|
createFile("@file:Suppress(\"PackageDirectoryMismatch\")").also { f ->
|
||||||
|
f += createNewLine(2)
|
||||||
|
f += createPackageDirective(FqName("kscience.kmath.gsl"))
|
||||||
|
f += createNewLine(2)
|
||||||
|
f += createImportDirective(ImportPath.fromString("kotlinx.cinterop.*"))
|
||||||
|
f += createNewLine(1)
|
||||||
|
f += createImportDirective(ImportPath.fromString("org.gnu.gsl.*"))
|
||||||
|
f += createNewLine(2)
|
||||||
|
createVectorClass(f, "double", "Double", "Real")
|
||||||
|
createVectorClass(f, "float", "Float")
|
||||||
|
createVectorClass(f, "short", "Short")
|
||||||
|
createVectorClass(f, "ushort", "UShort")
|
||||||
|
createVectorClass(f, "long", "Long")
|
||||||
|
createVectorClass(f, "ulong", "ULong")
|
||||||
|
createVectorClass(f, "int", "Int")
|
||||||
|
createVectorClass(f, "uint", "UInt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PsiTestUtil.checkFileStructure(f)
|
||||||
|
|
||||||
|
File(outputFile).apply {
|
||||||
|
parentFile.mkdirs()
|
||||||
|
writeText(f.text)
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,7 @@
|
|||||||
@file:Suppress("UNUSED_VARIABLE")
|
@file:Suppress("UNUSED_VARIABLE")
|
||||||
|
|
||||||
import kscience.kmath.gsl.codegen.matricesCodegen
|
import kscience.kmath.gsl.codegen.matricesCodegen
|
||||||
|
import kscience.kmath.gsl.codegen.vectorsCodegen
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("ru.mipt.npm.mpp")
|
id("ru.mipt.npm.mpp")
|
||||||
@ -38,6 +39,7 @@ kotlin {
|
|||||||
|
|
||||||
internal val codegen: Task by tasks.creating {
|
internal val codegen: Task by tasks.creating {
|
||||||
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt")
|
matricesCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Matrices.kt")
|
||||||
|
vectorsCodegen(kotlin.sourceSets["nativeMain"].kotlin.srcDirs.first().absolutePath + "/generated/Vectors.kt")
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen))
|
kotlin.sourceSets["nativeMain"].kotlin.srcDirs(files().builtBy(codegen))
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
|
@file:Suppress("PackageDirectoryMismatch")
|
||||||
|
|
||||||
package kscience.kmath.gsl
|
package kscience.kmath.gsl
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
import kscience.kmath.linear.*
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.operations.*
|
|
||||||
import org.gnu.gsl.*
|
import org.gnu.gsl.*
|
||||||
|
|
||||||
internal class GslRealMatrix(
|
internal class GslRealMatrix(
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
package kscience.kmath.gsl
|
package kscience.kmath.gsl
|
||||||
|
|
||||||
import kotlinx.cinterop.*
|
import kotlinx.cinterop.*
|
||||||
|
import kscience.kmath.linear.MatrixFeature
|
||||||
import kscience.kmath.operations.Complex
|
import kscience.kmath.operations.Complex
|
||||||
import org.gnu.gsl.gsl_complex
|
import org.gnu.gsl.*
|
||||||
|
|
||||||
internal fun CValue<gsl_complex>.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) }
|
internal fun CValue<gsl_complex>.toKMath(): Complex = useContents { Complex(dat[0], dat[1]) }
|
||||||
|
|
||||||
@ -10,3 +11,47 @@ internal fun Complex.toGsl(): CValue<gsl_complex> = cValue {
|
|||||||
dat[0] = re
|
dat[0] = re
|
||||||
dat[1] = im
|
dat[1] = im
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal class GslComplexMatrix(
|
||||||
|
override val nativeHandle: CPointer<gsl_matrix_complex>,
|
||||||
|
features: Set<MatrixFeature> = emptySet()
|
||||||
|
) : GslMatrix<Complex, gsl_matrix_complex>() {
|
||||||
|
override val rowNum: Int
|
||||||
|
get() = memScoped { nativeHandle.getPointer(this).pointed.size1.toInt() }
|
||||||
|
|
||||||
|
override val colNum: Int
|
||||||
|
get() = memScoped { nativeHandle.getPointer(this).pointed.size2.toInt() }
|
||||||
|
|
||||||
|
override val features: Set<MatrixFeature> = features
|
||||||
|
|
||||||
|
override fun suggestFeature(vararg features: MatrixFeature): GslComplexMatrix =
|
||||||
|
GslComplexMatrix(nativeHandle, this.features + features)
|
||||||
|
|
||||||
|
override operator fun get(i: Int, j: Int): Complex =
|
||||||
|
gsl_matrix_complex_get(nativeHandle, i.toULong(), j.toULong()).toKMath()
|
||||||
|
|
||||||
|
override operator fun set(i: Int, j: Int, value: Complex): Unit =
|
||||||
|
gsl_matrix_complex_set(nativeHandle, i.toULong(), j.toULong(), value.toGsl())
|
||||||
|
|
||||||
|
override fun copy(): GslComplexMatrix = memScoped {
|
||||||
|
val new = requireNotNull(gsl_matrix_complex_alloc(rowNum.toULong(), colNum.toULong()))
|
||||||
|
gsl_matrix_complex_memcpy(new, nativeHandle)
|
||||||
|
GslComplexMatrix(new, features)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun close(): Unit = gsl_matrix_complex_free(nativeHandle)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (other is GslComplexMatrix) gsl_matrix_complex_equal(nativeHandle, other.nativeHandle)
|
||||||
|
return super.equals(other)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal class GslComplexVector(override val nativeHandle: CPointer<gsl_vector_complex>) :
|
||||||
|
GslVector<Complex, gsl_vector_complex>() {
|
||||||
|
override val size: Int
|
||||||
|
get() = memScoped { nativeHandle.getPointer(this).pointed.size.toInt() }
|
||||||
|
|
||||||
|
override fun get(index: Int): Complex = gsl_vector_complex_get(nativeHandle, index.toULong()).toKMath()
|
||||||
|
override fun close(): Unit = gsl_vector_complex_free(nativeHandle)
|
||||||
|
}
|
||||||
|
@ -0,0 +1,17 @@
|
|||||||
|
package kscience.kmath.gsl
|
||||||
|
|
||||||
|
import kotlinx.cinterop.CStructVar
|
||||||
|
import kscience.kmath.linear.Point
|
||||||
|
|
||||||
|
abstract class GslVector<T, H : CStructVar> internal constructor(): GslMemoryHolder<H>(), Point<T> {
|
||||||
|
public override fun iterator(): Iterator<T> = object : Iterator<T> {
|
||||||
|
private var cursor = 0
|
||||||
|
|
||||||
|
override fun hasNext(): Boolean = cursor < size
|
||||||
|
|
||||||
|
override fun next(): T {
|
||||||
|
cursor++
|
||||||
|
return this@GslVector[cursor - 1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user