Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-refactor-agc

# Conflicts:
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt
This commit is contained in:
Iaroslav 2020-06-16 13:28:25 +07:00
commit fe64537cbc
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
13 changed files with 239 additions and 421 deletions

View File

@ -27,7 +27,7 @@ fun main() {
val complexTime = measureTimeMillis { val complexTime = measureTimeMillis {
complexField.run { complexField.run {
var res = one var res: NDBuffer<Complex> = one
repeat(n) { repeat(n) {
res += 1.0 res += 1.0
} }

View File

@ -23,9 +23,9 @@ fun main() {
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField.run { autoField.run {
var res = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += number(1.0)
} }
} }
} }
@ -63,7 +63,7 @@ fun main() {
genericField.run { genericField.run {
var res: NDBuffer<Double> = one var res: NDBuffer<Double> = one
repeat(n) { repeat(n) {
res += 1.0 res += one // con't avoid using `one` due to resolution ambiguity
} }
} }
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.ast package scientifik.kmath.ast
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
@ -40,12 +41,14 @@ sealed class MST {
//TODO add a function with named arguments //TODO add a function with named arguments
fun <T> NumericAlgebra<T>.evaluate(node: MST): T { fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) { return when (node) {
is MST.Numeric -> number(node.value) is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value) is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when { is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField.binaryOperation(
node.operation, node.operation,
@ -60,3 +63,5 @@ fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
} }
} }
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -0,0 +1,33 @@
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MSTAlgebra : NumericAlgebra<MST> {
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
override fun number(value: Number): MST = MST.Numeric(value)
}
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
override val zero: MST = number(0.0)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
}
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
override val one: MST = number(1.0)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
}
object MSTField : Field<MST>, Ring<MST> by MSTRing {
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
}

View File

@ -1,265 +0,0 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.*
/**
* A function declaration that could be compiled to [AsmBuilder].
*
* @param T the type the stored function returns.
*/
abstract class AsmExpression<T> internal constructor() {
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
internal open fun tryEvaluate(): T? = null
/**
* Compiles this declaration.
*
* @param gen the target [AsmBuilder].
*/
@PublishedApi
internal abstract fun compile(gen: AsmBuilder<T>)
}
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
if (!hasSpecific(context, name, 1))
gen.loadStringConstant(name)
expr.compile(gen)
if (gen.tryInvokeSpecific(context, name, 1))
return
gen.invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmBinaryOperation<T>(
private val context: Algebra<T>,
private val name: String,
first: AsmExpression<T>,
second: AsmExpression<T>
) : AsmExpression<T>() {
private val first: AsmExpression<T> = first.optimize()
private val second: AsmExpression<T> = second.optimize()
override fun tryEvaluate(): T? = context {
binaryOperation(
name,
first.tryEvaluate() ?: return@context null,
second.tryEvaluate() ?: return@context null
)
}
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
if (!hasSpecific(context, name, 2))
gen.loadStringConstant(name)
first.compile(gen)
second.compile(gen)
if (gen.tryInvokeSpecific(context, name, 2))
return
gen.invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
AsmExpression<T>() {
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadVariable(name, default)
}
internal class AsmConstantExpression<T>(private val value: T) :
AsmExpression<T>() {
override fun tryEvaluate(): T = value
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadTConstant(value)
}
internal class AsmConstProductExpression<T>(
private val context: Space<T>,
expr: AsmExpression<T>,
private val const: Number
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
gen.loadNumberConstant(const)
expr.compile(gen)
gen.invokeAlgebraOperation(
owner = AsmBuilder.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.NUMBER_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, private val value: Number) :
AsmExpression<T>() {
override fun tryEvaluate(): T? = context.number(value)
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadNumberConstant(value)
}
internal abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* A context class for [AsmExpression] construction.
*/
interface AsmExpressionAlgebra<T, A : NumericAlgebra<T>> : NumericAlgebra<AsmExpression<T>>,
ExpressionAlgebra<T, AsmExpression<T>> {
/**
* The algebra to provide for AsmExpressions built.
*/
val algebra: A
/**
* Builds an AsmExpression to wrap a number.
*/
override fun number(value: Number): AsmExpression<T> = AsmNumberExpression(algebra, value)
/**
* Builds an AsmExpression of constant expression which does not depend on arguments.
*/
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
/**
* Builds an AsmExpression to access a variable.
*/
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
/**
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, operation, left, right)
/**
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
AsmUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [AsmExpression] construction for [Space] algebras.
*/
open class AsmExpressionSpace<T, A>(override val algebra: A) : AsmExpressionAlgebra<T, A>,
Space<AsmExpression<T>> where A : Space<T>, A : NumericAlgebra<T> {
override val zero: AsmExpression<T>
get() = const(algebra.zero)
/**
* Builds an AsmExpression of addition of two another expressions.
*/
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an AsmExpression of multiplication of expression by number.
*/
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(algebra, a, k)
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
}
/**
* A context class for [AsmExpression] construction for [Ring] algebras.
*/
open class AsmExpressionRing<T, A>(override val algebra: A) : AsmExpressionSpace<T, A>(algebra),
Ring<AsmExpression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: AsmExpression<T>
get() = const(algebra.one)
/**
* Builds an AsmExpression of multiplication of two expressions.
*/
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionSpace>.number(value)
}
/**
* A context class for [AsmExpression] construction for [Field] algebras.
*/
open class AsmExpressionField<T, A>(override val algebra: A) :
AsmExpressionRing<T, A>(algebra),
Field<AsmExpression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an AsmExpression of division an expression by another one.
*/
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionRing>.number(value)
}

View File

@ -1,14 +1,24 @@
package scientifik.kmath.asm package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST import scientifik.kmath.ast.MST
import scientifik.kmath.ast.evaluate import scientifik.kmath.ast.MSTField
import scientifik.kmath.ast.MSTRing
import scientifik.kmath.ast.MSTSpace
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.* import scientifik.kmath.operations.*
import kotlin.reflect.KClass
@PublishedApi
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String { /**
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision" * Compile given MST to an Expression using AST compiler
*/
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try { try {
Class.forName(name) Class.forName(name)
@ -16,38 +26,78 @@ internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String
return name return name
} }
return buildName(expression, collision + 1) return buildName(mst, collision + 1)
} }
@PublishedApi fun AsmBuilder<T>.visit(node: MST): Unit {
internal inline fun <reified T> AsmExpression<T>.compile(algebra: Algebra<T>): Expression<T> = when (node) {
AsmBuilder(T::class.java, algebra, buildName(this), this).getInstance() is MST.Symbolic -> visitLoadFromVariables(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value)
} else {
error("Number literals are not supported in $algebra")
}
visitLoadFromConstants(constant)
}
is MST.Unary -> {
visitLoadAlgebra()
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm( if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
expressionAlgebra: E,
block: E.() -> AsmExpression<T>
): Expression<T> = expressionAlgebra.block().compile(expressionAlgebra.algebra)
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm( visit(node.value)
expressionAlgebra: E,
ast: MST
): Expression<T> = asm(expressionAlgebra) { evaluate(ast) }
inline fun <reified T, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> = if (!tryInvokeSpecific(algebra, node.operation, 1)) {
AsmExpressionSpace(this).let { it.block().compile(it.algebra) } visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
is MST.Binary -> {
visitLoadAlgebra()
inline fun <reified T, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> = if (!hasSpecific(algebra, node.operation, 2))
asmSpace { evaluate(ast) } visitStringConstant(node.operation)
inline fun <reified T, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> = visit(node.left)
AsmExpressionRing(this).let { it.block().compile(it.algebra) } visit(node.right)
inline fun <reified T, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> = if (!tryInvokeSpecific(algebra, node.operation, 2)) {
asmRing { evaluate(ast) }
inline fun <reified T, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> = visitAlgebraOperation(
AsmExpressionField(this).let { it.block().compile(it.algebra) } owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
}
}
inline fun <reified T, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> = val builder = AsmBuilder(type.java, algebra, buildName(this))
asmRing { evaluate(ast) } builder.visit(this)
return builder.generate()
}
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
mstAlgebra: E,
block: E.() -> MST
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
MSTSpace.block().compileWith(T::class, this)
inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
MSTRing.block().compileWith(T::class, this)
inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
MSTField.block().compileWith(T::class, this)

View File

@ -1,8 +1,6 @@
package scientifik.kmath.asm.internal package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmExpression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide") private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
@ -31,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
append("L${AsmBuilder.OBJECT_CLASS};") append("L${AsmBuilder.OBJECT_CLASS};")
} }
invokeAlgebraOperation( visitAlgebraOperation(
owner = owner, owner = owner,
method = aName, method = aName,
descriptor = sig, descriptor = sig,
@ -41,9 +39,8 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
return true return true
} }
//
@PublishedApi //internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
internal fun <T> AsmExpression<T>.optimize(): AsmExpression<T> { // val a = tryEvaluate()
val a = tryEvaluate() // return if (a == null) this else AsmConstantExpression(type, algebra, a)
return if (a == null) this else AsmConstantExpression(a) //}
}

View File

@ -1,75 +1,66 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.asmField //
import scientifik.kmath.asm.asmRing //class TestAsmAlgebras {
import scientifik.kmath.asm.asmSpace // @Test
import scientifik.kmath.expressions.invoke // fun space() {
import scientifik.kmath.operations.ByteRing // val res = ByteRing.asmInRing {
import scientifik.kmath.operations.RealField // binaryOperation(
import kotlin.test.Test // "+",
import kotlin.test.assertEquals //
// unaryOperation(
class TestAsmAlgebras { // "+",
@Test // 3.toByte() - (2.toByte() + (multiply(
fun space() { // add(number(1), number(1)),
val res = ByteRing.asmSpace { // 2
binaryOperation( // ) + 1.toByte()) * 3.toByte() - 1.toByte())
"+", // ),
//
unaryOperation( // number(1)
"+", // ) + symbol("x") + zero
3.toByte() - (2.toByte() + (multiply( // }("x" to 2.toByte())
add(const(1), const(1)), //
2 // assertEquals(16, res)
) + 1.toByte()) * 3.toByte() - 1.toByte()) // }
), //
// @Test
number(1) // fun ring() {
) + variable("x") + zero // val res = ByteRing.asmInRing {
}("x" to 2.toByte()) // binaryOperation(
// "+",
assertEquals(16, res) //
} // unaryOperation(
// "+",
@Test // (3.toByte() - (2.toByte() + (multiply(
fun ring() { // add(const(1), const(1)),
val res = ByteRing.asmRing { // 2
binaryOperation( // ) + 1.toByte()))) * 3.0 - 1.toByte()
"+", // ),
//
unaryOperation( // number(1)
"+", // ) * const(2)
(3.toByte() - (2.toByte() + (multiply( // }()
add(const(1), const(1)), //
2 // assertEquals(24, res)
) + 1.toByte()))) * 3.0 - 1.toByte() // }
), //
// @Test
number(1) // fun field() {
) * const(2) // val res = RealField.asmInField {
}() // +(3 - 2 + 2*(number(1)+1.0)
//
assertEquals(24, res) // unaryOperation(
} // "+",
// (3.0 - (2.0 + (multiply(
@Test // add((1.0), const(1.0)),
fun field() { // 2
val res = RealField.asmField { // ) + 1.0))) * 3 - 1.0
divide(binaryOperation( // )+
"+", //
// number(1)
unaryOperation( // ) / 2, const(2.0)) * one
"+", // }()
(3.0 - (2.0 + (multiply( //
add(const(1.0), const(1.0)), // assertEquals(3.0, res)
2 // }
) + 1.0))) * 3 - 1.0 //}
),
number(1)
) / 2, const(2.0)) * one
}()
assertEquals(3.0, res)
}
}

View File

@ -1,6 +1,6 @@
package scietifik.kmath.asm package scietifik.kmath.asm
import scientifik.kmath.asm.asmField import scientifik.kmath.asm.asmInField
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
@ -9,13 +9,13 @@ import kotlin.test.assertEquals
class TestAsmExpressions { class TestAsmExpressions {
@Test @Test
fun testUnaryOperationInvocation() { fun testUnaryOperationInvocation() {
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0) val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
assertEquals(2.0, res) assertEquals(-2.0, res)
} }
@Test @Test
fun testConstProductInvocation() { fun testConstProductInvocation() {
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0) val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res) assertEquals(4.0, res)
} }
} }

