Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-improved-trigonometry

This commit is contained in:
Iaroslav 2020-06-16 14:40:41 +07:00
commit f8f1814def
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
14 changed files with 253 additions and 447 deletions

View File

@ -27,7 +27,7 @@ fun main() {
val complexTime = measureTimeMillis { val complexTime = measureTimeMillis {
complexField.run { complexField.run {
var res = one var res: NDBuffer<Complex> = one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
} }

View File

@ -23,14 +23,14 @@ fun main() {
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField.run { autoField.run {
var res = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += number(1.0)
} }
} }
} }
measureAndPrint("Element addition"){ measureAndPrint("Element addition") {
var res = genericField.one var res = genericField.one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
@ -63,7 +63,7 @@ fun main() {
genericField.run { genericField.run {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += one // con't avoid using `one` due to resolution ambiguity
} }
} }
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.ast package scientifik.kmath.ast
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
@ -40,12 +41,14 @@ sealed class MST {
//TODO add a function with named arguments //TODO add a function with named arguments
fun <T> NumericAlgebra<T>.evaluate(node: MST): T { fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) { return when (node) {
is MST.Numeric -> number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField.binaryOperation(
node.operation, node.operation,
@ -60,3 +63,5 @@ fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
} }
} }
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -0,0 +1,33 @@
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MSTAlgebra : NumericAlgebra<MST> {
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
override fun number(value: Number): MST = MST.Numeric(value)
}
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
override val zero: MST = number(0.0)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
}
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
override val one: MST = number(1.0)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
}
object MSTField : Field<MST>, Ring<MST> by MSTRing {
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
}

View File

