forked from kscience/kmath
Merge branch 'dev' into feature/polynomials
This commit is contained in:
commit
d416f8cf34
@ -4,6 +4,8 @@
|
|||||||
### Added
|
### Added
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
- Kotlin 1.7
|
||||||
|
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.0"
|
version = "0.3.1-dev-1"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
@ -51,6 +51,18 @@ subprojects {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
plugins.withId("org.jetbrains.kotlin.multiplatform") {
|
||||||
|
configure<org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension> {
|
||||||
|
sourceSets {
|
||||||
|
val commonTest by getting {
|
||||||
|
dependencies {
|
||||||
|
implementation(projects.testUtils)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
readme.readmeTemplate = file("docs/templates/README-TEMPLATE.md")
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
plugins {
|
plugins {
|
||||||
|
kotlin("jvm") version "1.7.0"
|
||||||
`kotlin-dsl`
|
`kotlin-dsl`
|
||||||
`version-catalog`
|
`version-catalog`
|
||||||
alias(miptNpmLibs.plugins.kotlin.plugin.serialization)
|
alias(npmlibs.plugins.kotlin.plugin.serialization)
|
||||||
}
|
}
|
||||||
|
|
||||||
java.targetCompatibility = JavaVersion.VERSION_11
|
java.targetCompatibility = JavaVersion.VERSION_11
|
||||||
@ -13,17 +14,18 @@ repositories {
|
|||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
}
|
}
|
||||||
|
|
||||||
val toolsVersion: String by extra
|
val toolsVersion = npmlibs.versions.tools.get()
|
||||||
val kotlinVersion = miptNpmLibs.versions.kotlin.asProvider().get()
|
val kotlinVersion = npmlibs.versions.kotlin.asProvider().get()
|
||||||
val benchmarksVersion = miptNpmLibs.versions.kotlinx.benchmark.get()
|
val benchmarksVersion = npmlibs.versions.kotlinx.benchmark.get()
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api("ru.mipt.npm:gradle-tools:$toolsVersion")
|
api("ru.mipt.npm:gradle-tools:$toolsVersion")
|
||||||
|
api(npmlibs.atomicfu.gradle)
|
||||||
//plugins form benchmarks
|
//plugins form benchmarks
|
||||||
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
api("org.jetbrains.kotlinx:kotlinx-benchmark-plugin:$benchmarksVersion")
|
||||||
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
api("org.jetbrains.kotlin:kotlin-allopen:$kotlinVersion")
|
||||||
//to be used inside build-script only
|
//to be used inside build-script only
|
||||||
implementation(miptNpmLibs.kotlinx.serialization.json)
|
implementation(npmlibs.kotlinx.serialization.json)
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets.all {
|
kotlin.sourceSets.all {
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright 2018-2021 KMath contributors.
|
|
||||||
# Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
|
||||||
#
|
|
||||||
|
|
||||||
kotlin.code.style=official
|
|
||||||
toolsVersion=0.11.2-kotlin-1.6.10
|
|
@ -6,7 +6,17 @@
|
|||||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
||||||
|
|
||||||
dependencyResolutionManagement {
|
dependencyResolutionManagement {
|
||||||
val toolsVersion: String by extra
|
val projectProperties = java.util.Properties()
|
||||||
|
file("../gradle.properties").inputStream().use {
|
||||||
|
projectProperties.load(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
projectProperties.forEach { key, value ->
|
||||||
|
extra.set(key.toString(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
val toolsVersion: String = projectProperties["toolsVersion"].toString()
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenLocal()
|
mavenLocal()
|
||||||
@ -16,7 +26,7 @@ dependencyResolutionManagement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
versionCatalogs {
|
versionCatalogs {
|
||||||
create("miptNpmLibs") {
|
create("npmlibs") {
|
||||||
from("ru.mipt.npm:version-catalog:$toolsVersion")
|
from("ru.mipt.npm:version-catalog:$toolsVersion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -319,7 +319,9 @@ public object EjmlLinearSpace${ops} : EjmlLinearSpace<${type}, ${kmathAlgebra},
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -6,8 +6,10 @@ kotlin.code.style=official
|
|||||||
kotlin.jupyter.add.scanner=false
|
kotlin.jupyter.add.scanner=false
|
||||||
kotlin.mpp.stability.nowarn=true
|
kotlin.mpp.stability.nowarn=true
|
||||||
kotlin.native.ignoreDisabledTargets=true
|
kotlin.native.ignoreDisabledTargets=true
|
||||||
#kotlin.incremental.js.ir=true
|
//kotlin.incremental.js.ir=true
|
||||||
|
|
||||||
org.gradle.configureondemand=true
|
org.gradle.configureondemand=true
|
||||||
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1G
|
|
||||||
org.gradle.parallel=true
|
org.gradle.parallel=true
|
||||||
|
org.gradle.jvmargs=-Xmx4096m
|
||||||
|
|
||||||
|
toolsVersion=0.11.7-kotlin-1.7.0
|
||||||
|
@ -199,10 +199,7 @@ public fun main() {
|
|||||||
|
|
||||||
Result LaTeX:
|
Result LaTeX:
|
||||||
|
|
||||||
<div style="background-color:white;">
|
$$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
|
||||||
|
|
||||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
|
|
||||||
</div>
|
|
||||||
|
|
||||||
Result MathML (can be used with MathJax or other renderers):
|
Result MathML (can be used with MathJax or other renderers):
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ kotlin.sourceSets {
|
|||||||
|
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api("com.github.h0tk3y.betterParse:better-parse:0.4.2")
|
api("com.github.h0tk3y.betterParse:better-parse:0.4.4")
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -57,7 +57,7 @@ tasks.dokkaHtml {
|
|||||||
|
|
||||||
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
if (System.getProperty("space.kscience.kmath.ast.dump.generated.classes") == "1")
|
||||||
tasks.jvmTest {
|
tasks.jvmTest {
|
||||||
jvmArgs = (jvmArgs ?: emptyList()) + listOf("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
|
jvmArgs("-Dspace.kscience.kmath.ast.dump.generated.classes=1")
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -170,10 +170,7 @@ public fun main() {
|
|||||||
|
|
||||||
Result LaTeX:
|
Result LaTeX:
|
||||||
|
|
||||||
<div style="background-color:white;">
|
$$\operatorname{exp}\\,\left(\sqrt{x}\right)-\frac{\frac{\operatorname{arcsin}\\,\left(2\\,x\right)}{2\times10^{10}+x^{3}}}{12}+x^{2/3}$$
|
||||||
|
|
||||||
![](https://latex.codecogs.com/gif.latex?%5Coperatorname{exp}%5C,%5Cleft(%5Csqrt{x}%5Cright)-%5Cfrac{%5Cfrac{%5Coperatorname{arcsin}%5C,%5Cleft(2%5C,x%5Cright)}{2%5Ctimes10^{10}%2Bx^{3}}}{12}+x^{2/3})
|
|
||||||
</div>
|
|
||||||
|
|
||||||
Result MathML (can be used with MathJax or other renderers):
|
Result MathML (can be used with MathJax or other renderers):
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ public fun <T : Any> MST.compileToExpression(algebra: Algebra<T>): Expression<T>
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ESTreeBuilder<T> { visit(typed) }.instance
|
return ESTreeBuilder { visit(typed) }.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -22,28 +22,20 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suppress("UNUSED_VARIABLE")
|
||||||
val instance: Expression<T> by lazy {
|
val instance: Expression<T> by lazy {
|
||||||
val node = Program(
|
val node = Program(
|
||||||
sourceType = "script",
|
sourceType = "script",
|
||||||
VariableDeclaration(
|
ReturnStatement(bodyCallback())
|
||||||
kind = "var",
|
|
||||||
VariableDeclarator(
|
|
||||||
id = Identifier("executable"),
|
|
||||||
init = FunctionExpression(
|
|
||||||
params = arrayOf(Identifier("constants"), Identifier("arguments")),
|
|
||||||
body = BlockStatement(ReturnStatement(bodyCallback())),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
eval(generate(node))
|
val code = generate(node)
|
||||||
GeneratedExpression(js("executable"), constants.toTypedArray())
|
GeneratedExpression(js("new Function('constants', 'arguments_0', code)"), constants.toTypedArray())
|
||||||
}
|
}
|
||||||
|
|
||||||
private val constants = mutableListOf<Any>()
|
private val constants = mutableListOf<Any>()
|
||||||
|
|
||||||
fun constant(value: Any?) = when {
|
fun constant(value: Any?): BaseExpression = when {
|
||||||
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
|
value == null || jsTypeOf(value) == "number" || jsTypeOf(value) == "string" || jsTypeOf(value) == "boolean" ->
|
||||||
SimpleLiteral(value)
|
SimpleLiteral(value)
|
||||||
|
|
||||||
@ -61,7 +53,8 @@ internal class ESTreeBuilder<T>(val bodyCallback: ESTreeBuilder<T>.() -> BaseExp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fun variable(name: Symbol): BaseExpression = call(getOrFail, Identifier("arguments"), SimpleLiteral(name.identity))
|
fun variable(name: Symbol): BaseExpression =
|
||||||
|
call(getOrFail, Identifier("arguments_0"), SimpleLiteral(name.identity))
|
||||||
|
|
||||||
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
fun call(function: Function<T>, vararg args: BaseExpression): BaseExpression = SimpleCallExpression(
|
||||||
optional = false,
|
optional = false,
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:Suppress("unused")
|
||||||
|
|
||||||
package space.kscience.kmath.internal.estree
|
package space.kscience.kmath.internal.estree
|
||||||
|
|
||||||
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
|
internal fun Program(sourceType: String, vararg body: dynamic) = object : Program {
|
||||||
@ -28,9 +30,10 @@ internal fun Identifier(name: String) = object : Identifier {
|
|||||||
override var name = name
|
override var name = name
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun FunctionExpression(params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
|
internal fun FunctionExpression(id: Identifier?, params: Array<dynamic>, body: BlockStatement) = object : FunctionExpression {
|
||||||
override var params = params
|
override var params = params
|
||||||
override var type = "FunctionExpression"
|
override var type = "FunctionExpression"
|
||||||
|
override var id: Identifier? = id
|
||||||
override var body = body
|
override var body = body
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,7 +71,7 @@ internal sealed class WasmBuilder<T : Number, out E : Expression<T>>(
|
|||||||
protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
|
protected open fun visitBinary(mst: TypedMst.Binary<T>): ExpressionRef =
|
||||||
error("Binary operation ${mst.operation} not defined in $this")
|
error("Binary operation ${mst.operation} not defined in $this")
|
||||||
|
|
||||||
protected open fun createModule(): BinaryenModule = js("new \$module\$binaryen.Module()")
|
protected open fun createModule(): BinaryenModule = space.kscience.kmath.internal.binaryen.Module()
|
||||||
|
|
||||||
protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
|
protected fun visit(node: TypedMst<T>): ExpressionRef = when (node) {
|
||||||
is TypedMst.Constant -> visitNumber(
|
is TypedMst.Constant -> visitNumber(
|
||||||
|
@ -49,5 +49,7 @@ internal abstract class AsmBuilder {
|
|||||||
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
* ASM Type for [space.kscience.kmath.expressions.Symbol].
|
||||||
*/
|
*/
|
||||||
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
|
val SYMBOL_TYPE: Type = getObjectType("space/kscience/kmath/expressions/Symbol")
|
||||||
|
|
||||||
|
const val ARGUMENTS_NAME = "args"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,13 +19,15 @@ readme {
|
|||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "complex",
|
id = "complex",
|
||||||
description = "Complex Numbers",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Complex.kt"
|
||||||
)
|
){
|
||||||
|
"Complex numbers operations"
|
||||||
|
}
|
||||||
|
|
||||||
feature(
|
feature(
|
||||||
id = "quaternion",
|
id = "quaternion",
|
||||||
description = "Quaternions",
|
|
||||||
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt"
|
ref = "src/commonMain/kotlin/space/kscience/kmath/complex/Quaternion.kt"
|
||||||
)
|
){
|
||||||
|
"Quaternions and their composition"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,52 +16,130 @@ import space.kscience.kmath.structures.MutableBuffer
|
|||||||
import space.kscience.kmath.structures.MutableMemoryBuffer
|
import space.kscience.kmath.structures.MutableMemoryBuffer
|
||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents `double`-based quaternion.
|
||||||
|
*
|
||||||
|
* @property w The first component.
|
||||||
|
* @property x The second component.
|
||||||
|
* @property y The third component.
|
||||||
|
* @property z The fourth component.
|
||||||
|
*/
|
||||||
|
public class Quaternion(
|
||||||
|
public val w: Double,
|
||||||
|
public val x: Double,
|
||||||
|
public val y: Double,
|
||||||
|
public val z: Double,
|
||||||
|
) : Buffer<Double> {
|
||||||
|
init {
|
||||||
|
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
|
||||||
|
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
|
||||||
|
require(!y.isNaN()) { "y-component of quaternion is not-a-number" }
|
||||||
|
require(!z.isNaN()) { "z-component of quaternion is not-a-number" }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a string representation of this quaternion.
|
||||||
|
*/
|
||||||
|
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
|
||||||
|
|
||||||
|
override val size: Int get() = 4
|
||||||
|
|
||||||
|
override fun get(index: Int): Double = when (index) {
|
||||||
|
0 -> w
|
||||||
|
1 -> x
|
||||||
|
2 -> y
|
||||||
|
3 -> z
|
||||||
|
else -> error("Index $index out of bounds [0,3]")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun iterator(): Iterator<Double> = listOf(w, x, y, z).iterator()
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as Quaternion
|
||||||
|
|
||||||
|
if (w != other.w) return false
|
||||||
|
if (x != other.x) return false
|
||||||
|
if (y != other.y) return false
|
||||||
|
if (z != other.z) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var result = w.hashCode()
|
||||||
|
result = 31 * result + x.hashCode()
|
||||||
|
result = 31 * result + y.hashCode()
|
||||||
|
result = 31 * result + z.hashCode()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public companion object : MemorySpec<Quaternion> {
|
||||||
|
override val objectSize: Int get() = 32
|
||||||
|
|
||||||
|
override fun MemoryReader.read(offset: Int): Quaternion = Quaternion(
|
||||||
|
readDouble(offset),
|
||||||
|
readDouble(offset + 8),
|
||||||
|
readDouble(offset + 16),
|
||||||
|
readDouble(offset + 24)
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
|
||||||
|
writeDouble(offset, value.w)
|
||||||
|
writeDouble(offset + 8, value.x)
|
||||||
|
writeDouble(offset + 16, value.y)
|
||||||
|
writeDouble(offset + 24, value.z)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun Quaternion(w: Number, x: Number = 0.0, y: Number = 0.0, z: Number = 0.0): Quaternion = Quaternion(
|
||||||
|
w.toDouble(),
|
||||||
|
x.toDouble(),
|
||||||
|
y.toDouble(),
|
||||||
|
z.toDouble(),
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This quaternion's conjugate.
|
* This quaternion's conjugate.
|
||||||
*/
|
*/
|
||||||
public val Quaternion.conjugate: Quaternion
|
public val Quaternion.conjugate: Quaternion
|
||||||
get() = QuaternionField { z - x * i - y * j - z * k }
|
get() = Quaternion(w, -x, -y, -z)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This quaternion's reciprocal.
|
* This quaternion's reciprocal.
|
||||||
*/
|
*/
|
||||||
public val Quaternion.reciprocal: Quaternion
|
public val Quaternion.reciprocal: Quaternion
|
||||||
get() {
|
get() {
|
||||||
QuaternionField {
|
val norm2 = (w * w + x * x + y * y + z * z)
|
||||||
val n = norm(this@reciprocal)
|
return Quaternion(w / norm2, -x / norm2, -y / norm2, -z / norm2)
|
||||||
return conjugate / (n * n)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Absolute value of the quaternion.
|
|
||||||
*/
|
|
||||||
public val Quaternion.r: Double
|
|
||||||
get() = sqrt(w * w + x * x + y * y + z * z)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field of [Quaternion].
|
* A field of [Quaternion].
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
|
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Double>, PowerOperations<Quaternion>,
|
||||||
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
||||||
override val zero: Quaternion = 0.toQuaternion()
|
override val zero: Quaternion = Quaternion(0.0)
|
||||||
override val one: Quaternion = 1.toQuaternion()
|
override val one: Quaternion = Quaternion(1.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `i` quaternion unit.
|
* The `i` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val i: Quaternion = Quaternion(0, 1)
|
public val i: Quaternion = Quaternion(0.0, 1.0, 0.0, 0.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `j` quaternion unit.
|
* The `j` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val j: Quaternion = Quaternion(0, 0, 1)
|
public val j: Quaternion = Quaternion(0.0, 0.0, 1.0, 0.0)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The `k` quaternion unit.
|
* The `k` quaternion unit.
|
||||||
*/
|
*/
|
||||||
public val k: Quaternion = Quaternion(0, 0, 0, 1)
|
public val k: Quaternion = Quaternion(0.0, 0.0, 0.0, 1.0)
|
||||||
|
|
||||||
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
||||||
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
||||||
@ -133,7 +211,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
|
|
||||||
override fun exp(arg: Quaternion): Quaternion {
|
override fun exp(arg: Quaternion): Quaternion {
|
||||||
val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z
|
val un = arg.x * arg.x + arg.y * arg.y + arg.z * arg.z
|
||||||
if (un == 0.0) return exp(arg.w).toQuaternion()
|
if (un == 0.0) return Quaternion(exp(arg.w))
|
||||||
val n1 = sqrt(un)
|
val n1 = sqrt(un)
|
||||||
val ea = exp(arg.w)
|
val ea = exp(arg.w)
|
||||||
val n2 = ea * sin(n1) / n1
|
val n2 = ea * sin(n1) / n1
|
||||||
@ -158,7 +236,8 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
|
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
|
||||||
}
|
}
|
||||||
|
|
||||||
override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z)
|
override operator fun Number.plus(other: Quaternion): Quaternion =
|
||||||
|
Quaternion(toDouble() + other.w, other.x, other.y, other.z)
|
||||||
|
|
||||||
override operator fun Number.minus(other: Quaternion): Quaternion =
|
override operator fun Number.minus(other: Quaternion): Quaternion =
|
||||||
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
|
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
|
||||||
@ -170,7 +249,12 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
|
Quaternion(toDouble() * arg.w, toDouble() * arg.x, toDouble() * arg.y, toDouble() * arg.z)
|
||||||
|
|
||||||
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
|
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
|
||||||
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)
|
override fun norm(arg: Quaternion): Double = sqrt(
|
||||||
|
arg.w.pow(2) +
|
||||||
|
arg.x.pow(2) +
|
||||||
|
arg.y.pow(2) +
|
||||||
|
arg.z.pow(2)
|
||||||
|
)
|
||||||
|
|
||||||
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
|
override fun bindSymbolOrNull(value: String): Quaternion? = when (value) {
|
||||||
"i" -> i
|
"i" -> i
|
||||||
@ -179,7 +263,7 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun number(value: Number): Quaternion = value.toQuaternion()
|
override fun number(value: Number): Quaternion = Quaternion(value)
|
||||||
|
|
||||||
override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0
|
override fun sinh(arg: Quaternion): Quaternion = (exp(arg) - exp(-arg)) / 2.0
|
||||||
override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0
|
override fun cosh(arg: Quaternion): Quaternion = (exp(arg) + exp(-arg)) / 2.0
|
||||||
@ -189,76 +273,6 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
|||||||
override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0
|
override fun atanh(arg: Quaternion): Quaternion = (ln(arg + one) - ln(one - arg)) / 2.0
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents `double`-based quaternion.
|
|
||||||
*
|
|
||||||
* @property w The first component.
|
|
||||||
* @property x The second component.
|
|
||||||
* @property y The third component.
|
|
||||||
* @property z The fourth component.
|
|
||||||
*/
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public data class Quaternion(
|
|
||||||
val w: Double, val x: Double, val y: Double, val z: Double,
|
|
||||||
) {
|
|
||||||
public constructor(w: Number, x: Number, y: Number, z: Number) : this(
|
|
||||||
w.toDouble(),
|
|
||||||
x.toDouble(),
|
|
||||||
y.toDouble(),
|
|
||||||
z.toDouble(),
|
|
||||||
)
|
|
||||||
|
|
||||||
public constructor(w: Number, x: Number, y: Number) : this(w.toDouble(), x.toDouble(), y.toDouble(), 0.0)
|
|
||||||
public constructor(w: Number, x: Number) : this(w.toDouble(), x.toDouble(), 0.0, 0.0)
|
|
||||||
public constructor(w: Number) : this(w.toDouble(), 0.0, 0.0, 0.0)
|
|
||||||
public constructor(wx: Complex, yz: Complex) : this(wx.re, wx.im, yz.re, yz.im)
|
|
||||||
public constructor(wx: Complex) : this(wx.re, wx.im, 0, 0)
|
|
||||||
|
|
||||||
init {
|
|
||||||
require(!w.isNaN()) { "w-component of quaternion is not-a-number" }
|
|
||||||
require(!x.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
require(!y.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
require(!z.isNaN()) { "x-component of quaternion is not-a-number" }
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a string representation of this quaternion.
|
|
||||||
*/
|
|
||||||
override fun toString(): String = "($w + $x * i + $y * j + $z * k)"
|
|
||||||
|
|
||||||
public companion object : MemorySpec<Quaternion> {
|
|
||||||
override val objectSize: Int
|
|
||||||
get() = 32
|
|
||||||
|
|
||||||
override fun MemoryReader.read(offset: Int): Quaternion =
|
|
||||||
Quaternion(readDouble(offset), readDouble(offset + 8), readDouble(offset + 16), readDouble(offset + 24))
|
|
||||||
|
|
||||||
override fun MemoryWriter.write(offset: Int, value: Quaternion) {
|
|
||||||
writeDouble(offset, value.w)
|
|
||||||
writeDouble(offset + 8, value.x)
|
|
||||||
writeDouble(offset + 16, value.y)
|
|
||||||
writeDouble(offset + 24, value.z)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a quaternion with real part equal to this real.
|
|
||||||
*
|
|
||||||
* @receiver the real part.
|
|
||||||
* @return a new quaternion.
|
|
||||||
*/
|
|
||||||
public fun Number.toQuaternion(): Quaternion = Quaternion(this)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a quaternion with `w`-component equal to `re`-component of given complex and `x`-component equal to
|
|
||||||
* `im`-component of given complex.
|
|
||||||
*
|
|
||||||
* @receiver the complex number.
|
|
||||||
* @return a new quaternion.
|
|
||||||
*/
|
|
||||||
public fun Complex.toQuaternion(): Quaternion = Quaternion(this)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the
|
* Creates a new buffer of quaternions with the specified [size], where each element is calculated by calling the
|
||||||
* specified [init] function.
|
* specified [init] function.
|
||||||
|
@ -6,10 +6,23 @@
|
|||||||
package space.kscience.kmath.complex
|
package space.kscience.kmath.complex
|
||||||
|
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal class QuaternionFieldTest {
|
internal class QuaternionTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testNorm() {
|
||||||
|
assertEquals(2.0, QuaternionField.norm(Quaternion(1.0, 1.0, 1.0, 1.0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testInverse() = QuaternionField {
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0)
|
||||||
|
assertBufferEquals(one, q * q.reciprocal, 1e-4)
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAddition() {
|
fun testAddition() {
|
||||||
assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) })
|
assertEquals(Quaternion(42, 42), QuaternionField { Quaternion(16, 16) + Quaternion(26, 26) })
|
@ -1,8 +1,6 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
id("ru.mipt.npm.gradle.common")
|
|
||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
// id("com.xcporter.metaview") version "0.0.5"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
|
@ -10,25 +10,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
|
||||||
* An exception is thrown when the expected and actual shape of NDArray differ.
|
|
||||||
*
|
|
||||||
* @property expected the expected shape.
|
|
||||||
* @property actual the actual shape.
|
|
||||||
*/
|
|
||||||
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
|
||||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
|
||||||
|
|
||||||
public typealias Shape = IntArray
|
|
||||||
|
|
||||||
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
|
||||||
|
|
||||||
public interface WithShape {
|
|
||||||
public val shape: Shape
|
|
||||||
|
|
||||||
public val indices: ShapeIndexer get() = DefaultStrides(shape)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base interface for all ND-algebra implementations.
|
* The base interface for all ND-algebra implementations.
|
||||||
*
|
*
|
||||||
|
@ -47,7 +47,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
|||||||
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
|
public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = ::Strides
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An exception is thrown when the expected and actual shape of NDArray differ.
|
||||||
|
*
|
||||||
|
* @property expected the expected shape.
|
||||||
|
* @property actual the actual shape.
|
||||||
|
*/
|
||||||
|
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||||
|
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||||
|
|
||||||
|
public class IndexOutOfShapeException(public val shape: Shape, public val index: IntArray) :
|
||||||
|
RuntimeException("Index ${index.contentToString()} is out of shape ${shape.contentToString()}")
|
||||||
|
|
||||||
|
public typealias Shape = IntArray
|
||||||
|
|
||||||
|
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
||||||
|
|
||||||
|
public interface WithShape {
|
||||||
|
public val shape: Shape
|
||||||
|
|
||||||
|
public val indices: ShapeIndexer get() = DefaultStrides(shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun requireIndexInShape(index: IntArray, shape: Shape) {
|
||||||
|
if (index.size != shape.size) throw IndexOutOfShapeException(index, shape)
|
||||||
|
shape.forEachIndexed { axis, axisShape ->
|
||||||
|
if (index[axis] !in 0 until axisShape) throw IndexOutOfShapeException(index, shape)
|
||||||
|
}
|
||||||
|
}
|
@ -66,7 +66,7 @@ public abstract class Strides: ShapeIndexer {
|
|||||||
/**
|
/**
|
||||||
* Simple implementation of [Strides].
|
* Simple implementation of [Strides].
|
||||||
*/
|
*/
|
||||||
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
|
public class DefaultStrides(override val shape: IntArray) : Strides() {
|
||||||
override val linearSize: Int get() = strides[shape.size]
|
override val linearSize: Int get() = strides[shape.size]
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -112,6 +112,7 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
/**
|
/**
|
||||||
* Cached builder for default strides
|
* Cached builder for default strides
|
||||||
*/
|
*/
|
||||||
|
@Deprecated("Replace by Strides(shape)")
|
||||||
public operator fun invoke(shape: IntArray): Strides =
|
public operator fun invoke(shape: IntArray): Strides =
|
||||||
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||||
}
|
}
|
||||||
@ -119,3 +120,8 @@ public class DefaultStrides private constructor(override val shape: IntArray) :
|
|||||||
|
|
||||||
@ThreadLocal
|
@ThreadLocal
|
||||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cached builder for default strides
|
||||||
|
*/
|
||||||
|
public fun Strides(shape: IntArray): Strides = defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
@ -101,7 +101,7 @@ public interface StructureND<out T> : Featured<StructureFeature>, WithShape {
|
|||||||
val bufferRepr: String = when (structure.shape.size) {
|
val bufferRepr: String = when (structure.shape.size) {
|
||||||
1 -> (0 until structure.shape[0]).map { structure[it] }
|
1 -> (0 until structure.shape[0]).map { structure[it] }
|
||||||
.joinToString(prefix = "[", postfix = "]", separator = ", ")
|
.joinToString(prefix = "[", postfix = "]", separator = ", ")
|
||||||
2 -> (0 until structure.shape[0]).joinToString(prefix = "[", postfix = "]", separator = ", ") { i ->
|
2 -> (0 until structure.shape[0]).joinToString(prefix = "[\n", postfix = "\n]", separator = ",\n") { i ->
|
||||||
(0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
|
(0 until structure.shape[1]).joinToString(prefix = " [", postfix = "]", separator = ", ") { j ->
|
||||||
structure[i, j].toString()
|
structure[i, j].toString()
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,30 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
|
||||||
|
public open class VirtualStructureND<T>(
|
||||||
|
override val shape: Shape,
|
||||||
|
public val producer: (IntArray) -> T,
|
||||||
|
) : StructureND<T> {
|
||||||
|
override fun get(index: IntArray): T {
|
||||||
|
requireIndexInShape(index, shape)
|
||||||
|
return producer(index)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class VirtualDoubleStructureND(
|
||||||
|
shape: Shape,
|
||||||
|
producer: (IntArray) -> Double,
|
||||||
|
) : VirtualStructureND<Double>(shape, producer)
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public class VirtualIntStructureND(
|
||||||
|
shape: Shape,
|
||||||
|
producer: (IntArray) -> Int,
|
||||||
|
) : VirtualStructureND<Int>(shape, producer)
|
@ -0,0 +1,32 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
public fun <T> StructureND<T>.roll(axis: Int, step: Int = 1): StructureND<T> {
|
||||||
|
require(axis in shape.indices) { "Axis $axis is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
|
return VirtualStructureND(shape) { index ->
|
||||||
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
|
if (indexAxis == axis) {
|
||||||
|
(index[indexAxis] + step).mod(shape[indexAxis])
|
||||||
|
} else {
|
||||||
|
index[indexAxis]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
get(newIndex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T> StructureND<T>.roll(pair: Pair<Int, Int>, vararg others: Pair<Int, Int>): StructureND<T> {
|
||||||
|
val axisMap: Map<Int, Int> = mapOf(pair, *others)
|
||||||
|
require(axisMap.keys.all { it in shape.indices }) { "Some of axes ${axisMap.keys} is outside of shape dimensions: [0, ${shape.size})" }
|
||||||
|
return VirtualStructureND(shape) { index ->
|
||||||
|
val newIndex: IntArray = IntArray(index.size) { indexAxis ->
|
||||||
|
val offset = axisMap[indexAxis] ?: 0
|
||||||
|
(index[indexAxis] + offset).mod(shape[indexAxis])
|
||||||
|
}
|
||||||
|
get(newIndex)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,28 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class NdOperationsTest {
|
||||||
|
@Test
|
||||||
|
fun roll() {
|
||||||
|
val structure = DoubleField.ndAlgebra.structureND(5, 5) { index ->
|
||||||
|
index.sumOf { it.toDouble() }
|
||||||
|
}
|
||||||
|
|
||||||
|
println(StructureND.toString(structure))
|
||||||
|
|
||||||
|
val rolled = structure.roll(0,-1)
|
||||||
|
|
||||||
|
println(StructureND.toString(rolled))
|
||||||
|
|
||||||
|
assertEquals(4.0, rolled[0, 0])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -18,12 +18,12 @@ public class LazyStructureND<out T>(
|
|||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
private val cache: MutableMap<IntArray, Deferred<T>> = HashMap()
|
private val cache: MutableMap<IntArray, Deferred<T>> = HashMap()
|
||||||
|
|
||||||
public fun deferred(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
public fun async(index: IntArray): Deferred<T> = cache.getOrPut(index) {
|
||||||
scope.async(context = Dispatchers.Math) { function(index) }
|
scope.async(context = Dispatchers.Math) { function(index) }
|
||||||
}
|
}
|
||||||
|
|
||||||
public suspend fun await(index: IntArray): T = deferred(index).await()
|
public suspend fun await(index: IntArray): T = async(index).await()
|
||||||
override operator fun get(index: IntArray): T = runBlocking { deferred(index).await() }
|
override operator fun get(index: IntArray): T = runBlocking { async(index).await() }
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> {
|
override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
@ -33,8 +33,8 @@ public class LazyStructureND<out T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> StructureND<T>.deferred(index: IntArray): Deferred<T> =
|
public fun <T> StructureND<T>.async(index: IntArray): Deferred<T> =
|
||||||
if (this is LazyStructureND<T>) deferred(index) else CompletableDeferred(get(index))
|
if (this is LazyStructureND<T>) this@async.async(index) else CompletableDeferred(get(index))
|
||||||
|
|
||||||
public suspend fun <T> StructureND<T>.await(index: IntArray): T =
|
public suspend fun <T> StructureND<T>.await(index: IntArray): T =
|
||||||
if (this is LazyStructureND<T>) await(index) else get(index)
|
if (this is LazyStructureND<T>) await(index) else get(index)
|
||||||
|
@ -271,7 +271,9 @@ public object EjmlLinearSpaceDDRM : EjmlLinearSpace<Double, DoubleField, DMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -505,7 +507,9 @@ public object EjmlLinearSpaceFDRM : EjmlLinearSpace<Float, FloatField, FMatrixRM
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -734,7 +738,9 @@ public object EjmlLinearSpaceDSCC : EjmlLinearSpace<Double, DoubleField, DMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -963,7 +969,9 @@ public object EjmlLinearSpaceFSCC : EjmlLinearSpace<Float, FloatField, FMatrixSp
|
|||||||
}
|
}
|
||||||
|
|
||||||
else -> null
|
else -> null
|
||||||
}?.let(type::cast)
|
}?.let{
|
||||||
|
type.cast(it)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(PerformancePitfall::class)
|
||||||
|
|
||||||
package space.kscience.kmath.ejml
|
package space.kscience.kmath.ejml
|
||||||
|
|
||||||
import org.ejml.data.DMatrixRMaj
|
import org.ejml.data.DMatrixRMaj
|
||||||
@ -18,11 +20,11 @@ import kotlin.random.Random
|
|||||||
import kotlin.random.asJavaRandom
|
import kotlin.random.asJavaRandom
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
internal fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
||||||
fun <T : Any> assertMatrixEquals(expected: StructureND<T>, actual: StructureND<T>) {
|
|
||||||
assertTrue { StructureND.contentEquals(expected, actual) }
|
assertTrue { StructureND.contentEquals(expected, actual) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
internal class EjmlMatrixTest {
|
internal class EjmlMatrixTest {
|
||||||
private val random = Random(0)
|
private val random = Random(0)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ plugins {
|
|||||||
|
|
||||||
kotlin.sourceSets.commonMain {
|
kotlin.sourceSets.commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(projects.kmath.kmathComplex)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,12 +6,10 @@
|
|||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public interface Vector2D : Point<Double>, Vector {
|
public interface Vector2D : Point<Double>, Vector {
|
||||||
public val x: Double
|
public val x: Double
|
||||||
public val y: Double
|
public val y: Double
|
||||||
@ -29,7 +27,6 @@ public interface Vector2D : Point<Double>, Vector {
|
|||||||
public val Vector2D.r: Double
|
public val Vector2D.r: Double
|
||||||
get() = Euclidean2DSpace { norm() }
|
get() = Euclidean2DSpace { norm() }
|
||||||
|
|
||||||
@Suppress("FunctionName")
|
|
||||||
public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
public fun Vector2D(x: Double, y: Double): Vector2D = Vector2DImpl(x, y)
|
||||||
|
|
||||||
private data class Vector2DImpl(
|
private data class Vector2DImpl(
|
||||||
|
@ -6,12 +6,11 @@
|
|||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.math.sqrt
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
|
||||||
public interface Vector3D : Point<Double>, Vector {
|
public interface Vector3D : Point<Double>, Vector {
|
||||||
public val x: Double
|
public val x: Double
|
||||||
public val y: Double
|
public val y: Double
|
||||||
@ -31,6 +30,19 @@ public interface Vector3D : Point<Double>, Vector {
|
|||||||
@Suppress("FunctionName")
|
@Suppress("FunctionName")
|
||||||
public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
public fun Vector3D(x: Double, y: Double, z: Double): Vector3D = Vector3DImpl(x, y, z)
|
||||||
|
|
||||||
|
public fun Buffer<Double>.asVector3D(): Vector3D = object : Vector3D {
|
||||||
|
init {
|
||||||
|
require(this@asVector3D.size == 3) { "Buffer of size 3 is required for Vector3D" }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val x: Double get() = this@asVector3D[0]
|
||||||
|
override val y: Double get() = this@asVector3D[1]
|
||||||
|
override val z: Double get() = this@asVector3D[2]
|
||||||
|
|
||||||
|
override fun toString(): String = this@asVector3D.toString()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
public val Vector3D.r: Double get() = Euclidean3DSpace { norm() }
|
public val Vector3D.r: Double get() = Euclidean3DSpace { norm() }
|
||||||
|
|
||||||
private data class Vector3DImpl(
|
private data class Vector3DImpl(
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.geometry
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
//TODO move vector to receiver
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Project vector onto a line.
|
* Project vector onto a line.
|
||||||
* @param vector to project
|
* @param vector to project
|
@ -0,0 +1,100 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
import space.kscience.kmath.complex.Quaternion
|
||||||
|
import space.kscience.kmath.complex.QuaternionField
|
||||||
|
import space.kscience.kmath.complex.reciprocal
|
||||||
|
import space.kscience.kmath.linear.LinearSpace
|
||||||
|
import space.kscience.kmath.linear.Matrix
|
||||||
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.linear.matrix
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
|
import kotlin.math.pow
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
internal fun Vector3D.toQuaternion(): Quaternion = Quaternion(0.0, x, y, z)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Angle in radians denoted by this quaternion rotation
|
||||||
|
*/
|
||||||
|
public val Quaternion.theta: Double get() = kotlin.math.acos(w) * 2
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An axis of quaternion rotation
|
||||||
|
*/
|
||||||
|
public val Quaternion.vector: Vector3D
|
||||||
|
get() {
|
||||||
|
val sint2 = sqrt(1 - w * w)
|
||||||
|
|
||||||
|
return object : Vector3D {
|
||||||
|
override val x: Double get() = this@vector.x/sint2
|
||||||
|
override val y: Double get() = this@vector.y/sint2
|
||||||
|
override val z: Double get() = this@vector.z/sint2
|
||||||
|
override fun toString(): String = listOf(x, y, z).toString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rotate a vector in a [Euclidean3DSpace]
|
||||||
|
*/
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, q: Quaternion): Vector3D = with(QuaternionField) {
|
||||||
|
val p = vector.toQuaternion()
|
||||||
|
(q * p * q.reciprocal).vector
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use a composition of quaternions to create a rotation
|
||||||
|
*/
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, composition: QuaternionField.() -> Quaternion): Vector3D =
|
||||||
|
rotate(vector, QuaternionField.composition())
|
||||||
|
|
||||||
|
public fun Euclidean3DSpace.rotate(vector: Vector3D, matrix: Matrix<Double>): Vector3D {
|
||||||
|
require(matrix.colNum == 3 && matrix.rowNum == 3) { "Square 3x3 rotation matrix is required" }
|
||||||
|
return with(DoubleField.linearSpace) { matrix.dot(vector).asVector3D() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a [Quaternion] to a rotation matrix
|
||||||
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public fun Quaternion.toRotationMatrix(
|
||||||
|
linearSpace: LinearSpace<Double, *> = DoubleField.linearSpace,
|
||||||
|
): Matrix<Double> {
|
||||||
|
val s = QuaternionField.norm(this).pow(-2)
|
||||||
|
return linearSpace.matrix(3, 3)(
|
||||||
|
1.0 - 2 * s * (y * y + z * z), 2 * s * (x * y - z * w), 2 * s * (x * z + y * w),
|
||||||
|
2 * s * (x * y + z * w), 1.0 - 2 * s * (x * x + z * z), 2 * s * (y * z - x * w),
|
||||||
|
2 * s * (x * z - y * w), 2 * s * (y * z + x * w), 1.0 - 2 * s * (x * x + y * y)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* taken from https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf
|
||||||
|
*/
|
||||||
|
public fun Quaternion.Companion.fromRotationMatrix(matrix: Matrix<Double>): Quaternion {
|
||||||
|
val t: Double
|
||||||
|
val q = if (matrix[2, 2] < 0) {
|
||||||
|
if (matrix[0, 0] > matrix[1, 1]) {
|
||||||
|
t = 1 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]
|
||||||
|
Quaternion(t, matrix[0, 1] + matrix[1, 0], matrix[2, 0] + matrix[0, 2], matrix[1, 2] - matrix[2, 1])
|
||||||
|
} else {
|
||||||
|
t = 1 - matrix[0, 0] + matrix[1, 1] - matrix[2, 2]
|
||||||
|
Quaternion(matrix[0, 1] + matrix[1, 0], t, matrix[1, 2] + matrix[2, 1], matrix[2, 0] - matrix[0, 2])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (matrix[0, 0] < -matrix[1, 1]) {
|
||||||
|
t = 1 - matrix[0, 0] - matrix[1, 1] + matrix[2, 2]
|
||||||
|
Quaternion(matrix[2, 0] + matrix[0, 2], matrix[1, 2] + matrix[2, 1], t, matrix[0, 1] - matrix[1, 0])
|
||||||
|
} else {
|
||||||
|
t = 1 + matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
|
||||||
|
Quaternion(matrix[1, 2] - matrix[2, 1], matrix[2, 0] - matrix[0, 2], matrix[0, 1] - matrix[1, 0], t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return QuaternionField.invoke { q * (0.5 / sqrt(t)) }
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.geometry
|
||||||
|
|
||||||
|
import space.kscience.kmath.complex.Quaternion
|
||||||
|
import space.kscience.kmath.testutils.assertBufferEquals
|
||||||
|
import kotlin.test.Ignore
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class RotationTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun rotations() = with(Euclidean3DSpace) {
|
||||||
|
val vector = Vector3D(1.0, 1.0, 1.0)
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0)
|
||||||
|
val rotatedByQ = rotate(vector, q)
|
||||||
|
val matrix = q.toRotationMatrix()
|
||||||
|
val rotatedByM = rotate(vector,matrix)
|
||||||
|
|
||||||
|
assertBufferEquals(rotatedByQ, rotatedByM, 1e-4)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
fun rotationConversion() {
|
||||||
|
|
||||||
|
val q = Quaternion(1.0, 2.0, -3.0, 4.0)
|
||||||
|
|
||||||
|
val matrix = q.toRotationMatrix()
|
||||||
|
|
||||||
|
assertEquals(q, Quaternion.fromRotationMatrix(matrix))
|
||||||
|
}
|
||||||
|
}
|
@ -1,17 +1,15 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
id("ru.mipt.npm.gradle.common")
|
|
||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
}
|
}
|
||||||
|
|
||||||
kscience {
|
//apply(plugin = "kotlinx-atomicfu")
|
||||||
useAtomic()
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-core"))
|
api(project(":kmath-core"))
|
||||||
|
api(npmlibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
commonTest {
|
commonTest {
|
||||||
|
@ -6,12 +6,14 @@
|
|||||||
package space.kscience.kmath.histogram
|
package space.kscience.kmath.histogram
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.real.step
|
import space.kscience.kmath.real.step
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
class TreeHistogramTest {
|
class TreeHistogramTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
plugins {
|
plugins {
|
||||||
kotlin("multiplatform")
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
id("ru.mipt.npm.gradle.common")
|
|
||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,10 +3,6 @@ plugins {
|
|||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
}
|
}
|
||||||
|
|
||||||
kscience {
|
|
||||||
useAtomic()
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
all {
|
all {
|
||||||
languageSettings.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
|
languageSettings.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||||
@ -15,6 +11,7 @@ kotlin.sourceSets {
|
|||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
|
api(npmlibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,14 +3,11 @@ plugins {
|
|||||||
id("ru.mipt.npm.gradle.native")
|
id("ru.mipt.npm.gradle.native")
|
||||||
}
|
}
|
||||||
|
|
||||||
kscience {
|
|
||||||
useAtomic()
|
|
||||||
}
|
|
||||||
|
|
||||||
kotlin.sourceSets {
|
kotlin.sourceSets {
|
||||||
commonMain {
|
commonMain {
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-coroutines"))
|
api(project(":kmath-coroutines"))
|
||||||
|
implementation(npmlibs.atomicfu)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.stat
|
|||||||
import kotlinx.coroutines.flow.first
|
import kotlinx.coroutines.flow.first
|
||||||
import space.kscience.kmath.chains.Chain
|
import space.kscience.kmath.chains.Chain
|
||||||
import space.kscience.kmath.chains.combine
|
import space.kscience.kmath.chains.combine
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
@ -30,6 +31,7 @@ public fun interface Sampler<out T : Any> {
|
|||||||
/**
|
/**
|
||||||
* Sample a bunch of values
|
* Sample a bunch of values
|
||||||
*/
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public fun <T : Any> Sampler<T>.sampleBuffer(
|
public fun <T : Any> Sampler<T>.sampleBuffer(
|
||||||
generator: RandomGenerator,
|
generator: RandomGenerator,
|
||||||
size: Int,
|
size: Int,
|
||||||
|
@ -66,7 +66,7 @@ class MCScopeTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@OptIn(ObsoleteCoroutinesApi::class)
|
@OptIn(DelicateCoroutinesApi::class)
|
||||||
fun compareResult(test: ATest) {
|
fun compareResult(test: ATest) {
|
||||||
val res1 = runBlocking(Dispatchers.Default) { test() }
|
val res1 = runBlocking(Dispatchers.Default) { test() }
|
||||||
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
val res2 = runBlocking(newSingleThreadContext("test")) { test() }
|
||||||
|
@ -7,6 +7,7 @@ import org.tensorflow.op.core.Constant
|
|||||||
import org.tensorflow.types.TFloat64
|
import org.tensorflow.types.TFloat64
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.DefaultStrides
|
import space.kscience.kmath.nd.DefaultStrides
|
||||||
import space.kscience.kmath.nd.Shape
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
@ -74,6 +75,7 @@ public class DoubleTensorFlowAlgebra internal constructor(
|
|||||||
*
|
*
|
||||||
* The resulting tensor is available outside of scope
|
* The resulting tensor is available outside of scope
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public fun DoubleField.produceWithTF(
|
public fun DoubleField.produceWithTF(
|
||||||
block: DoubleTensorFlowAlgebra.() -> StructureND<Double>,
|
block: DoubleTensorFlowAlgebra.() -> StructureND<Double>,
|
||||||
): StructureND<Double> = Graph().use { graph ->
|
): StructureND<Double> = Graph().use { graph ->
|
||||||
|
@ -117,6 +117,7 @@ public open class ViktorFieldOpsND :
|
|||||||
|
|
||||||
public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND
|
public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public open class ViktorFieldND(
|
public open class ViktorFieldND(
|
||||||
override val shape: Shape,
|
override val shape: Shape,
|
||||||
) : ViktorFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
) : ViktorFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
||||||
|
@ -1,7 +1,26 @@
|
|||||||
rootProject.name = "kmath"
|
rootProject.name = "kmath"
|
||||||
|
|
||||||
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
|
||||||
|
|
||||||
|
dependencyResolutionManagement {
|
||||||
|
val toolsVersion: String by extra
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenLocal()
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
mavenCentral()
|
||||||
|
gradlePluginPortal()
|
||||||
|
}
|
||||||
|
|
||||||
|
versionCatalogs {
|
||||||
|
create("npmlibs") {
|
||||||
|
from("ru.mipt.npm:version-catalog:$toolsVersion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
include(
|
include(
|
||||||
|
":test-utils",
|
||||||
":kmath-memory",
|
":kmath-memory",
|
||||||
":kmath-complex",
|
":kmath-complex",
|
||||||
":kmath-core",
|
":kmath-core",
|
||||||
|
13
test-utils/build.gradle.kts
Normal file
13
test-utils/build.gradle.kts
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.gradle.mpp")
|
||||||
|
id("ru.mipt.npm.gradle.native")
|
||||||
|
}
|
||||||
|
|
||||||
|
kotlin.sourceSets {
|
||||||
|
commonMain {
|
||||||
|
dependencies {
|
||||||
|
api(projects.kmath.kmathCore)
|
||||||
|
api(kotlin("test"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ import space.kscience.kmath.operations.invoke
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertNotEquals
|
import kotlin.test.assertNotEquals
|
||||||
|
|
||||||
internal class FieldVerifier<T, out A : Field<T>>(
|
public class FieldVerifier<T, out A : Field<T>>(
|
||||||
algebra: A, a: T, b: T, c: T, x: Number,
|
algebra: A, a: T, b: T, c: T, x: Number,
|
||||||
) : RingVerifier<T, A>(algebra, a, b, c, x) {
|
) : RingVerifier<T, A>(algebra, a, b, c, x) {
|
||||||
|
|
@ -10,7 +10,7 @@ import space.kscience.kmath.operations.ScaleOperations
|
|||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
internal open class RingVerifier<T, out A>(algebra: A, a: T, b: T, c: T, x: Number) :
|
public open class RingVerifier<T, out A>(algebra: A, a: T, b: T, c: T, x: Number) :
|
||||||
SpaceVerifier<T, A>(algebra, a, b, c, x) where A : Ring<T>, A : ScaleOperations<T> {
|
SpaceVerifier<T, A>(algebra, a, b, c, x) where A : Ring<T>, A : ScaleOperations<T> {
|
||||||
|
|
||||||
override fun verify() {
|
override fun verify() {
|
@ -11,12 +11,12 @@ import space.kscience.kmath.operations.invoke
|
|||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertNotEquals
|
import kotlin.test.assertNotEquals
|
||||||
|
|
||||||
internal open class SpaceVerifier<T, out S>(
|
public open class SpaceVerifier<T, out S>(
|
||||||
override val algebra: S,
|
override val algebra: S,
|
||||||
val a: T,
|
public val a: T,
|
||||||
val b: T,
|
public val b: T,
|
||||||
val c: T,
|
public val c: T,
|
||||||
val x: Number,
|
public val x: Number,
|
||||||
) : AlgebraicVerifier<T, Ring<T>> where S : Ring<T>, S : ScaleOperations<T> {
|
) : AlgebraicVerifier<T, Ring<T>> where S : Ring<T>, S : ScaleOperations<T> {
|
||||||
override fun verify() {
|
override fun verify() {
|
||||||
algebra {
|
algebra {
|
20
test-utils/src/commonMain/kotlin/asserts.kt
Normal file
20
test-utils/src/commonMain/kotlin/asserts.kt
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.testutils
|
||||||
|
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.structures.indices
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
import kotlin.test.fail
|
||||||
|
|
||||||
|
public fun assertBufferEquals(expected: Buffer<Double>, result: Buffer<Double>, tolerance: Double = 1e-4) {
|
||||||
|
if (expected.size != result.size) {
|
||||||
|
fail("Expected size is ${expected.size}, but the result size is ${result.size}")
|
||||||
|
}
|
||||||
|
expected.indices.forEach {
|
||||||
|
assertEquals(expected[it], result[it], tolerance)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user