Merge remote-tracking branch 'origin/dev' into nd4j

# Conflicts:
#	build.gradle.kts
This commit is contained in:
Iaroslav Postovalov 2020-07-29 23:46:11 +07:00
commit acca495720
No known key found for this signature in database
GPG Key ID: 70D5F4DCB0972F1B
13 changed files with 257 additions and 79 deletions

View File

@ -5,7 +5,7 @@
Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) Bintray: [ ![Download](https://api.bintray.com/packages/mipt-npm/scientifik/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion)
Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/scientifik/kmath-core/_latestVersion) Bintray-dev: [ ![Download](https://api.bintray.com/packages/mipt-npm/dev/kmath-core/images/download.svg) ](https://bintray.com/mipt-npm/dev/kmath-core/_latestVersion)
# KMath # KMath
Could be pronounced as `key-math`. Could be pronounced as `key-math`.

View File

@ -12,6 +12,7 @@ allprojects {
jcenter() jcenter()
maven("https://dl.bintray.com/kotlin/kotlinx") maven("https://dl.bintray.com/kotlin/kotlinx")
mavenCentral() mavenCentral()
maven("https://dl.bintray.com/hotkeytlt/maven")
} }
group = "scientifik" group = "scientifik"

View File

@ -1,10 +1,4 @@
plugins { plugins { id("scientifik.mpp") }
id("scientifik.mpp")
}
repositories {
maven("https://dl.bintray.com/hotkeytlt/maven")
}
kotlin.sourceSets { kotlin.sourceSets {
// all { // all {
@ -15,23 +9,15 @@ kotlin.sourceSets {
commonMain { commonMain {
dependencies { dependencies {
api(project(":kmath-core")) api(project(":kmath-core"))
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform:0.4.0-alpha-3") implementation("com.github.h0tk3y.betterParse:better-parse:0.4.0")
implementation("com.github.h0tk3y.betterParse:better-parse-multiplatform-metadata:0.4.0-alpha-3")
} }
} }
jvmMain { jvmMain {
dependencies { dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-jvm:0.4.0-alpha-3")
implementation("org.ow2.asm:asm:8.0.1") implementation("org.ow2.asm:asm:8.0.1")
implementation("org.ow2.asm:asm-commons:8.0.1") implementation("org.ow2.asm:asm-commons:8.0.1")
implementation(kotlin("reflect")) implementation(kotlin("reflect"))
} }
} }
jsMain {
dependencies {
implementation("com.github.h0tk3y.betterParse:better-parse-js:0.4.0-alpha-3")
}
}
} }

View File

@ -0,0 +1,59 @@
grammar ArithmeticsEvaluator;
fragment DIGIT: '0'..'9';
fragment LETTER: 'a'..'z';
fragment CAPITAL_LETTER: 'A'..'Z';
fragment UNDERSCORE: '_';
ID: (LETTER | UNDERSCORE | CAPITAL_LETTER) (LETTER | UNDERSCORE | DIGIT | CAPITAL_LETTER)*;
NUM: (DIGIT | '.')+ ([eE] [-+]? DIGIT+)?;
MUL: '*';
DIV: '/';
PLUS: '+';
MINUS: '-';
POW: '^';
COMMA: ',';
LPAR: '(';
RPAR: ')';
WS: [ \n\t\r]+ -> skip;
num
: NUM
;
singular
: ID
;
unaryFunction
: ID LPAR subSumChain RPAR
;
binaryFunction
: ID LPAR subSumChain COMMA subSumChain RPAR
;
term
: num
| singular
| unaryFunction
| binaryFunction
| MINUS term
| LPAR subSumChain RPAR
;
powChain
: term (POW term)*
;
divMulChain
: powChain ((DIV | MUL) powChain)*
;
subSumChain
: divMulChain ((PLUS | MINUS) divMulChain)*
;
rootParser
: subSumChain EOF
;

View File

@ -41,27 +41,28 @@ sealed class MST {
//TODO add a function with named arguments //TODO add a function with named arguments
fun <T> Algebra<T>.evaluate(node: MST): T { fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
return when (node) { is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value)
is MST.Numeric -> (this as? NumericAlgebra<T>)?.number(node.value) ?: error("Numeric nodes are not supported by $this")
?: error("Numeric nodes are not supported by $this") is MST.Symbolic -> symbol(node.value)
is MST.Symbolic -> symbol(node.value) is MST.Unary -> unaryOperation(node.operation, evaluate(node.value))
is MST.Unary -> unaryOperation(node.operation, evaluate(node.value)) is MST.Binary -> when {
is MST.Binary -> when { this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
this !is NumericAlgebra -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
node.left is MST.Numeric && node.right is MST.Numeric -> { node.left is MST.Numeric && node.right is MST.Numeric -> {
val number = RealField.binaryOperation( val number = RealField.binaryOperation(
node.operation, node.operation,
node.left.value.toDouble(), node.left.value.toDouble(),
node.right.value.toDouble() node.right.value.toDouble()
) )
number(number)
} number(number)
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
} }
node.left is MST.Numeric -> leftSideNumberOperation(node.operation, node.left.value, evaluate(node.right))
node.right is MST.Numeric -> rightSideNumberOperation(node.operation, evaluate(node.left), node.right.value)
else -> binaryOperation(node.operation, evaluate(node.left), evaluate(node.right))
} }
} }
fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this) fun <T> MST.compile(algebra: Algebra<T>): T = algebra.evaluate(this)

