From 2e5c13aea99b8ecc82c867ca859353df80ea194c Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Wed, 10 Jun 2020 02:05:13 +0700 Subject: [PATCH 1/6] Improve support of string-identified operations API, rework trigonometric operations algebra part: introduce inverse trigonometric operations, rename tg to tan --- .../commons/expressions/DiffExpression.kt | 7 +- .../scientifik/kmath/operations/Complex.kt | 8 +- .../kmath/operations/NumberAlgebra.kt | 12 +- .../kmath/operations/OptionalOperations.kt | 34 +++-- .../kmath/structures/ComplexNDField.kt | 7 ++ .../kmath/structures/ExtendedNDField.kt | 2 +- .../kmath/structures/RealBufferField.kt | 117 ++++++++++++------ .../kmath/structures/RealNDField.kt | 7 ++ 8 files changed, 140 insertions(+), 54 deletions(-) diff --git a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt index 8c19395d3..88d32378e 100644 --- a/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt +++ b/kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions/DiffExpression.kt @@ -59,8 +59,10 @@ class DerivativeStructureField( override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() - override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() + override fun asin(arg: DerivativeStructure): DerivativeStructure = arg.asin() + override fun acos(arg: DerivativeStructure): DerivativeStructure = arg.acos() + override fun atan(arg: DerivativeStructure): DerivativeStructure = arg.atan() override fun power(arg: DerivativeStructure, pow: Number): DerivativeStructure = when (pow) { is Double -> arg.pow(pow) @@ -136,6 +138,3 @@ object DiffExpressionContext : ExpressionContext, Field< override fun divide(a: DiffExpression, b: DiffExpression) = DiffExpression { a.function(this) / b.function(this) } } - - - diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt index 6c529f55e..ea9425bef 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/Complex.kt @@ -8,6 +8,8 @@ import scientifik.memory.MemorySpec import scientifik.memory.MemoryWriter import kotlin.math.* +private val PI_DIV_2 = Complex(PI / 2, 0) + /** * A field for complex numbers */ @@ -30,9 +32,11 @@ object ComplexField : ExtendedFieldOperations, Field { return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm) } - override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg)) - + override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2 override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2 + override fun asin(arg: Complex): Complex = -i * ln(sqrt(one - arg pow 2) + i * arg) + override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(one - arg pow 2) + i * arg) + override fun atan(arg: Complex): Complex = i * (ln(one - i * arg) - ln(one + i * arg)) / 2 override fun power(arg: Complex, pow: Number): Complex = arg.r.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta)) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 9639e4c28..3d942beac 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -8,7 +8,7 @@ import kotlin.math.pow as kpow */ interface ExtendedFieldOperations : FieldOperations, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, ExponentialOperations @@ -44,6 +44,10 @@ object RealField : ExtendedField, Norm { override inline fun sin(arg: Double) = kotlin.math.sin(arg) override inline fun cos(arg: Double) = kotlin.math.cos(arg) + override inline fun tan(arg: Double): Double = kotlin.math.tan(arg) + override inline fun acos(arg: Double): Double = kotlin.math.acos(arg) + override inline fun asin(arg: Double): Double = kotlin.math.asin(arg) + override inline fun atan(arg: Double): Double = kotlin.math.atan(arg) override inline fun power(arg: Double, pow: Number) = arg.kpow(pow.toDouble()) @@ -75,6 +79,10 @@ object FloatField : ExtendedField, Norm { override inline fun sin(arg: Float) = kotlin.math.sin(arg) override inline fun cos(arg: Float) = kotlin.math.cos(arg) + override inline fun tan(arg: Float) = kotlin.math.tan(arg) + override inline fun acos(arg: Float) = kotlin.math.acos(arg) + override inline fun asin(arg: Float) = kotlin.math.asin(arg) + override inline fun atan(arg: Float) = kotlin.math.atan(arg) override inline fun power(arg: Float, pow: Number) = arg.pow(pow.toFloat()) @@ -180,4 +188,4 @@ object LongRing : Ring, Norm { override inline fun Long.minus(b: Long) = (this - b) override inline fun Long.times(b: Long) = (this * b) -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index bd83932e7..bbd7a110e 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -13,16 +13,33 @@ package scientifik.kmath.operations interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T + fun tan(arg: T): T = sin(arg) / cos(arg) - fun tg(arg: T): T = sin(arg) / cos(arg) + companion object { + const val SIN_OPERATION = "sin" + const val COS_OPERATION = "cos" + const val TAN_OPERATION = "tan" + } +} - fun ctg(arg: T): T = cos(arg) / sin(arg) +interface InverseTrigonometricOperations : TrigonometricOperations { + fun asin(arg: T): T + fun acos(arg: T): T + fun atan(arg: T): T + + companion object { + const val ASIN_OPERATION = "asin" + const val ACOS_OPERATION = "acos" + const val ATAN_OPERATION = "atan" + } } fun >> sin(arg: T): T = arg.context.sin(arg) fun >> cos(arg: T): T = arg.context.cos(arg) -fun >> tg(arg: T): T = arg.context.tg(arg) -fun >> ctg(arg: T): T = arg.context.ctg(arg) +fun >> tan(arg: T): T = arg.context.tan(arg) +fun >> asin(arg: T): T = arg.context.asin(arg) +fun >> acos(arg: T): T = arg.context.acos(arg) +fun >> atan(arg: T): T = arg.context.atan(arg) /* Power and roots */ @@ -32,8 +49,11 @@ fun >> ctg(arg: T): T = arg.conte interface PowerOperations : Algebra { fun power(arg: T, pow: Number): T fun sqrt(arg: T) = power(arg, 0.5) - infix fun T.pow(pow: Number) = power(this, pow) + + companion object { + const val SQRT_OPERATION = "sqrt" + } } infix fun >> T.pow(power: Double): T = context.power(this, power) @@ -42,7 +62,7 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T } @@ -54,4 +74,4 @@ interface Norm { fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt index a79366a99..c7e672c28 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ComplexNDField.kt @@ -79,6 +79,13 @@ class ComplexNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) {acos(it)} + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) {atan(it)} } diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt index 3437644ff..c986ff011 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/ExtendedNDField.kt @@ -4,7 +4,7 @@ import scientifik.kmath.operations.* interface ExtendedNDField> : NDField, - TrigonometricOperations, + InverseTrigonometricOperations, PowerOperations, ExponentialOperations where F : ExtendedFieldOperations, F : Field diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 88c8c29db..2fb6d15d4 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,8 +9,10 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { + override fun add(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -22,6 +24,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun multiply(a: Buffer, k: Number): DoubleBuffer { val kValue = k.toDouble() + return if (a is DoubleBuffer) { val aArray = a.array DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) @@ -32,6 +35,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -43,6 +47,7 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { override fun divide(a: Buffer, b: Buffer): DoubleBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } + return if (a is DoubleBuffer && b is DoubleBuffer) { val aArray = a.array val bArray = b.array @@ -52,49 +57,67 @@ object RealBufferFieldOperations : ExtendedFieldOperations> { } } - override fun sin(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) - } + override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) - } + override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } + override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) } - override fun exp(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } + override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun ln(arg: Buffer): DoubleBuffer { - return if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } + override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) + } + + override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) + } + + override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) + } + + override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) + } + + override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + val array = arg.array + DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else { + DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } } @@ -119,7 +142,6 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) @@ -135,6 +157,26 @@ class RealBufferField(val size: Int) : ExtendedField> { return RealBufferFieldOperations.cos(arg) } + override fun tan(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.tan(arg) + } + + override fun asin(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.asin(arg) + } + + override fun acos(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.acos(arg) + } + + override fun atan(arg: Buffer): DoubleBuffer { + require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } + return RealBufferFieldOperations.atan(arg) + } + override fun power(arg: Buffer, pow: Number): DoubleBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) @@ -149,5 +191,4 @@ class RealBufferField(val size: Int) : ExtendedField> { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } - -} \ No newline at end of file +} diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt index 8c1bd4239..22b33aa4d 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealNDField.kt @@ -74,6 +74,13 @@ class RealNDField(override val shape: IntArray) : override fun cos(arg: NDBuffer) = map(arg) { cos(it) } + override fun tan(arg: NDBuffer): NDBuffer = map(arg) { tan(it) } + + override fun asin(arg: NDBuffer): NDBuffer = map(arg) { asin(it) } + + override fun acos(arg: NDBuffer): NDBuffer = map(arg) { acos(it) } + + override fun atan(arg: NDBuffer): NDBuffer = map(arg) { atan(it) } } -- 2.34.1 From 48b688b6b1a6888d1812a36ab5c2bca12042f8d2 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sun, 14 Jun 2020 00:06:12 +0700 Subject: [PATCH 2/6] Fix minor problems occured after merge --- .../scientifik/kmath/operations/NumberAlgebra.kt | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 830b1496d..a1b845ccc 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -20,20 +20,18 @@ interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + else -> super.unaryOperation(operation, arg) } } interface ExtendedField : ExtendedFieldOperations, Field { - override fun rightSideNumberOperation(operation: String, left: T, right: Number): T { - return when (operation) { - PowerOperations.POW_OPERATION -> power(left, right) - else -> super.rightSideNumberOperation(operation, left, right) - } - + override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) { + PowerOperations.POW_OPERATION -> power(left, right) + else -> super.rightSideNumberOperation(operation, left, right) } } + /** * Real field element wrapping double. * -- 2.34.1 From 1e2460c5b3e4f03b8351fd5ad7b5feb63311923f Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Mon, 15 Jun 2020 21:02:38 +0700 Subject: [PATCH 3/6] Rename --- .../kmath/asm/internal/{MethodVisitors.kt => methodVisitors.kt} | 0 .../kmath/asm/internal/{Optimization.kt => optimization.kt} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{MethodVisitors.kt => methodVisitors.kt} (100%) rename kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/{Optimization.kt => optimization.kt} (100%) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/MethodVisitors.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt similarity index 100% rename from kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/Optimization.kt rename to kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt -- 2.34.1 From e47ec1aeb9587b0f04f54f99d90c4da1df933152 Mon Sep 17 00:00:00 2001 From: Commander Tvis Date: Thu, 25 Jun 2020 10:07:36 +0700 Subject: [PATCH 4/6] Delete AsmCompiledExpression abstract class, implement dynamic field generation to reduce quantity of cast instructions, minor refactor and renaming of internal APIs --- kmath-ast/README.md | 18 +- .../kmath/asm/internal/AsmBuilder.kt | 220 +++++++++--------- .../asm/internal/AsmCompiledExpression.kt | 18 -- .../kmath/asm/internal/buildName.kt | 3 +- .../kmath/asm/internal/classWriters.kt | 12 +- .../kmath/asm/internal/instructionAdapters.kt | 10 + .../kmath/asm/internal/methodVisitors.kt | 4 +- .../kmath/asm/internal/specialization.kt | 2 +- 8 files changed, 140 insertions(+), 147 deletions(-) delete mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt create mode 100644 kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt diff --git a/kmath-ast/README.md b/kmath-ast/README.md index b5ca5886f..4563e17cf 100644 --- a/kmath-ast/README.md +++ b/kmath-ast/README.md @@ -24,20 +24,20 @@ For example, the following builder: package scientifik.kmath.asm.generated; import java.util.Map; -import scientifik.kmath.asm.internal.AsmCompiledExpression; -import scientifik.kmath.operations.Algebra; +import scientifik.kmath.expressions.Expression; import scientifik.kmath.operations.RealField; -// The class's name is build with MST's hash-code and collision fixing number. -public final class AsmCompiledExpression_45045_0 extends AsmCompiledExpression { - // Plain constructor - public AsmCompiledExpression_45045_0(Algebra algebra, Object[] constants) { - super(algebra, constants); +public final class AsmCompiledExpression_1073786867_0 implements Expression { + private final RealField algebra; + private final Object[] constants; + + public AsmCompiledExpression_1073786867_0(RealField algebra, Object[] constants) { + this.algebra = algebra; + this.constants = constants; } - // The actual dynamic code: public final Double invoke(Map arguments) { - return (Double)((RealField)super.algebra).add((Double)arguments.get("x"), (Double)2.0D); + return (Double)this.algebra.add(((Double)arguments.get("x")).doubleValue(), 2.0D); } } ``` diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt index 8f45c4044..536d6136d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt @@ -6,6 +6,7 @@ import org.objectweb.asm.Opcodes.RETURN import org.objectweb.asm.commons.InstructionAdapter import scientifik.kmath.asm.internal.AsmBuilder.ClassLoader import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression import scientifik.kmath.operations.Algebra import java.util.* import kotlin.reflect.KClass @@ -36,32 +37,27 @@ internal class AsmBuilder internal constructor( */ private val classLoader: ClassLoader = ClassLoader(javaClass.classLoader) - @Suppress("PrivatePropertyName") - private val T_ALGEBRA_TYPE: Type = algebra::class.asm - - @Suppress("PrivatePropertyName") - internal val T_TYPE: Type = classOfT.asm - - @Suppress("PrivatePropertyName") - private val CLASS_TYPE: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! + private val tAlgebraType: Type = algebra::class.asm + internal val tType: Type = classOfT.asm + private val classType: Type = Type.getObjectType(className.replace(oldChar = '.', newChar = '/'))!! /** - * Index of `this` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `this` variable in invoke method of the built subclass. */ private val invokeThisVar: Int = 0 /** - * Index of `arguments` variable in invoke method of [AsmCompiledExpression] built subclass. + * Index of `arguments` variable in invoke method of the built subclass. */ private val invokeArgumentsVar: Int = 1 /** - * List of constants to provide to [AsmCompiledExpression] subclass. + * List of constants to provide to the subclass. */ private val constants: MutableList = mutableListOf() /** - * Method visitor of `invoke` method of [AsmCompiledExpression] subclass. + * Method visitor of `invoke` method of the subclass. */ private lateinit var invokeMethodVisitor: InstructionAdapter internal var primitiveMode = false @@ -72,78 +68,92 @@ internal class AsmBuilder internal constructor( @Suppress("PropertyName") internal var PRIMITIVE_MASK_BOXED: Type = OBJECT_TYPE private val typeStack = Stack() - internal val expectationStack = Stack().apply { push(T_TYPE) } + internal val expectationStack: Stack = Stack().apply { push(tType) } /** - * The cache of [AsmCompiledExpression] subclass built by this builder. + * The cache for instance built by this builder. */ - private var generatedInstance: AsmCompiledExpression? = null + private var generatedInstance: Expression? = null /** - * Subclasses, loads and instantiates the [AsmCompiledExpression] for given parameters. + * Subclasses, loads and instantiates [Expression] for given parameters. * * The built instance is cached. */ @Suppress("UNCHECKED_CAST") - fun getInstance(): AsmCompiledExpression { + fun getInstance(): Expression { generatedInstance?.let { return it } - if (SIGNATURE_LETTERS.containsKey(classOfT.java)) { + if (SIGNATURE_LETTERS.containsKey(classOfT)) { primitiveMode = true - PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT.java) - PRIMITIVE_MASK_BOXED = T_TYPE + PRIMITIVE_MASK = SIGNATURE_LETTERS.getValue(classOfT) + PRIMITIVE_MASK_BOXED = tType } val classWriter = ClassWriter(ClassWriter.COMPUTE_FRAMES) { visit( Opcodes.V1_8, Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL or Opcodes.ACC_SUPER, - CLASS_TYPE.internalName, - "L${ASM_COMPILED_EXPRESSION_TYPE.internalName}<${T_TYPE.descriptor}>;", - ASM_COMPILED_EXPRESSION_TYPE.internalName, - arrayOf() + classType.internalName, + "${OBJECT_TYPE.descriptor}L${EXPRESSION_TYPE.internalName}<${tType.descriptor}>;", + OBJECT_TYPE.internalName, + arrayOf(EXPRESSION_TYPE.internalName) + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "algebra", + descriptor = tAlgebraType.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd + ) + + visitField( + access = Opcodes.ACC_PRIVATE or Opcodes.ACC_FINAL, + name = "constants", + descriptor = OBJECT_ARRAY_TYPE.descriptor, + signature = null, + value = null, + block = FieldVisitor::visitEnd ) visitMethod( Opcodes.ACC_PUBLIC, "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), + Type.getMethodDescriptor(Type.VOID_TYPE, tAlgebraType, OBJECT_ARRAY_TYPE), null, null ).instructionAdapter { val thisVar = 0 val algebraVar = 1 val constantsVar = 2 - val l0 = Label() - visitLabel(l0) - load(thisVar, CLASS_TYPE) - load(algebraVar, ALGEBRA_TYPE) + val l0 = label() + load(thisVar, classType) + invokespecial(OBJECT_TYPE.internalName, "", Type.getMethodDescriptor(Type.VOID_TYPE), false) + label() + load(thisVar, classType) + load(algebraVar, tAlgebraType) + putfield(classType.internalName, "algebra", tAlgebraType.descriptor) + label() + load(thisVar, classType) load(constantsVar, OBJECT_ARRAY_TYPE) - - invokespecial( - ASM_COMPILED_EXPRESSION_TYPE.internalName, - "", - Type.getMethodDescriptor(Type.VOID_TYPE, ALGEBRA_TYPE, OBJECT_ARRAY_TYPE), - false - ) - - val l1 = Label() - visitLabel(l1) + putfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + label() visitInsn(RETURN) - val l2 = Label() - visitLabel(l2) - visitLocalVariable("this", CLASS_TYPE.descriptor, null, l0, l2, thisVar) + val l4 = label() + visitLocalVariable("this", classType.descriptor, null, l0, l4, thisVar) visitLocalVariable( "algebra", - ALGEBRA_TYPE.descriptor, - "L${ALGEBRA_TYPE.internalName}<${T_TYPE.descriptor}>;", + tAlgebraType.descriptor, + null, l0, - l2, + l4, algebraVar ) - visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l2, constantsVar) + visitLocalVariable("constants", OBJECT_ARRAY_TYPE.descriptor, null, l0, l4, constantsVar) visitMaxs(0, 3) visitEnd() } @@ -151,22 +161,20 @@ internal class AsmBuilder internal constructor( visitMethod( Opcodes.ACC_PUBLIC or Opcodes.ACC_FINAL, "invoke", - Type.getMethodDescriptor(T_TYPE, MAP_TYPE), - "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;)${T_TYPE.descriptor}", + Type.getMethodDescriptor(tType, MAP_TYPE), + "(L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;)${tType.descriptor}", null ).instructionAdapter { invokeMethodVisitor = this visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() invokeLabel0Visitor() - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -176,7 +184,7 @@ internal class AsmBuilder internal constructor( visitLocalVariable( "arguments", MAP_TYPE.descriptor, - "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${T_TYPE.descriptor}>;", + "L${MAP_TYPE.internalName}<${STRING_TYPE.descriptor}+${tType.descriptor}>;", l0, l1, invokeArgumentsVar @@ -196,18 +204,16 @@ internal class AsmBuilder internal constructor( val thisVar = 0 val argumentsVar = 1 visitCode() - val l0 = Label() - visitLabel(l0) + val l0 = label() load(thisVar, OBJECT_TYPE) load(argumentsVar, MAP_TYPE) - invokevirtual(CLASS_TYPE.internalName, "invoke", Type.getMethodDescriptor(T_TYPE, MAP_TYPE), false) - areturn(T_TYPE) - val l1 = Label() - visitLabel(l1) + invokevirtual(classType.internalName, "invoke", Type.getMethodDescriptor(tType, MAP_TYPE), false) + areturn(tType) + val l1 = label() visitLocalVariable( "this", - CLASS_TYPE.descriptor, + classType.descriptor, null, l0, l1, @@ -225,7 +231,7 @@ internal class AsmBuilder internal constructor( .defineClass(className, classWriter.toByteArray()) .constructors .first() - .newInstance(algebra, constants.toTypedArray()) as AsmCompiledExpression + .newInstance(algebra, constants.toTypedArray()) as Expression generatedInstance = new return new @@ -235,21 +241,21 @@ internal class AsmBuilder internal constructor( * Loads a constant from */ internal fun loadTConstant(value: T) { - if (classOfT.java in INLINABLE_NUMBERS) { + if (classOfT in INLINABLE_NUMBERS) { val expectedType = expectationStack.pop()!! val mustBeBoxed = expectedType.sort == Type.OBJECT loadNumberConstant(value as Number, mustBeBoxed) - if (mustBeBoxed) typeStack.push(T_TYPE) else typeStack.push(PRIMITIVE_MASK) + if (mustBeBoxed) typeStack.push(tType) else typeStack.push(PRIMITIVE_MASK) return } - loadConstant(value as Any, T_TYPE) + loadConstant(value as Any, tType) } private fun box(): Unit = invokeMethodVisitor.invokestatic( - T_TYPE.internalName, + tType.internalName, "valueOf", - Type.getMethodDescriptor(T_TYPE, PRIMITIVE_MASK), + Type.getMethodDescriptor(tType, PRIMITIVE_MASK), false ) @@ -263,16 +269,16 @@ internal class AsmBuilder internal constructor( private fun loadConstant(value: Any, type: Type): Unit = invokeMethodVisitor.run { val idx = if (value in constants) constants.indexOf(value) else constants.apply { add(value) }.lastIndex loadThis() - getfield(CLASS_TYPE.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) + getfield(classType.internalName, "constants", OBJECT_ARRAY_TYPE.descriptor) iconst(idx) visitInsn(AALOAD) checkcast(type) } - private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, CLASS_TYPE) + private fun loadThis(): Unit = invokeMethodVisitor.load(invokeThisVar, classType) /** - * Either loads a numeric constant [value] from [AsmCompiledExpression] constants field or boxes a primitive + * Either loads a numeric constant [value] from the class's constants field or boxes a primitive * constant from the constant pool (some numbers with special opcodes like [Opcodes.ICONST_0] aren't even loaded * from it). */ @@ -292,7 +298,7 @@ internal class AsmBuilder internal constructor( if (mustBeBoxed) { box() - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) } return @@ -300,11 +306,11 @@ internal class AsmBuilder internal constructor( loadConstant(value, boxed) if (!mustBeBoxed) unbox() - else invokeMethodVisitor.checkcast(T_TYPE) + else invokeMethodVisitor.checkcast(tType) } /** - * Loads a variable [name] from [AsmCompiledExpression.invoke] [Map] parameter. The [defaultValue] may be provided. + * Loads a variable [name] arguments [Map] parameter of [Expression.invoke]. The [defaultValue] may be provided. */ internal fun loadVariable(name: String, defaultValue: T? = null): Unit = invokeMethodVisitor.run { load(invokeArgumentsVar, OBJECT_ARRAY_TYPE) @@ -319,7 +325,7 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) return } @@ -331,11 +337,11 @@ internal class AsmBuilder internal constructor( Type.getMethodDescriptor(OBJECT_TYPE, OBJECT_TYPE) ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -343,15 +349,11 @@ internal class AsmBuilder internal constructor( } /** - * Loads algebra from according field of [AsmCompiledExpression] and casts it to class of [algebra] provided. + * Loads algebra from according field of the class and casts it to class of [algebra] provided. */ internal fun loadAlgebra() { loadThis() - - invokeMethodVisitor.run { - getfield(ASM_COMPILED_EXPRESSION_TYPE.internalName, "algebra", ALGEBRA_TYPE.descriptor) - checkcast(T_ALGEBRA_TYPE) - } + invokeMethodVisitor.getfield(classType.internalName, "algebra", tAlgebraType.descriptor) } /** @@ -368,7 +370,12 @@ internal class AsmBuilder internal constructor( tArity: Int, opcode: Int = Opcodes.INVOKEINTERFACE ) { - repeat(tArity) { if (!typeStack.empty()) typeStack.pop() } + run loop@{ + repeat(tArity) { + if (typeStack.empty()) return@loop + typeStack.pop() + } + } invokeMethodVisitor.visitMethodInsn( opcode, @@ -378,12 +385,12 @@ internal class AsmBuilder internal constructor( opcode == Opcodes.INVOKEINTERFACE ) - invokeMethodVisitor.checkcast(T_TYPE) + invokeMethodVisitor.checkcast(tType) val isLastExpr = expectationStack.size == 1 val expectedType = expectationStack.pop()!! if (expectedType.sort == Type.OBJECT || isLastExpr) - typeStack.push(T_TYPE) + typeStack.push(tType) else { unbox() typeStack.push(PRIMITIVE_MASK) @@ -399,27 +406,18 @@ internal class AsmBuilder internal constructor( /** * Maps JVM primitive numbers boxed types to their letters of JVM signature convention. */ - private val SIGNATURE_LETTERS: Map, Type> by lazy { + private val SIGNATURE_LETTERS: Map, Type> by lazy { hashMapOf( - java.lang.Byte::class.java to Type.BYTE_TYPE, - java.lang.Short::class.java to Type.SHORT_TYPE, - java.lang.Integer::class.java to Type.INT_TYPE, - java.lang.Long::class.java to Type.LONG_TYPE, - java.lang.Float::class.java to Type.FLOAT_TYPE, - java.lang.Double::class.java to Type.DOUBLE_TYPE + java.lang.Byte::class to Type.BYTE_TYPE, + java.lang.Short::class to Type.SHORT_TYPE, + java.lang.Integer::class to Type.INT_TYPE, + java.lang.Long::class to Type.LONG_TYPE, + java.lang.Float::class to Type.FLOAT_TYPE, + java.lang.Double::class to Type.DOUBLE_TYPE ) } - private val BOXED_TO_PRIMITIVES: Map by lazy { - hashMapOf( - java.lang.Byte::class.asm to Type.BYTE_TYPE, - java.lang.Short::class.asm to Type.SHORT_TYPE, - java.lang.Integer::class.asm to Type.INT_TYPE, - java.lang.Long::class.asm to Type.LONG_TYPE, - java.lang.Float::class.asm to Type.FLOAT_TYPE, - java.lang.Double::class.asm to Type.DOUBLE_TYPE - ) - } + private val BOXED_TO_PRIMITIVES: Map by lazy { SIGNATURE_LETTERS.mapKeys { (k, _) -> k.asm } } private val NUMBER_CONVERTER_METHODS: Map by lazy { hashMapOf( @@ -435,15 +433,15 @@ internal class AsmBuilder internal constructor( /** * Provides boxed number types values of which can be stored in JVM bytecode constant pool. */ - private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } - internal val ASM_COMPILED_EXPRESSION_TYPE: Type = AsmCompiledExpression::class.asm - internal val NUMBER_TYPE: Type = java.lang.Number::class.asm - internal val MAP_TYPE: Type = java.util.Map::class.asm - internal val OBJECT_TYPE: Type = java.lang.Object::class.asm + private val INLINABLE_NUMBERS: Set> by lazy { SIGNATURE_LETTERS.keys } + internal val EXPRESSION_TYPE: Type by lazy { Expression::class.asm } + internal val NUMBER_TYPE: Type by lazy { java.lang.Number::class.asm } + internal val MAP_TYPE: Type by lazy { java.util.Map::class.asm } + internal val OBJECT_TYPE: Type by lazy { java.lang.Object::class.asm } @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "RemoveRedundantQualifierName") - internal val OBJECT_ARRAY_TYPE: Type = Array::class.asm - internal val ALGEBRA_TYPE: Type = Algebra::class.asm - internal val STRING_TYPE: Type = java.lang.String::class.asm + internal val OBJECT_ARRAY_TYPE: Type by lazy { Array::class.asm } + internal val ALGEBRA_TYPE: Type by lazy { Algebra::class.asm } + internal val STRING_TYPE: Type by lazy { java.lang.String::class.asm } } } diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt deleted file mode 100644 index 7c4a9fc99..000000000 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmCompiledExpression.kt +++ /dev/null @@ -1,18 +0,0 @@ -package scientifik.kmath.asm.internal - -import scientifik.kmath.expressions.Expression -import scientifik.kmath.operations.Algebra - -/** - * [Expression] partial implementation to have it subclassed by actual implementations. Provides unified storage for - * objects needed to implement the expression. - * - * @property algebra the algebra to delegate calls. - * @property constants the constants array to have persistent objects to reference in [invoke]. - */ -internal abstract class AsmCompiledExpression internal constructor( - @JvmField protected val algebra: Algebra, - @JvmField protected val constants: Array -) : Expression { - abstract override fun invoke(arguments: Map): T -} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt index 66bd039c3..41dbf5807 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/buildName.kt @@ -1,9 +1,10 @@ package scientifik.kmath.asm.internal import scientifik.kmath.ast.MST +import scientifik.kmath.expressions.Expression /** - * Creates a class name for [AsmCompiledExpression] subclassed to implement [mst] provided. + * Creates a class name for [Expression] subclassed to implement [mst] provided. * * This methods helps to avoid collisions of class name to prevent loading several classes with the same name. If there * is a colliding class, change [collision] parameter or leave it `0` to check existing classes recursively. diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt index 95d713b18..af5c1049d 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/classWriters.kt @@ -1,15 +1,17 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.ClassWriter +import org.objectweb.asm.FieldVisitor import org.objectweb.asm.MethodVisitor -internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = ClassWriter(flags).apply(block) +internal inline fun ClassWriter(flags: Int, block: ClassWriter.() -> Unit): ClassWriter = + ClassWriter(flags).apply(block) -internal inline fun ClassWriter.visitMethod( +internal inline fun ClassWriter.visitField( access: Int, name: String, descriptor: String, signature: String?, - exceptions: Array?, - block: MethodVisitor.() -> Unit -): MethodVisitor = visitMethod(access, name, descriptor, signature, exceptions).apply(block) + value: Any?, + block: FieldVisitor.() -> Unit +): FieldVisitor = visitField(access, name, descriptor, signature, value).apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt new file mode 100644 index 000000000..f47293687 --- /dev/null +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/instructionAdapters.kt @@ -0,0 +1,10 @@ +package scientifik.kmath.asm.internal + +import org.objectweb.asm.Label +import org.objectweb.asm.commons.InstructionAdapter + +internal fun InstructionAdapter.label(): Label { + val l = Label() + visitLabel(l) + return l +} diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt index 7b0d346b7..aaae02ebb 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/methodVisitors.kt @@ -3,7 +3,7 @@ package scientifik.kmath.asm.internal import org.objectweb.asm.MethodVisitor import org.objectweb.asm.commons.InstructionAdapter -fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) +internal fun MethodVisitor.instructionAdapter(): InstructionAdapter = InstructionAdapter(this) -fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = +internal fun MethodVisitor.instructionAdapter(block: InstructionAdapter.() -> Unit): InstructionAdapter = instructionAdapter().apply(block) diff --git a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt index 2e15a1a93..4c7a0d57e 100644 --- a/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt +++ b/kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/specialization.kt @@ -22,7 +22,7 @@ internal fun AsmBuilder.buildExpectationStack(context: Algebra, name: val aName = methodNameAdapters[name] ?: name val hasSpecific = context.javaClass.methods.find { it.name == aName && it.parameters.size == arity } != null - val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else T_TYPE + val t = if (primitiveMode && hasSpecific) PRIMITIVE_MASK else tType repeat(arity) { expectationStack.push(t) } return hasSpecific -- 2.34.1 From a275e74cf287bb3cf6b929b78d0db2b329f3cb43 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Fri, 26 Jun 2020 14:57:07 +0700 Subject: [PATCH 5/6] Add mapping for other dynamic operations --- .../kotlin/scientifik/kmath/operations/NumberAlgebra.kt | 4 ++++ .../kotlin/scientifik/kmath/operations/OptionalOperations.kt | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index a1b845ccc..3fb57656c 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -17,6 +17,10 @@ interface ExtendedFieldOperations : override fun unaryOperation(operation: String, arg: T): T = when (operation) { TrigonometricOperations.COS_OPERATION -> cos(arg) TrigonometricOperations.SIN_OPERATION -> sin(arg) + TrigonometricOperations.TAN_OPERATION -> tan(arg) + InverseTrigonometricOperations.ACOS_OPERATION -> acos(arg) + InverseTrigonometricOperations.ASIN_OPERATION -> asin(arg) + InverseTrigonometricOperations.ATAN_OPERATION -> atan(arg) PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index a0266c78b..542d6376b 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -13,7 +13,7 @@ package scientifik.kmath.operations interface TrigonometricOperations : FieldOperations { fun sin(arg: T): T fun cos(arg: T): T - fun tan(arg: T): T = sin(arg) / cos(arg) + fun tan(arg: T): T companion object { const val SIN_OPERATION = "sin" -- 2.34.1 From e91c5a57c493dbee9f8f1545b294b69848572cf0 Mon Sep 17 00:00:00 2001 From: Iaroslav Date: Sat, 27 Jun 2020 20:31:42 +0700 Subject: [PATCH 6/6] Minor refactor for changed ExtendedFieldOperations, replace DoubleBuffer with RealBuffer --- .../kmath/operations/NumberAlgebra.kt | 4 +- .../kmath/operations/OptionalOperations.kt | 4 +- .../kmath/structures/RealBufferField.kt | 160 ++++++++---------- 3 files changed, 76 insertions(+), 92 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt index 3fb57656c..953c5a112 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/NumberAlgebra.kt @@ -7,7 +7,6 @@ import kotlin.math.pow as kpow * Advanced Number-like field that implements basic operations */ interface ExtendedFieldOperations : - FieldOperations, InverseTrigonometricOperations, PowerOperations, ExponentialOperations { @@ -24,9 +23,8 @@ interface ExtendedFieldOperations : PowerOperations.SQRT_OPERATION -> sqrt(arg) ExponentialOperations.EXP_OPERATION -> exp(arg) ExponentialOperations.LN_OPERATION -> ln(arg) - else -> super.unaryOperation(operation, arg) + else -> super.unaryOperation(operation, arg) } - } interface ExtendedField : ExtendedFieldOperations, Field { diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt index 542d6376b..709f0260f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/operations/OptionalOperations.kt @@ -64,7 +64,7 @@ fun >> sqr(arg: T): T = arg pow 2.0 /* Exponential */ -interface ExponentialOperations: Algebra { +interface ExponentialOperations : Algebra { fun exp(arg: T): T fun ln(arg: T): T @@ -81,4 +81,4 @@ interface Norm { fun norm(arg: T): R } -fun >, R> norm(arg: T): R = arg.context.norm(arg) \ No newline at end of file +fun >, R> norm(arg: T): R = arg.context.norm(arg) diff --git a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt index 2fb6d15d4..826203d1f 100644 --- a/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt +++ b/kmath-core/src/commonMain/kotlin/scientifik/kmath/structures/RealBufferField.kt @@ -9,185 +9,171 @@ import kotlin.math.* * A simple field over linear buffers of [Double] */ object RealBufferFieldOperations : ExtendedFieldOperations> { - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] + b[it] }) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { val kValue = k.toDouble() - return if (a is DoubleBuffer) { + return if (a is RealBuffer) { val aArray = a.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * kValue }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * kValue }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * kValue }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * kValue }) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] * b[it] }) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(b.size == a.size) { "The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " } - return if (a is DoubleBuffer && b is DoubleBuffer) { + return if (a is RealBuffer && b is RealBuffer) { val aArray = a.array val bArray = b.array - DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) - } else { - DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) - } + RealBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) + } else + RealBuffer(DoubleArray(a.size) { a[it] / b[it] }) } - override fun sin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun sin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { sin(array[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { sin(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { sin(arg[it]) }) } - override fun cos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun cos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { cos(array[it]) }) + RealBuffer(DoubleArray(arg.size) { cos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + + override fun tan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { tan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { tan(arg[it]) }) + + override fun asin(arg: Buffer): RealBuffer = if (arg is RealBuffer) { + val array = arg.array + RealBuffer(DoubleArray(arg.size) { asin(array[it]) }) } else { - DoubleBuffer(DoubleArray(arg.size) { cos(arg[it]) }) + RealBuffer(DoubleArray(arg.size) { asin(arg[it]) }) } - override fun tan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun acos(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { tan(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { tan(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { acos(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - override fun asin(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun atan(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { asin(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { asin(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { atan(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - override fun acos(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun power(arg: Buffer, pow: Number): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { acos(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { acos(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) + } else + RealBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - override fun atan(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { + override fun exp(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { atan(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { atan(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { exp(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - override fun power(arg: Buffer, pow: Number): DoubleBuffer = if (arg is DoubleBuffer) { + override fun ln(arg: Buffer): RealBuffer = if (arg is RealBuffer) { val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { array[it].pow(pow.toDouble()) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { arg[it].pow(pow.toDouble()) }) - } - - override fun exp(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { exp(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { exp(arg[it]) }) - } - - override fun ln(arg: Buffer): DoubleBuffer = if (arg is DoubleBuffer) { - val array = arg.array - DoubleBuffer(DoubleArray(arg.size) { ln(array[it]) }) - } else { - DoubleBuffer(DoubleArray(arg.size) { ln(arg[it]) }) - } + RealBuffer(DoubleArray(arg.size) { ln(array[it]) }) + } else + RealBuffer(DoubleArray(arg.size) { ln(arg[it]) }) } class RealBufferField(val size: Int) : ExtendedField> { + override val zero: Buffer by lazy { RealBuffer(size) { 0.0 } } + override val one: Buffer by lazy { RealBuffer(size) { 1.0 } } - override val zero: Buffer by lazy { DoubleBuffer(size) { 0.0 } } - - override val one: Buffer by lazy { DoubleBuffer(size) { 1.0 } } - - override fun add(a: Buffer, b: Buffer): DoubleBuffer { + override fun add(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.add(a, b) } - override fun multiply(a: Buffer, k: Number): DoubleBuffer { + override fun multiply(a: Buffer, k: Number): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, k) } - override fun multiply(a: Buffer, b: Buffer): DoubleBuffer { + override fun multiply(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.multiply(a, b) } - override fun divide(a: Buffer, b: Buffer): DoubleBuffer { + override fun divide(a: Buffer, b: Buffer): RealBuffer { require(a.size == size) { "The buffer size ${a.size} does not match context size $size" } return RealBufferFieldOperations.divide(a, b) } - override fun sin(arg: Buffer): DoubleBuffer { + override fun sin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.sin(arg) } - override fun cos(arg: Buffer): DoubleBuffer { + override fun cos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.cos(arg) } - override fun tan(arg: Buffer): DoubleBuffer { + override fun tan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.tan(arg) } - override fun asin(arg: Buffer): DoubleBuffer { + override fun asin(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.asin(arg) } - override fun acos(arg: Buffer): DoubleBuffer { + override fun acos(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.acos(arg) } - override fun atan(arg: Buffer): DoubleBuffer { + override fun atan(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.atan(arg) } - override fun power(arg: Buffer, pow: Number): DoubleBuffer { + override fun power(arg: Buffer, pow: Number): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.power(arg, pow) } - override fun exp(arg: Buffer): DoubleBuffer { + override fun exp(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.exp(arg) } - override fun ln(arg: Buffer): DoubleBuffer { + override fun ln(arg: Buffer): RealBuffer { require(arg.size == size) { "The buffer size ${arg.size} does not match context size $size" } return RealBufferFieldOperations.ln(arg) } -- 2.34.1