Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-refactor-agc

# Conflicts:
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/AsmExpressions.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/asm.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/AsmBuilder.kt
#	kmath-ast/src/jvmMain/kotlin/scientifik/kmath/asm/internal/optimization.kt
This commit is contained in:
Iaroslav 2020-06-16 13:28:25 +07:00
commit fe64537cbc
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
13 changed files with 239 additions and 421 deletions
examples/src/main/kotlin/scientifik/kmath/structures
kmath-ast/src
commonMain/kotlin/scientifik/kmath/ast
jvmMain/kotlin/scientifik/kmath/asm
jvmTest/kotlin/scietifik/kmath
kmath-commons/src/main/kotlin/scientifik/kmath/commons/expressions
kmath-core/src/commonMain/kotlin/scientifik/kmath
misc
operations

@ -27,7 +27,7 @@ fun main() {
val complexTime = measureTimeMillis {
complexField.run {
var res = one
var res: NDBuffer<Complex> = one
repeat(n) {
res += 1.0
}

@ -23,14 +23,14 @@ fun main() {
measureAndPrint("Automatic field addition") {
autoField.run {
var res = one
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
res += number(1.0)
}
}
}
measureAndPrint("Element addition"){
measureAndPrint("Element addition") {
var res = genericField.one
repeat(n) {
res += 1.0
@ -63,7 +63,7 @@ fun main() {
genericField.run {
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
res += one // con't avoid using `one` due to resolution ambiguity
}
}
}

@ -1,5 +1,6 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField
@ -40,12 +41,14 @@ sealed class MST {
//TODO add a function with named arguments
fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> number(node.value)
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
@ -59,4 +62,6 @@ fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}
}
}
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

@ -0,0 +1,33 @@
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MSTAlgebra : NumericAlgebra<MST> {
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
override fun number(value: Number): MST = MST.Numeric(value)
}
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
override val zero: MST = number(0.0)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
}
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
override val one: MST = number(1.0)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
}
object MSTField : Field<MST>, Ring<MST> by MSTRing {
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
}

@ -1,265 +0,0 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.*
/**
* A function declaration that could be compiled to [AsmBuilder].
*
* @param T the type the stored function returns.
*/
abstract class AsmExpression<T> internal constructor() {
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
internal open fun tryEvaluate(): T? = null
/**
* Compiles this declaration.
*
* @param gen the target [AsmBuilder].
*/
@PublishedApi
internal abstract fun compile(gen: AsmBuilder<T>)
}
internal class AsmUnaryOperation<T>(private val context: Algebra<T>, private val name: String, expr: AsmExpression<T>) :
AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = context { unaryOperation(name, expr.tryEvaluate() ?: return@context null) }
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
if (!hasSpecific(context, name, 1))
gen.loadStringConstant(name)
expr.compile(gen)
if (gen.tryInvokeSpecific(context, name, 1))
return
gen.invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmBinaryOperation<T>(
private val context: Algebra<T>,
private val name: String,
first: AsmExpression<T>,
second: AsmExpression<T>
) : AsmExpression<T>() {
private val first: AsmExpression<T> = first.optimize()
private val second: AsmExpression<T> = second.optimize()
override fun tryEvaluate(): T? = context {
binaryOperation(
name,
first.tryEvaluate() ?: return@context null,
second.tryEvaluate() ?: return@context null
)
}
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
if (!hasSpecific(context, name, 2))
gen.loadStringConstant(name)
first.compile(gen)
second.compile(gen)
if (gen.tryInvokeSpecific(context, name, 2))
return
gen.invokeAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmVariableExpression<T>(private val name: String, private val default: T? = null) :
AsmExpression<T>() {
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadVariable(name, default)
}
internal class AsmConstantExpression<T>(private val value: T) :
AsmExpression<T>() {
override fun tryEvaluate(): T = value
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadTConstant(value)
}
internal class AsmConstProductExpression<T>(
private val context: Space<T>,
expr: AsmExpression<T>,
private val const: Number
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = context { (expr.tryEvaluate() ?: return@context null) * const }
override fun compile(gen: AsmBuilder<T>) {
gen.loadAlgebra()
gen.loadNumberConstant(const)
expr.compile(gen)
gen.invokeAlgebraOperation(
owner = AsmBuilder.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.NUMBER_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmNumberExpression<T>(private val context: NumericAlgebra<T>, private val value: Number) :
AsmExpression<T>() {
override fun tryEvaluate(): T? = context.number(value)
override fun compile(gen: AsmBuilder<T>): Unit = gen.loadNumberConstant(value)
}
internal abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* A context class for [AsmExpression] construction.
*/
interface AsmExpressionAlgebra<T, A : NumericAlgebra<T>> : NumericAlgebra<AsmExpression<T>>,
ExpressionAlgebra<T, AsmExpression<T>> {
/**
* The algebra to provide for AsmExpressions built.
*/
val algebra: A
/**
* Builds an AsmExpression to wrap a number.
*/
override fun number(value: Number): AsmExpression<T> = AsmNumberExpression(algebra, value)
/**
* Builds an AsmExpression of constant expression which does not depend on arguments.
*/
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(value)
/**
* Builds an AsmExpression to access a variable.
*/
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(name, default)
/**
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, operation, left, right)
/**
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
AsmUnaryOperation(algebra, operation, arg)
}
/**
* A context class for [AsmExpression] construction for [Space] algebras.
*/
open class AsmExpressionSpace<T, A>(override val algebra: A) : AsmExpressionAlgebra<T, A>,
Space<AsmExpression<T>> where A : Space<T>, A : NumericAlgebra<T> {
override val zero: AsmExpression<T>
get() = const(algebra.zero)
/**
* Builds an AsmExpression of addition of two another expressions.
*/
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an AsmExpression of multiplication of expression by number.
*/
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(algebra, a, k)
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
}
/**
* A context class for [AsmExpression] construction for [Ring] algebras.
*/
open class AsmExpressionRing<T, A>(override val algebra: A) : AsmExpressionSpace<T, A>(algebra),
Ring<AsmExpression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: AsmExpression<T>
get() = const(algebra.one)
/**
* Builds an AsmExpression of multiplication of two expressions.
*/
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionSpace>.number(value)
}
/**
* A context class for [AsmExpression] construction for [Field] algebras.
*/
open class AsmExpressionField<T, A>(override val algebra: A) :
AsmExpressionRing<T, A>(algebra),
Field<AsmExpression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an AsmExpression of division an expression by another one.
*/
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionRing>.number(value)
}

@ -1,53 +1,103 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.MSTField
import scientifik.kmath.ast.MSTRing
import scientifik.kmath.ast.MSTSpace
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.*
import kotlin.reflect.KClass
@PublishedApi
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
/**
* Compile given MST to an Expression using AST compiler
*/
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}
return buildName(expression, collision + 1)
fun AsmBuilder<T>.visit(node: MST): Unit {
when (node) {
is MST.Symbolic -> visitLoadFromVariables(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value)
} else {
error("Number literals are not supported in $algebra")
}
visitLoadFromConstants(constant)
}
is MST.Unary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
is MST.Binary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 2))
visitStringConstant(node.operation)
visit(node.left)
visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
}
}
val builder = AsmBuilder(type.java, algebra, buildName(this))
builder.visit(this)
return builder.generate()
}
@PublishedApi
internal inline fun <reified T> AsmExpression<T>.compile(algebra: Algebra<T>): Expression<T> =
AsmBuilder(T::class.java, algebra, buildName(this), this).getInstance()
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
expressionAlgebra: E,
block: E.() -> AsmExpression<T>
): Expression<T> = expressionAlgebra.block().compile(expressionAlgebra.algebra)
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
mstAlgebra: E,
block: E.() -> MST
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
inline fun <reified T, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
expressionAlgebra: E,
ast: MST
): Expression<T> = asm(expressionAlgebra) { evaluate(ast) }
inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
MSTSpace.block().compileWith(T::class, this)
inline fun <reified T, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
AsmExpressionSpace(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
asmSpace { evaluate(ast) }
inline fun <reified T, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
AsmExpressionRing(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
asmRing { evaluate(ast) }
inline fun <reified T, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
AsmExpressionField(this).let { it.block().compile(it.algebra) }
inline fun <reified T, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
asmRing { evaluate(ast) }
inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
MSTRing.block().compileWith(T::class, this)
inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
MSTField.block().compileWith(T::class, this)

@ -1,8 +1,6 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmExpression
import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
@ -31,7 +29,7 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
append("L${AsmBuilder.OBJECT_CLASS};")
}
invokeAlgebraOperation(
visitAlgebraOperation(
owner = owner,
method = aName,
descriptor = sig,
@ -41,9 +39,8 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
return true
}
@PublishedApi
internal fun <T> AsmExpression<T>.optimize(): AsmExpression<T> {
val a = tryEvaluate()
return if (a == null) this else AsmConstantExpression(a)
}
//
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
// val a = tryEvaluate()
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
//}

@ -1,75 +1,66 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.asmRing
import scientifik.kmath.asm.asmSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmAlgebras {
@Test
fun space() {
val res = ByteRing.asmSpace {
binaryOperation(
"+",
unaryOperation(
"+",
3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()) * 3.toByte() - 1.toByte())
),
number(1)
) + variable("x") + zero
}("x" to 2.toByte())
assertEquals(16, res)
}
@Test
fun ring() {
val res = ByteRing.asmRing {
binaryOperation(
"+",
unaryOperation(
"+",
(3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * const(2)
}()
assertEquals(24, res)
}
@Test
fun field() {
val res = RealField.asmField {
divide(binaryOperation(
"+",
unaryOperation(
"+",
(3.0 - (2.0 + (multiply(
add(const(1.0), const(1.0)),
2
) + 1.0))) * 3 - 1.0
),
number(1)
) / 2, const(2.0)) * one
}()
assertEquals(3.0, res)
}
}
//
//class TestAsmAlgebras {
// @Test
// fun space() {
// val res = ByteRing.asmInRing {
// binaryOperation(
// "+",
//
// unaryOperation(
// "+",
// 3.toByte() - (2.toByte() + (multiply(
// add(number(1), number(1)),
// 2
// ) + 1.toByte()) * 3.toByte() - 1.toByte())
// ),
//
// number(1)
// ) + symbol("x") + zero
// }("x" to 2.toByte())
//
// assertEquals(16, res)
// }
//
// @Test
// fun ring() {
// val res = ByteRing.asmInRing {
// binaryOperation(
// "+",
//
// unaryOperation(
// "+",
// (3.toByte() - (2.toByte() + (multiply(
// add(const(1), const(1)),
// 2
// ) + 1.toByte()))) * 3.0 - 1.toByte()
// ),
//
// number(1)
// ) * const(2)
// }()
//
// assertEquals(24, res)
// }
//
// @Test
// fun field() {
// val res = RealField.asmInField {
// +(3 - 2 + 2*(number(1)+1.0)
//
// unaryOperation(
// "+",
// (3.0 - (2.0 + (multiply(
// add((1.0), const(1.0)),
// 2
// ) + 1.0))) * 3 - 1.0
// )+
//
// number(1)
// ) / 2, const(2.0)) * one
// }()
//
// assertEquals(3.0, res)
// }
//}

@ -1,6 +1,6 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.asmInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
@ -9,13 +9,13 @@ import kotlin.test.assertEquals
class TestAsmExpressions {
@Test
fun testUnaryOperationInvocation() {
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0)
assertEquals(2.0, res)
val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testConstProductInvocation() {
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0)
val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res)
}
}

@ -1,18 +1,18 @@
package scietifik.kmath.ast
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest {
@Test
fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.asmField(mst)()
val res = ComplexField.compile(mst)()
assertEquals(Complex(10.0, 0.0), res)
}
}

