Merge remote-tracking branch 'mipt-npm/adv-expr' into adv-expr-improved-trigonometry

This commit is contained in:
Iaroslav 2020-06-16 14:40:41 +07:00
commit f8f1814def
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
14 changed files with 253 additions and 447 deletions

View File

@ -27,7 +27,7 @@ fun main() {
val complexTime = measureTimeMillis {
complexField.run {
var res = one
var res: NDBuffer<Complex> = one
repeat(n) {
res += 1.0
}

View File

@ -23,14 +23,14 @@ fun main() {
measureAndPrint("Automatic field addition") {
autoField.run {
var res = one
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
res += number(1.0)
}
}
}
measureAndPrint("Element addition"){
measureAndPrint("Element addition") {
var res = genericField.one
repeat(n) {
res += 1.0
@ -63,7 +63,7 @@ fun main() {
genericField.run {
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
res += one // con't avoid using `one` due to resolution ambiguity
}
}
}

View File

@ -1,5 +1,6 @@
package scientifik.kmath.ast
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.RealField
@ -40,12 +41,14 @@ sealed class MST {
//TODO add a function with named arguments
fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
fun <T> Algebra<T>.evaluate(node: MST): T {
return when (node) {
is MST.Numeric -> number(node.value)
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
?: error("Numeric nodes are not supported by $this")
is MST.Symbolic -> symbol(node.value)
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Binary -> when {
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation(
node.operation,
@ -59,4 +62,6 @@ fun <T> NumericAlgebra<T>.evaluate(node: MST): T {
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
}
}
}
}
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -0,0 +1,33 @@
@file:Suppress("DELEGATED_MEMBER_HIDES_SUPERTYPE_OVERRIDE")
package scientifik.kmath.ast
import scientifik.kmath.operations.*
object MSTAlgebra : NumericAlgebra<MST> {
override fun symbol(value: String): MST = MST.Symbolic(value)
override fun unaryOperation(operation: String, arg: MST): MST = MST.Unary(operation, arg)
override fun binaryOperation(operation: String, left: MST, right: MST): MST = MST.Binary(operation, left, right)
override fun number(value: Number): MST = MST.Numeric(value)
}
object MSTSpace : Space<MST>, NumericAlgebra<MST> by MSTAlgebra {
override val zero: MST = number(0.0)
override fun add(a: MST, b: MST): MST = binaryOperation(SpaceOperations.PLUS_OPERATION, a, b)
override fun multiply(a: MST, k: Number): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, number(k))
}
object MSTRing : Ring<MST>, Space<MST> by MSTSpace {
override val one: MST = number(1.0)
override fun multiply(a: MST, b: MST): MST = binaryOperation(RingOperations.TIMES_OPERATION, a, b)
}
object MSTField : Field<MST>, Ring<MST> by MSTRing {
override fun divide(a: MST, b: MST): MST = binaryOperation(FieldOperations.DIV_OPERATION, a, b)
}

View File

@ -1,292 +0,0 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.optimize
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.expressions.Expression
import scientifik.kmath.expressions.ExpressionAlgebra
import scientifik.kmath.operations.*
import kotlin.reflect.KClass
/**
* A function declaration that could be compiled to [AsmBuilder].
*
* @param T the type the stored function returns.
*/
sealed class AsmExpression<T : Any>: Expression<T> {
abstract val type: KClass<out T>
abstract val algebra: Algebra<T>
/**
* Tries to evaluate this function without its variables. This method is intended for optimization.
*
* @return `null` if the function depends on its variables, the value if the function is a constant.
*/
internal open fun tryEvaluate(): T? = null
/**
* Compiles this declaration.
*
* @param gen the target [AsmBuilder].
*/
internal abstract fun appendTo(gen: AsmBuilder<T>)
/**
* Compile and cache the expression
*/
private val compiledExpression by lazy{
val builder = AsmBuilder(type.java, algebra, buildName(this))
this.appendTo(builder)
builder.generate()
}
override fun invoke(arguments: Map<String, T>): T = compiledExpression.invoke(arguments)
}
internal class AsmUnaryOperation<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
expr: AsmExpression<T>
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = algebra { unaryOperation(name, expr.tryEvaluate() ?: return@algebra null) }
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(algebra, name, 1))
gen.visitStringConstant(name)
expr.appendTo(gen)
if (gen.tryInvokeSpecific(algebra, name, 1))
return
gen.visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmBinaryOperation<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
first: AsmExpression<T>,
second: AsmExpression<T>
) : AsmExpression<T>() {
private val first: AsmExpression<T> = first.optimize()
private val second: AsmExpression<T> = second.optimize()
override fun tryEvaluate(): T? = algebra {
binaryOperation(
name,
first.tryEvaluate() ?: return@algebra null,
second.tryEvaluate() ?: return@algebra null
)
}
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
if (!hasSpecific(algebra, name, 2))
gen.visitStringConstant(name)
first.appendTo(gen)
second.appendTo(gen)
if (gen.tryInvokeSpecific(algebra, name, 2))
return
gen.visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmVariableExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val name: String,
private val default: T? = null
) : AsmExpression<T>() {
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitLoadFromVariables(name, default)
}
internal class AsmConstantExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Algebra<T>,
private val value: T
) : AsmExpression<T>() {
override fun tryEvaluate(): T = value
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitLoadFromConstants(value)
}
internal class AsmConstProductExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: Space<T>,
expr: AsmExpression<T>,
private val const: Number
) : AsmExpression<T>() {
private val expr: AsmExpression<T> = expr.optimize()
override fun tryEvaluate(): T? = algebra { (expr.tryEvaluate() ?: return@algebra null) * const }
override fun appendTo(gen: AsmBuilder<T>) {
gen.visitLoadAlgebra()
gen.visitNumberConstant(const)
expr.appendTo(gen)
gen.visitAlgebraOperation(
owner = AsmBuilder.SPACE_OPERATIONS_CLASS,
method = "multiply",
descriptor = "(L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.NUMBER_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
internal class AsmNumberExpression<T : Any>(
override val type: KClass<out T>,
override val algebra: NumericAlgebra<T>,
private val value: Number
) : AsmExpression<T>() {
override fun tryEvaluate(): T? = algebra.number(value)
override fun appendTo(gen: AsmBuilder<T>): Unit = gen.visitNumberConstant(value)
}
internal abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* A context class for [AsmExpression] construction.
*
* @param algebra The algebra to provide for AsmExpressions built.
*/
open class AsmExpressionAlgebra<T : Any, A : NumericAlgebra<T>>(val type: KClass<out T>, val algebra: A) :
NumericAlgebra<AsmExpression<T>>, ExpressionAlgebra<T, AsmExpression<T>> {
/**
* Builds an AsmExpression to wrap a number.
*/
override fun number(value: Number): AsmExpression<T> = AsmNumberExpression(type, algebra, value)
/**
* Builds an AsmExpression of constant expression which does not depend on arguments.
*/
override fun const(value: T): AsmExpression<T> = AsmConstantExpression(type, algebra, value)
/**
* Builds an AsmExpression to access a variable.
*/
override fun variable(name: String, default: T?): AsmExpression<T> = AsmVariableExpression(type, algebra, name, default)
/**
* Builds an AsmExpression of dynamic call of binary operation [operation] on [left] and [right].
*/
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, operation, left, right)
/**
* Builds an AsmExpression of dynamic call of unary operation with name [operation] on [arg].
*/
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
AsmUnaryOperation(type, algebra, operation, arg)
}
/**
* A context class for [AsmExpression] construction for [Space] algebras.
*/
open class AsmExpressionSpace<T : Any, A>(type: KClass<out T>, algebra: A) : AsmExpressionAlgebra<T, A>(type, algebra),
Space<AsmExpression<T>> where A : Space<T>, A : NumericAlgebra<T> {
override val zero: AsmExpression<T> get() = const(algebra.zero)
/**
* Builds an AsmExpression of addition of two another expressions.
*/
override fun add(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, SpaceOperations.PLUS_OPERATION, a, b)
/**
* Builds an AsmExpression of multiplication of expression by number.
*/
override fun multiply(a: AsmExpression<T>, k: Number): AsmExpression<T> = AsmConstProductExpression(type, algebra, a, k)
operator fun AsmExpression<T>.plus(arg: T): AsmExpression<T> = this + const(arg)
operator fun AsmExpression<T>.minus(arg: T): AsmExpression<T> = this - const(arg)
operator fun T.plus(arg: AsmExpression<T>): AsmExpression<T> = arg + this
operator fun T.minus(arg: AsmExpression<T>): AsmExpression<T> = arg - this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionAlgebra>.binaryOperation(operation, left, right)
}
/**
* A context class for [AsmExpression] construction for [Ring] algebras.
*/
open class AsmExpressionRing<T : Any, A>(type: KClass<out T>, algebra: A) : AsmExpressionSpace<T, A>(type, algebra),
Ring<AsmExpression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
override val one: AsmExpression<T> get() = const(algebra.one)
/**
* Builds an AsmExpression of multiplication of two expressions.
*/
override fun multiply(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, RingOperations.TIMES_OPERATION, a, b)
operator fun AsmExpression<T>.times(arg: T): AsmExpression<T> = this * const(arg)
operator fun T.times(arg: AsmExpression<T>): AsmExpression<T> = arg * this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionSpace>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionSpace>.number(value)
}
/**
* A context class for [AsmExpression] construction for [Field] algebras.
*/
open class AsmExpressionField<T : Any, A>(type: KClass<out T>, algebra: A) :
AsmExpressionRing<T, A>(type, algebra),
Field<AsmExpression<T>> where A : Field<T>, A : NumericAlgebra<T> {
/**
* Builds an AsmExpression of division an expression by another one.
*/
override fun divide(a: AsmExpression<T>, b: AsmExpression<T>): AsmExpression<T> =
AsmBinaryOperation(type, algebra, FieldOperations.DIV_OPERATION, a, b)
operator fun AsmExpression<T>.div(arg: T): AsmExpression<T> = this / const(arg)
operator fun T.div(arg: AsmExpression<T>): AsmExpression<T> = arg / this
override fun unaryOperation(operation: String, arg: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.unaryOperation(operation, arg)
override fun binaryOperation(operation: String, left: AsmExpression<T>, right: AsmExpression<T>): AsmExpression<T> =
super<AsmExpressionRing>.binaryOperation(operation, left, right)
override fun number(value: Number): AsmExpression<T> = super<AsmExpressionRing>.number(value)
}

View File

@ -1,47 +1,103 @@
package scientifik.kmath.asm
import scientifik.kmath.asm.internal.AsmBuilder
import scientifik.kmath.asm.internal.hasSpecific
import scientifik.kmath.asm.internal.tryInvokeSpecific
import scientifik.kmath.ast.MST
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.MSTField
import scientifik.kmath.ast.MSTRing
import scientifik.kmath.ast.MSTSpace
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.NumericAlgebra
import scientifik.kmath.operations.Ring
import scientifik.kmath.operations.Space
import scientifik.kmath.operations.*
import kotlin.reflect.KClass
internal fun buildName(expression: AsmExpression<*>, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${expression.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
/**
* Compile given MST to an Expression using AST compiler
*/
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun buildName(mst: MST, collision: Int = 0): String {
val name = "scientifik.kmath.expressions.generated.AsmCompiledExpression_${mst.hashCode()}_$collision"
try {
Class.forName(name)
} catch (ignored: ClassNotFoundException) {
return name
}
return buildName(mst, collision + 1)
}
return buildName(expression, collision + 1)
fun AsmBuilder<T>.visit(node: MST): Unit {
when (node) {
is MST.Symbolic -> visitLoadFromVariables(node.value)
is MST.Numeric -> {
val constant = if (algebra is NumericAlgebra<T>) {
algebra.number(node.value)
} else {
error("Number literals are not supported in $algebra")
}
visitLoadFromConstants(constant)
}
is MST.Unary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 1)) visitStringConstant(node.operation)
visit(node.value)
if (!tryInvokeSpecific(algebra, node.operation, 1)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "unaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
is MST.Binary -> {
visitLoadAlgebra()
if (!hasSpecific(algebra, node.operation, 2))
visitStringConstant(node.operation)
visit(node.left)
visit(node.right)
if (!tryInvokeSpecific(algebra, node.operation, 2)) {
visitAlgebraOperation(
owner = AsmBuilder.ALGEBRA_CLASS,
method = "binaryOperation",
descriptor = "(L${AsmBuilder.STRING_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};" +
"L${AsmBuilder.OBJECT_CLASS};)" +
"L${AsmBuilder.OBJECT_CLASS};"
)
}
}
}
}
val builder = AsmBuilder(type.java, algebra, buildName(this))
builder.visit(this)
return builder.generate()
}
inline fun <reified T : Any, A : NumericAlgebra<T>, E : AsmExpressionAlgebra<T, A>> A.asm(
expressionAlgebra: E,
block: E.() -> AsmExpression<T>
): Expression<T> = expressionAlgebra.block()
inline fun <reified T : Any> Algebra<T>.compile(mst: MST): Expression<T> = mst.compileWith(T::class, this)
inline fun <reified T : Any> NumericAlgebra<T>.asm(ast: MST): Expression<T> =
AsmExpressionAlgebra(T::class, this).evaluate(ast)
inline fun <reified T : Any, A : Algebra<T>, E : Algebra<MST>> A.asm(
mstAlgebra: E,
block: E.() -> MST
): Expression<T> = mstAlgebra.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmSpace(block: AsmExpressionSpace<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
AsmExpressionSpace<T, A>(T::class, this).block()
inline fun <reified T : Any, A : Space<T>> A.asmInSpace(block: MSTSpace.() -> MST): Expression<T> =
MSTSpace.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmSpace(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Space<T> =
asmSpace { evaluate(ast) }
inline fun <reified T : Any, A : Ring<T>> A.asmInRing(block: MSTRing.() -> MST): Expression<T> =
MSTRing.block().compileWith(T::class, this)
inline fun <reified T : Any, A> A.asmRing(block: AsmExpressionRing<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
AsmExpressionRing(T::class, this).block()
inline fun <reified T : Any, A> A.asmRing(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Ring<T> =
asmRing { evaluate(ast) }
inline fun <reified T : Any, A> A.asmField(block: AsmExpressionField<T, A>.() -> AsmExpression<T>): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
AsmExpressionField(T::class, this).block()
inline fun <reified T : Any, A> A.asmField(ast: MST): Expression<T> where A : NumericAlgebra<T>, A : Field<T> =
asmRing { evaluate(ast) }
inline fun <reified T : Any, A : Field<T>> A.asmInField(block: MSTField.() -> MST): Expression<T> =
MSTField.block().compileWith(T::class, this)

View File

@ -4,10 +4,19 @@ import org.objectweb.asm.ClassWriter
import org.objectweb.asm.Label
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.FunctionalCompiledExpression
import scientifik.kmath.asm.internal.AsmBuilder.AsmClassLoader
import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra
private abstract class FunctionalCompiledExpression<T> internal constructor(
@JvmField protected val algebra: Algebra<T>,
@JvmField protected val constants: Array<Any>
) : Expression<T> {
abstract override fun invoke(arguments: Map<String, T>): T
}
/**
* AsmGenerationContext is a structure that abstracts building a class that unwraps [AsmNode] to plain Java
* expression. This class uses [AsmClassLoader] for loading the generated class, then it is able to instantiate the new
@ -16,6 +25,8 @@ import scientifik.kmath.operations.Algebra
* @param T the type of AsmExpression to unwrap.
* @param algebra the algebra the applied AsmExpressions use.
* @param className the unique class name of new loaded class.
*
* @author [Iaroslav Postovalov](https://github.com/CommanderTvis)
*/
internal class AsmBuilder<T>(
private val classOfT: Class<*>,
@ -44,7 +55,6 @@ internal class AsmBuilder<T>(
private val invokeMethodVisitor: MethodVisitor
private val invokeL0: Label
private lateinit var invokeL1: Label
private var generatedInstance: FunctionalCompiledExpression<T>? = null
init {
asmCompiledClassWriter.visit(
@ -113,8 +123,7 @@ internal class AsmBuilder<T>(
}
@Suppress("UNCHECKED_CAST")
fun generate(): FunctionalCompiledExpression<T> {
generatedInstance?.let { return it }
fun generate(): Expression<T> {
invokeMethodVisitor.run {
visitInsn(Opcodes.ARETURN)
@ -182,7 +191,6 @@ internal class AsmBuilder<T>(
.first()
.newInstance(algebra, constants.toTypedArray()) as FunctionalCompiledExpression<T>
generatedInstance = new
return new
}

View File

@ -1,8 +1,6 @@
package scientifik.kmath.asm.internal
import org.objectweb.asm.Opcodes
import scientifik.kmath.asm.AsmConstantExpression
import scientifik.kmath.asm.AsmExpression
import scientifik.kmath.operations.Algebra
private val methodNameAdapters: Map<String, String> = mapOf("+" to "add", "*" to "multiply", "/" to "divide")
@ -41,8 +39,8 @@ internal fun <T> AsmBuilder<T>.tryInvokeSpecific(context: Algebra<T>, name: Stri
return true
}
internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
val a = tryEvaluate()
return if (a == null) this else AsmConstantExpression(type, algebra, a)
}
//
//internal fun <T : Any> AsmExpression<T>.optimize(): AsmExpression<T> {
// val a = tryEvaluate()
// return if (a == null) this else AsmConstantExpression(type, algebra, a)
//}

View File

@ -1,75 +1,66 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.asmRing
import scientifik.kmath.asm.asmSpace
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.ByteRing
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
class TestAsmAlgebras {
@Test
fun space() {
val res = ByteRing.asmSpace {
binaryOperation(
"+",
unaryOperation(
"+",
3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()) * 3.toByte() - 1.toByte())
),
number(1)
) + variable("x") + zero
}("x" to 2.toByte())
assertEquals(16, res)
}
@Test
fun ring() {
val res = ByteRing.asmRing {
binaryOperation(
"+",
unaryOperation(
"+",
(3.toByte() - (2.toByte() + (multiply(
add(const(1), const(1)),
2
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * const(2)
}()
assertEquals(24, res)
}
@Test
fun field() {
val res = RealField.asmField {
divide(binaryOperation(
"+",
unaryOperation(
"+",
(3.0 - (2.0 + (multiply(
add(const(1.0), const(1.0)),
2
) + 1.0))) * 3 - 1.0
),
number(1)
) / 2, const(2.0)) * one
}()
assertEquals(3.0, res)
}
}
//
//class TestAsmAlgebras {
// @Test
// fun space() {
// val res = ByteRing.asmInRing {
// binaryOperation(
// "+",
//
// unaryOperation(
// "+",
// 3.toByte() - (2.toByte() + (multiply(
// add(number(1), number(1)),
// 2
// ) + 1.toByte()) * 3.toByte() - 1.toByte())
// ),
//
// number(1)
// ) + symbol("x") + zero
// }("x" to 2.toByte())
//
// assertEquals(16, res)
// }
//
// @Test
// fun ring() {
// val res = ByteRing.asmInRing {
// binaryOperation(
// "+",
//
// unaryOperation(
// "+",
// (3.toByte() - (2.toByte() + (multiply(
// add(const(1), const(1)),
// 2
// ) + 1.toByte()))) * 3.0 - 1.toByte()
// ),
//
// number(1)
// ) * const(2)
// }()
//
// assertEquals(24, res)
// }
//
// @Test
// fun field() {
// val res = RealField.asmInField {
// +(3 - 2 + 2*(number(1)+1.0)
//
// unaryOperation(
// "+",
// (3.0 - (2.0 + (multiply(
// add((1.0), const(1.0)),
// 2
// ) + 1.0))) * 3 - 1.0
// )+
//
// number(1)
// ) / 2, const(2.0)) * one
// }()
//
// assertEquals(3.0, res)
// }
//}

View File

@ -1,6 +1,6 @@
package scietifik.kmath.asm
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.asmInField
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.RealField
import kotlin.test.Test
@ -9,13 +9,13 @@ import kotlin.test.assertEquals
class TestAsmExpressions {
@Test
fun testUnaryOperationInvocation() {
val res = RealField.asmField { unaryOperation("+", variable("x")) }("x" to 2.0)
assertEquals(2.0, res)
val res = RealField.asmInField { -symbol("x") }("x" to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testConstProductInvocation() {
val res = RealField.asmField { variable("x") * 2 }("x" to 2.0)
val res = RealField.asmInField { symbol("x") * 2 }("x" to 2.0)
assertEquals(4.0, res)
}
}

View File

@ -1,18 +1,18 @@
package scietifik.kmath.ast
import scientifik.kmath.asm.asmField
import scientifik.kmath.asm.compile
import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import kotlin.test.assertEquals
import kotlin.test.Test
import kotlin.test.assertEquals
class AsmTest {
@Test
fun parsedExpression() {
val mst = "2+2*(2+2)".parseMath()
val res = ComplexField.asmField(mst)()
val res = ComplexField.compile(mst)()
assertEquals(Complex(10.0, 0.0), res)
}
}

View File

@ -76,10 +76,10 @@ class DerivativeStructureField(
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
operator fun DerivativeStructure.plus(n: Number): DerivativeStructure = add(n.toDouble())
operator fun DerivativeStructure.minus(n: Number): DerivativeStructure = subtract(n.toDouble())
operator fun Number.plus(s: DerivativeStructure) = s + this
operator fun Number.minus(s: DerivativeStructure) = s - this
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
override operator fun Number.plus(b: DerivativeStructure) = b + this
override operator fun Number.minus(b: DerivativeStructure) = b - this
}
/**

View File

@ -90,20 +90,20 @@ abstract class AutoDiffField<T : Any, F : Field<T>> : Field<Variable<T>> {
// Overloads for Double constants
operator fun Number.plus(that: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + that.value }) { z ->
that.d += z.d
override operator fun Number.plus(b: Variable<T>): Variable<T> =
derive(variable { this@plus.toDouble() * one + b.value }) { z ->
b.d += z.d
}
operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
override operator fun Variable<T>.plus(b: Number): Variable<T> = b.plus(this)
operator fun Number.minus(that: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - that.value }) { z ->
that.d -= z.d
override operator fun Number.minus(b: Variable<T>): Variable<T> =
derive(variable { this@minus.toDouble() * one - b.value }) { z ->
b.d -= z.d
}
operator fun Variable<T>.minus(that: Number): Variable<T> =
derive(variable { this@minus.value - one * that.toDouble() }) { z ->
override operator fun Variable<T>.minus(b: Number): Variable<T> =
derive(variable { this@minus.value - one * b.toDouble() }) { z ->
this@minus.d += z.d
}
}

View File

@ -12,9 +12,6 @@ interface Algebra<T> {
*/
fun symbol(value: String): T = error("Wrapping of '$value' is not supported in $this")
@Deprecated("Symbol is more concise",replaceWith = ReplaceWith("symbol"))
fun raw(value: String): T = symbol(value)
/**
* Dynamic call of unary operation with name [operation] on [arg]
*/
@ -64,6 +61,8 @@ interface SpaceOperations<T> : Algebra<T> {
//Operation to be performed in this context. Could be moved to extensions in case of KEEP-176
operator fun T.unaryMinus(): T = multiply(this, -1.0)
operator fun T.unaryPlus(): T = this
operator fun T.plus(b: T): T = add(this, b)
operator fun T.minus(b: T): T = add(this, -b)
operator fun T.times(k: Number) = multiply(this, k.toDouble())
@ -138,17 +137,25 @@ interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = one * value.toDouble()
override fun leftSideNumberOperation(operation: String, left: Number, right: T): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.leftSideNumberOperation(operation, left, right)
}
//TODO those operators are blocked by type conflict in RealField
override fun rightSideNumberOperation(operation: String, left: T, right: Number): T = when (operation) {
SpaceOperations.PLUS_OPERATION -> left + right
SpaceOperations.MINUS_OPERATION -> left - right
RingOperations.TIMES_OPERATION -> left * right
else -> super.rightSideNumberOperation(operation, left, right)
}
// operator fun T.plus(b: Number) = this.plus(b * one)
// operator fun Number.plus(b: T) = b + this
//
// operator fun T.minus(b: Number) = this.minus(b * one)
// operator fun Number.minus(b: T) = -b + this
operator fun T.plus(b: Number) = this.plus(number(b))
operator fun Number.plus(b: T) = b + this
operator fun T.minus(b: Number) = this.minus(number(b))
operator fun Number.minus(b: T) = -b + this
}
/**