Simplify and revise test cases for MST compilation engines #285

This commit is contained in:
Iaroslav Postovalov 2021-05-03 00:24:32 +07:00
parent 72d91b04da
commit f0627b2ced
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
18 changed files with 171 additions and 641 deletions

View File

@ -3,12 +3,12 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
package space.kscience.kmath.wasm package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstField import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.interpret import space.kscience.kmath.expressions.interpret
import space.kscience.kmath.misc.symbol import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
@ -16,45 +16,41 @@ import space.kscience.kmath.operations.invoke
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class TestWasmConsistencyWithInterpreter { internal class TestCompilerConsistencyWithInterpreter {
@Test @Test
fun intRing() { fun intRing() = runCompilerTest {
val mst = MstRing { val mst = MstRing {
binaryOperationFunction("+")( binaryOperationFunction("+")(
unaryOperationFunction("+")( unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale( (bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)), add(number(1), number(1)),
2.0 2.0,
) + 1.toByte()))) * 3.0 - 1.toByte() ) + 1.toByte()))) * 3.0 - 1.toByte()
), ),
number(1) number(1),
) * number(2) ) * number(2)
} }
assertEquals( assertEquals(
mst.interpret(IntRing, x to 3), mst.interpret(IntRing, x to 3),
mst.compile(IntRing, x to 3) mst.compile(IntRing, x to 3),
) )
} }
@Test @Test
fun doubleField() { fun doubleField() = runCompilerTest {
val mst = MstField { val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")( +(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0 (3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1), + number(1),
number(1) / 2 + number(2.0) * one number(1) / 2 + number(2.0) * one,
) + zero ) + zero
} }
assertEquals( assertEquals(
mst.interpret(DoubleField, x to 2.0), mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0) mst.compile(DoubleField, x to 2.0),
) )
} }
private companion object {
private val x by symbol
}
} }

View File

@ -0,0 +1,65 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestCompilerOperations {
@Test
fun testUnaryPlus() = runCompilerTest {
val expr = MstExtendedField { +bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() = runCompilerTest {
val expr = MstExtendedField { -bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) + bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() = runCompilerTest {
val expr = MstExtendedField { sin(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testCosine() = runCompilerTest {
val expr = MstExtendedField { cos(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 0.0))
}
@Test
fun testSubtract() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) - bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) / bindSymbol(x) }.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() = runCompilerTest {
val expr = MstExtendedField { bindSymbol(x) pow 2 }.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
}

View File

@ -3,11 +3,11 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
package space.kscience.kmath.wasm package space.kscience.kmath.ast
import space.kscience.kmath.expressions.MstRing import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol import space.kscience.kmath.misc.Symbol.Companion.x
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -15,20 +15,16 @@ import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
internal class TestWasmVariables { internal class TestCompilerVariables {
@Test @Test
fun testVariable() { fun testVariable() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing) val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
assertEquals(1, expr(x to 1)) assertEquals(1, expr(x to 1))
} }
@Test @Test
fun testUndefinedVariableFails() { fun testUndefinedVariableFails() = runCompilerTest {
val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing) val expr = MstRing { bindSymbol(x) }.compileToExpression(IntRing)
assertFailsWith<NoSuchElementException> { expr() } assertFailsWith<NoSuchElementException> { expr() }
} }
private companion object {
private val x by symbol
}
} }

View File

@ -13,7 +13,7 @@ import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class ParserTest { internal class TestParser {
@Test @Test
fun evaluateParsedMst() { fun evaluateParsedMst() {
val mst = "2+2*(2+2)".parseMath() val mst = "2+2*(2+2)".parseMath()

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.operations.DoubleField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
internal class ParserPrecedenceTest { internal class TestParserPrecedence {
@Test @Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath())) fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))

View File

@ -0,0 +1,25 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
internal interface CompilerTestContext {
fun MST.compileToExpression(algebra: IntRing): Expression<Int>
fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int
fun MST.compile(algebra: IntRing, vararg arguments: Pair<Symbol, Int>): Int = compile(algebra, mapOf(*arguments))
fun MST.compileToExpression(algebra: DoubleField): Expression<Double>
fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double
fun MST.compile(algebra: DoubleField, vararg arguments: Pair<Symbol, Double>): Double =
compile(algebra, mapOf(*arguments))
}
internal expect inline fun runCompilerTest(action: CompilerTestContext.() -> Unit)

View File

@ -0,0 +1,39 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.estree.compile as estreeCompile
import space.kscience.kmath.estree.compileToExpression as estreeCompileToExpression
import space.kscience.kmath.wasm.compile as wasmCompile
import space.kscience.kmath.wasm.compileToExpression as wasmCompileToExpression
private object WasmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = wasmCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = wasmCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = wasmCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
wasmCompile(algebra, arguments)
}
private object ESTreeCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = estreeCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = estreeCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = estreeCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
estreeCompile(algebra, arguments)
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) {
action(WasmCompilerTestContext)
action(ESTreeCompilerTestContext)
}

View File

@ -1,97 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeConsistencyWithInterpreter {
@Test
fun mstSpace() {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol(x) + zero
}
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun doubleField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex()),
)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestESTreeSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.estree
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestESTreeVariables {
@Test
fun testVariable() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr(x to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestWasmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.wasm
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestWasmSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -1,97 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.complex.ComplexField
import space.kscience.kmath.complex.toComplex
import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmConsistencyWithInterpreter {
@Test
fun mstSpace() {
val mst = MstGroup {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
number(3.toByte()) - (number(2.toByte()) + (scale(
add(number(1), number(1)),
2.0
) + number(1.toByte()) * 3.toByte() - number(1.toByte())))
),
number(1)
) + bindSymbol(x) + zero
}
assertEquals(
mst.interpret(MstGroup, x to MST.Numeric(2)),
mst.compile(MstGroup, x to MST.Numeric(2))
)
}
@Test
fun byteRing() {
val mst = MstRing {
binaryOperationFunction("+")(
unaryOperationFunction("+")(
(bindSymbol(x) - (2.toByte() + (scale(
add(number(1), number(1)),
2.0
) + 1.toByte()))) * 3.0 - 1.toByte()
),
number(1)
) * number(2)
}
assertEquals(
mst.interpret(ByteRing, x to 3.toByte()),
mst.compile(ByteRing, x to 3.toByte())
)
}
@Test
fun doubleField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(DoubleField, x to 2.0),
mst.compile(DoubleField, x to 2.0)
)
}
@Test
fun complexField() {
val mst = MstField {
+(3 - 2 + 2 * number(1) + 1.0) + binaryOperationFunction("+")(
(3.0 - (bindSymbol(x) + (scale(add(number(1.0), number(1.0)), 2.0) + 1.0))) * 3 - 1.0
+ number(1),
number(1) / 2 + number(2.0) * one
) + zero
}
assertEquals(
mst.interpret(ComplexField, x to 2.0.toComplex()),
mst.compile(ComplexField, x to 2.0.toComplex())
)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,42 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstField
import space.kscience.kmath.expressions.MstGroup
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmOperationsSupport {
@Test
fun testUnaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-2.0, res)
}
@Test
fun testBinaryOperationInvocation() {
val expression = MstGroup { -bindSymbol(x) + number(1.0) }.compileToExpression(DoubleField)
val res = expression(x to 2.0)
assertEquals(-1.0, res)
}
@Test
fun testConstProductInvocation() {
val res = MstField { bindSymbol(x) * 2 }.compileToExpression(DoubleField)(x to 2.0)
assertEquals(4.0, res)
}
private companion object {
private val x by symbol
}
}