@ -1,292 +0,0 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.*
import kotlin.reflect.KClass
/**
* A function declaration that could be compiled to [AsmBuilder].
*
* @param T the type the stored function returns.
*/
sealed class AsmExpression<T : Any>: Expression<T> {
abstract val type: KClass<out T>
abstract val algebra: Algebra<T>
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
internal open fun tryEvaluate(): T? = null
/**
* Compiles this declaration.
*
* @param gen the target [AsmBuilder].
*/
internal abstract fun appendTo(gen: AsmBuilder<T>)
/**
* Compile and cache the expression
*/
private val compiledExpression by lazy{
val builder = AsmBuilder(type.java, algebra, buildName(this))
this.appendTo(builder)
builder.generate()
}
override fun invoke(arguments: Map<String, T>): T = compiledExpression.invoke(arguments)
}
internal class AsmUnaryOperation<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
expr: AsmExpression<T>
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = algebra { unaryOperation(name, expr.tryEvaluate() ?: return@algebra null) }
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(algebra, name, 1))
gen.visitStringConstant(name)
expr.appendTo(gen)
if (gen.tryInvokeSpecific(algebra, name, 1))
return
gen.visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmBinaryOperation<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
first: AsmExpression<T>,
second: AsmExpression<T>
) : AsmExpression<T>() {
private val first: AsmExpression<T> = first.optimize()
private val second: AsmExpression<T> = second.optimize()
override fun tryEvaluate(): T? = algebra {
binaryOperation(
name,
first.tryEvaluate() ?: return@algebra null,
second.tryEvaluate() ?: return@algebra null
)
}
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(algebra, name, 2))
gen.visitStringConstant(name)
first.appendTo(gen)
second.appendTo(gen)
if (gen.tryInvokeSpecific(algebra, name, 2))
return
gen.visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmVariableExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
private val default: T? = null
) : AsmExpression<T>() {
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitLoadFromVariables(name, default)
}
internal class AsmConstantExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val value: T
) : AsmExpression<T>() {
override fun tryEvaluate(): T = value
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitLoadFromConstants(value)
}
internal class AsmConstProductExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Space<T>,
expr: AsmExpression<T>,
private val const: Number
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const }
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
gen.visitNumberConstant(const)
expr.appendTo(gen)
gen.visitAlgebraOperation(
owner = AsmBuilder.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.NUMBER_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmNumberExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: NumericAlgebra<T>,
private val value: Number
) : AsmExpression<T>() {
override fun tryEvaluate(): T? = algebra.number(value)
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitNumberConstant(value)
}
internal abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* A context class for [AsmExpression] construction.
*
* @param algebra The algebra to provide for AsmExpressions built.
*/
open class AsmExpressionAlgebra<T : Any, A : NumericAlgebra<T>>(val type: KClass<out T>, val algebra: A) :
NumericAlgebra<AsmExpression<T>>, ExpressionAlgebra<T, AsmExpression<T>> {
/**
* Builds an AsmExpression to wrap a number.
*/
override fun number(value: Number): AsmExpression<T> = AsmNumberExpression(type, algebra, value)
/**
* Builds an AsmExpression of constant expression which does not depend on arguments.
*/
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(type, algebra, value)
/**
* Builds an AsmExpression to access a variable.
*/
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(type, algebra, name, default)
/**
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, operation, left, right)
/**
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
AsmUnaryOperation(type, algebra, operation, arg)
}
/**
* A context class for [AsmExpression] construction for [Space] algebras.
*/
open class AsmExpressionSpace<T : Any, A>(type: KClass<out T>, algebra: A) : AsmExpressionAlgebra<T, A>(type, algebra),
Space<AsmExpression<T>> where A : Space<T>, A : NumericAlgebra<T> {
override val zero: AsmExpression<T> get() = const(algebra.zero)
/**
* Builds an AsmExpression of addition of two another expressions.
*/
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an AsmExpression of multiplication of expression by number.
*/
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(type, algebra, a, k)
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
}
/**
* A context class for [AsmExpression] construction for [Ring] algebras.
*/
open class AsmExpressionRing<T : Any, A>(type: KClass<out T>, algebra: A) : AsmExpressionSpace<T, A>(type, algebra),
Ring<AsmExpression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: AsmExpression<T> get() = const(algebra.one)
/**
* Builds an AsmExpression of multiplication of two expressions.
*/
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionSpace>.number(value)
}
/**
* A context class for [AsmExpression] construction for [Field] algebras.
*/
open class AsmExpressionField<T : Any, A>(type: KClass<out T>, algebra: A) :
AsmExpressionRing<T, A>(type, algebra),
Field<AsmExpression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an AsmExpression of division an expression by another one.
*/
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionRing>.number(value)
}

View File

@ -1,15 +1,24 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.MSTField
import scientifik.kmath.ast.MSTRing
import scientifik.kmath.ast.MSTSpace
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Field import scientifik.kmath.operations.*
import scientifik.kmath.operations.NumericAlgebra import kotlin.reflect.KClass
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" /**
* Compile given MST to an Expression using AST compiler
*/
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try { try {
Class.forName(name) Class.forName(name)
@ -17,31 +26,78 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String
return name return name
} }
return buildName(expression, collision + 1) return buildName(mst, collision + 1)
}
fun AsmBuilder<T>.visit(node: MST): Unit {
when (node) {
is MST.Symbolic -> visitLoadFromVariables(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value)
} else {
error("Number literals are not supported in $algebra")
}
visitLoadFromConstants(constant)
}
is MST.Unary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
is MST.Binary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 2))
visitStringConstant(node.operation)
visit(node.left)
visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
}
}
val builder = AsmBuilder(type.java, algebra, buildName(this))
builder.visit(this)
return builder.generate()
} }
inline fun <reified T : Any, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm( inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
expressionAlgebra: E,
block: E.() -> AsmExpression<T>
): Expression<T> = expressionAlgebra.block()
inline fun <reified T : Any> NumericAlgebra<T>.asm(ast: MST): Expression<T> = inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
AsmExpressionAlgebra(T::class, this).evaluate(ast) mstAlgebra: E,
block: E.() -> MST
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> = inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
AsmExpressionSpace<T, A>(T::class, this).block() MSTSpace.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> = inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
asmSpace { evaluate(ast) } MSTRing.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> = inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
AsmExpressionRing(T::class, this).block() MSTField.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
asmRing { evaluate(ast) }
inline fun <reified T : Any, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
AsmExpressionField(T::class, this).block()
inline fun <reified T : Any, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
asmRing { evaluate(ast) }