@ -74,10 +74,10 @@ class DerivativeStructureField(
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this
operator fun Number.minus(s: DerivativeStructure) = s - this
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
override operator fun Number.plus(b: DerivativeStructure) = b + this
override operator fun Number.minus(b: DerivativeStructure) = b - this
}
/**

@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
// Overloads for Double constants
operator fun Number.plus(that: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
that.d += z.d
override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
operator fun Number.minus(that: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
that.d -= z.d
override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
b.d -= z.d
}
operator fun Variable<T>.minus(that: Number): Variable<T> =
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
this@minus.d += z.d
}
}

@ -12,9 +12,6 @@ interface Algebra<T> {
*/
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
@Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol"))
fun raw(value: String): T = symbol(value)
/**
* Dynamic call of unary operation with name [operation] on [arg]
*/
@ -64,6 +61,8 @@ interface SpaceOperations<T> : Algebra<T> {
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.unaryPlus(): T = this
operator fun T.plus(b: T): T = add(this, b)
operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble())
@ -138,17 +137,25 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
//TODO those operators are blocked by type conflict in RealField
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this
//
// operator fun T.minus(b: Number) = this.minus(b * one)
// operator fun Number.minus(b: T) = -b + this
operator fun T.plus(b: Number) = this.plus(number(b))
operator fun Number.plus(b: T) = b + this
operator fun T.minus(b: Number) = this.minus(number(b))
operator fun Number.minus(b: T) = -b + this
}
/**