View File

@ -5,6 +5,9 @@ import com.github.h0tk3y.betterParse.grammar.Grammar
import com.github.h0tk3y.betterParse.grammar.parseToEnd import com.github.h0tk3y.betterParse.grammar.parseToEnd
import com.github.h0tk3y.betterParse.grammar.parser import com.github.h0tk3y.betterParse.grammar.parser
import com.github.h0tk3y.betterParse.grammar.tryParseToEnd import com.github.h0tk3y.betterParse.grammar.tryParseToEnd
import com.github.h0tk3y.betterParse.lexer.Token
import com.github.h0tk3y.betterParse.lexer.TokenMatch
import com.github.h0tk3y.betterParse.lexer.regexToken
import com.github.h0tk3y.betterParse.parser.ParseResult import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import scientifik.kmath.operations.FieldOperations import scientifik.kmath.operations.FieldOperations
@ -13,47 +16,69 @@ import scientifik.kmath.operations.RingOperations
import scientifik.kmath.operations.SpaceOperations import scientifik.kmath.operations.SpaceOperations
/** /**
* TODO move to common * TODO move to core
*/ */
private object ArithmeticsEvaluator : Grammar<MST>() { object ArithmeticsEvaluator : Grammar<MST>() {
val num by token("-?[\\d.]+(?:[eE]-?\\d+)?".toRegex()) // TODO replace with "...".toRegex() when better-parse 0.4.1 is released
val lpar by token("\\(".toRegex()) private val num: Token by regexToken("[\\d.]+(?:[eE][-+]?\\d+)?")
val rpar by token("\\)".toRegex()) private val id: Token by regexToken("[a-z_A-Z][\\da-z_A-Z]*")
val mul by token("\\*".toRegex()) private val lpar: Token by regexToken("\\(")
val pow by token("\\^".toRegex()) private val rpar: Token by regexToken("\\)")
val div by token("/".toRegex()) private val comma: Token by regexToken(",")
val minus by token("-".toRegex()) private val mul: Token by regexToken("\\*")
val plus by token("\\+".toRegex()) private val pow: Token by regexToken("\\^")
val ws by token("\\s+".toRegex(), ignore = true) private val div: Token by regexToken("/")
private val minus: Token by regexToken("-")
private val plus: Token by regexToken("\\+")
private val ws: Token by regexToken("\\s+", ignore = true)
val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) } private val number: Parser<MST> by num use { MST.Numeric(text.toDouble()) }
private val singular: Parser<MST> by id use { MST.Symbolic(text) }
val term: Parser<MST> by number or private val unaryFunction: Parser<MST> by (id and skip(lpar) and parser(::subSumChain) and skip(rpar))
(skip(minus) and parser(this::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) }) or .map { (id, term) -> MST.Unary(id.text, term) }
(skip(lpar) and parser(this::rootParser) and skip(rpar))
val powChain by leftAssociative(term, pow) { a, _, b -> private val binaryFunction: Parser<MST> by id
.and(skip(lpar))
.and(parser(::subSumChain))
.and(skip(comma))
.and(parser(::subSumChain))
.and(skip(rpar))
.map { (id, left, right) -> MST.Binary(id.text, left, right) }
private val term: Parser<MST> by number
.or(binaryFunction)
.or(unaryFunction)
.or(singular)
.or(skip(minus) and parser(::term) map { MST.Unary(SpaceOperations.MINUS_OPERATION, it) })
.or(skip(lpar) and parser(::subSumChain) and skip(rpar))
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
MST.Binary(PowerOperations.POW_OPERATION, a, b) MST.Binary(PowerOperations.POW_OPERATION, a, b)
} }
val divMulChain: Parser<MST> by leftAssociative(powChain, div or mul use { type }) { a, op, b -> private val divMulChain: Parser<MST> by leftAssociative(
if (op == div) { term = powChain,
operator = div or mul use TokenMatch::type
) { a, op, b ->
if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOperations.DIV_OPERATION, a, b)
} else { else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOperations.TIMES_OPERATION, a, b)
}
} }
val subSumChain: Parser<MST> by leftAssociative(divMulChain, plus or minus use { type }) { a, op, b -> private val subSumChain: Parser<MST> by leftAssociative(
if (op == plus) { term = divMulChain,
operator = plus or minus use TokenMatch::type
) { a, op, b ->
if (op == plus)
MST.Binary(SpaceOperations.PLUS_OPERATION, a, b) MST.Binary(SpaceOperations.PLUS_OPERATION, a, b)
} else { else
MST.Binary(SpaceOperations.MINUS_OPERATION, a, b) MST.Binary(SpaceOperations.MINUS_OPERATION, a, b)
}
} }
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain
} }
fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this) fun String.tryParseMath(): ParseResult<MST> = ArithmeticsEvaluator.tryParseToEnd(this)
fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this) fun String.parseMath(): MST = ArithmeticsEvaluator.parseToEnd(this)