View File

@ -1,18 +1,18 @@
package scietifik.kmath.ast package scietifik.kmath.ast
import scientifik.kmath.asm.asmField import scientifik.kmath.asm.compile
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest { class AsmTest {
@Test @Test
fun parsedExpression() { fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.asmField(mst)() val res = ComplexField.compile(mst)()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
} }

View File

@ -74,10 +74,10 @@ class DerivativeStructureField(
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble()) override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble()) override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this override operator fun Number.plus(b: DerivativeStructure) = b + this
operator fun Number.minus(s: DerivativeStructure) = s - this override operator fun Number.minus(b: DerivativeStructure) = b - this
} }
/** /**

View File

@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
// Overloads for Double constants // Overloads for Double constants
operator fun Number.plus(that: Variable<T>): Variable<T> = override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + that.value }) { z -> derive(variable { this@plus.toDouble() * one + b.value }) { z ->
that.d += z.d b.d += z.d
} }
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this) override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
operator fun Number.minus(that: Variable<T>): Variable<T> = override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - that.value }) { z -> derive(variable { this@minus.toDouble() * one - b.value }) { z ->
that.d -= z.d b.d -= z.d
} }
operator fun Variable<T>.minus(that: Number): Variable<T> = override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * that.toDouble() }) { z -> derive(variable { this@minus.value - one * b.toDouble() }) { z ->
this@minus.d += z.d this@minus.d += z.d
} }
} }

View File

@ -12,9 +12,6 @@ interface Algebra<T> {
*/ */
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this") fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
@Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol"))
fun raw(value: String): T = symbol(value)
/** /**
* Dynamic call of unary operation with name [operation] on [arg] * Dynamic call of unary operation with name [operation] on [arg]
*/ */
@ -64,6 +61,8 @@ interface SpaceOperations<T> : Algebra<T> {
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176 //Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0) operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.unaryPlus(): T = this
operator fun T.plus(b: T): T = add(this, b) operator fun T.plus(b: T): T = add(this, b)
operator fun T.minus(b: T): T = add(this, -b) operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble()) operator fun T.times(k: Number) = multiply(this, k.toDouble())
@ -138,17 +137,25 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble() override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) { override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right) else -> super.leftSideNumberOperation(operation, left, right)
} }
//TODO those operators are blocked by type conflict in RealField override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this operator fun T.plus(b: Number) = this.plus(number(b))
// operator fun Number.plus(b: T) = b + this
// operator fun T.minus(b: Number) = this.minus(b * one)
// operator fun Number.minus(b: T) = -b + this operator fun T.minus(b: Number) = this.minus(number(b))
operator fun Number.minus(b: T) = -b + this
} }
/** /**