View File

@ -4,10 +4,19 @@ import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Label import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.FunctionalCompiledExpression
import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/** /**
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java * AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java
* expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new * expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new
@ -16,6 +25,8 @@ import scientifik.kmath.operations.Algebra
* @param T the type of AsmExpression to unwrap. * @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use. * @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class. * @param className the unique class name of new loaded class.
*
* @author [Iaroslav Postovalov](https://github.com/CommanderTvis)
*/ */
internal class AsmBuilder<T>( internal class AsmBuilder<T>(
private val classOfT: Class<*>, private val classOfT: Class<*>,
@ -44,7 +55,6 @@ internal class AsmBuilder<T>(
private val invokeMethodVisitor: MethodVisitor private val invokeMethodVisitor: MethodVisitor
private val invokeL0: Label private val invokeL0: Label
private lateinit var invokeL1: Label private lateinit var invokeL1: Label
private var generatedInstance: FunctionalCompiledExpression<T>? = null
init { init {
asmCompiledClassWriter.visit( asmCompiledClassWriter.visit(
@ -113,8 +123,7 @@ internal class AsmBuilder<T>(
} }
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
fun generate(): FunctionalCompiledExpression<T> { fun generate(): Expression<T> {
generatedInstance?.let { return it }
invokeMethodVisitor.run { invokeMethodVisitor.run {
visitInsn(Opcodes.ARETURN) visitInsn(Opcodes.ARETURN)
@ -182,7 +191,6 @@ internal class AsmBuilder<T>(
.first() .first()
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T> .newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
generatedInstance = new
return new return new
} }

View File

@ -1,8 +1,6 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmExpression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide") private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
@ -41,8 +39,8 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
return true return true
} }
//
internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> { //internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
val a = tryEvaluate() // val a = tryEvaluate()
return if (a == null) this else AsmConstantExpression(type, algebra, a) // return if (a == null) this else AsmConstantExpression(type, algebra, a)
} //}

View File

@ -1,75 +1,66 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.asmField //
import scientifik.kmath.asm.asmRing //class TestAsmAlgebras {
import scientifik.kmath.asm.asmSpace // @Test
import scientifik.kmath.expressions.invoke // fun space() {
import scientifik.kmath.operations.ByteRing // val res = ByteRing.asmInRing {
import scientifik.kmath.operations.RealField // binaryOperation(
import kotlin.test.Test // "+",
import kotlin.test.assertEquals //
// unaryOperation(
class TestAsmAlgebras { // "+",
@Test // 3.toByte() - (2.toByte() + (multiply(
fun space() { // add(number(1), number(1)),
val res = ByteRing.asmSpace { // 2
binaryOperation( // ) + 1.toByte()) * 3.toByte() - 1.toByte())
"+", // ),
//
unaryOperation( // number(1)
"+", // ) + symbol("x") + zero
3.toByte() - (2.toByte() + (multiply( // }("x" to 2.toByte())
add(const(1), const(1)), //
2 // assertEquals(16, res)
) + 1.toByte()) * 3.toByte() - 1.toByte()) // }
), //
// @Test
number(1) // fun ring() {
) + variable("x") + zero // val res = ByteRing.asmInRing {
}("x" to 2.toByte()) // binaryOperation(
// "+",
assertEquals(16, res) //
} // unaryOperation(
// "+",
@Test // (3.toByte() - (2.toByte() + (multiply(
fun ring() { // add(const(1), const(1)),
val res = ByteRing.asmRing { // 2
binaryOperation( // ) + 1.toByte()))) * 3.0 - 1.toByte()
"+", // ),
//
unaryOperation( // number(1)
"+", // ) * const(2)
(3.toByte() - (2.toByte() + (multiply( // }()
add(const(1), const(1)), //
2 // assertEquals(24, res)
) + 1.toByte()))) * 3.0 - 1.toByte() // }
), //
// @Test
number(1) // fun field() {
) * const(2) // val res = RealField.asmInField {
}() // +(3 - 2 + 2*(number(1)+1.0)
//
assertEquals(24, res) // unaryOperation(
} // "+",
// (3.0 - (2.0 + (multiply(
@Test // add((1.0), const(1.0)),
fun field() { // 2
val res = RealField.asmField { // ) + 1.0))) * 3 - 1.0
divide(binaryOperation( // )+
"+", //
// number(1)
unaryOperation( // ) / 2, const(2.0)) * one
"+", // }()
(3.0 - (2.0 + (multiply( //
add(const(1.0), const(1.0)), // assertEquals(3.0, res)
2 // }
) + 1.0))) * 3 - 1.0 //}
),
number(1)
) / 2, const(2.0)) * one
}()
assertEquals(3.0, res)
}
}

View File

@ -1,6 +1,6 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.asmField import scientifik.kmath.asm.asmInField
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
@ -9,13 +9,13 @@ import kotlin.test.assertEquals
class TestAsmExpressions { class TestAsmExpressions {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0) val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
assertEquals(2.0, res) assertEquals(-2.0, res)
} }
@Test @Test
fun testConstProductInvocation() { fun testConstProductInvocation() {
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0) val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res) assertEquals(4.0, res)
} }
} }

View File

@ -1,18 +1,18 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.asm.asmField import scientifik.kmath.asm.compile
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest { class AsmTest {
@Test @Test
fun parsedExpression() { fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.asmField(mst)() val res = ComplexField.compile(mst)()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
} }

View File

@ -76,10 +76,10 @@ class DerivativeStructureField(
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this override operator fun Number.plus(b: DerivativeStructure) = b + this
operator fun Number.minus(s: DerivativeStructure) = s - this override operator fun Number.minus(b: DerivativeStructure) = b - this
} }
/** /**

View File

@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
// Overloads for Double constants // Overloads for Double constants
operator fun Number.plus(that: Variable<T>): Variable<T> = override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + that.value }) { z -> derive(variable { this@plus.toDouble() * one + b.value }) { z ->
that.d += z.d b.d += z.d
} }
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this) override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
operator fun Number.minus(that: Variable<T>): Variable<T> = override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - that.value }) { z -> derive(variable { this@minus.toDouble() * one - b.value }) { z ->
that.d -= z.d b.d -= z.d
} }
operator fun Variable<T>.minus(that: Number): Variable<T> = override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * that.toDouble() }) { z -> derive(variable { this@minus.value - one * b.toDouble() }) { z ->
this@minus.d += z.d this@minus.d += z.d
} }
} }

View File

@ -12,9 +12,6 @@ interface Algebra<T> {
*/ */
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
@Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol"))
fun raw(value: String): T = symbol(value)
/** /**
* Dynamic call of unary operation with name [operation] on [arg] * Dynamic call of unary operation with name [operation] on [arg]
*/ */
@ -64,6 +61,8 @@ interface SpaceOperations<T> : Algebra<T> {
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.unaryPlus(): T = this
operator fun T.plus(b: T): T = add(this, b) operator fun T.plus(b: T): T = add(this, b)
operator fun T.minus(b: T): T = add(this, -b) operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.times(k: Number) = multiply(this, k.toDouble())
@ -138,17 +137,25 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble() override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right) else -> super.leftSideNumberOperation(operation, left, right)
} }
//TODO those operators are blocked by type conflict in RealField override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this operator fun T.plus(b: Number) = this.plus(number(b))
// operator fun Number.plus(b: T) = b + this
// operator fun T.minus(b: Number) = this.minus(b * one)
// operator fun Number.minus(b: T) = -b + this operator fun T.minus(b: Number) = this.minus(number(b))
operator fun Number.minus(b: T) = -b + this
} }
/** /**