View File

@ -1,76 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstExtendedField
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
internal class TestAsmSpecialization {
@Test
fun testUnaryPlus() {
val expr = MstExtendedField { unaryOperationFunction("+")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(2.0, expr(x to 2.0))
}
@Test
fun testUnaryMinus() {
val expr = MstExtendedField { unaryOperationFunction("-")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(-2.0, expr(x to 2.0))
}
@Test
fun testAdd() {
val expr = MstExtendedField {
binaryOperationFunction("+")(
bindSymbol(x),
bindSymbol(x),
)
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
@Test
fun testSine() {
val expr = MstExtendedField { unaryOperationFunction("sin")(bindSymbol(x)) }.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 0.0))
}
@Test
fun testSubtract() {
val expr = MstExtendedField {
binaryOperationFunction("-")(bindSymbol(x),
bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(0.0, expr(x to 2.0))
}
@Test
fun testDivide() {
val expr = MstExtendedField {
binaryOperationFunction("/")(bindSymbol(x), bindSymbol(x))
}.compileToExpression(DoubleField)
assertEquals(1.0, expr(x to 2.0))
}
@Test
fun testPower() {
val expr = MstExtendedField {
binaryOperationFunction("pow")(bindSymbol(x), number(2))
}.compileToExpression(DoubleField)
assertEquals(4.0, expr(x to 2.0))
}
private companion object {
private val x by symbol
}
}

View File

@ -1,34 +0,0 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.asm
import space.kscience.kmath.expressions.MstRing
import space.kscience.kmath.expressions.invoke
import space.kscience.kmath.misc.symbol
import space.kscience.kmath.operations.ByteRing
import space.kscience.kmath.operations.bindSymbol
import space.kscience.kmath.operations.invoke
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
internal class TestAsmVariables {
@Test
fun testVariable() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertEquals(1.toByte(), expr(x to 1.toByte()))
}
@Test
fun testUndefinedVariableFails() {
val expr = MstRing { bindSymbol(x) }.compileToExpression(ByteRing)
assertFailsWith<NoSuchElementException> { expr() }
}
private companion object {
private val x by symbol
}
}

View File

@ -0,0 +1,25 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.ast
import space.kscience.kmath.expressions.Expression
import space.kscience.kmath.expressions.MST
import space.kscience.kmath.misc.Symbol
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.asm.compile as asmCompile
import space.kscience.kmath.asm.compileToExpression as asmCompileToExpression
private object AsmCompilerTestContext : CompilerTestContext {
override fun MST.compileToExpression(algebra: IntRing): Expression<Int> = asmCompileToExpression(algebra)
override fun MST.compile(algebra: IntRing, arguments: Map<Symbol, Int>): Int = asmCompile(algebra, arguments)
override fun MST.compileToExpression(algebra: DoubleField): Expression<Double> = asmCompileToExpression(algebra)
override fun MST.compile(algebra: DoubleField, arguments: Map<Symbol, Double>): Double =
asmCompile(algebra, arguments)
}
internal actual inline fun runCompilerTest(action: CompilerTestContext.() -> Unit) = action(AsmCompilerTestContext)