diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt deleted file mode 100644 index 38ed00605..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressionSpaces.kt +++ /dev/null @@ -1,38 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Space -import scientifik.kmath.operations.invoke - -open class AsmExpressionSpace(private val space: Space) : - Space>, - ExpressionContext> { - override val zero: AsmExpression = AsmConstantExpression(space.zero) - override fun const(value: T): AsmExpression = AsmConstantExpression(value) - override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) - override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = AsmSumExpression(space, a, b) - override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) - operator fun AsmExpression.plus(arg: T): AsmExpression = this + const(arg) - operator fun AsmExpression.minus(arg: T): AsmExpression = this - const(arg) - operator fun T.plus(arg: AsmExpression): AsmExpression = arg + this - operator fun T.minus(arg: AsmExpression): AsmExpression = arg - this -} - -class AsmExpressionField(private val field: Field) : - ExpressionContext>, - Field>, - AsmExpressionSpace(field) { - override val one: AsmExpression - get() = const(this.field.one) - - fun number(value: Number): AsmExpression = const(field { one * value }) - - override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = - AsmProductExpression(field, a, b) - - override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = AsmDivExpression(field, a, b) - operator fun AsmExpression.times(arg: T): AsmExpression = this * const(arg) - operator fun AsmExpression.div(arg: T): AsmExpression = this / const(arg) - operator fun T.times(arg: AsmExpression): AsmExpression = arg * this - operator fun T.div(arg: AsmExpression): AsmExpression = arg / this -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt deleted file mode 100644 index fc7788589..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmExpressions.kt +++ /dev/null @@ -1,123 +0,0 @@ -package scientifik.kmath.expressions - -import scientifik.kmath.operations.* - -abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: MutableList -) : Expression { - abstract override fun invoke(arguments: Map): T -} - -interface AsmExpression { - fun tryEvaluate(): T? = null - fun invoke(gen: AsmGenerationContext) -} - -internal class AsmVariableExpression(val name: String, val default: T? = null) : - AsmExpression { - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) -} - -internal class AsmConstantExpression(val value: T) : - AsmExpression { - override fun tryEvaluate(): T = value - override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) -} - -internal class AsmSumExpression( - private val algebra: SpaceOperations, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (first.tryEvaluate() ?: return@algebra null) + (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - first.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "add", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmProductExpression( - private val algebra: RingOperations, - first: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val first: AsmExpression = first.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (first.tryEvaluate() ?: return@algebra null) * (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - first.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.RING_OPERATIONS_CLASS, - method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmConstProductExpression( - private val algebra: SpaceOperations, - expr: AsmExpression, - private val const: Number -) : AsmExpression { - private val expr: AsmExpression = expr.optimize() - - override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - gen.visitNumberConstant(const) - expr.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, - method = "multiply", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} - -internal class AsmDivExpression( - private val algebra: FieldOperations, - expr: AsmExpression, - second: AsmExpression -) : AsmExpression { - private val expr: AsmExpression = expr.optimize() - private val second: AsmExpression = second.optimize() - - override fun tryEvaluate(): T? = algebra { - (expr.tryEvaluate() ?: return@algebra null) / (second.tryEvaluate() ?: return@algebra null) - } - - override fun invoke(gen: AsmGenerationContext) { - gen.visitLoadAlgebra() - expr.invoke(gen) - second.invoke(gen) - - gen.visitAlgebraOperation( - owner = AsmGenerationContext.FIELD_OPERATIONS_CLASS, - method = "divide", - descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.OBJECT_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" - ) - } -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt deleted file mode 100644 index bb7d0476d..000000000 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Optimization.kt +++ /dev/null @@ -1,6 +0,0 @@ -package scientifik.kmath.expressions - -fun AsmExpression.optimize(): AsmExpression { - val a = tryEvaluate() - return if (a == null) this else AsmConstantExpression(a) -} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt new file mode 100644 index 000000000..17b6dc023 --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressionSpaces.kt @@ -0,0 +1,74 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.ExpressionContext +import scientifik.kmath.operations.* + +open class AsmExpressionAlgebra(val algebra: Algebra) : + Algebra>, + ExpressionContext> { + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun const(value: T): AsmExpression = AsmConstantExpression(value) + override fun variable(name: String, default: T?): AsmExpression = AsmVariableExpression(name, default) +} + +open class AsmExpressionSpace( + val space: Space +) : AsmExpressionAlgebra(space), Space> { + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override val zero: AsmExpression = AsmConstantExpression(space.zero) + + override fun add(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) + + override fun multiply(a: AsmExpression, k: Number): AsmExpression = AsmConstProductExpression(space, a, k) + operator fun AsmExpression.plus(arg: T) = this + const(arg) + operator fun AsmExpression.minus(arg: T) = this - const(arg) + operator fun T.plus(arg: AsmExpression) = arg + this + operator fun T.minus(arg: AsmExpression) = arg - this +} + +open class AsmExpressionRing(private val ring: Ring) : AsmExpressionSpace(ring), Ring> { + override val one: AsmExpression + get() = const(this.ring.one) + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + fun number(value: Number): AsmExpression = const(ring { one * value }) + + override fun multiply(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) + + operator fun AsmExpression.times(arg: T) = this * const(arg) + operator fun T.times(arg: AsmExpression) = arg * this +} + +open class AsmExpressionField(private val field: Field) : + AsmExpressionRing(field), + Field> { + + override fun unaryOperation(operation: String, arg: AsmExpression): AsmExpression = + AsmUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: AsmExpression, right: AsmExpression): AsmExpression = + AsmBinaryOperation(algebra, operation, left, right) + + override fun divide(a: AsmExpression, b: AsmExpression): AsmExpression = + AsmBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + + operator fun AsmExpression.div(arg: T) = this / const(arg) + operator fun T.div(arg: AsmExpression) = arg / this +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt new file mode 100644 index 000000000..d7d655d6e --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmExpressions.kt @@ -0,0 +1,106 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.Expression +import scientifik.kmath.operations.Algebra +import scientifik.kmath.operations.Space +import scientifik.kmath.operations.invoke + +interface AsmExpression { + fun tryEvaluate(): T? = null + fun invoke(gen: AsmGenerationContext) +} + +internal class AsmUnaryOperation(private val context: Algebra, private val name: String, expr: AsmExpression) : + AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = context { + unaryOperation( + name, + expr.tryEvaluate() ?: return@context null + ) + } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitStringConstant(name) + expr.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "unaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmBinaryOperation( + private val context: Algebra, + private val name: String, + first: AsmExpression, + second: AsmExpression +) : AsmExpression { + private val first: AsmExpression = first.optimize() + private val second: AsmExpression = second.optimize() + + override fun tryEvaluate(): T? = context { + binaryOperation( + name, + first.tryEvaluate() ?: return@context null, + second.tryEvaluate() ?: return@context null + ) + } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitStringConstant(name) + first.invoke(gen) + second.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.ALGEBRA_CLASS, + method = "binaryOperation", + descriptor = "(L${AsmGenerationContext.STRING_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};" + + "L${AsmGenerationContext.OBJECT_CLASS};)" + + "L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal class AsmVariableExpression(private val name: String, private val default: T? = null) : AsmExpression { + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromVariables(name, default) +} + +internal class AsmConstantExpression(private val value: T) : AsmExpression { + override fun tryEvaluate(): T = value + override fun invoke(gen: AsmGenerationContext): Unit = gen.visitLoadFromConstants(value) +} + +internal class AsmConstProductExpression(private val context: Space, expr: AsmExpression, private val const: Number) : + AsmExpression { + private val expr: AsmExpression = expr.optimize() + + override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const } + + override fun invoke(gen: AsmGenerationContext) { + gen.visitLoadAlgebra() + gen.visitNumberConstant(const) + expr.invoke(gen) + + gen.visitAlgebraOperation( + owner = AsmGenerationContext.SPACE_OPERATIONS_CLASS, + method = "multiply", + descriptor = "(L${AsmGenerationContext.OBJECT_CLASS};L${AsmGenerationContext.NUMBER_CLASS};)L${AsmGenerationContext.OBJECT_CLASS};" + ) + } +} + +internal abstract class FunctionalCompiledExpression internal constructor( + @JvmField protected val algebra: Algebra, + @JvmField protected val constants: MutableList +) : Expression { + abstract override fun invoke(arguments: Map): T +} diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt similarity index 92% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt index cedd5c0fd..3e47da5e3 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/AsmGenerationContext.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/AsmGenerationContext.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm import org.objectweb.asm.ClassWriter import org.objectweb.asm.Label @@ -33,15 +33,15 @@ class AsmGenerationContext( private val invokeMethodVisitor: MethodVisitor private val invokeL0: Label private lateinit var invokeL1: Label - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: FunctionalCompiledExpression? = null init { asmCompiledClassWriter.visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, slashesClassName, - "L$ASM_COMPILED_EXPRESSION_CLASS;", - ASM_COMPILED_EXPRESSION_CLASS, + "L$FUNCTIONAL_COMPILED_EXPRESSION_CLASS;", + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, arrayOf() ) @@ -58,7 +58,7 @@ class AsmGenerationContext( visitMethodInsn( Opcodes.INVOKESPECIAL, - ASM_COMPILED_EXPRESSION_CLASS, + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "", "(L$ALGEBRA_CLASS;L$LIST_CLASS;)V", false @@ -103,7 +103,7 @@ class AsmGenerationContext( @PublishedApi @Suppress("UNCHECKED_CAST") - internal fun generate(): AsmCompiledExpression { + internal fun generate(): FunctionalCompiledExpression { generatedInstance?.let { return it } invokeMethodVisitor.run { @@ -170,7 +170,7 @@ class AsmGenerationContext( .defineClass(className, asmCompiledClassWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants) as AsmCompiledExpression + .newInstance(algebra, constants) as FunctionalCompiledExpression generatedInstance = new return new @@ -245,7 +245,7 @@ class AsmGenerationContext( invokeMethodVisitor.visitFieldInsn( Opcodes.GETFIELD, - ASM_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" + FUNCTIONAL_COMPILED_EXPRESSION_CLASS, "algebra", "L$ALGEBRA_CLASS;" ) invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_ALGEBRA_CLASS) @@ -259,6 +259,10 @@ class AsmGenerationContext( private fun visitCastToT(): Unit = invokeMethodVisitor.visitTypeInsn(Opcodes.CHECKCAST, T_CLASS) + internal fun visitStringConstant(string: String) { + invokeMethodVisitor.visitLdcInsn(string) + } + internal companion object { private val SIGNATURE_LETTERS = mapOf( java.lang.Byte::class.java to "B", @@ -269,15 +273,13 @@ class AsmGenerationContext( java.lang.Double::class.java to "D" ) - internal const val ASM_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/AsmCompiledExpression" + internal const val FUNCTIONAL_COMPILED_EXPRESSION_CLASS = "scientifik/kmath/expressions/asm/FunctionalCompiledExpression" internal const val LIST_CLASS = "java/util/List" internal const val MAP_CLASS = "java/util/Map" internal const val OBJECT_CLASS = "java/lang/Object" internal const val ALGEBRA_CLASS = "scientifik/kmath/operations/Algebra" internal const val SPACE_OPERATIONS_CLASS = "scientifik/kmath/operations/SpaceOperations" internal const val STRING_CLASS = "java/lang/String" - internal const val FIELD_OPERATIONS_CLASS = "scientifik/kmath/operations/FieldOperations" - internal const val RING_OPERATIONS_CLASS = "scientifik/kmath/operations/RingOperations" internal const val NUMBER_CLASS = "java/lang/Number" } } diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt similarity index 69% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt index f8b4afd5b..d55555567 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/Builders.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Builders.kt @@ -1,7 +1,9 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Field +import scientifik.kmath.operations.Ring import scientifik.kmath.operations.Space @PublishedApi @@ -25,11 +27,21 @@ inline fun asm(i: I, algebra: Algebra, block: I.() -> AsmExpre return ctx.generate() } +inline fun asmAlgebra( + algebra: Algebra, + block: AsmExpressionAlgebra.() -> AsmExpression +): Expression = asm(AsmExpressionAlgebra(algebra), algebra, block) + inline fun asmSpace( algebra: Space, block: AsmExpressionSpace.() -> AsmExpression ): Expression = asm(AsmExpressionSpace(algebra), algebra, block) +inline fun asmRing( + algebra: Ring, + block: AsmExpressionRing.() -> AsmExpression +): Expression = asm(AsmExpressionRing(algebra), algebra, block) + inline fun asmField( algebra: Field, block: AsmExpressionField.() -> AsmExpression diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt similarity index 94% rename from kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt rename to kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt index 9cdb0672f..9f697ab6d 100644 --- a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/MethodVisitors.kt +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/MethodVisitors.kt @@ -1,4 +1,4 @@ -package scientifik.kmath.expressions +package scientifik.kmath.expressions.asm import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* diff --git a/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt new file mode 100644 index 000000000..db57a690c --- /dev/null +++ b/kmath-asm/src/main/kotlin/scientifik/kmath/expressions/asm/Optimization.kt @@ -0,0 +1,9 @@ +package scientifik.kmath.expressions.asm + +import scientifik.kmath.expressions.asm.AsmConstantExpression +import scientifik.kmath.expressions.asm.AsmExpression + +fun AsmExpression.optimize(): AsmExpression { + val a = tryEvaluate() + return if (a == null) this else AsmConstantExpression(a) +} diff --git a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt index 437fd158f..55c240cef 100644 --- a/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt +++ b/kmath-asm/src/test/kotlin/scientifik/kmath/expressions/AsmTest.kt @@ -1,103 +1,40 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Algebra +import scientifik.kmath.expressions.asm.AsmExpression +import scientifik.kmath.expressions.asm.AsmExpressionField +import scientifik.kmath.expressions.asm.asmField import scientifik.kmath.operations.RealField import kotlin.test.Test import kotlin.test.assertEquals class AsmTest { - private fun testExpressionValue( - expectedValue: T, - expr: AsmExpression, - arguments: Map, - algebra: Algebra, - clazz: Class<*> - ): Unit = assertEquals( - expectedValue, AsmGenerationContext(clazz, algebra, "TestAsmCompiled") - .also(expr::invoke) - .generate() - .invoke(arguments) - ) - - @Suppress("UNCHECKED_CAST") - private fun testDoubleExpressionValue( - expectedValue: Double, - expr: AsmExpression, - arguments: Map, - algebra: Algebra = RealField, - clazz: Class = java.lang.Double::class.java as Class - ): Unit = testExpressionValue(expectedValue, expr, arguments, algebra, clazz) + private fun testDoubleExpression( + expected: Double?, + arguments: Map = emptyMap(), + block: AsmExpressionField.() -> AsmExpression + ): Unit = assertEquals(expected = expected, actual = asmField(RealField, block)(arguments)) @Test - fun testSum() = testDoubleExpressionValue( - 25.0, - AsmSumExpression(RealField, AsmConstantExpression(1.0), AsmVariableExpression("x")), - mapOf("x" to 24.0) - ) + fun testConstantsSum() = testDoubleExpression(16.0) { const(8.0) + 8.0 } @Test - fun testConst(): Unit = testDoubleExpressionValue( - 123.0, - AsmConstantExpression(123.0), - mapOf() - ) + fun testVarsSum() = testDoubleExpression(1000.0, mapOf("x" to 500.0)) { variable("x") + 500.0 } @Test - fun testDiv(): Unit = testDoubleExpressionValue( - 0.5, - AsmDivExpression(RealField, AsmConstantExpression(1.0), AsmConstantExpression(2.0)), - mapOf() - ) + fun testProduct() = testDoubleExpression(24.0) { const(4.0) * const(6.0) } @Test - fun testProduct(): Unit = testDoubleExpressionValue( - 25.0, - AsmProductExpression(RealField,AsmVariableExpression("x"), AsmVariableExpression("x")), - mapOf("x" to 5.0) - ) + fun testConstantProduct() = testDoubleExpression(984.0) { const(8.0) * 123 } @Test - fun testCProduct(): Unit = testDoubleExpressionValue( - 25.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5.0), - mapOf("x" to 5.0) - ) + fun testSubtraction() = testDoubleExpression(2.0) { const(4.0) - 2.0 } @Test - fun testCProductWithOtherTypeNumber(): Unit = testDoubleExpressionValue( - 25.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), 5f), - mapOf("x" to 5.0) - ) - - object CustomZero : Number() { - override fun toByte(): Byte = 0 - override fun toChar(): Char = 0.toChar() - override fun toDouble(): Double = 0.0 - override fun toFloat(): Float = 0f - override fun toInt(): Int = 0 - override fun toLong(): Long = 0L - override fun toShort(): Short = 0 - } + fun testDivision() = testDoubleExpression(64.0) { const(128.0) / 2 } @Test - fun testCProductWithCustomTypeNumber(): Unit = testDoubleExpressionValue( - 0.0, - AsmConstProductExpression(RealField,AsmVariableExpression("x"), CustomZero), - mapOf("x" to 5.0) - ) + fun testDirectCall() = testDoubleExpression(4096.0) { binaryOperation("*", const(64.0), const(64.0)) } - @Test - fun testVar(): Unit = testDoubleExpressionValue( - 10000.0, - AsmVariableExpression("x"), - mapOf("x" to 10000.0) - ) - - @Test - fun testVarWithDefault(): Unit = testDoubleExpressionValue( - 10000.0, - AsmVariableExpression("x", 10000.0), - mapOf() - ) +// @Test +// fun testSine() = testDoubleExpression(0.0) { unaryOperation("sin", const(PI)) } } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt index 97002f664..b36ea8b52 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/expressions/FunctionalExpressions.kt @@ -1,80 +1,102 @@ package scientifik.kmath.expressions -import scientifik.kmath.operations.Field -import scientifik.kmath.operations.Ring -import scientifik.kmath.operations.Space +import scientifik.kmath.operations.* -internal class VariableExpression(val name: String, val default: T? = null) : Expression { +internal class FunctionalUnaryOperation(val context: Algebra, val name: String, val expr: Expression) : + Expression { + override fun invoke(arguments: Map): T = context.unaryOperation(name, expr.invoke(arguments)) +} + +internal class FunctionalBinaryOperation( + val context: Algebra, + val name: String, + val first: Expression, + val second: Expression +) : Expression { + override fun invoke(arguments: Map): T = + context.binaryOperation(name, first.invoke(arguments), second.invoke(arguments)) +} + +internal class FunctionalVariableExpression(val name: String, val default: T? = null) : Expression { override fun invoke(arguments: Map): T = arguments[name] ?: default ?: error("Parameter not found: $name") } -internal class ConstantExpression(val value: T) : Expression { +internal class FunctionalConstantExpression(val value: T) : Expression { override fun invoke(arguments: Map): T = value } -internal class SumExpression( - val context: Space, - val first: Expression, - val second: Expression -) : Expression { - override fun invoke(arguments: Map): T = context.add(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ProductExpression(val context: Ring, val first: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = - context.multiply(first.invoke(arguments), second.invoke(arguments)) -} - -internal class ConstProductExpession(val context: Space, val expr: Expression, val const: Number) : +internal class FunctionalConstProductExpression(val context: Space, val expr: Expression, val const: Number) : Expression { override fun invoke(arguments: Map): T = context.multiply(expr.invoke(arguments), const) } -internal class DivExpession(val context: Field, val expr: Expression, val second: Expression) : - Expression { - override fun invoke(arguments: Map): T = context.divide(expr.invoke(arguments), second.invoke(arguments)) +open class FunctionalExpressionAlgebra(val algebra: Algebra) : + Algebra>, + ExpressionContext> { + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + override fun const(value: T): Expression = FunctionalConstantExpression(value) + override fun variable(name: String, default: T?): Expression = FunctionalVariableExpression(name, default) } -open class FunctionalExpressionSpace( - val space: Space -) : Space>, ExpressionContext> { +open class FunctionalExpressionSpace(val space: Space) : + FunctionalExpressionAlgebra(space), + Space> { + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) - override val zero: Expression = ConstantExpression(space.zero) + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) - override fun const(value: T): Expression = ConstantExpression(value) - - override fun variable(name: String, default: T?): Expression = VariableExpression(name, default) - - override fun add(a: Expression, b: Expression): Expression = SumExpression(space, a, b) - - override fun multiply(a: Expression, k: Number): Expression = ConstProductExpession(space, a, k) + override val zero: Expression = FunctionalConstantExpression(space.zero) + override fun add(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, SpaceOperations.PLUS_OPERATION, a, b) + override fun multiply(a: Expression, k: Number): Expression = FunctionalConstProductExpression(space, a, k) operator fun Expression.plus(arg: T) = this + const(arg) operator fun Expression.minus(arg: T) = this - const(arg) - operator fun T.plus(arg: Expression) = arg + this operator fun T.minus(arg: Expression) = arg - this } -open class FunctionalExpressionField( - val field: Field -) : Field>, ExpressionContext>, FunctionalExpressionSpace(field) { - +open class FunctionalExpressionRing(val ring: Ring) : FunctionalExpressionSpace(ring), Ring> { override val one: Expression - get() = const(this.field.one) + get() = const(this.ring.one) - fun number(value: Number): Expression = const(field.run { one * value }) + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) - override fun multiply(a: Expression, b: Expression): Expression = ProductExpression(field, a, b) + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) - override fun divide(a: Expression, b: Expression): Expression = DivExpession(field, a, b) + fun number(value: Number): Expression = const(ring { one * value }) + + override fun multiply(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, RingOperations.TIMES_OPERATION, a, b) operator fun Expression.times(arg: T) = this * const(arg) - operator fun Expression.div(arg: T) = this / const(arg) - operator fun T.times(arg: Expression) = arg * this +} + +open class FunctionalExpressionField(val field: Field) : + FunctionalExpressionRing(field), + Field> { + + override fun unaryOperation(operation: String, arg: Expression): Expression = + FunctionalUnaryOperation(algebra, operation, arg) + + override fun binaryOperation(operation: String, left: Expression, right: Expression): Expression = + FunctionalBinaryOperation(algebra, operation, left, right) + + override fun divide(a: Expression, b: Expression): Expression = + FunctionalBinaryOperation(space, FieldOperations.DIV_OPERATION, a, b) + + operator fun Expression.div(arg: T) = this / const(arg) operator fun T.div(arg: Expression) = arg / this -} \ No newline at end of file +}