View File

@ -8,7 +8,6 @@ import scientifik.kmath.ast.MST
import scientifik.kmath.ast.MstExpression import scientifik.kmath.ast.MstExpression
import scientifik.kmath.expressions.Expression import scientifik.kmath.expressions.Expression
import scientifik.kmath.operations.Algebra import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.NumericAlgebra
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -17,7 +16,19 @@ import kotlin.reflect.KClass
fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> { fun <T : Any> MST.compileWith(type: KClass<T>, algebra: Algebra<T>): Expression<T> {
fun AsmBuilder<T>.visit(node: MST) { fun AsmBuilder<T>.visit(node: MST) {
when (node) { when (node) {
is MST.Symbolic -> loadVariable(node.value) is MST.Symbolic -> {
val symbol = try {
algebra.symbol(node.value)
} catch (ignored: Throwable) {
null
}
if (symbol != null)
loadTConstant(symbol)
else
loadVariable(node.value)
}
is MST.Numeric -> loadNumeric(node.value) is MST.Numeric -> loadNumeric(node.value)
is MST.Unary -> buildAlgebraOperationCall( is MST.Unary -> buildAlgebraOperationCall(

View File

@ -288,7 +288,7 @@ internal class AsmBuilder<T> internal constructor(
/** /**
* Loads a [T] constant from [constants]. * Loads a [T] constant from [constants].
*/ */
private fun loadTConstant(value: T) { internal fun loadTConstant(value: T) {
if (classOfT in INLINABLE_NUMBERS) { if (classOfT in INLINABLE_NUMBERS) {
val expectedType = expectationStack.pop() val expectedType = expectationStack.pop()
val mustBeBoxed = expectedType.sort == Type.OBJECT val mustBeBoxed = expectedType.sort == Type.OBJECT
@ -340,7 +340,7 @@ internal class AsmBuilder<T> internal constructor(
checkcast(type) checkcast(type)
} }
fun loadNumeric(value: Number) { internal fun loadNumeric(value: Number) {
if (expectationStack.peek() == NUMBER_TYPE) { if (expectationStack.peek() == NUMBER_TYPE) {
loadNumberConstant(value, true) loadNumberConstant(value, true)
expectationStack.pop() expectationStack.pop()

View File

@ -0,0 +1,36 @@
package scietifik.kmath.ast
import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.parseMath
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.RealField
import kotlin.test.Test
import kotlin.test.assertEquals
internal class ParserPrecedenceTest {
private val f: Field<Double> = RealField
@Test
fun test1(): Unit = assertEquals(6.0, f.evaluate("2*2+2".parseMath()))
@Test
fun test2(): Unit = assertEquals(6.0, f.evaluate("2+2*2".parseMath()))
@Test
fun test3(): Unit = assertEquals(10.0, f.evaluate("2^3+2".parseMath()))
@Test
fun test4(): Unit = assertEquals(10.0, f.evaluate("2+2^3".parseMath()))
@Test
fun test5(): Unit = assertEquals(16.0, f.evaluate("2^3*2".parseMath()))
@Test
fun test6(): Unit = assertEquals(16.0, f.evaluate("2*2^3".parseMath()))
@Test
fun test7(): Unit = assertEquals(18.0, f.evaluate("2+2^3*2".parseMath()))
@Test
fun test8(): Unit = assertEquals(18.0, f.evaluate("2*2^3+2".parseMath()))
}

View File

@ -4,8 +4,10 @@ import scientifik.kmath.ast.evaluate
import scientifik.kmath.ast.mstInField import scientifik.kmath.ast.mstInField
import scientifik.kmath.ast.parseMath import scientifik.kmath.ast.parseMath
import scientifik.kmath.expressions.invoke import scientifik.kmath.expressions.invoke
import scientifik.kmath.operations.Algebra
import scientifik.kmath.operations.Complex import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.RealField
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -22,4 +24,37 @@ internal class ParserTest {
val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }() val res = ComplexField.mstInField { number(2) + number(2) * (number(2) + number(2)) }()
assertEquals(Complex(10.0, 0.0), res) assertEquals(Complex(10.0, 0.0), res)
} }
@Test
fun `evaluate MST with singular`() {
val mst = "i".parseMath()
val res = ComplexField.evaluate(mst)
assertEquals(ComplexField.i, res)
}
@Test
fun `evaluate MST with unary function`() {
val mst = "sin(0)".parseMath()
val res = RealField.evaluate(mst)
assertEquals(0.0, res)
}
@Test
fun `evaluate MST with binary function`() {
val magicalAlgebra = object : Algebra<String> {
override fun symbol(value: String): String = value
override fun unaryOperation(operation: String, arg: String): String = throw NotImplementedError()
override fun binaryOperation(operation: String, left: String, right: String): String = when (operation) {
"magic" -> "$left$right"
else -> throw NotImplementedError()
}
}
val mst = "magic(a, b)".parseMath()
val res = magicalAlgebra.evaluate(mst)
assertEquals("a ★ b", res)
}
} }

View File

@ -1,5 +1,6 @@
package scientifik.kmath.operations package scientifik.kmath.operations
import scientifik.kmath.operations.RealField.pow
import kotlin.math.abs import kotlin.math.abs
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
@ -85,6 +86,11 @@ object RealField : ExtendedField<Double>, Norm<Double, Double> {
override inline fun Double.times(b: Double) = this * b override inline fun Double.times(b: Double) = this * b
override inline fun Double.div(b: Double) = this / b override inline fun Double.div(b: Double) = this / b
override fun binaryOperation(operation: String, left: Double, right: Double): Double = when (operation) {
PowerOperations.POW_OPERATION -> left pow right
else -> super.binaryOperation(operation, left, right)
}
} }
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")

View File

@ -8,30 +8,48 @@ import java.math.MathContext
* A field wrapper for Java [BigInteger] * A field wrapper for Java [BigInteger]
*/ */
object JBigIntegerField : Field<BigInteger> { object JBigIntegerField : Field<BigInteger> {
override val zero: BigInteger = BigInteger.ZERO override val zero: BigInteger
override val one: BigInteger = BigInteger.ONE get() = BigInteger.ZERO
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) override val one: BigInteger
get() = BigInteger.ONE
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b) override fun divide(a: BigInteger, b: BigInteger): BigInteger = a.div(b)
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
override fun BigInteger.minus(b: BigInteger): BigInteger = this.subtract(b)
override fun multiply(a: BigInteger, k: Number): BigInteger = a.multiply(k.toInt().toBigInteger())
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
override fun BigInteger.unaryMinus(): BigInteger = negate()
} }
/** /**
* A Field wrapper for Java [BigDecimal] * A Field wrapper for Java [BigDecimal]
*/ */
class JBigDecimalField(val mathContext: MathContext = MathContext.DECIMAL64) : Field<BigDecimal> { abstract class JBigDecimalFieldBase internal constructor(val mathContext: MathContext = MathContext.DECIMAL64) :
override val zero: BigDecimal = BigDecimal.ZERO Field<BigDecimal>,
override val one: BigDecimal = BigDecimal.ONE PowerOperations<BigDecimal> {
override val zero: BigDecimal
get() = BigDecimal.ZERO
override val one: BigDecimal
get() = BigDecimal.ONE
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
override fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
override fun multiply(a: BigDecimal, k: Number): BigDecimal = override fun multiply(a: BigDecimal, k: Number): BigDecimal =
a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext) a.multiply(k.toDouble().toBigDecimal(mathContext), mathContext)
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
} override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
}
class JBigDecimalField(mathContext: MathContext = MathContext.DECIMAL64) : JBigDecimalFieldBase(mathContext) {
companion object : JBigDecimalFieldBase()
}

View File

@ -31,5 +31,5 @@ object D2 : Dimension {
} }
object D3 : Dimension { object D3 : Dimension {
override val dim: UInt get() = 31U override val dim: UInt get() = 3U
} }