Merge branch 'dev' into feature/quaternion
# Conflicts: # CHANGELOG.md # examples/src/main/kotlin/kscience/kmath/stat/DistributionBenchmark.kt # kmath-complex/src/commonMain/kotlin/kscience/kmath/complex/ComplexNDField.kt # kmath-core/src/commonMain/kotlin/kscience/kmath/operations/Complex.kt # kmath-core/src/commonTest/kotlin/kscience/kmath/structures/NDFieldTest.kt
This commit is contained in:
commit
9748e0bfbe
25
CHANGELOG.md
25
CHANGELOG.md
@ -4,28 +4,29 @@
|
||||
### Added
|
||||
- `fun` annotation for SAM interfaces in library
|
||||
- Explicit `public` visibility for all public APIs
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140).
|
||||
- Better trigonometric and hyperbolic functions for `AutoDiffField` (https://github.com/mipt-npm/kmath/pull/140)
|
||||
- Automatic README generation for features (#139)
|
||||
- Native support for `memory`, `core` and `dimensions`
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136).
|
||||
- `kmath-ejml` to supply EJML SimpleMatrix wrapper (https://github.com/mipt-npm/kmath/pull/136)
|
||||
- A separate `Symbol` entity, which is used for global unbound symbol.
|
||||
- A `Symbol` indexing scope.
|
||||
- Basic optimization API for Commons-math.
|
||||
- Chi squared optimization for array-like data in CM
|
||||
- `Fitting` utility object in prob/stat
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`.
|
||||
- Coroutine-deterministic Monte-Carlo scope with a random number generator.
|
||||
- Some minor utilities to `kmath-for-real`.
|
||||
- Basic Quaternion vector support in `kmath-complex`.
|
||||
- ND4J support module submitting `NDStructure` and `NDAlgebra` over `INDArray`
|
||||
- Coroutine-deterministic Monte-Carlo scope with a random number generator
|
||||
- Some minor utilities to `kmath-for-real`
|
||||
- Generic operation result parameter to `MatrixContext`
|
||||
- New `MatrixFeature` interfaces for matrix decompositions
|
||||
- Basic Quaternion vector support in `kmath-complex`.
|
||||
|
||||
### Changed
|
||||
- Package changed from `scientifik` to `kscience.kmath`.
|
||||
- Gradle version: 6.6 -> 6.7.1
|
||||
- Package changed from `scientifik` to `kscience.kmath`
|
||||
- Gradle version: 6.6 -> 6.8
|
||||
- Minor exceptions refactor (throwing `IllegalArgumentException` by argument checks instead of `IllegalStateException`)
|
||||
- `Polynomial` secondary constructor made function.
|
||||
- Kotlin version: 1.3.72 -> 1.4.20
|
||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library.
|
||||
- `Polynomial` secondary constructor made function
|
||||
- Kotlin version: 1.3.72 -> 1.4.21
|
||||
- `kmath-ast` doesn't depend on heavy `kotlin-reflect` library
|
||||
- Full autodiff refactoring based on `Symbol`
|
||||
- `kmath-prob` renamed to `kmath-stat`
|
||||
- Grid generators moved to `kmath-for-real`
|
||||
@ -33,6 +34,8 @@
|
||||
- Optimized dot product for buffer matrices moved to `kmath-for-real`
|
||||
- EjmlMatrix context is an object
|
||||
- Matrix LUP `inverse` renamed to `inverseWithLUP`
|
||||
- `NumericAlgebra` moved outside of regular algebra chain (`Ring` no longer implements it).
|
||||
- Features moved to NDStructure and became transparent.
|
||||
- `Complex` and related features moved to a separate module `kmath-complex`
|
||||
|
||||
### Deprecated
|
||||
|
@ -4,7 +4,7 @@ plugins {
|
||||
id("ru.mipt.npm.project")
|
||||
}
|
||||
|
||||
internal val kmathVersion: String by extra("0.2.0-dev-4")
|
||||
internal val kmathVersion: String by extra("0.2.0-dev-5")
|
||||
internal val bintrayRepo: String by extra("kscience")
|
||||
internal val githubProject: String by extra("kmath")
|
||||
|
||||
|
@ -2,19 +2,22 @@ package kscience.kmath.benchmarks
|
||||
|
||||
import kotlinx.benchmark.Benchmark
|
||||
import kscience.kmath.commons.linear.CMMatrixContext
|
||||
import kscience.kmath.commons.linear.CMMatrixContext.dot
|
||||
import kscience.kmath.commons.linear.toCM
|
||||
import kscience.kmath.ejml.EjmlMatrixContext
|
||||
import kscience.kmath.ejml.toEjml
|
||||
import kscience.kmath.linear.BufferMatrixContext
|
||||
import kscience.kmath.linear.RealMatrixContext
|
||||
import kscience.kmath.linear.real
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.Matrix
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
class MultiplicationBenchmark {
|
||||
class DotBenchmark {
|
||||
companion object {
|
||||
val random = Random(12224)
|
||||
val dim = 1000
|
||||
@ -32,14 +35,14 @@ class MultiplicationBenchmark {
|
||||
|
||||
@Benchmark
|
||||
fun commonsMathMultiplication() {
|
||||
CMMatrixContext.invoke {
|
||||
CMMatrixContext {
|
||||
cmMatrix1 dot cmMatrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun ejmlMultiplication() {
|
||||
EjmlMatrixContext.invoke {
|
||||
EjmlMatrixContext {
|
||||
ejmlMatrix1 dot ejmlMatrix2
|
||||
}
|
||||
}
|
||||
@ -48,13 +51,22 @@ class MultiplicationBenchmark {
|
||||
fun ejmlMultiplicationwithConversion() {
|
||||
val ejmlMatrix1 = matrix1.toEjml()
|
||||
val ejmlMatrix2 = matrix2.toEjml()
|
||||
EjmlMatrixContext.invoke {
|
||||
EjmlMatrixContext {
|
||||
ejmlMatrix1 dot ejmlMatrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun bufferedMultiplication() {
|
||||
matrix1 dot matrix2
|
||||
BufferMatrixContext(RealField, Buffer.Companion::real).invoke{
|
||||
matrix1 dot matrix2
|
||||
}
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun realMultiplication(){
|
||||
RealMatrixContext {
|
||||
matrix1 dot matrix2
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package kscience.kmath.benchmarks
|
||||
|
||||
import kscience.kmath.structures.NDField
|
||||
import org.openjdk.jmh.annotations.Benchmark
|
||||
import org.openjdk.jmh.annotations.Scope
|
||||
import org.openjdk.jmh.annotations.State
|
||||
import org.openjdk.jmh.infra.Blackhole
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
class LargeNDBenchmark {
|
||||
val arraySize = 10000
|
||||
val RANDOM = Random(222)
|
||||
val src1 = DoubleArray(arraySize) { RANDOM.nextDouble() }
|
||||
val src2 = DoubleArray(arraySize) { RANDOM.nextDouble() }
|
||||
val field = NDField.real(arraySize)
|
||||
val kmathArray1 = field.produce { (a) -> src1[a] }
|
||||
val kmathArray2 = field.produce { (a) -> src2[a] }
|
||||
|
||||
@Benchmark
|
||||
fun test10000(bh: Blackhole) {
|
||||
bh.consume(field.add(kmathArray1, kmathArray2))
|
||||
}
|
||||
|
||||
}
|
@ -1,8 +1,9 @@
|
||||
package kscience.kmath.stat
|
||||
package kscience.kmath.commons.prob
|
||||
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.async
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kscience.kmath.stat.*
|
||||
import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler
|
||||
import org.apache.commons.rng.simple.RandomSource
|
||||
import java.time.Duration
|
||||
|
@ -13,7 +13,7 @@ fun main() {
|
||||
val n = 1000
|
||||
|
||||
val realField = NDField.real(dim, dim)
|
||||
val complexField = NDField.complex(dim, dim)
|
||||
val complexField: ComplexNDField = NDField.complex(dim, dim)
|
||||
|
||||
val realTime = measureTimeMillis {
|
||||
realField {
|
||||
|
@ -33,7 +33,7 @@ fun main() {
|
||||
measureAndPrint("Automatic field addition") {
|
||||
autoField {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) { res += number(1.0) }
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ fun main() {
|
||||
measureAndPrint("Nd4j specialized addition") {
|
||||
nd4jField {
|
||||
var res = one
|
||||
repeat(n) { res += 1.0 as Number }
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ fun main() {
|
||||
genericField {
|
||||
var res: NDBuffer<Double> = one
|
||||
repeat(n) {
|
||||
res += one // couldn't avoid using `one` due to resolution ambiguity }
|
||||
res += 1.0 // couldn't avoid using `one` due to resolution ambiguity }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,8 @@ import kscience.kmath.dimensions.D2
|
||||
import kscience.kmath.dimensions.D3
|
||||
import kscience.kmath.dimensions.DMatrixContext
|
||||
import kscience.kmath.dimensions.Dimension
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
private fun DMatrixContext<Double, RealField>.simple() {
|
||||
private fun DMatrixContext<Double>.simple() {
|
||||
val m1 = produce<D2, D3> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D3, D2> { i, j -> (i + j).toDouble() }
|
||||
|
||||
@ -18,7 +17,7 @@ private object D5 : Dimension {
|
||||
override val dim: UInt = 5u
|
||||
}
|
||||
|
||||
private fun DMatrixContext<Double, RealField>.custom() {
|
||||
private fun DMatrixContext<Double>.custom() {
|
||||
val m1 = produce<D2, D5> { i, j -> (i + j).toDouble() }
|
||||
val m2 = produce<D5, D2> { i, j -> (i - j).toDouble() }
|
||||
val m3 = produce<D2, D2> { i, j -> (i - j).toDouble() }
|
||||
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.7.1-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -2,10 +2,9 @@ package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* A Mathematical Syntax Tree node for mathematical expressions.
|
||||
* A Mathematical Syntax Tree (MST) node for mathematical expressions.
|
||||
*
|
||||
* @author Alexander Nozik
|
||||
*/
|
||||
@ -57,21 +56,22 @@ public fun <T> Algebra<T>.evaluate(node: MST): T = when (node) {
|
||||
?: error("Numeric nodes are not supported by $this")
|
||||
|
||||
is MST.Symbolic -> symbol(node.value)
|
||||
is MST.Unary -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
|
||||
is MST.Unary -> when {
|
||||
this is NumericAlgebra && node.value is MST.Numeric -> unaryOperationFunction(node.operation)(number(node.value.value))
|
||||
else -> unaryOperationFunction(node.operation)(evaluate(node.value))
|
||||
}
|
||||
|
||||
is MST.Binary -> when {
|
||||
this !is NumericAlgebra -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
this is NumericAlgebra && node.left is MST.Numeric && node.right is MST.Numeric ->
|
||||
binaryOperationFunction(node.operation)(number(node.left.value), number(node.right.value))
|
||||
|
||||
node.left is MST.Numeric && node.right is MST.Numeric -> {
|
||||
val number = RealField
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
this is NumericAlgebra && node.left is MST.Numeric ->
|
||||
leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
|
||||
number(number)
|
||||
}
|
||||
this is NumericAlgebra && node.right is MST.Numeric ->
|
||||
rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
|
||||
node.left is MST.Numeric -> leftSideNumberOperationFunction(node.operation)(node.left.value, evaluate(node.right))
|
||||
node.right is MST.Numeric -> rightSideNumberOperationFunction(node.operation)(evaluate(node.left), node.right.value)
|
||||
else -> binaryOperationFunction(node.operation)(evaluate(node.left), evaluate(node.right))
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.ast
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
@ -25,8 +26,11 @@ public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
public override fun number(value: Number): MST.Numeric = MstAlgebra.number(value)
|
||||
public override fun symbol(value: String): MST.Symbolic = MstAlgebra.symbol(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(SpaceOperations.PLUS_OPERATION)(a, b)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this)
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this)
|
||||
public override operator fun MST.unaryPlus(): MST.Unary =
|
||||
unaryOperationFunction(SpaceOperations.PLUS_OPERATION)(this)
|
||||
|
||||
public override operator fun MST.unaryMinus(): MST.Unary =
|
||||
unaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this)
|
||||
|
||||
public override operator fun MST.minus(b: MST): MST.Binary =
|
||||
binaryOperationFunction(SpaceOperations.MINUS_OPERATION)(this, b)
|
||||
@ -44,7 +48,8 @@ public object MstSpace : Space<MST>, NumericAlgebra<MST> {
|
||||
/**
|
||||
* [Ring] over [MST] nodes.
|
||||
*/
|
||||
public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstRing : Ring<MST>, RingWithNumbers<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstSpace.zero
|
||||
|
||||
@ -54,7 +59,9 @@ public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
public override fun symbol(value: String): MST.Symbolic = MstSpace.symbol(value)
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstSpace.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstSpace.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = MstSpace { +this@unaryPlus }
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = MstSpace { -this@unaryMinus }
|
||||
public override operator fun MST.minus(b: MST): MST.Binary = MstSpace { this@minus - b }
|
||||
@ -69,7 +76,8 @@ public object MstRing : Ring<MST>, NumericAlgebra<MST> {
|
||||
/**
|
||||
* [Field] over [MST] nodes.
|
||||
*/
|
||||
public object MstField : Field<MST> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstField : Field<MST>, RingWithNumbers<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstRing.zero
|
||||
|
||||
@ -81,7 +89,9 @@ public object MstField : Field<MST> {
|
||||
public override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||
public override fun multiply(a: MST, k: Number): MST.Binary = MstRing.multiply(a, k)
|
||||
public override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary = binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
public override fun divide(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
|
||||
public override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
||||
public override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
||||
public override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b }
|
||||
@ -89,13 +99,14 @@ public object MstField : Field<MST> {
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstRing.binaryOperationFunction(operation)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstRing.unaryOperationFunction(operation)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstRing.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
/**
|
||||
* [ExtendedField] over [MST] nodes.
|
||||
*/
|
||||
public object MstExtendedField : ExtendedField<MST> {
|
||||
public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
||||
public override val zero: MST.Numeric
|
||||
get() = MstField.zero
|
||||
|
||||
@ -103,6 +114,7 @@ public object MstExtendedField : ExtendedField<MST> {
|
||||
get() = MstField.one
|
||||
|
||||
public override fun symbol(value: String): MST.Symbolic = MstField.symbol(value)
|
||||
public override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||
public override fun sin(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
public override fun cos(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.COS_OPERATION)(arg)
|
||||
public override fun tan(arg: MST): MST.Unary = unaryOperationFunction(TrigonometricOperations.TAN_OPERATION)(arg)
|
||||
@ -132,5 +144,6 @@ public object MstExtendedField : ExtendedField<MST> {
|
||||
public override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstField.binaryOperationFunction(operation)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary = MstField.unaryOperationFunction(operation)
|
||||
public override fun unaryOperationFunction(operation: String): (arg: MST) -> MST.Unary =
|
||||
MstField.unaryOperationFunction(operation)
|
||||
}
|
||||
|
@ -1,18 +1,18 @@
|
||||
package kscience.kmath.estree
|
||||
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.estree.internal.ESTreeBuilder
|
||||
import kscience.kmath.estree.internal.estree.BaseExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
@PublishedApi
|
||||
internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
fun ESTreeBuilder<T>.visit(node: MST): BaseExpression = when (node) {
|
||||
is MST.Symbolic -> {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
@ -25,25 +25,29 @@ internal fun <T> MST.compileWith(algebra: Algebra<T>): Expression<T> {
|
||||
variable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> constant(node.value)
|
||||
is MST.Unary -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||
is Numeric -> constant(node.value)
|
||||
|
||||
is MST.Binary -> when {
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> constant(
|
||||
algebra.number(
|
||||
RealField
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
)
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> constant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
else -> call(algebra.unaryOperationFunction(node.operation), visit(node.value))
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> constant(
|
||||
algebra
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> call(
|
||||
algebra is NumericAlgebra && node.left is Numeric -> call(
|
||||
algebra.leftSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> call(
|
||||
algebra is NumericAlgebra && node.right is Numeric -> call(
|
||||
algebra.rightSideNumberOperationFunction(node.operation),
|
||||
visit(node.left),
|
||||
visit(node.right),
|
||||
|
@ -3,11 +3,11 @@ package kscience.kmath.asm
|
||||
import kscience.kmath.asm.internal.AsmBuilder
|
||||
import kscience.kmath.asm.internal.buildName
|
||||
import kscience.kmath.ast.MST
|
||||
import kscience.kmath.ast.MST.*
|
||||
import kscience.kmath.ast.MstExpression
|
||||
import kscience.kmath.expressions.Expression
|
||||
import kscience.kmath.operations.Algebra
|
||||
import kscience.kmath.operations.NumericAlgebra
|
||||
import kscience.kmath.operations.RealField
|
||||
|
||||
/**
|
||||
* Compiles given MST to an Expression using AST compiler.
|
||||
@ -20,7 +20,7 @@ import kscience.kmath.operations.RealField
|
||||
@PublishedApi
|
||||
internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Expression<T> {
|
||||
fun AsmBuilder<T>.visit(node: MST): Unit = when (node) {
|
||||
is MST.Symbolic -> {
|
||||
is Symbolic -> {
|
||||
val symbol = try {
|
||||
algebra.symbol(node.value)
|
||||
} catch (ignored: IllegalStateException) {
|
||||
@ -33,24 +33,29 @@ internal fun <T : Any> MST.compileWith(type: Class<T>, algebra: Algebra<T>): Exp
|
||||
loadVariable(node.value)
|
||||
}
|
||||
|
||||
is MST.Numeric -> loadNumberConstant(node.value)
|
||||
is MST.Unary -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
is Numeric -> loadNumberConstant(node.value)
|
||||
|
||||
is MST.Binary -> when {
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric && node.right is MST.Numeric -> loadObjectConstant(
|
||||
algebra.number(
|
||||
RealField
|
||||
.binaryOperationFunction(node.operation)
|
||||
.invoke(node.left.value.toDouble(), node.right.value.toDouble())
|
||||
)
|
||||
is Unary -> when {
|
||||
algebra is NumericAlgebra && node.value is Numeric -> loadObjectConstant(
|
||||
algebra.unaryOperationFunction(node.operation)(algebra.number(node.value.value)))
|
||||
|
||||
else -> buildCall(algebra.unaryOperationFunction(node.operation)) { visit(node.value) }
|
||||
}
|
||||
|
||||
is Binary -> when {
|
||||
algebra is NumericAlgebra && node.left is Numeric && node.right is Numeric -> loadObjectConstant(
|
||||
algebra.binaryOperationFunction(node.operation)
|
||||
.invoke(algebra.number(node.left.value), algebra.number(node.right.value))
|
||||
)
|
||||
|
||||
algebra is NumericAlgebra<T> && node.left is MST.Numeric -> buildCall(algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
algebra is NumericAlgebra && node.left is Numeric -> buildCall(
|
||||
algebra.leftSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
||||
algebra is NumericAlgebra<T> && node.right is MST.Numeric -> buildCall(algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
algebra is NumericAlgebra && node.right is Numeric -> buildCall(
|
||||
algebra.rightSideNumberOperationFunction(node.operation)) {
|
||||
visit(node.left)
|
||||
visit(node.right)
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package kscience.kmath.commons.expressions
|
||||
|
||||
import kscience.kmath.expressions.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.ExtendedField
|
||||
import kscience.kmath.operations.RingWithNumbers
|
||||
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
|
||||
/**
|
||||
@ -10,15 +12,18 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
* @property order The derivation order.
|
||||
* @property bindings The map of bindings values. All bindings are considered free parameters
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class DerivativeStructureField(
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, Double>,
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure> {
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, RingWithNumbers<DerivativeStructure> {
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
public override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||
public override val one: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order, 1.0) }
|
||||
|
||||
override fun number(value: Number): DerivativeStructure = const(value.toDouble())
|
||||
|
||||
/**
|
||||
* A class that implements both [DerivativeStructure] and a [Symbol]
|
||||
*/
|
||||
|
@ -1,41 +1,28 @@
|
||||
package kscience.kmath.commons.linear
|
||||
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.linear.DiagonalFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.MatrixWrapper
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import org.apache.commons.math3.linear.*
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
public class CMMatrix(public val origin: RealMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
public inline class CMMatrix(public val origin: RealMatrix) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.rowDimension
|
||||
public override val colNum: Int get() = origin.columnDimension
|
||||
|
||||
public override val features: Set<MatrixFeature> = features ?: sequence<MatrixFeature> {
|
||||
if (origin is DiagonalMatrix) yield(DiagonalFeature)
|
||||
}.toHashSet()
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): CMMatrix =
|
||||
CMMatrix(origin, this.features + features)
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
|
||||
DiagonalFeature::class -> if (origin is DiagonalMatrix) DiagonalFeature else null
|
||||
else -> null
|
||||
}?.let { type.cast(it) }
|
||||
|
||||
public override operator fun get(i: Int, j: Int): Double = origin.getEntry(i, j)
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = origin.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.toCM(): CMMatrix = if (this is CMMatrix) {
|
||||
this
|
||||
} else {
|
||||
//TODO add feature analysis
|
||||
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
|
||||
CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
public fun RealMatrix.asMatrix(): CMMatrix = CMMatrix(this)
|
||||
|
||||
@ -60,6 +47,16 @@ public object CMMatrixContext : MatrixContext<Double, CMMatrix> {
|
||||
return CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
|
||||
public fun Matrix<Double>.toCM(): CMMatrix = when {
|
||||
this is CMMatrix -> this
|
||||
this is MatrixWrapper && matrix is CMMatrix -> matrix as CMMatrix
|
||||
else -> {
|
||||
//TODO add feature analysis
|
||||
val array = Array(rowNum) { i -> DoubleArray(colNum) { j -> get(i, j) } }
|
||||
CMMatrix(Array2DRowRealMatrix(array))
|
||||
}
|
||||
}
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): CMMatrix =
|
||||
CMMatrix(toCM().origin.multiply(other.toCM().origin))
|
||||
|
||||
|
@ -1,27 +1,31 @@
|
||||
package kscience.kmath.complex
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.operations.FieldElement
|
||||
import kscience.kmath.structures.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
/**
|
||||
* Convenience alias for [BufferedNDFieldElement] of [Complex].
|
||||
*/
|
||||
public typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
|
||||
|
||||
/**
|
||||
* An optimized nd-field for complex numbers
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class ComplexNDField(override val shape: IntArray) :
|
||||
BufferedNDField<Complex, ComplexField>,
|
||||
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
|
||||
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>>,
|
||||
RingWithNumbers<NDBuffer<Complex>>{
|
||||
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
override val elementContext: ComplexField get() = ComplexField
|
||||
override val zero: ComplexNDElement by lazy { produce { zero } }
|
||||
override val one: ComplexNDElement by lazy { produce { one } }
|
||||
|
||||
override fun number(value: Number): NDBuffer<Complex> {
|
||||
val c = value.toComplex()
|
||||
return produce { c }
|
||||
}
|
||||
|
||||
public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
||||
Buffer.complex(size) { initializer(it) }
|
||||
|
||||
@ -30,7 +34,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
*/
|
||||
override fun map(
|
||||
arg: NDBuffer<Complex>,
|
||||
transform: ComplexField.(Complex) -> Complex
|
||||
transform: ComplexField.(Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(arg)
|
||||
val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) }
|
||||
@ -44,7 +48,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
|
||||
override fun mapIndexed(
|
||||
arg: NDBuffer<Complex>,
|
||||
transform: ComplexField.(index: IntArray, Complex) -> Complex
|
||||
transform: ComplexField.(index: IntArray, Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(arg)
|
||||
|
||||
@ -61,7 +65,7 @@ public class ComplexNDField(override val shape: IntArray) :
|
||||
override fun combine(
|
||||
a: NDBuffer<Complex>,
|
||||
b: NDBuffer<Complex>,
|
||||
transform: ComplexField.(Complex, Complex) -> Complex
|
||||
transform: ComplexField.(Complex, Complex) -> Complex,
|
||||
): ComplexNDElement {
|
||||
check(a, b)
|
||||
|
||||
@ -142,7 +146,7 @@ public fun NDField.Companion.complex(vararg shape: Int): ComplexNDField = Comple
|
||||
|
||||
public fun NDElement.Companion.complex(
|
||||
vararg shape: Int,
|
||||
initializer: ComplexField.(IntArray) -> Complex
|
||||
initializer: ComplexField.(IntArray) -> Complex,
|
||||
): ComplexNDElement = NDField.complex(*shape).produce(initializer)
|
||||
|
||||
/**
|
||||
|
@ -7,8 +7,9 @@ import kscience.kmath.operations.*
|
||||
*
|
||||
* @param algebra The algebra to provide for Expressions built.
|
||||
*/
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val algebra: A) :
|
||||
ExpressionAlgebra<T, Expression<T>> {
|
||||
public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(
|
||||
public val algebra: A,
|
||||
) : ExpressionAlgebra<T, Expression<T>> {
|
||||
/**
|
||||
* Builds an Expression of constant expression which does not depend on arguments.
|
||||
*/
|
||||
@ -42,8 +43,9 @@ public abstract class FunctionalExpressionAlgebra<T, A : Algebra<T>>(public val
|
||||
/**
|
||||
* A context class for [Expression] construction for [Space] algebras.
|
||||
*/
|
||||
public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
public open class FunctionalExpressionSpace<T, A : Space<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionAlgebra<T, A>(algebra), Space<Expression<T>> {
|
||||
public override val zero: Expression<T> get() = const(algebra.zero)
|
||||
|
||||
/**
|
||||
@ -71,8 +73,9 @@ public open class FunctionalExpressionSpace<T, A : Space<T>>(algebra: A) :
|
||||
super<FunctionalExpressionAlgebra>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpressionSpace<T, A>(algebra),
|
||||
Ring<Expression<T>> where A : Ring<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionRing<T, A : Ring<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionSpace<T, A>(algebra), Ring<Expression<T>> {
|
||||
public override val one: Expression<T>
|
||||
get() = const(algebra.one)
|
||||
|
||||
@ -92,9 +95,8 @@ public open class FunctionalExpressionRing<T, A>(algebra: A) : FunctionalExpress
|
||||
super<FunctionalExpressionSpace>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>>
|
||||
where A : Field<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionField<T, A : Field<T>>(algebra: A) :
|
||||
FunctionalExpressionRing<T, A>(algebra), Field<Expression<T>> {
|
||||
/**
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
@ -111,9 +113,12 @@ public open class FunctionalExpressionField<T, A>(algebra: A) :
|
||||
super<FunctionalExpressionRing>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||
FunctionalExpressionField<T, A>(algebra),
|
||||
ExtendedField<Expression<T>> where A : ExtendedField<T>, A : NumericAlgebra<T> {
|
||||
public open class FunctionalExpressionExtendedField<T, A : ExtendedField<T>>(
|
||||
algebra: A,
|
||||
) : FunctionalExpressionField<T, A>(algebra), ExtendedField<Expression<T>> {
|
||||
|
||||
override fun number(value: Number): Expression<T> = const(algebra.number(value))
|
||||
|
||||
public override fun sin(arg: Expression<T>): Expression<T> =
|
||||
unaryOperationFunction(TrigonometricOperations.SIN_OPERATION)(arg)
|
||||
|
||||
@ -135,7 +140,8 @@ public open class FunctionalExpressionExtendedField<T, A>(algebra: A) :
|
||||
public override fun exp(arg: Expression<T>): Expression<T> =
|
||||
unaryOperationFunction(ExponentialOperations.EXP_OPERATION)(arg)
|
||||
|
||||
public override fun ln(arg: Expression<T>): Expression<T> = unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg)
|
||||
public override fun ln(arg: Expression<T>): Expression<T> =
|
||||
unaryOperationFunction(ExponentialOperations.LN_OPERATION)(arg)
|
||||
|
||||
public override fun unaryOperationFunction(operation: String): (arg: Expression<T>) -> Expression<T> =
|
||||
super<FunctionalExpressionField>.unaryOperationFunction(operation)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package kscience.kmath.expressions
|
||||
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
@ -79,10 +80,11 @@ public fun <T : Any, F : Field<T>> F.simpleAutoDiff(
|
||||
/**
|
||||
* Represents field in context of which functions can be derived.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>> {
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, RingWithNumbers<AutoDiffValue<T>> {
|
||||
public override val zero: AutoDiffValue<T>
|
||||
get() = const(context.zero)
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
@ -21,30 +20,11 @@ public class BufferMatrixContext<T : Any, R : Ring<T>>(
|
||||
public companion object
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
public object RealMatrixContext : GenericMatrixContext<Double, RealField, BufferMatrix<Double>> {
|
||||
public override val elementContext: RealField
|
||||
get() = RealField
|
||||
|
||||
public override inline fun produce(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (i: Int, j: Int) -> Double,
|
||||
): BufferMatrix<Double> {
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
public override inline fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
||||
RealBuffer(size, initializer)
|
||||
}
|
||||
|
||||
public class BufferMatrix<T : Any>(
|
||||
public override val rowNum: Int,
|
||||
public override val colNum: Int,
|
||||
public val buffer: Buffer<out T>,
|
||||
public override val features: Set<MatrixFeature> = emptySet(),
|
||||
) : FeaturedMatrix<T> {
|
||||
) : Matrix<T> {
|
||||
|
||||
init {
|
||||
require(buffer.size == rowNum * colNum) { "Dimension mismatch for matrix structure" }
|
||||
@ -52,9 +32,6 @@ public class BufferMatrix<T : Any>(
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): BufferMatrix<T> =
|
||||
BufferMatrix(rowNum, colNum, buffer, this.features + features)
|
||||
|
||||
public override operator fun get(index: IntArray): T = get(index[0], index[1])
|
||||
public override operator fun get(i: Int, j: Int): T = buffer[i * colNum + j]
|
||||
|
||||
@ -66,23 +43,26 @@ public class BufferMatrix<T : Any>(
|
||||
if (this === other) return true
|
||||
|
||||
return when (other) {
|
||||
is NDStructure<*> -> return NDStructure.equals(this, other)
|
||||
is NDStructure<*> -> NDStructure.equals(this, other)
|
||||
else -> false
|
||||
}
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = buffer.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
override fun hashCode(): Int {
|
||||
var result = rowNum
|
||||
result = 31 * result + colNum
|
||||
result = 31 * result + buffer.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
public override fun toString(): String {
|
||||
return if (rowNum <= 5 && colNum <= 5)
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)\n" +
|
||||
"Matrix(rowsNum = $rowNum, colNum = $colNum)\n" +
|
||||
rows.asSequence().joinToString(prefix = "(", postfix = ")", separator = "\n ") { buffer ->
|
||||
buffer.asSequence().joinToString(separator = "\t") { it.toString() }
|
||||
}
|
||||
else "Matrix(rowsNum = $rowNum, colNum = $colNum, features=$features)"
|
||||
else "Matrix(rowsNum = $rowNum, colNum = $colNum)"
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -1,83 +0,0 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kotlin.math.sqrt
|
||||
|
||||
/**
|
||||
* A 2d structure plus optional matrix-specific features
|
||||
*/
|
||||
public interface FeaturedMatrix<T : Any> : Matrix<T> {
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
public val features: Set<MatrixFeature>
|
||||
|
||||
/**
|
||||
* Suggest new feature for this matrix. The result is the new matrix that may or may not reuse existing data structure.
|
||||
*
|
||||
* The implementation does not guarantee to check that matrix actually have the feature, so one should be careful to
|
||||
* add only those features that are valid.
|
||||
*/
|
||||
public fun suggestFeature(vararg features: MatrixFeature): FeaturedMatrix<T>
|
||||
|
||||
public companion object
|
||||
}
|
||||
|
||||
public inline fun Structure2D.Companion.real(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (Int, Int) -> Double,
|
||||
): BufferMatrix<Double> = MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): FeaturedMatrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
public val Matrix<*>.features: Set<MatrixFeature> get() = (this as? FeaturedMatrix)?.features ?: emptySet()
|
||||
|
||||
/**
|
||||
* Check if matrix has the given feature class
|
||||
*/
|
||||
public inline fun <reified T : Any> Matrix<*>.hasFeature(): Boolean =
|
||||
features.find { it is T } != null
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
public inline fun <reified T : Any> Matrix<*>.getFeature(): T? =
|
||||
features.filterIsInstance<T>().firstOrNull()
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix(rows, columns, DiagonalFeature) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): FeaturedMatrix<T> =
|
||||
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero }
|
||||
|
||||
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
colNum,
|
||||
rowNum,
|
||||
setOf(TransposedFeature(this))
|
||||
) { i, j -> get(j, i) }
|
||||
}
|
@ -1,31 +1,31 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
/**
|
||||
* Common implementation of [LUPDecompositionFeature]
|
||||
* Common implementation of [LupDecompositionFeature].
|
||||
*/
|
||||
public class LUPDecomposition<T : Any>(
|
||||
public val context: MatrixContext<T, FeaturedMatrix<T>>,
|
||||
public class LupDecomposition<T : Any>(
|
||||
public val context: MatrixContext<T, Matrix<T>>,
|
||||
public val elementContext: Field<T>,
|
||||
public val lu: Structure2D<T>,
|
||||
public val lu: Matrix<T>,
|
||||
public val pivot: IntArray,
|
||||
private val even: Boolean,
|
||||
) : LUPDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
|
||||
) : LupDecompositionFeature<T>, DeterminantFeature<T> {
|
||||
/**
|
||||
* Returns the matrix L of the decomposition.
|
||||
*
|
||||
* L is a lower-triangular matrix with [Ring.one] in diagonal
|
||||
*/
|
||||
override val l: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(LFeature)) { i, j ->
|
||||
override val l: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
when {
|
||||
j < i -> lu[i, j]
|
||||
j == i -> elementContext.one
|
||||
else -> elementContext.zero
|
||||
}
|
||||
}
|
||||
} + LFeature
|
||||
|
||||
|
||||
/**
|
||||
@ -33,9 +33,9 @@ public class LUPDecomposition<T : Any>(
|
||||
*
|
||||
* U is an upper-triangular matrix including the diagonal
|
||||
*/
|
||||
override val u: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1], setOf(UFeature)) { i, j ->
|
||||
override val u: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j >= i) lu[i, j] else elementContext.zero
|
||||
}
|
||||
} + UFeature
|
||||
|
||||
/**
|
||||
* Returns the P rows permutation matrix.
|
||||
@ -43,7 +43,7 @@ public class LUPDecomposition<T : Any>(
|
||||
* P is a sparse matrix with exactly one element set to [Ring.one] in
|
||||
* each row and each column, all other elements being set to [Ring.zero].
|
||||
*/
|
||||
override val p: FeaturedMatrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
override val p: Matrix<T> = VirtualMatrix(lu.shape[0], lu.shape[1]) { i, j ->
|
||||
if (j == pivot[i]) elementContext.one else elementContext.zero
|
||||
}
|
||||
|
||||
@ -64,12 +64,12 @@ internal fun <T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, *>.abs
|
||||
/**
|
||||
* Create a lup decomposition of generic matrix.
|
||||
*/
|
||||
public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup(
|
||||
public fun <T : Comparable<T>> MatrixContext<T, Matrix<T>>.lup(
|
||||
factory: MutableBufferFactory<T>,
|
||||
elementContext: Field<T>,
|
||||
matrix: Matrix<T>,
|
||||
checkSingular: (T) -> Boolean,
|
||||
): LUPDecomposition<T> {
|
||||
): LupDecomposition<T> {
|
||||
require(matrix.rowNum == matrix.colNum) { "LU decomposition supports only square matrices" }
|
||||
val m = matrix.colNum
|
||||
val pivot = IntArray(matrix.rowNum)
|
||||
@ -138,20 +138,23 @@ public fun <T : Comparable<T>> MatrixContext<T, FeaturedMatrix<T>>.lup(
|
||||
for (row in col + 1 until m) lu[row, col] /= luDiag
|
||||
}
|
||||
|
||||
return LUPDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
|
||||
return LupDecomposition(this@lup, elementContext, lu.collect(), pivot, even)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.lup(
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.lup(
|
||||
matrix: Matrix<T>,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): LUPDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular)
|
||||
): LupDecomposition<T> = lup(MutableBuffer.Companion::auto, elementContext, matrix, checkSingular)
|
||||
|
||||
public fun MatrixContext<Double, FeaturedMatrix<Double>>.lup(matrix: Matrix<Double>): LUPDecomposition<Double> =
|
||||
public fun MatrixContext<Double, Matrix<Double>>.lup(matrix: Matrix<Double>): LupDecomposition<Double> =
|
||||
lup(Buffer.Companion::real, RealField, matrix) { it < 1e-11 }
|
||||
|
||||
public fun <T : Any> LUPDecomposition<T>.solveWithLUP(factory: MutableBufferFactory<T>, matrix: Matrix<T>): FeaturedMatrix<T> {
|
||||
public fun <T : Any> LupDecomposition<T>.solveWithLUP(
|
||||
factory: MutableBufferFactory<T>,
|
||||
matrix: Matrix<T>,
|
||||
): Matrix<T> {
|
||||
require(matrix.rowNum == pivot.size) { "Matrix dimension mismatch. Expected ${pivot.size}, but got ${matrix.colNum}" }
|
||||
|
||||
BufferAccessor2D(matrix.rowNum, matrix.colNum, factory).run {
|
||||
@ -196,34 +199,40 @@ public fun <T : Any> LUPDecomposition<T>.solveWithLUP(factory: MutableBufferFact
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <reified T : Any> LUPDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> =
|
||||
public inline fun <reified T : Any> LupDecomposition<T>.solveWithLUP(matrix: Matrix<T>): Matrix<T> =
|
||||
solveWithLUP(MutableBuffer.Companion::auto, matrix)
|
||||
|
||||
/**
|
||||
* Solve a linear equation **a*x = b** using LUP decomposition
|
||||
*/
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.solveWithLUP(
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.solveWithLUP(
|
||||
a: Matrix<T>,
|
||||
b: Matrix<T>,
|
||||
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val decomposition = a.getFeature() ?: lup(bufferFactory, elementContext, a, checkSingular)
|
||||
return decomposition.solveWithLUP(bufferFactory, b)
|
||||
}
|
||||
|
||||
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): FeaturedMatrix<Double> =
|
||||
solveWithLUP(a, b) { it < 1e-11 }
|
||||
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, FeaturedMatrix<T>>.inverseWithLUP(
|
||||
public inline fun <reified T : Comparable<T>, F : Field<T>> GenericMatrixContext<T, F, Matrix<T>>.inverseWithLUP(
|
||||
matrix: Matrix<T>,
|
||||
noinline bufferFactory: MutableBufferFactory<T> = MutableBuffer.Companion::auto,
|
||||
noinline checkSingular: (T) -> Boolean,
|
||||
): FeaturedMatrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
||||
): Matrix<T> = solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), bufferFactory, checkSingular)
|
||||
|
||||
|
||||
public fun RealMatrixContext.solveWithLUP(a: Matrix<Double>, b: Matrix<Double>): Matrix<Double> {
|
||||
// Use existing decomposition if it is provided by matrix
|
||||
val bufferFactory: MutableBufferFactory<Double> = MutableBuffer.Companion::real
|
||||
val decomposition: LupDecomposition<Double> = a.getFeature() ?: lup(bufferFactory, RealField, a) { it < 1e-11 }
|
||||
return decomposition.solveWithLUP(bufferFactory, b)
|
||||
}
|
||||
|
||||
/**
|
||||
* Inverses a square matrix using LUP decomposition. Non square matrix will throw a error.
|
||||
*/
|
||||
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): FeaturedMatrix<Double> =
|
||||
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum), Buffer.Companion::real) { it < 1e-11 }
|
||||
public fun RealMatrixContext.inverseWithLUP(matrix: Matrix<Double>): Matrix<Double> =
|
||||
solveWithLUP(matrix, one(matrix.rowNum, matrix.colNum))
|
@ -1,12 +1,9 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.BufferFactory
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kscience.kmath.structures.*
|
||||
|
||||
public class MatrixBuilder(public val rows: Int, public val columns: Int) {
|
||||
public operator fun <T : Any> invoke(vararg elements: T): FeaturedMatrix<T> {
|
||||
public operator fun <T : Any> invoke(vararg elements: T): Matrix<T> {
|
||||
require(rows * columns == elements.size) { "The number of elements ${elements.size} is not equal $rows * $columns" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
@ -17,7 +14,7 @@ public class MatrixBuilder(public val rows: Int, public val columns: Int) {
|
||||
|
||||
public fun Structure2D.Companion.build(rows: Int, columns: Int): MatrixBuilder = MatrixBuilder(rows, columns)
|
||||
|
||||
public fun <T : Any> Structure2D.Companion.row(vararg values: T): FeaturedMatrix<T> {
|
||||
public fun <T : Any> Structure2D.Companion.row(vararg values: T): Matrix<T> {
|
||||
val buffer = values.asBuffer()
|
||||
return BufferMatrix(1, values.size, buffer)
|
||||
}
|
||||
@ -26,12 +23,12 @@ public inline fun <reified T : Any> Structure2D.Companion.row(
|
||||
size: Int,
|
||||
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||
noinline builder: (Int) -> T
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
val buffer = factory(size, builder)
|
||||
return BufferMatrix(1, size, buffer)
|
||||
}
|
||||
|
||||
public fun <T : Any> Structure2D.Companion.column(vararg values: T): FeaturedMatrix<T> {
|
||||
public fun <T : Any> Structure2D.Companion.column(vararg values: T): Matrix<T> {
|
||||
val buffer = values.asBuffer()
|
||||
return BufferMatrix(values.size, 1, buffer)
|
||||
}
|
||||
@ -40,7 +37,7 @@ public inline fun <reified T : Any> Structure2D.Companion.column(
|
||||
size: Int,
|
||||
factory: BufferFactory<T> = Buffer.Companion::auto,
|
||||
noinline builder: (Int) -> T
|
||||
): FeaturedMatrix<T> {
|
||||
): Matrix<T> {
|
||||
val buffer = factory(size, builder)
|
||||
return BufferMatrix(size, 1, buffer)
|
||||
}
|
||||
|
@ -18,6 +18,11 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
||||
*/
|
||||
public fun produce(rows: Int, columns: Int, initializer: (i: Int, j: Int) -> T): M
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space (and possibly optimized for it)
|
||||
*/
|
||||
public fun point(size: Int, initializer: (Int) -> T): Point<T> = Buffer.boxing(size, initializer)
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public override fun binaryOperationFunction(operation: String): (left: Matrix<T>, right: Matrix<T>) -> M =
|
||||
when (operation) {
|
||||
@ -62,10 +67,6 @@ public interface MatrixContext<T : Any, out M : Matrix<T>> : SpaceOperations<Mat
|
||||
public operator fun T.times(m: Matrix<T>): M = m * this
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Non-boxing double matrix
|
||||
*/
|
||||
public val real: RealMatrixContext = RealMatrixContext
|
||||
|
||||
/**
|
||||
* A structured matrix with custom buffer
|
||||
@ -89,11 +90,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
|
||||
*/
|
||||
public val elementContext: R
|
||||
|
||||
/**
|
||||
* Produce a point compatible with matrix space
|
||||
*/
|
||||
public fun point(size: Int, initializer: (Int) -> T): Point<T>
|
||||
|
||||
public override infix fun Matrix<T>.dot(other: Matrix<T>): M {
|
||||
//TODO add typed error
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
@ -137,8 +133,6 @@ public interface GenericMatrixContext<T : Any, R : Ring<T>, out M : Matrix<T>> :
|
||||
public override fun multiply(a: Matrix<T>, k: Number): M =
|
||||
produce(a.rowNum, a.colNum) { i, j -> elementContext { a[i, j] * k } }
|
||||
|
||||
public operator fun Number.times(matrix: FeaturedMatrix<T>): M = multiply(matrix, this)
|
||||
|
||||
public override operator fun Matrix<T>.times(value: T): M =
|
||||
produce(rowNum, colNum) { i, j -> elementContext { get(i, j) * value } }
|
||||
}
|
||||
|
@ -1,62 +1,158 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Matrix
|
||||
|
||||
/**
|
||||
* A marker interface representing some matrix feature like diagonal, sparse, zero, etc. Features used to optimize matrix
|
||||
* operations performance in some cases.
|
||||
* A marker interface representing some properties of matrices or additional transformations of them. Features are used
|
||||
* to optimize matrix operations performance in some cases or retrieve the APIs.
|
||||
*/
|
||||
public interface MatrixFeature
|
||||
|
||||
/**
|
||||
* The matrix with this feature is considered to have only diagonal non-null elements
|
||||
* Matrices with this feature are considered to have only diagonal non-null elements.
|
||||
*/
|
||||
public object DiagonalFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrix with this feature has all zero elements
|
||||
*/
|
||||
public object ZeroFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrix with this feature have unit elements on diagonal and zero elements in all other places
|
||||
*/
|
||||
public object UnitFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Inverted matrix feature
|
||||
*/
|
||||
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
public val inverse: FeaturedMatrix<T>
|
||||
public interface DiagonalFeature : MatrixFeature{
|
||||
public companion object: DiagonalFeature
|
||||
}
|
||||
|
||||
/**
|
||||
* A determinant container
|
||||
* Matrices with this feature have all zero elements.
|
||||
*/
|
||||
public object ZeroFeature : DiagonalFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature have unit elements on diagonal and zero elements in all other places.
|
||||
*/
|
||||
public object UnitFeature : DiagonalFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature can be inverted: [inverse] = `a`<sup>-1</sup> where `a` is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface InverseMatrixFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The inverse matrix of the matrix that owns this feature.
|
||||
*/
|
||||
public val inverse: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature can compute their determinant.
|
||||
*/
|
||||
public interface DeterminantFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The determinant of the matrix that owns this feature.
|
||||
*/
|
||||
public val determinant: T
|
||||
}
|
||||
|
||||
/**
|
||||
* Produces a [DeterminantFeature] where the [DeterminantFeature.determinant] is [determinant].
|
||||
*
|
||||
* @param determinant the value of determinant.
|
||||
* @return a new [DeterminantFeature].
|
||||
*/
|
||||
@Suppress("FunctionName")
|
||||
public fun <T : Any> DeterminantFeature(determinant: T): DeterminantFeature<T> = object : DeterminantFeature<T> {
|
||||
override val determinant: T = determinant
|
||||
}
|
||||
|
||||
/**
|
||||
* Lower triangular matrix
|
||||
* Matrices with this feature are lower triangular ones.
|
||||
*/
|
||||
public object LFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Upper triangular feature
|
||||
* Matrices with this feature are upper triangular ones.
|
||||
*/
|
||||
public object UFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* TODO add documentation
|
||||
* Matrices with this feature support LU factorization with partial pivoting: *[p] · a = [l] · [u]* where
|
||||
* *a* is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface LUPDecompositionFeature<T : Any> : MatrixFeature {
|
||||
public val l: FeaturedMatrix<T>
|
||||
public val u: FeaturedMatrix<T>
|
||||
public val p: FeaturedMatrix<T>
|
||||
public interface LupDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The lower triangular matrix in this decomposition. It may have [LFeature].
|
||||
*/
|
||||
public val l: Matrix<T>
|
||||
|
||||
/**
|
||||
* The upper triangular matrix in this decomposition. It may have [UFeature].
|
||||
*/
|
||||
public val u: Matrix<T>
|
||||
|
||||
/**
|
||||
* The permutation matrix in this decomposition.
|
||||
*/
|
||||
public val p: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature are orthogonal ones: *a · a<sup>T</sup> = u* where *a* is the owning matrix, *u*
|
||||
* is the unit matrix ([UnitFeature]).
|
||||
*/
|
||||
public object OrthogonalFeature : MatrixFeature
|
||||
|
||||
/**
|
||||
* Matrices with this feature support QR factorization: *a = [q] · [r]* where *a* is the owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface QRDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The orthogonal matrix in this decomposition. It may have [OrthogonalFeature].
|
||||
*/
|
||||
public val q: Matrix<T>
|
||||
|
||||
/**
|
||||
* The upper triangular matrix in this decomposition. It may have [UFeature].
|
||||
*/
|
||||
public val r: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature support Cholesky factorization: *a = [l] · [l]<sup>H</sup>* where *a* is the
|
||||
* owning matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface CholeskyDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The triangular matrix in this decomposition. It may have either [UFeature] or [LFeature].
|
||||
*/
|
||||
public val l: Matrix<T>
|
||||
}
|
||||
|
||||
/**
|
||||
* Matrices with this feature support SVD: *a = [u] · [s] · [v]<sup>H</sup>* where *a* is the owning
|
||||
* matrix.
|
||||
*
|
||||
* @param T the type of matrices' items.
|
||||
*/
|
||||
public interface SingularValueDecompositionFeature<T : Any> : MatrixFeature {
|
||||
/**
|
||||
* The matrix in this decomposition. It is unitary, and it consists from left singular vectors.
|
||||
*/
|
||||
public val u: Matrix<T>
|
||||
|
||||
/**
|
||||
* The matrix in this decomposition. Its main diagonal elements are singular values.
|
||||
*/
|
||||
public val s: Matrix<T>
|
||||
|
||||
/**
|
||||
* The matrix in this decomposition. It is unitary, and it consists from right singular vectors.
|
||||
*/
|
||||
public val v: Matrix<T>
|
||||
|
||||
/**
|
||||
* The buffer of singular values of this SVD.
|
||||
*/
|
||||
public val singularValues: Point<T>
|
||||
}
|
||||
|
||||
//TODO add sparse matrix feature
|
||||
|
@ -0,0 +1,97 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
import kscience.kmath.structures.asBuffer
|
||||
import kscience.kmath.structures.getFeature
|
||||
import kotlin.math.sqrt
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.safeCast
|
||||
|
||||
/**
|
||||
* A [Matrix] that holds [MatrixFeature] objects.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public class MatrixWrapper<T : Any>(
|
||||
public val matrix: Matrix<T>,
|
||||
public val features: Set<MatrixFeature>,
|
||||
) : Matrix<T> by matrix {
|
||||
|
||||
/**
|
||||
* Get the first feature matching given class. Does not guarantee that matrix has only one feature matching the criteria
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = type.safeCast(features.find { type.isInstance(it) })
|
||||
|
||||
override fun equals(other: Any?): Boolean = matrix == other
|
||||
override fun hashCode(): Int = matrix.hashCode()
|
||||
override fun toString(): String {
|
||||
return "MatrixWrapper(matrix=$matrix, features=$features)"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a single feature to a [Matrix]
|
||||
*/
|
||||
public operator fun <T : Any> Matrix<T>.plus(newFeature: MatrixFeature): MatrixWrapper<T> = if (this is MatrixWrapper) {
|
||||
MatrixWrapper(matrix, features + newFeature)
|
||||
} else {
|
||||
MatrixWrapper(this, setOf(newFeature))
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a collection of features to a [Matrix]
|
||||
*/
|
||||
public operator fun <T : Any> Matrix<T>.plus(newFeatures: Collection<MatrixFeature>): MatrixWrapper<T> =
|
||||
if (this is MatrixWrapper) {
|
||||
MatrixWrapper(matrix, features + newFeatures)
|
||||
} else {
|
||||
MatrixWrapper(this, newFeatures.toSet())
|
||||
}
|
||||
|
||||
public inline fun Structure2D.Companion.real(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (Int, Int) -> Double,
|
||||
): BufferMatrix<Double> = MatrixContext.real.produce(rows, columns, initializer)
|
||||
|
||||
/**
|
||||
* Build a square matrix from given elements.
|
||||
*/
|
||||
public fun <T : Any> Structure2D.Companion.square(vararg elements: T): Matrix<T> {
|
||||
val size: Int = sqrt(elements.size.toDouble()).toInt()
|
||||
require(size * size == elements.size) { "The number of elements ${elements.size} is not a full square" }
|
||||
val buffer = elements.asBuffer()
|
||||
return BufferMatrix(size, size, buffer)
|
||||
}
|
||||
|
||||
/**
|
||||
* Diagonal matrix of ones. The matrix is virtual no actual matrix is created
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.one(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix(rows, columns) { i, j ->
|
||||
if (i == j) elementContext.one else elementContext.zero
|
||||
} + UnitFeature
|
||||
|
||||
|
||||
/**
|
||||
* A virtual matrix of zeroes
|
||||
*/
|
||||
public fun <T : Any, R : Ring<T>> GenericMatrixContext<T, R, *>.zero(rows: Int, columns: Int): Matrix<T> =
|
||||
VirtualMatrix(rows, columns) { _, _ -> elementContext.zero } + ZeroFeature
|
||||
|
||||
public class TransposedFeature<T : Any>(public val original: Matrix<T>) : MatrixFeature
|
||||
|
||||
/**
|
||||
* Create a virtual transposed matrix without copying anything. `A.transpose().transpose() === A`
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun <T : Any> Matrix<T>.transpose(): Matrix<T> {
|
||||
return getFeature<TransposedFeature<T>>()?.original ?: VirtualMatrix(
|
||||
colNum,
|
||||
rowNum,
|
||||
) { i, j -> get(j, i) } + TransposedFeature(this)
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
package kscience.kmath.linear
|
||||
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
public object RealMatrixContext : MatrixContext<Double, BufferMatrix<Double>> {
|
||||
|
||||
public override inline fun produce(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: (i: Int, j: Int) -> Double,
|
||||
): BufferMatrix<Double> {
|
||||
val buffer = RealBuffer(rows * columns) { offset -> initializer(offset / columns, offset % columns) }
|
||||
return BufferMatrix(rows, columns, buffer)
|
||||
}
|
||||
|
||||
private fun Matrix<Double>.wrap(): BufferMatrix<Double> = if (this is BufferMatrix) this else {
|
||||
produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
public fun one(rows: Int, columns: Int): Matrix<Double> = VirtualMatrix(rows, columns) { i, j ->
|
||||
if (i == j) 1.0 else 0.0
|
||||
} + DiagonalFeature
|
||||
|
||||
public override infix fun Matrix<Double>.dot(other: Matrix<Double>): BufferMatrix<Double> {
|
||||
require(colNum == other.rowNum) { "Matrix dot operation dimension mismatch: ($rowNum, $colNum) x (${other.rowNum}, ${other.colNum})" }
|
||||
return produce(rowNum, other.colNum) { i, j ->
|
||||
var res = 0.0
|
||||
for (l in 0 until colNum) {
|
||||
res += get(i, l) * other.get(l, j)
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
public override infix fun Matrix<Double>.dot(vector: Point<Double>): Point<Double> {
|
||||
require(colNum == vector.size) { "Matrix dot vector operation dimension mismatch: ($rowNum, $colNum) x (${vector.size})" }
|
||||
return RealBuffer(rowNum) { i ->
|
||||
var res = 0.0
|
||||
for (j in 0 until colNum) {
|
||||
res += get(i, j) * vector[j]
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
override fun add(a: Matrix<Double>, b: Matrix<Double>): BufferMatrix<Double> {
|
||||
require(a.rowNum == b.rowNum) { "Row number mismatch in matrix addition. Left side: ${a.rowNum}, right side: ${b.rowNum}" }
|
||||
require(a.colNum == b.colNum) { "Column number mismatch in matrix addition. Left side: ${a.colNum}, right side: ${b.colNum}" }
|
||||
return produce(a.rowNum, a.colNum) { i, j ->
|
||||
a[i, j] + b[i, j]
|
||||
}
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.times(value: Double): BufferMatrix<Double> =
|
||||
produce(rowNum, colNum) { i, j -> get(i, j) * value }
|
||||
|
||||
|
||||
override fun multiply(a: Matrix<Double>, k: Number): BufferMatrix<Double> =
|
||||
produce(a.rowNum, a.colNum) { i, j -> a[i, j] * k.toDouble() }
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Partially optimized real-valued matrix
|
||||
*/
|
||||
public val MatrixContext.Companion.real: RealMatrixContext get() = RealMatrixContext
|
@ -5,31 +5,16 @@ import kscience.kmath.structures.Matrix
|
||||
public class VirtualMatrix<T : Any>(
|
||||
override val rowNum: Int,
|
||||
override val colNum: Int,
|
||||
override val features: Set<MatrixFeature> = emptySet(),
|
||||
public val generator: (i: Int, j: Int) -> T
|
||||
) : FeaturedMatrix<T> {
|
||||
public constructor(
|
||||
rowNum: Int,
|
||||
colNum: Int,
|
||||
vararg features: MatrixFeature,
|
||||
generator: (i: Int, j: Int) -> T
|
||||
) : this(
|
||||
rowNum,
|
||||
colNum,
|
||||
setOf(*features),
|
||||
generator
|
||||
)
|
||||
) : Matrix<T> {
|
||||
|
||||
override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = generator(i, j)
|
||||
|
||||
override fun suggestFeature(vararg features: MatrixFeature): VirtualMatrix<T> =
|
||||
VirtualMatrix(rowNum, colNum, this.features + features, generator)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is FeaturedMatrix<*>) return false
|
||||
if (other !is Matrix<*>) return false
|
||||
|
||||
if (rowNum != other.rowNum) return false
|
||||
if (colNum != other.colNum) return false
|
||||
@ -40,21 +25,9 @@ public class VirtualMatrix<T : Any>(
|
||||
override fun hashCode(): Int {
|
||||
var result = rowNum
|
||||
result = 31 * result + colNum
|
||||
result = 31 * result + features.hashCode()
|
||||
result = 31 * result + generator.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Wrap a matrix adding additional features to it
|
||||
*/
|
||||
public fun <T : Any> wrap(matrix: Matrix<T>, vararg features: MatrixFeature): FeaturedMatrix<T> {
|
||||
return if (matrix is VirtualMatrix)
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features, matrix.generator)
|
||||
else
|
||||
VirtualMatrix(matrix.rowNum, matrix.colNum, matrix.features + features) { i, j -> matrix[i, j] }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -88,85 +88,6 @@ public interface Algebra<T> {
|
||||
public fun binaryOperation(operation: String, left: T, right: T): T = binaryOperationFunction(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* An algebraic structure where elements can have numeric representation.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
*/
|
||||
public interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wraps a number to [T] object.
|
||||
*
|
||||
* @param value the number to wrap.
|
||||
* @return an object.
|
||||
*/
|
||||
public fun number(value: Number): T
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(number(l), r) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
leftSideNumberOperationFunction(operation)(left, right)
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(l, number(r)) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric second argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
rightSideNumberOperationFunction(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call a block with an [Algebra] as receiver.
|
||||
*/
|
||||
@ -341,47 +262,11 @@ public interface RingOperations<T> : SpaceOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this ring.
|
||||
*/
|
||||
public interface Ring<T> : Space<T>, RingOperations<T>, NumericAlgebra<T> {
|
||||
public interface Ring<T> : Space<T>, RingOperations<T> {
|
||||
/**
|
||||
* neutral operation for multiplication
|
||||
*/
|
||||
public val one: T
|
||||
|
||||
public override fun number(value: Number): T = one * value.toDouble()
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun T.plus(b: Number): T = this + number(b)
|
||||
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun Number.plus(b: T): T = b + this
|
||||
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: Number): T = this - number(b)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun Number.minus(b: T): T = -b + this
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||
import kscience.kmath.structures.*
|
||||
@ -16,7 +17,8 @@ public typealias TBase = ULong
|
||||
*
|
||||
* @author Robert Drynkin (https://github.com/robdrynkin) and Peter Klimai (https://github.com/pklimai)
|
||||
*/
|
||||
public object BigIntField : Field<BigInt> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object BigIntField : Field<BigInt>, RingWithNumbers<BigInt> {
|
||||
override val zero: BigInt = BigInt.ZERO
|
||||
override val one: BigInt = BigInt.ONE
|
||||
|
||||
|
@ -0,0 +1,219 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.memory.MemoryReader
|
||||
import kscience.kmath.memory.MemorySpec
|
||||
import kscience.kmath.memory.MemoryWriter
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.MemoryBuffer
|
||||
import kscience.kmath.structures.MutableBuffer
|
||||
import kscience.kmath.structures.MutableMemoryBuffer
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* This complex's conjugate.
|
||||
*/
|
||||
public val Complex.conjugate: Complex
|
||||
get() = Complex(re, -im)
|
||||
|
||||
/**
|
||||
* This complex's reciprocal.
|
||||
*/
|
||||
public val Complex.reciprocal: Complex
|
||||
get() {
|
||||
val scale = re * re + im * im
|
||||
return Complex(re / scale, -im / scale)
|
||||
}
|
||||
|
||||
/**
|
||||
* Absolute value of complex number.
|
||||
*/
|
||||
public val Complex.r: Double
|
||||
get() = sqrt(re * re + im * im)
|
||||
|
||||
/**
|
||||
* An angle between vector represented by complex number and X axis.
|
||||
*/
|
||||
public val Complex.theta: Double
|
||||
get() = atan(im / re)
|
||||
|
||||
private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
|
||||
/**
|
||||
* A field of [Complex].
|
||||
*/
|
||||
public object ComplexField : ExtendedField<Complex>, Norm<Complex, Complex>, RingWithNumbers<Complex> {
|
||||
override val zero: Complex = 0.0.toComplex()
|
||||
override val one: Complex = 1.0.toComplex()
|
||||
|
||||
/**
|
||||
* The imaginary unit.
|
||||
*/
|
||||
public val i: Complex = Complex(0.0, 1.0)
|
||||
|
||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||
|
||||
override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
||||
|
||||
override fun multiply(a: Complex, b: Complex): Complex =
|
||||
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||
|
||||
override fun divide(a: Complex, b: Complex): Complex = when {
|
||||
b.re.isNaN() || b.im.isNaN() -> Complex(Double.NaN, Double.NaN)
|
||||
|
||||
(if (b.im < 0) -b.im else +b.im) < (if (b.re < 0) -b.re else +b.re) -> {
|
||||
val wr = b.im / b.re
|
||||
val wd = b.re + wr * b.im
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
Complex(Double.NaN, Double.NaN)
|
||||
else
|
||||
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
|
||||
}
|
||||
|
||||
b.im == 0.0 -> Complex(Double.NaN, Double.NaN)
|
||||
|
||||
else -> {
|
||||
val wr = b.re / b.im
|
||||
val wd = b.im + wr * b.re
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
Complex(Double.NaN, Double.NaN)
|
||||
else
|
||||
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
|
||||
}
|
||||
}
|
||||
|
||||
override fun sin(arg: Complex): Complex = i * (exp(-i * arg) - exp(i * arg)) / 2
|
||||
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
||||
|
||||
override fun tan(arg: Complex): Complex {
|
||||
val e1 = exp(-i * arg)
|
||||
val e2 = exp(i * arg)
|
||||
return i * (e1 - e2) / (e1 + e2)
|
||||
}
|
||||
|
||||
override fun asin(arg: Complex): Complex = -i * ln(sqrt(1 - (arg * arg)) + i * arg)
|
||||
override fun acos(arg: Complex): Complex = PI_DIV_2 + i * ln(sqrt(1 - (arg * arg)) + i * arg)
|
||||
|
||||
override fun atan(arg: Complex): Complex {
|
||||
val iArg = i * arg
|
||||
return i * (ln(1 - iArg) - ln(1 + iArg)) / 2
|
||||
}
|
||||
|
||||
override fun power(arg: Complex, pow: Number): Complex = if (arg.im == 0.0)
|
||||
arg.re.pow(pow.toDouble()).toComplex()
|
||||
else
|
||||
exp(pow * ln(arg))
|
||||
|
||||
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
|
||||
|
||||
override fun ln(arg: Complex): Complex = ln(arg.r) + i * atan2(arg.im, arg.re)
|
||||
|
||||
/**
|
||||
* Adds complex number to real one.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param c the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
public operator fun Double.plus(c: Complex): Complex = add(this.toComplex(), c)
|
||||
|
||||
/**
|
||||
* Subtracts complex number from real one.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param c the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
public operator fun Double.minus(c: Complex): Complex = add(this.toComplex(), -c)
|
||||
|
||||
/**
|
||||
* Adds real number to complex one.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param d the augend.
|
||||
* @return the sum.
|
||||
*/
|
||||
public operator fun Complex.plus(d: Double): Complex = d + this
|
||||
|
||||
/**
|
||||
* Subtracts real number from complex one.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param d the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
public operator fun Complex.minus(d: Double): Complex = add(this, -d.toComplex())
|
||||
|
||||
/**
|
||||
* Multiplies real number by complex one.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param c the multiplicand.
|
||||
* @receiver the product.
|
||||
*/
|
||||
public operator fun Double.times(c: Complex): Complex = Complex(c.re * this, c.im * this)
|
||||
|
||||
override fun norm(arg: Complex): Complex = sqrt(arg.conjugate * arg)
|
||||
|
||||
override fun symbol(value: String): Complex = if (value == "i") i else super<ExtendedField>.symbol(value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents `double`-based complex number.
|
||||
*
|
||||
* @property re The real part.
|
||||
* @property im The imaginary part.
|
||||
*/
|
||||
public data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Complex, ComplexField>,
|
||||
Comparable<Complex> {
|
||||
public constructor(re: Number, im: Number) : this(re.toDouble(), im.toDouble())
|
||||
|
||||
override val context: ComplexField get() = ComplexField
|
||||
|
||||
override fun unwrap(): Complex = this
|
||||
|
||||
override fun Complex.wrap(): Complex = this
|
||||
|
||||
override fun compareTo(other: Complex): Int = r.compareTo(other.r)
|
||||
|
||||
override fun toString(): String {
|
||||
return "($re + i*$im)"
|
||||
}
|
||||
|
||||
|
||||
public companion object : MemorySpec<Complex> {
|
||||
override val objectSize: Int
|
||||
get() = 16
|
||||
|
||||
override fun MemoryReader.read(offset: Int): Complex = Complex(readDouble(offset), readDouble(offset + 8))
|
||||
|
||||
override fun MemoryWriter.write(offset: Int, value: Complex) {
|
||||
writeDouble(offset, value.re)
|
||||
writeDouble(offset + 8, value.im)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Creates a complex number with real part equal to this real.
|
||||
*
|
||||
* @receiver the real part.
|
||||
* @return the new complex number.
|
||||
*/
|
||||
public fun Number.toComplex(): Complex = Complex(this, 0.0)
|
||||
|
||||
/**
|
||||
* Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the
|
||||
* specified [init] function.
|
||||
*/
|
||||
public inline fun Buffer.Companion.complex(size: Int, init: (Int) -> Complex): Buffer<Complex> =
|
||||
MemoryBuffer.create(Complex, size, init)
|
||||
|
||||
/**
|
||||
* Creates a new buffer of complex numbers with the specified [size], where each element is calculated by calling the
|
||||
* specified [init] function.
|
||||
*/
|
||||
public inline fun MutableBuffer.Companion.complex(size: Int, init: (Int) -> Complex): MutableBuffer<Complex> =
|
||||
MutableMemoryBuffer.create(Complex, size, init)
|
@ -0,0 +1,125 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* An algebraic structure where elements can have numeric representation.
|
||||
*
|
||||
* @param T the type of element of this structure.
|
||||
*/
|
||||
public interface NumericAlgebra<T> : Algebra<T> {
|
||||
/**
|
||||
* Wraps a number to [T] object.
|
||||
*
|
||||
* @param value the number to wrap.
|
||||
* @return an object.
|
||||
*/
|
||||
public fun number(value: Number): T
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun leftSideNumberOperationFunction(operation: String): (left: Number, right: T) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(number(l), r) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with second [leftSideNumberOperation] overload:
|
||||
* i.e. `leftSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun leftSideNumberOperation(operation: String, left: Number, right: T): T =
|
||||
leftSideNumberOperationFunction(operation)(left, right)
|
||||
|
||||
/**
|
||||
* Dynamically dispatches a binary operation with the certain name with numeric first argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == leftSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @return an operation.
|
||||
*/
|
||||
public fun rightSideNumberOperationFunction(operation: String): (left: T, right: Number) -> T =
|
||||
{ l, r -> binaryOperationFunction(operation)(l, number(r)) }
|
||||
|
||||
/**
|
||||
* Dynamically invokes a binary operation with the certain name with numeric second argument.
|
||||
*
|
||||
* This function must follow two properties:
|
||||
*
|
||||
* 1. In case if operation is not defined in the structure, the function throws [kotlin.IllegalStateException].
|
||||
* 2. This function is symmetric with the other [rightSideNumberOperationFunction] overload:
|
||||
* i.e. `rightSideNumberOperationFunction(a)(b, c) == rightSideNumberOperation(a, b, c)`.
|
||||
*
|
||||
* @param operation the name of operation.
|
||||
* @param left the first argument of operation.
|
||||
* @param right the second argument of operation.
|
||||
* @return a result of operation.
|
||||
*/
|
||||
public fun rightSideNumberOperation(operation: String, left: T, right: Number): T =
|
||||
rightSideNumberOperationFunction(operation)(left, right)
|
||||
}
|
||||
|
||||
/**
|
||||
* A combination of [NumericAlgebra] and [Ring] that adds intrinsic simple operations on numbers like `T+1`
|
||||
* TODO to be removed and replaced by extensions after multiple receivers are there
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface RingWithNumbers<T>: Ring<T>, NumericAlgebra<T>{
|
||||
public override fun number(value: Number): T = one * value
|
||||
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun T.plus(b: Number): T = this + number(b)
|
||||
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the addend.
|
||||
* @param b the augend.
|
||||
*/
|
||||
public operator fun Number.plus(b: T): T = b + this
|
||||
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: Number): T = this - number(b)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun Number.minus(b: T): T = -b + this
|
||||
}
|
@ -37,7 +37,7 @@ public interface ExtendedFieldOperations<T> :
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations.
|
||||
*/
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T> {
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T> {
|
||||
public override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2
|
||||
public override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2
|
||||
public override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||
@ -80,6 +80,8 @@ public object RealField : ExtendedField<Double>, Norm<Double, Double> {
|
||||
public override val one: Double
|
||||
get() = 1.0
|
||||
|
||||
override fun number(value: Number): Double = value.toDouble()
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: Double, right: Double) -> Double =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
@ -131,10 +133,13 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
public override val one: Float
|
||||
get() = 1.0f
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float = when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
override fun number(value: Number): Float = value.toFloat()
|
||||
|
||||
public override fun binaryOperationFunction(operation: String): (left: Float, right: Float) -> Float =
|
||||
when (operation) {
|
||||
PowerOperations.POW_OPERATION -> ::power
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public override inline fun add(a: Float, b: Float): Float = a + b
|
||||
public override inline fun multiply(a: Float, k: Number): Float = a * k.toFloat()
|
||||
@ -174,13 +179,15 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
* A field for [Int] without boxing. Does not produce corresponding ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
|
||||
public override val zero: Int
|
||||
get() = 0
|
||||
|
||||
public override val one: Int
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Int = value.toInt()
|
||||
|
||||
public override inline fun add(a: Int, b: Int): Int = a + b
|
||||
public override inline fun multiply(a: Int, k: Number): Int = k.toInt() * a
|
||||
|
||||
@ -198,13 +205,15 @@ public object IntRing : Ring<Int>, Norm<Int, Int> {
|
||||
* A field for [Short] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> {
|
||||
public override val zero: Short
|
||||
get() = 0
|
||||
|
||||
public override val one: Short
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Short = value.toShort()
|
||||
|
||||
public override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
public override inline fun multiply(a: Short, k: Number): Short = (a * k.toShort()).toShort()
|
||||
|
||||
@ -222,13 +231,15 @@ public object ShortRing : Ring<Short>, Norm<Short, Short> {
|
||||
* A field for [Byte] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
|
||||
public override val zero: Byte
|
||||
get() = 0
|
||||
|
||||
public override val one: Byte
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Byte = value.toByte()
|
||||
|
||||
public override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
public override inline fun multiply(a: Byte, k: Number): Byte = (a * k.toByte()).toByte()
|
||||
|
||||
@ -246,13 +257,15 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte> {
|
||||
* A field for [Double] without boxing. Does not produce appropriate ring element.
|
||||
*/
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public object LongRing : Ring<Long>, Norm<Long, Long> {
|
||||
public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
|
||||
public override val zero: Long
|
||||
get() = 0L
|
||||
|
||||
public override val one: Long
|
||||
get() = 1L
|
||||
|
||||
override fun number(value: Number): Long = value.toLong()
|
||||
|
||||
public override inline fun add(a: Long, b: Long): Long = a + b
|
||||
public override inline fun multiply(a: Long, k: Number): Long = a * k.toLong()
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.native.concurrent.ThreadLocal
|
||||
import kotlin.reflect.KClass
|
||||
@ -38,9 +39,17 @@ public interface NDStructure<T> {
|
||||
*/
|
||||
public fun elements(): Sequence<Pair<IntArray, T>>
|
||||
|
||||
//force override equality and hash code
|
||||
public override fun equals(other: Any?): Boolean
|
||||
public override fun hashCode(): Int
|
||||
|
||||
/**
|
||||
* Feature is additional property or hint that does not directly affect the structure, but could in some cases help
|
||||
* optimize operations and performance. If the feature is not present, null is defined.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun <T : Any> getFeature(type: KClass<T>): T? = null
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Indicates whether some [NDStructure] is equal to another one.
|
||||
@ -120,6 +129,9 @@ public interface NDStructure<T> {
|
||||
*/
|
||||
public operator fun <T> NDStructure<T>.get(vararg index: Int): T = get(index)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public inline fun <reified T : Any> NDStructure<*>.getFeature(): T? = getFeature(T::class)
|
||||
|
||||
/**
|
||||
* Represents mutable [NDStructure].
|
||||
*/
|
||||
@ -133,6 +145,9 @@ public interface MutableNDStructure<T> : NDStructure<T> {
|
||||
public operator fun set(index: IntArray, value: T)
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform a structure element-by element in place.
|
||||
*/
|
||||
public inline fun <T> MutableNDStructure<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
|
||||
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||
|
||||
|
@ -150,6 +150,8 @@ public class RealBufferField(public val size: Int) : ExtendedField<Buffer<Double
|
||||
public override val zero: Buffer<Double> by lazy { RealBuffer(size) { 0.0 } }
|
||||
public override val one: Buffer<Double> by lazy { RealBuffer(size) { 1.0 } }
|
||||
|
||||
override fun number(value: Number): Buffer<Double> = RealBuffer(size) { value.toDouble() }
|
||||
|
||||
public override fun add(a: Buffer<Double>, b: Buffer<Double>): RealBuffer {
|
||||
require(a.size == size) { "The buffer size ${a.size} does not match context size $size" }
|
||||
return RealBufferFieldOperations.add(a, b)
|
||||
|
@ -1,13 +1,17 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.FieldElement
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.RingWithNumbers
|
||||
|
||||
public typealias RealNDElement = BufferedNDFieldElement<Double, RealField>
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class RealNDField(override val shape: IntArray) :
|
||||
BufferedNDField<Double, RealField>,
|
||||
ExtendedNDField<Double, RealField, NDBuffer<Double>> {
|
||||
ExtendedNDField<Double, RealField, NDBuffer<Double>>,
|
||||
RingWithNumbers<NDBuffer<Double>>{
|
||||
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
@ -15,7 +19,12 @@ public class RealNDField(override val shape: IntArray) :
|
||||
override val zero: RealNDElement by lazy { produce { zero } }
|
||||
override val one: RealNDElement by lazy { produce { one } }
|
||||
|
||||
public inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||
override fun number(value: Number): NDBuffer<Double> {
|
||||
val d = value.toDouble()
|
||||
return produce { d }
|
||||
}
|
||||
|
||||
private inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||
RealBuffer(DoubleArray(size) { initializer(it) })
|
||||
|
||||
/**
|
||||
@ -59,7 +68,8 @@ public class RealNDField(override val shape: IntArray) :
|
||||
check(a, b)
|
||||
return BufferedNDFieldElement(
|
||||
this,
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) }
|
||||
)
|
||||
}
|
||||
|
||||
override fun NDBuffer<Double>.toElement(): FieldElement<NDBuffer<Double>, *, out BufferedNDField<Double, RealField>> =
|
||||
|
@ -1,12 +1,42 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
/**
|
||||
* A structure that is guaranteed to be two-dimensional
|
||||
* A structure that is guaranteed to be two-dimensional.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public interface Structure2D<T> : NDStructure<T> {
|
||||
public val rowNum: Int get() = shape[0]
|
||||
public val colNum: Int get() = shape[1]
|
||||
/**
|
||||
* The number of rows in this structure.
|
||||
*/
|
||||
public val rowNum: Int
|
||||
|
||||
/**
|
||||
* The number of columns in this structure.
|
||||
*/
|
||||
public val colNum: Int
|
||||
|
||||
public override val shape: IntArray get() = intArrayOf(rowNum, colNum)
|
||||
|
||||
/**
|
||||
* The buffer of rows of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
public val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } }
|
||||
|
||||
/**
|
||||
* The buffer of columns of this structure. It gets elements from the structure dynamically.
|
||||
*/
|
||||
public val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
||||
|
||||
/**
|
||||
* Retrieves an element from the structure by two indices.
|
||||
*
|
||||
* @param i the first index.
|
||||
* @param j the second index.
|
||||
* @return an element.
|
||||
*/
|
||||
public operator fun get(i: Int, j: Int): T
|
||||
|
||||
override operator fun get(index: IntArray): T {
|
||||
@ -14,15 +44,9 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
return get(index[0], index[1])
|
||||
}
|
||||
|
||||
public val rows: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(rowNum) { i -> VirtualBuffer(colNum) { j -> get(i, j) } }
|
||||
|
||||
public val columns: Buffer<Buffer<T>>
|
||||
get() = VirtualBuffer(colNum) { j -> VirtualBuffer(rowNum) { i -> get(i, j) } }
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = sequence {
|
||||
for (i in (0 until rowNum))
|
||||
for (j in (0 until colNum)) yield(intArrayOf(i, j) to get(i, j))
|
||||
for (i in 0 until rowNum)
|
||||
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
|
||||
public companion object
|
||||
@ -34,6 +58,9 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
private inline class Structure2DWrapper<T>(val structure: NDStructure<T>) : Structure2D<T> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = structure.elements()
|
||||
@ -47,4 +74,9 @@ public fun <T> NDStructure<T>.as2D(): Structure2D<T> = if (shape.size == 2)
|
||||
else
|
||||
error("Can't create 2d-structure from ${shape.size}d-structure")
|
||||
|
||||
/**
|
||||
* Alias for [Structure2D] with more familiar name.
|
||||
*
|
||||
* @param T the type of items.
|
||||
*/
|
||||
public typealias Matrix<T> = Structure2D<T>
|
||||
|
@ -1,14 +1,13 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.testutils.FieldVerifier
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.operations.internal.FieldVerifier
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
internal class NDFieldTest {
|
||||
@Test
|
||||
fun verify() {
|
||||
(NDField.real(12, 32)) { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||
NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -7,7 +7,7 @@ import java.math.MathContext
|
||||
/**
|
||||
* A field over [BigInteger].
|
||||
*/
|
||||
public object JBigIntegerField : Field<BigInteger> {
|
||||
public object JBigIntegerField : Field<BigInteger>, NumericAlgebra<BigInteger> {
|
||||
public override val zero: BigInteger
|
||||
get() = BigInteger.ZERO
|
||||
|
||||
@ -28,9 +28,9 @@ public object JBigIntegerField : Field<BigInteger> {
|
||||
*
|
||||
* @property mathContext the [MathContext] to use.
|
||||
*/
|
||||
public abstract class JBigDecimalFieldBase internal constructor(public val mathContext: MathContext = MathContext.DECIMAL64) :
|
||||
Field<BigDecimal>,
|
||||
PowerOperations<BigDecimal> {
|
||||
public abstract class JBigDecimalFieldBase internal constructor(
|
||||
private val mathContext: MathContext = MathContext.DECIMAL64,
|
||||
) : Field<BigDecimal>, PowerOperations<BigDecimal>, NumericAlgebra<BigDecimal> {
|
||||
public override val zero: BigDecimal
|
||||
get() = BigDecimal.ZERO
|
||||
|
||||
|
@ -1,11 +1,6 @@
|
||||
package kscience.kmath.dimensions
|
||||
|
||||
import kscience.kmath.linear.GenericMatrixContext
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.transpose
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.Structure2D
|
||||
@ -42,9 +37,11 @@ public interface DMatrix<T, R : Dimension, C : Dimension> : Structure2D<T> {
|
||||
* An inline wrapper for a Matrix
|
||||
*/
|
||||
public inline class DMatrixWrapper<T, R : Dimension, C : Dimension>(
|
||||
private val structure: Structure2D<T>
|
||||
private val structure: Structure2D<T>,
|
||||
) : DMatrix<T, R, C> {
|
||||
override val shape: IntArray get() = structure.shape
|
||||
override val rowNum: Int get() = shape[0]
|
||||
override val colNum: Int get() = shape[1]
|
||||
override operator fun get(i: Int, j: Int): T = structure[i, j]
|
||||
}
|
||||
|
||||
@ -81,7 +78,7 @@ public inline class DPointWrapper<T, D : Dimension>(public val point: Point<T>)
|
||||
/**
|
||||
* Basic operations on dimension-safe matrices. Operates on [Matrix]
|
||||
*/
|
||||
public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: GenericMatrixContext<T, Ri, Matrix<T>>) {
|
||||
public inline class DMatrixContext<T : Any>(public val context: MatrixContext<T, Matrix<T>>) {
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> Matrix<T>.coerce(): DMatrix<T, R, C> {
|
||||
require(rowNum == Dimension.dim<R>().toInt()) {
|
||||
"Row number mismatch: expected ${Dimension.dim<R>()} but found $rowNum"
|
||||
@ -115,7 +112,7 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
||||
}
|
||||
|
||||
public inline infix fun <reified R1 : Dimension, reified C1 : Dimension, reified C2 : Dimension> DMatrix<T, R1, C1>.dot(
|
||||
other: DMatrix<T, C1, C2>
|
||||
other: DMatrix<T, C1, C2>,
|
||||
): DMatrix<T, R1, C2> = context { this@dot dot other }.coerce()
|
||||
|
||||
public inline infix fun <reified R : Dimension, reified C : Dimension> DMatrix<T, R, C>.dot(vector: DPoint<T, C>): DPoint<T, R> =
|
||||
@ -139,18 +136,20 @@ public inline class DMatrixContext<T : Any, Ri : Ring<T>>(public val context: Ge
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrix<T, C, R>.transpose(): DMatrix<T, R, C> =
|
||||
context { (this@transpose as Matrix<T>).transpose() }.coerce()
|
||||
|
||||
/**
|
||||
* A square unit matrix
|
||||
*/
|
||||
public inline fun <reified D : Dimension> one(): DMatrix<T, D, D> = produce { i, j ->
|
||||
if (i == j) context.elementContext.one else context.elementContext.zero
|
||||
}
|
||||
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> zero(): DMatrix<T, R, C> = produce { _, _ ->
|
||||
context.elementContext.zero
|
||||
}
|
||||
|
||||
public companion object {
|
||||
public val real: DMatrixContext<Double, RealField> = DMatrixContext(MatrixContext.real)
|
||||
public val real: DMatrixContext<Double> = DMatrixContext(MatrixContext.real)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A square unit matrix
|
||||
*/
|
||||
public inline fun <reified D : Dimension> DMatrixContext<Double>.one(): DMatrix<Double, D, D> = produce { i, j ->
|
||||
if (i == j) 1.0 else 0.0
|
||||
}
|
||||
|
||||
public inline fun <reified R : Dimension, reified C : Dimension> DMatrixContext<Double>.zero(): DMatrix<Double, R, C> =
|
||||
produce { _, _ ->
|
||||
0.0
|
||||
}
|
@ -3,6 +3,7 @@ package kscience.dimensions
|
||||
import kscience.kmath.dimensions.D2
|
||||
import kscience.kmath.dimensions.D3
|
||||
import kscience.kmath.dimensions.DMatrixContext
|
||||
import kscience.kmath.dimensions.one
|
||||
import kotlin.test.Test
|
||||
|
||||
internal class DMatrixContextTest {
|
||||
|
@ -1,12 +1,13 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import kscience.kmath.linear.DeterminantFeature
|
||||
import kscience.kmath.linear.FeaturedMatrix
|
||||
import kscience.kmath.linear.LUPDecompositionFeature
|
||||
import kscience.kmath.linear.MatrixFeature
|
||||
import kscience.kmath.structures.NDStructure
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
||||
/**
|
||||
* Represents featured matrix over EJML [SimpleMatrix].
|
||||
@ -14,58 +15,65 @@ import kscience.kmath.structures.NDStructure
|
||||
* @property origin the underlying [SimpleMatrix].
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public class EjmlMatrix(public val origin: SimpleMatrix, features: Set<MatrixFeature>? = null) : FeaturedMatrix<Double> {
|
||||
public override val rowNum: Int
|
||||
get() = origin.numRows()
|
||||
public inline class EjmlMatrix(
|
||||
public val origin: SimpleMatrix,
|
||||
) : Matrix<Double> {
|
||||
public override val rowNum: Int get() = origin.numRows()
|
||||
|
||||
public override val colNum: Int
|
||||
get() = origin.numCols()
|
||||
public override val colNum: Int get() = origin.numCols()
|
||||
|
||||
public override val shape: IntArray
|
||||
get() = intArrayOf(origin.numRows(), origin.numCols())
|
||||
|
||||
public override val features: Set<MatrixFeature> = setOf(
|
||||
object : LUPDecompositionFeature<Double>, DeterminantFeature<Double> {
|
||||
override val determinant: Double
|
||||
get() = origin.determinant()
|
||||
|
||||
private val lup by lazy {
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols())
|
||||
.also { it.decompose(origin.ddrm.copy()) }
|
||||
|
||||
Triple(
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getRowPivot(null))),
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getLower(null))),
|
||||
EjmlMatrix(SimpleMatrix(ludecompositionF64.getUpper(null))),
|
||||
)
|
||||
@UnstableKMathAPI
|
||||
override fun <T : Any> getFeature(type: KClass<T>): T? = when (type) {
|
||||
InverseMatrixFeature::class -> object : InverseMatrixFeature<Double> {
|
||||
override val inverse: Matrix<Double> by lazy { EjmlMatrix(origin.invert()) }
|
||||
}
|
||||
DeterminantFeature::class -> object : DeterminantFeature<Double> {
|
||||
override val determinant: Double by lazy(origin::determinant)
|
||||
}
|
||||
SingularValueDecompositionFeature::class -> object : SingularValueDecompositionFeature<Double> {
|
||||
private val svd by lazy {
|
||||
DecompositionFactory_DDRM.svd(origin.numRows(), origin.numCols(), true, true, false)
|
||||
.apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
override val l: FeaturedMatrix<Double>
|
||||
get() = lup.second
|
||||
|
||||
override val u: FeaturedMatrix<Double>
|
||||
get() = lup.third
|
||||
|
||||
override val p: FeaturedMatrix<Double>
|
||||
get() = lup.first
|
||||
override val u: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getU(null, false))) }
|
||||
override val s: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getW(null))) }
|
||||
override val v: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(svd.getV(null, false))) }
|
||||
override val singularValues: Point<Double> by lazy { RealBuffer(svd.singularValues) }
|
||||
}
|
||||
) union features.orEmpty()
|
||||
QRDecompositionFeature::class -> object : QRDecompositionFeature<Double> {
|
||||
private val qr by lazy {
|
||||
DecompositionFactory_DDRM.qr().apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
public override fun suggestFeature(vararg features: MatrixFeature): EjmlMatrix =
|
||||
EjmlMatrix(origin, this.features + features)
|
||||
override val q: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getQ(null, false))) }
|
||||
override val r: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(qr.getR(null, false))) }
|
||||
}
|
||||
CholeskyDecompositionFeature::class -> object : CholeskyDecompositionFeature<Double> {
|
||||
override val l: Matrix<Double> by lazy {
|
||||
val cholesky =
|
||||
DecompositionFactory_DDRM.chol(rowNum, true).apply { decompose(origin.ddrm.copy()) }
|
||||
|
||||
EjmlMatrix(SimpleMatrix(cholesky.getT(null))) + LFeature
|
||||
}
|
||||
}
|
||||
LupDecompositionFeature::class -> object : LupDecompositionFeature<Double> {
|
||||
private val lup by lazy {
|
||||
DecompositionFactory_DDRM.lu(origin.numRows(), origin.numCols()).apply { decompose(origin.ddrm.copy()) }
|
||||
}
|
||||
|
||||
override val l: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getLower(null))) + LFeature
|
||||
}
|
||||
|
||||
override val u: Matrix<Double> by lazy {
|
||||
EjmlMatrix(SimpleMatrix(lup.getUpper(null))) + UFeature
|
||||
}
|
||||
|
||||
override val p: Matrix<Double> by lazy { EjmlMatrix(SimpleMatrix(lup.getRowPivot(null))) }
|
||||
}
|
||||
else -> null
|
||||
}?.let{type.cast(it)}
|
||||
|
||||
public override operator fun get(i: Int, j: Int): Double = origin[i, j]
|
||||
|
||||
public override fun equals(other: Any?): Boolean {
|
||||
if (other is EjmlMatrix) return origin.isIdentical(other.origin, 0.0)
|
||||
return NDStructure.equals(this, other as? NDStructure<*> ?: return false)
|
||||
}
|
||||
|
||||
public override fun hashCode(): Int {
|
||||
var result = origin.hashCode()
|
||||
result = 31 * result + features.hashCode()
|
||||
return result
|
||||
}
|
||||
|
||||
public override fun toString(): String = "EjmlMatrix(origin=$origin, features=$features)"
|
||||
}
|
||||
|
@ -1,16 +1,14 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.InverseMatrixFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.MatrixWrapper
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.getFeature
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||
if (this is EjmlMatrix) this else EjmlMatrixContext.produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
|
||||
/**
|
||||
* Represents context of basic operations operating with [EjmlMatrix].
|
||||
*
|
||||
@ -18,6 +16,15 @@ public fun Matrix<Double>.toEjml(): EjmlMatrix =
|
||||
*/
|
||||
public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
|
||||
|
||||
/**
|
||||
* Converts this matrix to EJML one.
|
||||
*/
|
||||
public fun Matrix<Double>.toEjml(): EjmlMatrix = when {
|
||||
this is EjmlMatrix -> this
|
||||
this is MatrixWrapper && matrix is EjmlMatrix -> matrix as EjmlMatrix
|
||||
else -> produce(rowNum, colNum) { i, j -> get(i, j) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts this vector to EJML one.
|
||||
*/
|
||||
@ -33,6 +40,11 @@ public object EjmlMatrixContext : MatrixContext<Double, EjmlMatrix> {
|
||||
}
|
||||
})
|
||||
|
||||
override fun point(size: Int, initializer: (Int) -> Double): Point<Double> =
|
||||
EjmlVector(SimpleMatrix(size, 1).also {
|
||||
(0 until it.numRows()).forEach { row -> it[row, 0] = initializer(row) }
|
||||
})
|
||||
|
||||
public override fun Matrix<Double>.dot(other: Matrix<Double>): EjmlMatrix =
|
||||
EjmlMatrix(toEjml().origin.mult(other.toEjml().origin))
|
||||
|
||||
@ -74,11 +86,7 @@ public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Matrix<Double>): EjmlMa
|
||||
public fun EjmlMatrixContext.solve(a: Matrix<Double>, b: Point<Double>): EjmlVector =
|
||||
EjmlVector(a.toEjml().origin.solve(b.toEjml().origin))
|
||||
|
||||
/**
|
||||
* Returns the inverse of given matrix: b = a^(-1).
|
||||
*
|
||||
* @param a the matrix.
|
||||
* @return the inverse of this matrix.
|
||||
* @author Iaroslav Postovalov
|
||||
*/
|
||||
public fun EjmlMatrixContext.inverse(a: Matrix<Double>): EjmlMatrix = EjmlMatrix(a.toEjml().origin.invert())
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public fun EjmlMatrix.inverted(): EjmlMatrix = getFeature<InverseMatrixFeature<Double>>()!!.inverse as EjmlMatrix
|
||||
|
||||
public fun EjmlMatrixContext.inverse(matrix: Matrix<Double>): Matrix<Double> = matrix.toEjml().inverted()
|
@ -1,9 +1,10 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.DeterminantFeature
|
||||
import kscience.kmath.linear.LUPDecompositionFeature
|
||||
import kscience.kmath.linear.LupDecompositionFeature
|
||||
import kscience.kmath.linear.MatrixFeature
|
||||
import kscience.kmath.linear.getFeature
|
||||
import kscience.kmath.linear.plus
|
||||
import kscience.kmath.structures.getFeature
|
||||
import org.ejml.dense.row.factory.DecompositionFactory_DDRM
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
import kotlin.random.Random
|
||||
@ -44,7 +45,7 @@ internal class EjmlMatrixTest {
|
||||
val w = EjmlMatrix(m)
|
||||
val det = w.getFeature<DeterminantFeature<Double>>() ?: fail()
|
||||
assertEquals(m.determinant(), det.determinant)
|
||||
val lup = w.getFeature<LUPDecompositionFeature<Double>>() ?: fail()
|
||||
val lup = w.getFeature<LupDecompositionFeature<Double>>() ?: fail()
|
||||
|
||||
val ludecompositionF64 = DecompositionFactory_DDRM.lu(m.numRows(), m.numCols())
|
||||
.also { it.decompose(m.ddrm.copy()) }
|
||||
@ -58,7 +59,7 @@ internal class EjmlMatrixTest {
|
||||
|
||||
@Test
|
||||
fun suggestFeature() {
|
||||
assertNotNull(EjmlMatrix(randomMatrix).suggestFeature(SomeFeature).getFeature<SomeFeature>())
|
||||
assertNotNull((EjmlMatrix(randomMatrix) + SomeFeature).getFeature<SomeFeature>())
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1,14 +1,12 @@
|
||||
package kscience.kmath.real
|
||||
|
||||
import kscience.kmath.linear.FeaturedMatrix
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.RealMatrixContext.elementContext
|
||||
import kscience.kmath.linear.VirtualMatrix
|
||||
import kscience.kmath.linear.inverseWithLUP
|
||||
import kscience.kmath.linear.real
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.operations.sum
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import kscience.kmath.structures.asIterable
|
||||
import kotlin.math.pow
|
||||
@ -25,7 +23,7 @@ import kotlin.math.pow
|
||||
* Functions that help create a real (Double) matrix
|
||||
*/
|
||||
|
||||
public typealias RealMatrix = FeaturedMatrix<Double>
|
||||
public typealias RealMatrix = Matrix<Double>
|
||||
|
||||
public fun realMatrix(rowNum: Int, colNum: Int, initializer: (i: Int, j: Int) -> Double): RealMatrix =
|
||||
MatrixContext.real.produce(rowNum, colNum, initializer)
|
||||
@ -122,8 +120,7 @@ public fun RealMatrix.extractColumn(columnIndex: Int): RealMatrix =
|
||||
extractColumns(columnIndex..columnIndex)
|
||||
|
||||
public fun RealMatrix.sumByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
val column = columns[j]
|
||||
elementContext { sum(column.asIterable()) }
|
||||
columns[j].asIterable().sum()
|
||||
}
|
||||
|
||||
public fun RealMatrix.minByColumn(): RealBuffer = RealBuffer(colNum) { j ->
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kaceince.kmath.real
|
||||
|
||||
import kscience.kmath.linear.VirtualMatrix
|
||||
import kscience.kmath.linear.build
|
||||
import kscience.kmath.real.*
|
||||
import kscience.kmath.structures.Matrix
|
||||
@ -42,7 +41,7 @@ internal class RealMatrixTest {
|
||||
1.0, 0.0, 0.0,
|
||||
0.0, 1.0, 2.0
|
||||
)
|
||||
assertEquals(VirtualMatrix.wrap(matrix2), matrix1.repeatStackVertical(3))
|
||||
assertEquals(matrix2, matrix1.repeatStackVertical(3))
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1,5 +1,6 @@
|
||||
package kscience.kmath.nd4j
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.operations.*
|
||||
import kscience.kmath.structures.NDAlgebra
|
||||
import kscience.kmath.structures.NDField
|
||||
@ -35,7 +36,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
|
||||
public override fun mapIndexed(
|
||||
arg: Nd4jArrayStructure<T>,
|
||||
transform: C.(index: IntArray, T) -> T
|
||||
transform: C.(index: IntArray, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(arg)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
@ -46,7 +47,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
public override fun combine(
|
||||
a: Nd4jArrayStructure<T>,
|
||||
b: Nd4jArrayStructure<T>,
|
||||
transform: C.(T, T) -> T
|
||||
transform: C.(T, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
@ -61,8 +62,8 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C, Nd4jArrayStructure<T>>
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||
Nd4jArrayAlgebra<T, S> where S : Space<T> {
|
||||
public interface Nd4jArraySpace<T, S : Space<T>> : NDSpace<T, S, Nd4jArrayStructure<T>>, Nd4jArrayAlgebra<T, S> {
|
||||
|
||||
public override val zero: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
|
||||
@ -103,7 +104,9 @@ public interface Nd4jArraySpace<T, S> : NDSpace<T, S, Nd4jArrayStructure<T>>,
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> where R : Ring<T> {
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public interface Nd4jArrayRing<T, R : Ring<T>> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4jArraySpace<T, R> {
|
||||
|
||||
public override val one: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
@ -111,21 +114,21 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
|
||||
check(a, b)
|
||||
return a.ndArray.mul(b.ndArray).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.sub(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||
check(this)
|
||||
return ndArray.add(b).wrap()
|
||||
}
|
||||
|
||||
public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(b)
|
||||
return b.ndArray.rsub(this).wrap()
|
||||
}
|
||||
//
|
||||
// public override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
// check(this)
|
||||
// return ndArray.sub(b).wrap()
|
||||
// }
|
||||
//
|
||||
// public override operator fun Nd4jArrayStructure<T>.plus(b: Number): Nd4jArrayStructure<T> {
|
||||
// check(this)
|
||||
// return ndArray.add(b).wrap()
|
||||
// }
|
||||
//
|
||||
// public override operator fun Number.minus(b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
// check(b)
|
||||
// return b.ndArray.rsub(this).wrap()
|
||||
// }
|
||||
|
||||
public companion object {
|
||||
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||
@ -165,7 +168,8 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> where F : Field<T> {
|
||||
public interface Nd4jArrayField<T, F : Field<T>> : NDField<T, F, Nd4jArrayStructure<T>>, Nd4jArrayRing<T, F> {
|
||||
|
||||
public override fun divide(a: Nd4jArrayStructure<T>, b: Nd4jArrayStructure<T>): Nd4jArrayStructure<T> {
|
||||
check(a, b)
|
||||
return a.ndArray.div(b.ndArray).wrap()
|
||||
|
Loading…
Reference in New Issue
Block a user