Merge pull request #419 from mipt-npm/feature/multik
Feature/multik
This commit is contained in:
commit
ae8655d6af
@ -1,6 +0,0 @@
|
||||
<component name="CopyrightManager">
|
||||
<copyright>
|
||||
<option name="notice" value="Copyright 2018-2021 KMath contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file." />
|
||||
<option name="myName" value="kmath" />
|
||||
</copyright>
|
||||
</component>
|
@ -1,21 +0,0 @@
|
||||
<component name="CopyrightManager">
|
||||
<settings default="kmath">
|
||||
<module2copyright>
|
||||
<element module="Apply copyright" copyright="kmath" />
|
||||
</module2copyright>
|
||||
<LanguageOptions name="Groovy">
|
||||
<option name="fileTypeOverride" value="1" />
|
||||
</LanguageOptions>
|
||||
<LanguageOptions name="HTML">
|
||||
<option name="fileTypeOverride" value="1" />
|
||||
<option name="prefixLines" value="false" />
|
||||
</LanguageOptions>
|
||||
<LanguageOptions name="Properties">
|
||||
<option name="fileTypeOverride" value="1" />
|
||||
</LanguageOptions>
|
||||
<LanguageOptions name="XML">
|
||||
<option name="fileTypeOverride" value="1" />
|
||||
<option name="prefixLines" value="false" />
|
||||
</LanguageOptions>
|
||||
</settings>
|
||||
</component>
|
@ -1,4 +0,0 @@
|
||||
<component name="DependencyValidationManager">
|
||||
<scope name="Apply copyright"
|
||||
pattern="!file[*]:*//testData//*&&!file[*]:testData//*&&!file[*]:*.gradle.kts&&!file[*]:*.gradle&&!file[group:kotlin-ultimate]:*/&&!file[kotlin.libraries]:stdlib/api//*"/>
|
||||
</component>
|
@ -42,6 +42,9 @@
|
||||
- Use `Symbol` factory function instead of `StringSymbol`
|
||||
- New discoverability pattern: `<Type>.algebra.<nd/etc>`
|
||||
- Adjusted commons-math API for linear solvers to match conventions.
|
||||
- Buffer algebra does not require size anymore
|
||||
- Operations -> Ops
|
||||
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
|
||||
|
||||
### Deprecated
|
||||
- Specialized `DoubleBufferAlgebra`
|
||||
|
@ -48,6 +48,7 @@ kotlin {
|
||||
implementation(project(":kmath-nd4j"))
|
||||
implementation(project(":kmath-kotlingrad"))
|
||||
implementation(project(":kmath-viktor"))
|
||||
implementation(projects.kmathMultik)
|
||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||
// uncomment if your system supports AVX2
|
||||
// val os = System.getProperty("os.name")
|
||||
|
@ -9,56 +9,85 @@ import kotlinx.benchmark.Benchmark
|
||||
import kotlinx.benchmark.Blackhole
|
||||
import kotlinx.benchmark.Scope
|
||||
import kotlinx.benchmark.State
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ones
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.multik.multikND
|
||||
import space.kscience.kmath.multik.multikTensorAlgebra
|
||||
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.autoNdAlgebra
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.one
|
||||
import space.kscience.kmath.nd4j.nd4j
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import space.kscience.kmath.tensors.core.ones
|
||||
import space.kscience.kmath.tensors.core.one
|
||||
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||
import space.kscience.kmath.viktor.viktorAlgebra
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
internal class NDFieldBenchmark {
|
||||
@Benchmark
|
||||
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
|
||||
var res: StructureND<Double> = one
|
||||
repeat(n) { res += one }
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun multikAdd(blackhole: Blackhole) = with(multikField) {
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun viktorAdd(blackhole: Blackhole) = with(viktorField) {
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||
var res: DoubleTensor = ones(dim, dim)
|
||||
var res: DoubleTensor = one(shape)
|
||||
repeat(n) { res = res + 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||
val res: DoubleTensor = ones(dim, dim)
|
||||
val res: DoubleTensor = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) {
|
||||
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
|
||||
// @Benchmark
|
||||
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
|
||||
// var res: StructureND<Double> = one
|
||||
// var res: StructureND<Double> = one(dim, dim)
|
||||
// repeat(n) { res += 1.0 }
|
||||
// blackhole.consume(res)
|
||||
// }
|
||||
@ -66,9 +95,12 @@ internal class NDFieldBenchmark {
|
||||
private companion object {
|
||||
private const val dim = 1000
|
||||
private const val n = 100
|
||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
||||
private val specializedField = DoubleField.ndAlgebra(dim, dim)
|
||||
private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim)
|
||||
private val nd4jField = DoubleField.nd4j(dim, dim)
|
||||
private val shape = intArrayOf(dim, dim)
|
||||
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||
private val specializedField = DoubleField.ndAlgebra
|
||||
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
||||
private val nd4jField = DoubleField.nd4j
|
||||
private val multikField = DoubleField.multikND
|
||||
private val viktorField = DoubleField.viktorAlgebra
|
||||
}
|
||||
}
|
||||
|
@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole
|
||||
import kotlinx.benchmark.Scope
|
||||
import kotlinx.benchmark.State
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.autoNdAlgebra
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.viktor.ViktorNDField
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.viktor.ViktorFieldND
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
internal class ViktorBenchmark {
|
||||
@Benchmark
|
||||
fun automaticFieldAddition(blackhole: Blackhole) {
|
||||
with(autoField) {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
@ -30,7 +29,7 @@ internal class ViktorBenchmark {
|
||||
@Benchmark
|
||||
fun realFieldAddition(blackhole: Blackhole) {
|
||||
with(realField) {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
@ -39,7 +38,7 @@ internal class ViktorBenchmark {
|
||||
@Benchmark
|
||||
fun viktorFieldAddition(blackhole: Blackhole) {
|
||||
with(viktorField) {
|
||||
var res = one
|
||||
var res = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
@ -56,10 +55,11 @@ internal class ViktorBenchmark {
|
||||
private companion object {
|
||||
private const val dim = 1000
|
||||
private const val n = 100
|
||||
private val shape = Shape(dim, dim)
|
||||
|
||||
// automatically build context most suited for given type.
|
||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
||||
private val realField = DoubleField.ndAlgebra(dim, dim)
|
||||
private val viktorField = ViktorNDField(dim, dim)
|
||||
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||
private val realField = DoubleField.ndAlgebra
|
||||
private val viktorField = ViktorFieldND(dim, dim)
|
||||
}
|
||||
}
|
||||
|
@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole
|
||||
import kotlinx.benchmark.Scope
|
||||
import kotlinx.benchmark.State
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
import space.kscience.kmath.nd.autoNdAlgebra
|
||||
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.one
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.viktor.ViktorFieldND
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
internal class ViktorLogBenchmark {
|
||||
@Benchmark
|
||||
fun realFieldLog(blackhole: Blackhole) {
|
||||
with(realNdField) {
|
||||
val fortyTwo = produce { 42.0 }
|
||||
var res = one
|
||||
with(realField) {
|
||||
val fortyTwo = structureND(shape) { 42.0 }
|
||||
var res = one(shape)
|
||||
repeat(n) { res = ln(fortyTwo) }
|
||||
blackhole.consume(res)
|
||||
}
|
||||
@ -30,7 +33,7 @@ internal class ViktorLogBenchmark {
|
||||
@Benchmark
|
||||
fun viktorFieldLog(blackhole: Blackhole) {
|
||||
with(viktorField) {
|
||||
val fortyTwo = produce { 42.0 }
|
||||
val fortyTwo = structureND(shape) { 42.0 }
|
||||
var res = one
|
||||
repeat(n) { res = ln(fortyTwo) }
|
||||
blackhole.consume(res)
|
||||
@ -48,10 +51,11 @@ internal class ViktorLogBenchmark {
|
||||
private companion object {
|
||||
private const val dim = 1000
|
||||
private const val n = 100
|
||||
private val shape = Shape(dim, dim)
|
||||
|
||||
// automatically build context most suited for given type.
|
||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
||||
private val realNdField = DoubleField.ndAlgebra(dim, dim)
|
||||
private val viktorField = ViktorFieldND(intArrayOf(dim, dim))
|
||||
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||
private val realField = DoubleField.ndAlgebra
|
||||
private val viktorField = ViktorFieldND(dim, dim)
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,11 @@ dependencies {
|
||||
implementation(project(":kmath-tensors"))
|
||||
implementation(project(":kmath-symja"))
|
||||
implementation(project(":kmath-for-real"))
|
||||
//jafama
|
||||
implementation(project(":kmath-jafama"))
|
||||
//multik
|
||||
implementation(projects.kmathMultik)
|
||||
|
||||
|
||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||
|
||||
@ -42,11 +47,12 @@ dependencies {
|
||||
// } else
|
||||
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||
|
||||
implementation("org.slf4j:slf4j-simple:1.7.31")
|
||||
// multik implementation
|
||||
implementation("org.jetbrains.kotlinx:multik-default:0.1.0")
|
||||
|
||||
implementation("org.slf4j:slf4j-simple:1.7.32")
|
||||
// plotting
|
||||
implementation("space.kscience:plotlykt-server:0.4.2")
|
||||
//jafama
|
||||
implementation(project(":kmath-jafama"))
|
||||
implementation("space.kscience:plotlykt-server:0.5.0")
|
||||
}
|
||||
|
||||
kotlin.sourceSets.all {
|
||||
|
@ -9,6 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator
|
||||
import space.kscience.kmath.integration.integrate
|
||||
import space.kscience.kmath.integration.value
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.structureND
|
||||
import space.kscience.kmath.nd.withNdAlgebra
|
||||
import space.kscience.kmath.operations.algebra
|
||||
import space.kscience.kmath.operations.invoke
|
||||
@ -17,7 +18,7 @@ fun main(): Unit = Double.algebra {
|
||||
withNdAlgebra(2, 2) {
|
||||
|
||||
//Produce a diagonal StructureND
|
||||
fun diagonal(v: Double) = produce { (i, j) ->
|
||||
fun diagonal(v: Double) = structureND { (i, j) ->
|
||||
if (i == j) v else 0.0
|
||||
}
|
||||
|
||||
|
@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra
|
||||
import space.kscience.kmath.complex.ndAlgebra
|
||||
import space.kscience.kmath.nd.BufferND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.structureND
|
||||
|
||||
fun main() = Complex.algebra {
|
||||
val complex = 2 + 2 * i
|
||||
println(complex * 8 - 5 * i)
|
||||
|
||||
//flat buffer
|
||||
val buffer = bufferAlgebra(8).run {
|
||||
buffer { Complex(it, -it) }.map { Complex(it.im, it.re) }
|
||||
val buffer = with(bufferAlgebra){
|
||||
buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) }
|
||||
}
|
||||
println(buffer)
|
||||
|
||||
|
||||
// 2d element
|
||||
val element: BufferND<Complex> = ndAlgebra(2, 2).produce { (i, j) ->
|
||||
val element: BufferND<Complex> = ndAlgebra.structureND(2, 2) { (i, j) ->
|
||||
Complex(i - j, i + j)
|
||||
}
|
||||
println(element)
|
||||
|
||||
// 1d element operation
|
||||
val result: StructureND<Complex> = ndAlgebra(8).run {
|
||||
val a = produce { (it) -> i * it - it.toDouble() }
|
||||
val result: StructureND<Complex> = ndAlgebra{
|
||||
val a = structureND(8) { (it) -> i * it - it.toDouble() }
|
||||
val b = 3
|
||||
val c = Complex(1.0, 1.0)
|
||||
|
||||
|
@ -0,0 +1,24 @@
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.linear.matrix
|
||||
import space.kscience.kmath.nd.DoubleBufferND
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.viktor.ViktorStructureND
|
||||
import space.kscience.kmath.viktor.viktorAlgebra
|
||||
|
||||
fun main() {
|
||||
val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.structureND(Shape(2, 2)) { (i, j) ->
|
||||
if (i == j) 2.0 else 0.0
|
||||
}
|
||||
|
||||
val cmMatrix: Structure2D<Double> = CMLinearSpace.matrix(2, 2)(0.0, 1.0, 0.0, 3.0)
|
||||
|
||||
val res: DoubleBufferND = DoubleField.ndAlgebra {
|
||||
exp(viktorStructure) + 2.0 * cmMatrix
|
||||
}
|
||||
|
||||
println(res)
|
||||
}
|
@ -12,6 +12,7 @@ import space.kscience.kmath.linear.transpose
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.structureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.system.measureTimeMillis
|
||||
@ -54,7 +55,7 @@ fun complexExample() {
|
||||
val x = one * 2.5
|
||||
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
|
||||
//a structure generator specific to this context
|
||||
val matrix = produce { (k, l) -> k + l * i }
|
||||
val matrix = structureND { (k, l) -> k + l * i }
|
||||
//Perform sum
|
||||
val sum = matrix + x + 1.0
|
||||
|
||||
|
@ -8,13 +8,11 @@ package space.kscience.kmath.structures
|
||||
import kotlinx.coroutines.DelicateCoroutinesApi
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.autoNdAlgebra
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd4j.Nd4jArrayField
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.nd4j.nd4j
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.viktor.ViktorNDField
|
||||
import space.kscience.kmath.viktor.ViktorFieldND
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.system.measureTimeMillis
|
||||
@ -31,37 +29,39 @@ fun main() {
|
||||
Nd4j.zeros(0)
|
||||
val dim = 1000
|
||||
val n = 1000
|
||||
val shape = Shape(dim, dim)
|
||||
|
||||
|
||||
// automatically build context most suited for given type.
|
||||
val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
||||
val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||
// specialized nd-field for Double. It works as generic Double field as well.
|
||||
val realField = DoubleField.ndAlgebra(dim, dim)
|
||||
val realField = DoubleField.ndAlgebra
|
||||
//A generic boxing field. It should be used for objects, not primitives.
|
||||
val boxingField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim)
|
||||
val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
||||
// Nd4j specialized field.
|
||||
val nd4jField = Nd4jArrayField.real(dim, dim)
|
||||
val nd4jField = DoubleField.nd4j
|
||||
//viktor field
|
||||
val viktorField = ViktorNDField(dim, dim)
|
||||
val viktorField = ViktorFieldND(dim, dim)
|
||||
//parallel processing based on Java Streams
|
||||
val parallelField = DoubleField.ndStreaming(dim, dim)
|
||||
|
||||
measureAndPrint("Boxing addition") {
|
||||
boxingField {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Specialized addition") {
|
||||
realField {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Nd4j specialized addition") {
|
||||
nd4jField {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
@ -82,13 +82,13 @@ fun main() {
|
||||
|
||||
measureAndPrint("Automatic field addition") {
|
||||
autoField {
|
||||
var res: StructureND<Double> = one
|
||||
var res: StructureND<Double> = one(shape)
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Lazy addition") {
|
||||
val res = realField.one.mapAsync(GlobalScope) {
|
||||
val res = realField.one(shape).mapAsync(GlobalScope) {
|
||||
var c = 0.0
|
||||
repeat(n) {
|
||||
c += 1.0
|
||||
|
@ -8,7 +8,7 @@ package space.kscience.kmath.structures
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.ExtendedField
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.NumbersAddOps
|
||||
import java.util.*
|
||||
import java.util.stream.IntStream
|
||||
|
||||
@ -17,17 +17,17 @@ import java.util.stream.IntStream
|
||||
* execution.
|
||||
*/
|
||||
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
||||
NumbersAddOperations<StructureND<Double>>,
|
||||
NumbersAddOps<StructureND<Double>>,
|
||||
ExtendedField<StructureND<Double>> {
|
||||
|
||||
private val strides = DefaultStrides(shape)
|
||||
override val elementContext: DoubleField get() = DoubleField
|
||||
override val zero: BufferND<Double> by lazy { produce { zero } }
|
||||
override val one: BufferND<Double> by lazy { produce { one } }
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
override val zero: BufferND<Double> by lazy { structureND(shape) { zero } }
|
||||
override val one: BufferND<Double> by lazy { structureND(shape) { one } }
|
||||
|
||||
override fun number(value: Number): BufferND<Double> {
|
||||
val d = value.toDouble() // minimize conversions
|
||||
return produce { d }
|
||||
return structureND(shape) { d }
|
||||
}
|
||||
|
||||
private val StructureND<Double>.buffer: DoubleBuffer
|
||||
@ -36,11 +36,11 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
||||
this@StreamDoubleFieldND.shape,
|
||||
shape
|
||||
)
|
||||
this is BufferND && this.strides == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
|
||||
this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
|
||||
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||
}
|
||||
|
||||
override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
||||
override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
||||
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||
val index = strides.index(offset)
|
||||
DoubleField.initializer(index)
|
||||
@ -69,13 +69,13 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
||||
return BufferND(strides, array.asBuffer())
|
||||
}
|
||||
|
||||
override fun combine(
|
||||
a: StructureND<Double>,
|
||||
b: StructureND<Double>,
|
||||
override fun zip(
|
||||
left: StructureND<Double>,
|
||||
right: StructureND<Double>,
|
||||
transform: DoubleField.(Double, Double) -> Double,
|
||||
): BufferND<Double> {
|
||||
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
||||
DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset])
|
||||
}.toArray()
|
||||
return BufferND(strides, array.asBuffer())
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.structures
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.buffer
|
||||
import space.kscience.kmath.operations.bufferAlgebra
|
||||
import space.kscience.kmath.operations.withSize
|
||||
|
||||
inline fun <reified R : Any> MutableBuffer.Companion.same(
|
||||
n: Int,
|
||||
@ -16,7 +17,7 @@ inline fun <reified R : Any> MutableBuffer.Companion.same(
|
||||
|
||||
|
||||
fun main() {
|
||||
with(DoubleField.bufferAlgebra(5)) {
|
||||
with(DoubleField.bufferAlgebra.withSize(5)) {
|
||||
println(number(2.0) + buffer(1, 2, 3, 4, 5))
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,24 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import edu.mcgill.kaliningraph.power
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||
import org.jetbrains.kotlinx.multik.api.math.exp
|
||||
import org.jetbrains.kotlinx.multik.api.ndarray
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.unaryMinus
|
||||
import space.kscience.kmath.multik.multikND
|
||||
import space.kscience.kmath.nd.one
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
|
||||
fun main(): Unit = with(DoubleField.multikND) {
|
||||
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
|
||||
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0))
|
||||
one(a.shape) - a
|
||||
}
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult
|
||||
import com.github.h0tk3y.betterParse.parser.Parser
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.operations.FieldOperations
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.FieldOps
|
||||
import space.kscience.kmath.operations.GroupOps
|
||||
import space.kscience.kmath.operations.PowerOperations
|
||||
import space.kscience.kmath.operations.RingOperations
|
||||
import space.kscience.kmath.operations.RingOps
|
||||
|
||||
/**
|
||||
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
|
||||
@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
.or(binaryFunction)
|
||||
.or(unaryFunction)
|
||||
.or(singular)
|
||||
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) })
|
||||
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) })
|
||||
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
|
||||
|
||||
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||
@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
operator = div or mul use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == div)
|
||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
||||
MST.Binary(FieldOps.DIV_OPERATION, a, b)
|
||||
else
|
||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
||||
MST.Binary(RingOps.TIMES_OPERATION, a, b)
|
||||
}
|
||||
|
||||
private val subSumChain: Parser<MST> by leftAssociative(
|
||||
@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
||||
operator = plus or minus use TokenMatch::type
|
||||
) { a, op, b ->
|
||||
if (op == plus)
|
||||
MST.Binary(GroupOperations.PLUS_OPERATION, a, b)
|
||||
MST.Binary(GroupOps.PLUS_OPERATION, a, b)
|
||||
else
|
||||
MST.Binary(GroupOperations.MINUS_OPERATION, a, b)
|
||||
MST.Binary(GroupOps.MINUS_OPERATION, a, b)
|
||||
}
|
||||
|
||||
override val rootParser: Parser<MST> by subSumChain
|
||||
|
@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node ->
|
||||
@UnstableKMathAPI
|
||||
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
|
||||
UnaryMinusSyntax(
|
||||
operation = GroupOperations.MINUS_OPERATION,
|
||||
operation = GroupOps.MINUS_OPERATION,
|
||||
operand = OperandSyntax(
|
||||
operand = NumberSyntax(string = s.removePrefix("-")),
|
||||
parentheses = true,
|
||||
@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
|
||||
val exponent = afterE.toDouble().toString().removeSuffix(".0")
|
||||
|
||||
return MultiplicationSyntax(
|
||||
operation = RingOperations.TIMES_OPERATION,
|
||||
operation = RingOps.TIMES_OPERATION,
|
||||
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
|
||||
right = OperandSyntax(
|
||||
operand = SuperscriptSyntax(
|
||||
@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
|
||||
|
||||
if (toString.startsWith('-'))
|
||||
return UnaryMinusSyntax(
|
||||
operation = GroupOperations.MINUS_OPERATION,
|
||||
operation = GroupOps.MINUS_OPERATION,
|
||||
operand = OperandSyntax(operand = infty, parentheses = true),
|
||||
)
|
||||
|
||||
@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
||||
* The default instance configured with [GroupOps.PLUS_OPERATION].
|
||||
*/
|
||||
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
||||
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
||||
* The default instance configured with [GroupOps.MINUS_OPERATION].
|
||||
*/
|
||||
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
||||
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
||||
* The default instance configured with [GroupOps.PLUS_OPERATION].
|
||||
*/
|
||||
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
||||
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
||||
* The default instance configured with [GroupOps.MINUS_OPERATION].
|
||||
*/
|
||||
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
||||
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
@ -295,9 +295,9 @@ public class Fraction(operations: Collection<String>?) : Binary(operations) {
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [FieldOperations.DIV_OPERATION].
|
||||
* The default instance configured with [FieldOps.DIV_OPERATION].
|
||||
*/
|
||||
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
|
||||
public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
@ -422,9 +422,9 @@ public class Multiplication(operations: Collection<String>?) : Binary(operations
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* The default instance configured with [RingOperations.TIMES_OPERATION].
|
||||
* The default instance configured with [RingOps.TIMES_OPERATION].
|
||||
*/
|
||||
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION))
|
||||
public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.FieldOperations
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.FieldOps
|
||||
import space.kscience.kmath.operations.GroupOps
|
||||
import space.kscience.kmath.operations.PowerOperations
|
||||
import space.kscience.kmath.operations.RingOperations
|
||||
import space.kscience.kmath.operations.RingOps
|
||||
|
||||
/**
|
||||
* Removes unnecessary times (×) symbols from [MultiplicationSyntax].
|
||||
@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
|
||||
|
||||
is BinarySyntax -> when (it.operation) {
|
||||
PowerOperations.POW_OPERATION -> 1
|
||||
RingOperations.TIMES_OPERATION -> 3
|
||||
FieldOperations.DIV_OPERATION -> 3
|
||||
GroupOperations.MINUS_OPERATION -> 4
|
||||
GroupOperations.PLUS_OPERATION -> 4
|
||||
RingOps.TIMES_OPERATION -> 3
|
||||
FieldOps.DIV_OPERATION -> 3
|
||||
GroupOps.MINUS_OPERATION -> 4
|
||||
GroupOps.PLUS_OPERATION -> 4
|
||||
else -> 0
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.GroupOps
|
||||
import kotlin.test.Test
|
||||
|
||||
internal class TestLatex {
|
||||
@ -36,7 +36,7 @@ internal class TestLatex {
|
||||
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
|
||||
|
||||
@Test
|
||||
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1")
|
||||
fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1")
|
||||
|
||||
@Test
|
||||
fun unaryMinus() = testLatex("-x", "-x")
|
||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
|
||||
|
||||
import space.kscience.kmath.ast.rendering.TestUtils.testMathML
|
||||
import space.kscience.kmath.expressions.MST
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.GroupOps
|
||||
import kotlin.test.Test
|
||||
|
||||
internal class TestMathML {
|
||||
@ -47,7 +47,7 @@ internal class TestMathML {
|
||||
|
||||
@Test
|
||||
fun unaryPlus() =
|
||||
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
|
||||
testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
|
||||
|
||||
@Test
|
||||
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")
|
||||
|
@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
|
||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
||||
|
||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
||||
GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
||||
GroupOperations.PLUS_OPERATION -> visit(mst.value)
|
||||
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
||||
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
||||
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
|
||||
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
|
||||
@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
|
||||
}
|
||||
|
||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||
GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||
GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
||||
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||
FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
||||
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
|
||||
else -> super.visitBinary(mst)
|
||||
}
|
||||
@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, targ
|
||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
||||
|
||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
||||
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
||||
GroupOperations.PLUS_OPERATION -> visit(mst.value)
|
||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
||||
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
||||
else -> super.visitUnary(mst)
|
||||
}
|
||||
|
||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||
GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||
else -> super.visitBinary(mst)
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
||||
import space.kscience.kmath.expressions.*
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.ExtendedField
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.NumbersAddOps
|
||||
|
||||
/**
|
||||
* A field over commons-math [DerivativeStructure].
|
||||
@ -22,7 +22,7 @@ public class DerivativeStructureField(
|
||||
public val order: Int,
|
||||
bindings: Map<Symbol, Double>,
|
||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
||||
NumbersAddOperations<DerivativeStructure> {
|
||||
NumbersAddOps<DerivativeStructure> {
|
||||
public val numberOfVariables: Int = bindings.size
|
||||
|
||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||
@ -70,12 +70,12 @@ public class DerivativeStructureField(
|
||||
|
||||
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
|
||||
|
||||
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b)
|
||||
override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right)
|
||||
|
||||
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
|
||||
|
||||
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b)
|
||||
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b)
|
||||
override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right)
|
||||
override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
|
||||
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
|
||||
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
|
||||
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
|
||||
@ -99,10 +99,10 @@ public class DerivativeStructureField(
|
||||
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
|
||||
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
|
||||
|
||||
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble())
|
||||
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble())
|
||||
override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this
|
||||
override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this
|
||||
override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure = add(other.toDouble())
|
||||
override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = subtract(other.toDouble())
|
||||
override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this
|
||||
override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
|
||||
public object ComplexField :
|
||||
ExtendedField<Complex>,
|
||||
Norm<Complex, Complex>,
|
||||
NumbersAddOperations<Complex>,
|
||||
NumbersAddOps<Complex>,
|
||||
ScaleOperations<Complex> {
|
||||
|
||||
override val zero: Complex = 0.0.toComplex()
|
||||
@ -77,33 +77,33 @@ public object ComplexField :
|
||||
|
||||
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
|
||||
|
||||
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||
override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im)
|
||||
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
|
||||
|
||||
override fun multiply(a: Complex, b: Complex): Complex =
|
||||
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||
override fun multiply(left: Complex, right: Complex): Complex =
|
||||
Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re)
|
||||
|
||||
override fun divide(a: Complex, b: Complex): Complex = when {
|
||||
abs(b.im) < abs(b.re) -> {
|
||||
val wr = b.im / b.re
|
||||
val wd = b.re + wr * b.im
|
||||
override fun divide(left: Complex, right: Complex): Complex = when {
|
||||
abs(right.im) < abs(right.re) -> {
|
||||
val wr = right.im / right.re
|
||||
val wd = right.re + wr * right.im
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
throw ArithmeticException("Division by zero or infinity")
|
||||
else
|
||||
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd)
|
||||
Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd)
|
||||
}
|
||||
|
||||
b.im == 0.0 -> throw ArithmeticException("Division by zero")
|
||||
right.im == 0.0 -> throw ArithmeticException("Division by zero")
|
||||
|
||||
else -> {
|
||||
val wr = b.re / b.im
|
||||
val wd = b.im + wr * b.re
|
||||
val wr = right.re / right.im
|
||||
val wd = right.im + wr * right.re
|
||||
|
||||
if (wd.isNaN() || wd == 0.0)
|
||||
throw ArithmeticException("Division by zero or infinity")
|
||||
else
|
||||
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd)
|
||||
Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd)
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) {
|
||||
|
||||
public val Complex.Companion.algebra: ComplexField get() = ComplexField
|
||||
|
||||
|
||||
/**
|
||||
* Creates a complex number with real part equal to this real.
|
||||
*
|
||||
|
@ -6,13 +6,8 @@
|
||||
package space.kscience.kmath.complex
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.BufferND
|
||||
import space.kscience.kmath.nd.BufferedFieldND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.BufferField
|
||||
import space.kscience.kmath.operations.ExtendedField
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.bufferAlgebra
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
@ -22,100 +17,61 @@ import kotlin.contracts.contract
|
||||
* An optimized nd-field for complex numbers
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class ComplexFieldND(
|
||||
shape: IntArray,
|
||||
) : BufferedFieldND<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
|
||||
NumbersAddOperations<StructureND<Complex>>,
|
||||
ExtendedField<StructureND<Complex>> {
|
||||
public sealed class ComplexFieldOpsND : BufferedFieldOpsND<Complex, ComplexField>(ComplexField.bufferAlgebra),
|
||||
ScaleOperations<StructureND<Complex>>, ExtendedFieldOps<StructureND<Complex>> {
|
||||
|
||||
override val zero: BufferND<Complex> by lazy { produce { zero } }
|
||||
override val one: BufferND<Complex> by lazy { produce { one } }
|
||||
|
||||
override fun number(value: Number): BufferND<Complex> {
|
||||
val d = value.toComplex() // minimize conversions
|
||||
return produce { d }
|
||||
override fun StructureND<Complex>.toBufferND(): BufferND<Complex> = when (this) {
|
||||
is BufferND -> this
|
||||
else -> {
|
||||
val indexer = indexerBuilder(shape)
|
||||
BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// @Suppress("OVERRIDE_BY_INLINE")
|
||||
// override inline fun map(
|
||||
// arg: AbstractNDBuffer<Double>,
|
||||
// transform: DoubleField.(Double) -> Double,
|
||||
// ): RealNDElement {
|
||||
// check(arg)
|
||||
// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) }
|
||||
// return BufferedNDFieldElement(this, array)
|
||||
// }
|
||||
//
|
||||
// @Suppress("OVERRIDE_BY_INLINE")
|
||||
// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
|
||||
// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
||||
// return BufferedNDFieldElement(this, array)
|
||||
// }
|
||||
//
|
||||
// @Suppress("OVERRIDE_BY_INLINE")
|
||||
// override inline fun mapIndexed(
|
||||
// arg: AbstractNDBuffer<Double>,
|
||||
// transform: DoubleField.(index: IntArray, Double) -> Double,
|
||||
// ): RealNDElement {
|
||||
// check(arg)
|
||||
// return BufferedNDFieldElement(
|
||||
// this,
|
||||
// RealBuffer(arg.strides.linearSize) { offset ->
|
||||
// elementContext.transform(
|
||||
// arg.strides.index(offset),
|
||||
// arg.buffer[offset]
|
||||
// )
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// @Suppress("OVERRIDE_BY_INLINE")
|
||||
// override inline fun combine(
|
||||
// a: AbstractNDBuffer<Double>,
|
||||
// b: AbstractNDBuffer<Double>,
|
||||
// transform: DoubleField.(Double, Double) -> Double,
|
||||
// ): RealNDElement {
|
||||
// check(a, b)
|
||||
// val buffer = RealBuffer(strides.linearSize) { offset ->
|
||||
// elementContext.transform(a.buffer[offset], b.buffer[offset])
|
||||
// }
|
||||
// return BufferedNDFieldElement(this, buffer)
|
||||
// }
|
||||
//TODO do specialization
|
||||
|
||||
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) }
|
||||
override fun scale(a: StructureND<Complex>, value: Double): BufferND<Complex> =
|
||||
mapInline(a.toBufferND()) { it * value }
|
||||
|
||||
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) }
|
||||
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> =
|
||||
mapInline(arg.toBufferND()) { power(it, pow) }
|
||||
|
||||
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) }
|
||||
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { exp(it) }
|
||||
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { ln(it) }
|
||||
|
||||
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) }
|
||||
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) }
|
||||
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) }
|
||||
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) }
|
||||
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) }
|
||||
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) }
|
||||
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sin(it) }
|
||||
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cos(it) }
|
||||
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tan(it) }
|
||||
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asin(it) }
|
||||
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acos(it) }
|
||||
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atan(it) }
|
||||
|
||||
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) }
|
||||
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) }
|
||||
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) }
|
||||
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) }
|
||||
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) }
|
||||
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) }
|
||||
}
|
||||
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sinh(it) }
|
||||
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cosh(it) }
|
||||
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tanh(it) }
|
||||
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asinh(it) }
|
||||
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acosh(it) }
|
||||
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atanh(it) }
|
||||
|
||||
|
||||
/**
|
||||
* Fast element production using function inlining
|
||||
*/
|
||||
public inline fun BufferedFieldND<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND<Complex> {
|
||||
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
|
||||
return BufferND(strides, buffer)
|
||||
public companion object : ComplexFieldOpsND()
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun ComplexField.bufferAlgebra(size: Int): BufferField<Complex, ComplexField> =
|
||||
bufferAlgebra(Buffer.Companion::complex, size)
|
||||
public val ComplexField.bufferAlgebra: BufferFieldOps<Complex, ComplexField>
|
||||
get() = bufferAlgebra(Buffer.Companion::complex)
|
||||
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class ComplexFieldND(override val shape: Shape) :
|
||||
ComplexFieldOpsND(), FieldND<Complex, ComplexField>, NumbersAddOps<StructureND<Complex>> {
|
||||
|
||||
override fun number(value: Number): BufferND<Complex> {
|
||||
val d = value.toDouble() // minimize conversions
|
||||
return structureND(shape) { d.toComplex() }
|
||||
}
|
||||
}
|
||||
|
||||
public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND
|
||||
|
||||
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)
|
||||
|
||||
|
@ -44,7 +44,7 @@ public val Quaternion.r: Double
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
|
||||
ExponentialOperations<Quaternion>, NumbersAddOperations<Quaternion>, ScaleOperations<Quaternion> {
|
||||
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
||||
override val zero: Quaternion = 0.toQuaternion()
|
||||
override val one: Quaternion = 1.toQuaternion()
|
||||
|
||||
@ -63,27 +63,27 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
||||
*/
|
||||
public val k: Quaternion = Quaternion(0, 0, 0, 1)
|
||||
|
||||
override fun add(a: Quaternion, b: Quaternion): Quaternion =
|
||||
Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z)
|
||||
override fun add(left: Quaternion, right: Quaternion): Quaternion =
|
||||
Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
|
||||
|
||||
override fun scale(a: Quaternion, value: Double): Quaternion =
|
||||
Quaternion(a.w * value, a.x * value, a.y * value, a.z * value)
|
||||
|
||||
override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion(
|
||||
a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z,
|
||||
a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y,
|
||||
a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x,
|
||||
a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w,
|
||||
override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion(
|
||||
left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z,
|
||||
left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y,
|
||||
left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x,
|
||||
left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w,
|
||||
)
|
||||
|
||||
override fun divide(a: Quaternion, b: Quaternion): Quaternion {
|
||||
val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z
|
||||
override fun divide(left: Quaternion, right: Quaternion): Quaternion {
|
||||
val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z
|
||||
|
||||
return Quaternion(
|
||||
(b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s,
|
||||
(b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s,
|
||||
(b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s,
|
||||
(b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s,
|
||||
(right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s,
|
||||
(right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s,
|
||||
(right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s,
|
||||
(right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s,
|
||||
)
|
||||
}
|
||||
|
||||
@ -158,16 +158,16 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
|
||||
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
|
||||
}
|
||||
|
||||
override operator fun Number.plus(b: Quaternion): Quaternion = Quaternion(toDouble() + b.w, b.x, b.y, b.z)
|
||||
override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z)
|
||||
|
||||
override operator fun Number.minus(b: Quaternion): Quaternion =
|
||||
Quaternion(toDouble() - b.w, -b.x, -b.y, -b.z)
|
||||
override operator fun Number.minus(other: Quaternion): Quaternion =
|
||||
Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
|
||||
|
||||
override operator fun Quaternion.plus(b: Number): Quaternion = Quaternion(w + b.toDouble(), x, y, z)
|
||||
override operator fun Quaternion.minus(b: Number): Quaternion = Quaternion(w - b.toDouble(), x, y, z)
|
||||
override operator fun Quaternion.plus(other: Number): Quaternion = Quaternion(w + other.toDouble(), x, y, z)
|
||||
override operator fun Quaternion.minus(other: Number): Quaternion = Quaternion(w - other.toDouble(), x, y, z)
|
||||
|
||||
override operator fun Number.times(b: Quaternion): Quaternion =
|
||||
Quaternion(toDouble() * b.w, toDouble() * b.x, toDouble() * b.y, toDouble() * b.z)
|
||||
override operator fun Number.times(other: Quaternion): Quaternion =
|
||||
Quaternion(toDouble() * other.w, toDouble() * other.x, toDouble() * other.y, toDouble() * other.z)
|
||||
|
||||
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
|
||||
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)
|
||||
|
@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
|
||||
override val zero: Expression<T> get() = const(algebra.zero)
|
||||
|
||||
override fun Expression<T>.unaryMinus(): Expression<T> =
|
||||
unaryOperation(GroupOperations.MINUS_OPERATION, this)
|
||||
unaryOperation(GroupOps.MINUS_OPERATION, this)
|
||||
|
||||
/**
|
||||
* Builds an Expression of addition of two another expressions.
|
||||
*/
|
||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperation(GroupOperations.PLUS_OPERATION, a, b)
|
||||
override fun add(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
binaryOperation(GroupOps.PLUS_OPERATION, left, right)
|
||||
|
||||
// /**
|
||||
// * Builds an Expression of multiplication of expression by number.
|
||||
@ -88,8 +88,8 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
|
||||
/**
|
||||
* Builds an Expression of multiplication of two expressions.
|
||||
*/
|
||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
override fun multiply(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
|
||||
|
||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||
@ -107,8 +107,8 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
|
||||
/**
|
||||
* Builds an Expression of division an expression by another one.
|
||||
*/
|
||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
override fun divide(left: Expression<T>, right: Expression<T>): Expression<T> =
|
||||
binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
|
||||
|
||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||
|
@ -31,18 +31,18 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b)
|
||||
override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right)
|
||||
override operator fun MST.unaryPlus(): MST.Unary =
|
||||
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this)
|
||||
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
|
||||
|
||||
override operator fun MST.unaryMinus(): MST.Unary =
|
||||
unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this)
|
||||
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
|
||||
|
||||
override operator fun MST.minus(b: MST): MST.Binary =
|
||||
binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b)
|
||||
override operator fun MST.minus(other: MST): MST.Binary =
|
||||
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, other)
|
||||
|
||||
override fun scale(a: MST, value: Double): MST.Binary =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value))
|
||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstNumericAlgebra.binaryOperationFunction(operation)
|
||||
@ -56,23 +56,23 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
||||
*/
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
|
||||
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||
override inline val zero: MST.Numeric get() = MstGroup.zero
|
||||
override val one: MST.Numeric = number(1.0)
|
||||
|
||||
override fun number(value: Number): MST.Numeric = MstGroup.number(value)
|
||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
|
||||
override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right)
|
||||
|
||||
override fun scale(a: MST, value: Double): MST.Binary =
|
||||
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
||||
override fun multiply(left: MST, right: MST): MST.Binary =
|
||||
binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
|
||||
|
||||
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
|
||||
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
|
||||
override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b }
|
||||
override operator fun MST.minus(other: MST): MST.Binary = MstGroup { this@minus - other }
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstGroup.binaryOperationFunction(operation)
|
||||
@ -86,24 +86,24 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
|
||||
*/
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
|
||||
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||
override inline val zero: MST.Numeric get() = MstRing.zero
|
||||
override inline val one: MST.Numeric get() = MstRing.one
|
||||
|
||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||
override fun number(value: Number): MST.Numeric = MstRing.number(value)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||
override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right)
|
||||
|
||||
override fun scale(a: MST, value: Double): MST.Binary =
|
||||
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST.Binary =
|
||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
||||
override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right)
|
||||
override fun divide(left: MST, right: MST): MST.Binary =
|
||||
binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
|
||||
|
||||
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
||||
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
||||
override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b }
|
||||
override operator fun MST.minus(other: MST): MST.Binary = MstRing { this@minus - other }
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||
MstRing.binaryOperationFunction(operation)
|
||||
@ -134,17 +134,17 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
||||
override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg)
|
||||
override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
|
||||
override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
|
||||
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b)
|
||||
override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right)
|
||||
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
||||
|
||||
override fun scale(a: MST, value: Double): MST =
|
||||
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value))
|
||||
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
|
||||
|
||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||
override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right)
|
||||
override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right)
|
||||
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
|
||||
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
|
||||
override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b }
|
||||
override operator fun MST.minus(other: MST): MST.Binary = MstField { this@minus - other }
|
||||
|
||||
override fun power(arg: MST, pow: Number): MST.Binary =
|
||||
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))
|
||||
|
@ -59,7 +59,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
public val context: F,
|
||||
bindings: Map<Symbol, T>,
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOperations<AutoDiffValue<T>> {
|
||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
|
||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||
|
||||
@ -168,22 +168,22 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||
|
||||
// Basic math (+, -, *, /)
|
||||
|
||||
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value + b.value }) { z ->
|
||||
a.d += z.d
|
||||
b.d += z.d
|
||||
override fun add(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { left.value + right.value }) { z ->
|
||||
left.d += z.d
|
||||
right.d += z.d
|
||||
}
|
||||
|
||||
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value * b.value }) { z ->
|
||||
a.d += z.d * b.value
|
||||
b.d += z.d * a.value
|
||||
override fun multiply(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { left.value * right.value }) { z ->
|
||||
left.d += z.d * right.value
|
||||
right.d += z.d * left.value
|
||||
}
|
||||
|
||||
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { a.value / b.value }) { z ->
|
||||
a.d += z.d / b.value
|
||||
b.d -= z.d * a.value / (b.value * b.value)
|
||||
override fun divide(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
|
||||
derive(const { left.value / right.value }) { z ->
|
||||
left.d += z.d / right.value
|
||||
right.d -= z.d * left.value / (right.value * right.value)
|
||||
}
|
||||
|
||||
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> =
|
||||
|
@ -6,12 +6,10 @@
|
||||
package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.BufferedRingND
|
||||
import space.kscience.kmath.nd.BufferedRingOpsND
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.nd.asND
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.VirtualBuffer
|
||||
@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices
|
||||
|
||||
|
||||
public class BufferedLinearSpace<T, out A : Ring<T>>(
|
||||
override val elementAlgebra: A,
|
||||
private val bufferFactory: BufferFactory<T>,
|
||||
private val bufferAlgebra: BufferAlgebra<T, A>
|
||||
) : LinearSpace<T, A> {
|
||||
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||
|
||||
private fun ndRing(
|
||||
rows: Int,
|
||||
cols: Int,
|
||||
): BufferedRingND<T, A> = elementAlgebra.ndAlgebra(bufferFactory, rows, cols)
|
||||
private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
|
||||
|
||||
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
|
||||
ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
|
||||
ndAlgebra.structureND(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
|
||||
|
||||
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
|
||||
bufferFactory(size) { elementAlgebra.initializer(it) }
|
||||
bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }
|
||||
|
||||
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndAlgebra {
|
||||
asND().map { -it }.as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
@ -88,11 +83,11 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
|
||||
}
|
||||
}
|
||||
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<T>.times(value: T): Matrix<T> = ndAlgebra {
|
||||
asND().map { it * value }.as2D()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
|
||||
BufferedLinearSpace(this, bufferFactory)
|
||||
BufferedLinearSpace(BufferRingOps(this, bufferFactory))
|
||||
|
@ -6,11 +6,12 @@
|
||||
package space.kscience.kmath.linear
|
||||
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.DoubleFieldND
|
||||
import space.kscience.kmath.nd.DoubleFieldOpsND
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.nd.asND
|
||||
import space.kscience.kmath.operations.DoubleBufferOperations
|
||||
import space.kscience.kmath.operations.DoubleBufferOps
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
|
||||
@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
|
||||
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
private fun ndRing(
|
||||
rows: Int,
|
||||
cols: Int,
|
||||
): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols))
|
||||
|
||||
override fun buildMatrix(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
initializer: DoubleField.(i: Int, j: Int) -> Double
|
||||
): Matrix<Double> = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D()
|
||||
): Matrix<Double> = DoubleFieldOpsND.structureND(intArrayOf(rows, columns)) { (i, j) ->
|
||||
DoubleField.initializer(i, j)
|
||||
}.as2D()
|
||||
|
||||
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
|
||||
DoubleBuffer(size) { DoubleField.initializer(it) }
|
||||
|
||||
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = DoubleFieldOpsND {
|
||||
asND().map { -it }.as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||
asND().plus(other.asND()).as2D()
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
|
||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||
asND().minus(other.asND()).as2D()
|
||||
}
|
||||
@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
|
||||
|
||||
}
|
||||
|
||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> = ndRing(rowNum, colNum).run {
|
||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> = DoubleFieldOpsND {
|
||||
asND().map { it * value }.as2D()
|
||||
}
|
||||
|
||||
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
|
||||
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
|
||||
this@plus + other
|
||||
}
|
||||
|
||||
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
|
||||
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
|
||||
this@minus - other
|
||||
}
|
||||
|
||||
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOperations.run {
|
||||
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOps.run {
|
||||
scale(this@times, value)
|
||||
}
|
||||
|
||||
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOperations.run {
|
||||
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOps.run {
|
||||
scale(this@div, 1.0 / value)
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.nd.StructureFeature
|
||||
import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.operations.BufferRingOps
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.operations.invoke
|
||||
@ -188,7 +189,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
|
||||
public fun <T : Any, A : Ring<T>> buffered(
|
||||
algebra: A,
|
||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
): LinearSpace<T, A> = BufferedLinearSpace(algebra, bufferFactory)
|
||||
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
|
||||
|
||||
@Deprecated("use DoubleField.linearSpace")
|
||||
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)
|
||||
|
@ -7,7 +7,6 @@ package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.*
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
@ -19,6 +18,14 @@ import kotlin.reflect.KClass
|
||||
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||
|
||||
public typealias Shape = IntArray
|
||||
|
||||
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
||||
|
||||
public interface WithShape {
|
||||
public val shape: Shape
|
||||
}
|
||||
|
||||
/**
|
||||
* The base interface for all ND-algebra implementations.
|
||||
*
|
||||
@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
|
||||
* @param C the type of the element context.
|
||||
*/
|
||||
public interface AlgebraND<T, out C : Algebra<T>> {
|
||||
/**
|
||||
* The shape of ND-structures this algebra operates on.
|
||||
*/
|
||||
public val shape: IntArray
|
||||
|
||||
/**
|
||||
* The algebra over elements of ND structure.
|
||||
*/
|
||||
public val elementContext: C
|
||||
public val elementAlgebra: C
|
||||
|
||||
/**
|
||||
* Produces a new NDStructure using given initializer function.
|
||||
* Produces a new [StructureND] using given initializer function.
|
||||
*/
|
||||
public fun produce(initializer: C.(IntArray) -> T): StructureND<T>
|
||||
public fun structureND(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
|
||||
|
||||
/**
|
||||
* Maps elements from one structure to another one by applying [transform] to them.
|
||||
@ -54,7 +56,7 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
||||
/**
|
||||
* Combines two structures into one.
|
||||
*/
|
||||
public fun combine(a: StructureND<T>, b: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
|
||||
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
|
||||
|
||||
/**
|
||||
* Element-wise invocation of function working on [T] on a [StructureND].
|
||||
@ -77,7 +79,6 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
||||
public companion object
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get a feature of the structure in this scope. Structure features take precedence other context features.
|
||||
*
|
||||
@ -89,46 +90,22 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
||||
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
|
||||
getFeature(structure, F::class)
|
||||
|
||||
/**
|
||||
* Checks if given elements are consistent with this context.
|
||||
*
|
||||
* @param structures the structures to check.
|
||||
* @return the array of valid structures.
|
||||
*/
|
||||
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(vararg structures: StructureND<T>): Array<out StructureND<T>> =
|
||||
structures
|
||||
.map(StructureND<T>::shape)
|
||||
.singleOrNull { !shape.contentEquals(it) }
|
||||
?.let<IntArray, Array<out StructureND<T>>> { throw ShapeMismatchException(shape, it) }
|
||||
?: structures
|
||||
|
||||
/**
|
||||
* Checks if given element is consistent with this context.
|
||||
*
|
||||
* @param element the structure to check.
|
||||
* @return the valid structure.
|
||||
*/
|
||||
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND<T>): StructureND<T> {
|
||||
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
|
||||
return element
|
||||
}
|
||||
|
||||
/**
|
||||
* Space of [StructureND].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of group over structure elements.
|
||||
* @param A the type of group over structure elements.
|
||||
*/
|
||||
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> {
|
||||
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
|
||||
/**
|
||||
* Element-wise addition.
|
||||
*
|
||||
* @param a the augend.
|
||||
* @param b the addend.
|
||||
* @param left the augend.
|
||||
* @param right the addend.
|
||||
* @return the sum.
|
||||
*/
|
||||
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||
combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
||||
override fun add(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||
zip(left, right) { aValue, bValue -> add(aValue, bValue) }
|
||||
|
||||
// TODO move to extensions after KEEP-176
|
||||
|
||||
@ -157,7 +134,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
|
||||
* @param arg the addend.
|
||||
* @return the sum.
|
||||
*/
|
||||
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(this@plus, value) }
|
||||
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this
|
||||
|
||||
/**
|
||||
* Subtracts an ND structure from an element of it.
|
||||
@ -171,22 +148,26 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
|
||||
public companion object
|
||||
}
|
||||
|
||||
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
|
||||
override val zero: StructureND<T> get() = structureND(shape) { elementAlgebra.zero }
|
||||
}
|
||||
|
||||
/**
|
||||
* Ring of [StructureND].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param R the type of ring over structure elements.
|
||||
* @param A the type of ring over structure elements.
|
||||
*/
|
||||
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> {
|
||||
public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, GroupOpsND<T, A> {
|
||||
/**
|
||||
* Element-wise multiplication.
|
||||
*
|
||||
* @param a the multiplicand.
|
||||
* @param b the multiplier.
|
||||
* @param left the multiplicand.
|
||||
* @param right the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||
combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||
zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
|
||||
@ -211,24 +192,32 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
|
||||
public companion object
|
||||
}
|
||||
|
||||
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
|
||||
override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Field of [StructureND].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param F the type field over structure elements.
|
||||
* @param A the type field over structure elements.
|
||||
*/
|
||||
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F> {
|
||||
public interface FieldOpsND<T, out A : Field<T>> :
|
||||
FieldOps<StructureND<T>>,
|
||||
RingOpsND<T, A>,
|
||||
ScaleOperations<StructureND<T>> {
|
||||
/**
|
||||
* Element-wise division.
|
||||
*
|
||||
* @param a the dividend.
|
||||
* @param b the divisor.
|
||||
* @param left the dividend.
|
||||
* @param right the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||
combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||
override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
|
||||
zip(left, right) { aValue, bValue -> divide(aValue, bValue) }
|
||||
|
||||
//TODO move to extensions after KEEP-176
|
||||
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
|
||||
/**
|
||||
* Divides an ND structure by an element of it.
|
||||
*
|
||||
@ -247,42 +236,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
|
||||
*/
|
||||
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
|
||||
|
||||
/**
|
||||
* Element-wise scaling.
|
||||
*
|
||||
* @param a the multiplicand.
|
||||
* @param value the multiplier.
|
||||
* @return the product.
|
||||
*/
|
||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
|
||||
|
||||
// @ThreadLocal
|
||||
// public companion object {
|
||||
// private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
||||
//
|
||||
// /**
|
||||
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
|
||||
// */
|
||||
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
||||
//
|
||||
// /**
|
||||
// * Create an ND field with boxing generic buffer.
|
||||
// */
|
||||
// public fun <T : Any, F : Field<T>> boxing(
|
||||
// field: F,
|
||||
// vararg shape: Int,
|
||||
// bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||
// ): BufferedNDField<T, F> = BufferedNDField(shape, field, bufferFactory)
|
||||
//
|
||||
// /**
|
||||
// * Create a most suitable implementation for nd-field using reified class.
|
||||
// */
|
||||
// @Suppress("UNCHECKED_CAST")
|
||||
// public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): NDField<T, F> =
|
||||
// when {
|
||||
// T::class == Double::class -> real(*shape) as NDField<T, F>
|
||||
// T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
|
||||
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
|
||||
override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
|
||||
}
|
@ -3,145 +3,177 @@
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
@file:OptIn(UnstableKMathAPI::class)
|
||||
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||
public val strides: Strides
|
||||
public val bufferFactory: BufferFactory<T>
|
||||
public val indexerBuilder: (IntArray) -> ShapeIndex
|
||||
public val bufferAlgebra: BufferAlgebra<T, A>
|
||||
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||
|
||||
override fun produce(initializer: A.(IntArray) -> T): BufferND<T> = BufferND(
|
||||
strides,
|
||||
bufferFactory(strides.linearSize) { offset ->
|
||||
elementContext.initializer(strides.index(offset))
|
||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
|
||||
val indexer = indexerBuilder(shape)
|
||||
return BufferND(
|
||||
indexer,
|
||||
bufferAlgebra.buffer(indexer.linearSize) { offset ->
|
||||
elementAlgebra.initializer(indexer.index(offset))
|
||||
}
|
||||
)
|
||||
|
||||
public val StructureND<T>.buffer: Buffer<T>
|
||||
get() = when {
|
||||
!shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException(
|
||||
this@BufferAlgebraND.shape,
|
||||
shape
|
||||
)
|
||||
this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer
|
||||
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||
}
|
||||
|
||||
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> {
|
||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||
elementContext.transform(buffer[offset])
|
||||
public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
|
||||
is BufferND -> this
|
||||
else -> {
|
||||
val indexer = indexerBuilder(shape)
|
||||
BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||
}
|
||||
return BufferND(strides, buffer)
|
||||
}
|
||||
|
||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> {
|
||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||
elementContext.transform(
|
||||
strides.index(offset),
|
||||
buffer[offset]
|
||||
)
|
||||
}
|
||||
return BufferND(strides, buffer)
|
||||
}
|
||||
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> = mapInline(toBufferND(), transform)
|
||||
|
||||
override fun combine(a: StructureND<T>, b: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> {
|
||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||
elementContext.transform(a.buffer[offset], b.buffer[offset])
|
||||
}
|
||||
return BufferND(strides, buffer)
|
||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> =
|
||||
mapIndexedInline(toBufferND(), transform)
|
||||
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> =
|
||||
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||
|
||||
public companion object {
|
||||
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
|
||||
}
|
||||
}
|
||||
|
||||
public open class BufferedGroupND<T, out A : Group<T>>(
|
||||
final override val shape: IntArray,
|
||||
final override val elementContext: A,
|
||||
final override val bufferFactory: BufferFactory<T>,
|
||||
) : GroupND<T, A>, BufferAlgebraND<T, A> {
|
||||
override val strides: Strides = DefaultStrides(shape)
|
||||
override val zero: BufferND<T> by lazy { produce { zero } }
|
||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
|
||||
public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
|
||||
arg: BufferND<T>,
|
||||
crossinline transform: A.(T) -> T
|
||||
): BufferND<T> {
|
||||
val indexes = arg.indexes
|
||||
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
|
||||
}
|
||||
|
||||
public open class BufferedRingND<T, out R : Ring<T>>(
|
||||
shape: IntArray,
|
||||
elementContext: R,
|
||||
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
|
||||
arg: BufferND<T>,
|
||||
crossinline transform: A.(index: IntArray, arg: T) -> T
|
||||
): BufferND<T> {
|
||||
val indexes = arg.indexes
|
||||
return BufferND(
|
||||
indexes,
|
||||
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
|
||||
transform(indexes.index(offset), value)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
|
||||
l: BufferND<T>,
|
||||
r: BufferND<T>,
|
||||
crossinline block: A.(l: T, r: T) -> T
|
||||
): BufferND<T> {
|
||||
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
|
||||
val indexes = l.indexes
|
||||
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
|
||||
}
|
||||
|
||||
public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
||||
override val bufferAlgebra: BufferAlgebra<T, A>,
|
||||
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
||||
}
|
||||
|
||||
public open class BufferedRingOpsND<T, out A : Ring<T>>(
|
||||
bufferAlgebra: BufferAlgebra<T, A>,
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
|
||||
|
||||
public open class BufferedFieldOpsND<T, out A : Field<T>>(
|
||||
bufferAlgebra: BufferAlgebra<T, A>,
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
|
||||
|
||||
public constructor(
|
||||
elementAlgebra: A,
|
||||
bufferFactory: BufferFactory<T>,
|
||||
) : BufferedGroupND<T, R>(shape, elementContext, bufferFactory), RingND<T, R> {
|
||||
override val one: BufferND<T> by lazy { produce { one } }
|
||||
}
|
||||
|
||||
public open class BufferedFieldND<T, out R : Field<T>>(
|
||||
shape: IntArray,
|
||||
elementContext: R,
|
||||
bufferFactory: BufferFactory<T>,
|
||||
) : BufferedRingND<T, R>(shape, elementContext, bufferFactory), FieldND<T, R> {
|
||||
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
|
||||
|
||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
||||
}
|
||||
|
||||
// group factories
|
||||
public fun <T, A : Group<T>> A.ndAlgebra(
|
||||
bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
): BufferedGroupND<T, A> = BufferedGroupND(shape, this, bufferFactory)
|
||||
public val <T, A : Group<T>> BufferAlgebra<T, A>.nd: BufferedGroupNDOps<T, A> get() = BufferedGroupNDOps(this)
|
||||
public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get() = BufferedRingOpsND(this)
|
||||
public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
|
||||
|
||||
@JvmName("withNdGroup")
|
||||
public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedGroupND<T, A>.() -> R,
|
||||
): R {
|
||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
}
|
||||
|
||||
//ring factories
|
||||
public fun <T, A : Ring<T>> A.ndAlgebra(
|
||||
bufferFactory: BufferFactory<T>,
|
||||
public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.structureND(
|
||||
vararg shape: Int,
|
||||
): BufferedRingND<T, A> = BufferedRingND(shape, this, bufferFactory)
|
||||
initializer: A.(IntArray) -> T
|
||||
): BufferND<T> = structureND(shape, initializer)
|
||||
|
||||
@JvmName("withNdRing")
|
||||
public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedRingND<T, A>.() -> R,
|
||||
): R {
|
||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
}
|
||||
public fun <T, EA : Algebra<T>, A> A.structureND(
|
||||
initializer: EA.(IntArray) -> T
|
||||
): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = structureND(shape, initializer)
|
||||
|
||||
//field factories
|
||||
public fun <T, A : Field<T>> A.ndAlgebra(
|
||||
bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
): BufferedFieldND<T, A> = BufferedFieldND(shape, this, bufferFactory)
|
||||
//// group factories
|
||||
//public fun <T, A : Group<T>> A.ndAlgebra(
|
||||
// bufferAlgebra: BufferAlgebra<T, A>,
|
||||
// vararg shape: Int,
|
||||
//): BufferedGroupNDOps<T, A> = BufferedGroupNDOps(bufferAlgebra)
|
||||
//
|
||||
//@JvmName("withNdGroup")
|
||||
//public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
|
||||
// noinline bufferFactory: BufferFactory<T>,
|
||||
// vararg shape: Int,
|
||||
// action: BufferedGroupNDOps<T, A>.() -> R,
|
||||
//): R {
|
||||
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
//}
|
||||
|
||||
/**
|
||||
* Create a [FieldND] for this [Field] inferring proper buffer factory from the type
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
|
||||
vararg shape: Int,
|
||||
): FieldND<T, A> = when (this) {
|
||||
DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
|
||||
else -> BufferedFieldND(shape, this, Buffer.Companion::auto)
|
||||
}
|
||||
|
||||
@JvmName("withNdField")
|
||||
public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedFieldND<T, A>.() -> R,
|
||||
): R {
|
||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
}
|
||||
////ring factories
|
||||
//public fun <T, A : Ring<T>> A.ndAlgebra(
|
||||
// bufferFactory: BufferFactory<T>,
|
||||
// vararg shape: Int,
|
||||
//): BufferedRingNDOps<T, A> = BufferedRingNDOps(shape, this, bufferFactory)
|
||||
//
|
||||
//@JvmName("withNdRing")
|
||||
//public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
|
||||
// noinline bufferFactory: BufferFactory<T>,
|
||||
// vararg shape: Int,
|
||||
// action: BufferedRingNDOps<T, A>.() -> R,
|
||||
//): R {
|
||||
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
//}
|
||||
//
|
||||
////field factories
|
||||
//public fun <T, A : Field<T>> A.ndAlgebra(
|
||||
// bufferFactory: BufferFactory<T>,
|
||||
// vararg shape: Int,
|
||||
//): BufferedFieldNDOps<T, A> = BufferedFieldNDOps(shape, this, bufferFactory)
|
||||
//
|
||||
///**
|
||||
// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type
|
||||
// */
|
||||
//@UnstableKMathAPI
|
||||
//@Suppress("UNCHECKED_CAST")
|
||||
//public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
|
||||
// vararg shape: Int,
|
||||
//): FieldND<T, A> = when (this) {
|
||||
// DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
|
||||
// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto)
|
||||
//}
|
||||
//
|
||||
//@JvmName("withNdField")
|
||||
//public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
|
||||
// noinline bufferFactory: BufferFactory<T>,
|
||||
// vararg shape: Int,
|
||||
// action: BufferedFieldNDOps<T, A>.() -> R,
|
||||
//): R {
|
||||
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||
//}
|
@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
||||
* Represents [StructureND] over [Buffer].
|
||||
*
|
||||
* @param T the type of items.
|
||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
||||
* @param indexes The strides to access elements of [Buffer] by linear indices.
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public open class BufferND<out T>(
|
||||
public val strides: Strides,
|
||||
public val buffer: Buffer<T>,
|
||||
public val indexes: ShapeIndex,
|
||||
public open val buffer: Buffer<T>,
|
||||
) : StructureND<T> {
|
||||
|
||||
init {
|
||||
if (strides.linearSize != buffer.size) {
|
||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
||||
}
|
||||
}
|
||||
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
|
||||
|
||||
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
||||
|
||||
override val shape: IntArray get() = strides.shape
|
||||
override val shape: IntArray get() = indexes.shape
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map {
|
||||
it to this[it]
|
||||
}
|
||||
|
||||
@ -49,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||
crossinline transform: (T) -> R,
|
||||
): BufferND<R> {
|
||||
return if (this is BufferND<T>)
|
||||
BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
||||
BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
|
||||
else {
|
||||
val strides = DefaultStrides(shape)
|
||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
@ -61,14 +55,14 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
||||
*
|
||||
* @param T the type of items.
|
||||
* @param strides The strides to access elements of [MutableBuffer] by linear indices.
|
||||
* @param mutableBuffer The underlying buffer.
|
||||
* @param buffer The underlying buffer.
|
||||
*/
|
||||
public class MutableBufferND<T>(
|
||||
strides: Strides,
|
||||
public val mutableBuffer: MutableBuffer<T>,
|
||||
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
||||
strides: ShapeIndex,
|
||||
override val buffer: MutableBuffer<T>,
|
||||
) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
|
||||
override fun set(index: IntArray, value: T) {
|
||||
mutableBuffer[strides.offset(index)] = value
|
||||
buffer[indexes.offset(index)] = value
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
||||
crossinline transform: (T) -> R,
|
||||
): MutableBufferND<R> {
|
||||
return if (this is MutableBufferND<T>)
|
||||
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
|
||||
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
|
||||
else {
|
||||
val strides = DefaultStrides(shape)
|
||||
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||
|
@ -6,108 +6,186 @@
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.ExtendedField
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
import kotlin.math.pow
|
||||
|
||||
public class DoubleBufferND(
|
||||
indexes: ShapeIndex,
|
||||
override val buffer: DoubleBuffer,
|
||||
) : BufferND<Double>(indexes, buffer)
|
||||
|
||||
|
||||
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
|
||||
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
|
||||
|
||||
override fun StructureND<Double>.toBufferND(): DoubleBufferND = when (this) {
|
||||
is DoubleBufferND -> this
|
||||
else -> {
|
||||
val indexer = indexerBuilder(shape)
|
||||
DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||
}
|
||||
}
|
||||
|
||||
private inline fun mapInline(
|
||||
arg: DoubleBufferND,
|
||||
transform: (Double) -> Double
|
||||
): DoubleBufferND {
|
||||
val indexes = arg.indexes
|
||||
val array = arg.buffer.array
|
||||
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
|
||||
}
|
||||
|
||||
private inline fun zipInline(
|
||||
l: DoubleBufferND,
|
||||
r: DoubleBufferND,
|
||||
block: (l: Double, r: Double) -> Double
|
||||
): DoubleBufferND {
|
||||
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
|
||||
val indexes = l.indexes
|
||||
val lArray = l.buffer.array
|
||||
val rArray = r.buffer.array
|
||||
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): BufferND<Double> =
|
||||
mapInline(toBufferND()) { DoubleField.transform(it) }
|
||||
|
||||
|
||||
override fun zip(
|
||||
left: StructureND<Double>,
|
||||
right: StructureND<Double>,
|
||||
transform: DoubleField.(Double, Double) -> Double
|
||||
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
|
||||
|
||||
override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
|
||||
val indexer = indexerBuilder(shape)
|
||||
return DoubleBufferND(
|
||||
indexer,
|
||||
DoubleBuffer(indexer.linearSize) { offset ->
|
||||
elementAlgebra.initializer(indexer.index(offset))
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
override fun add(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r }
|
||||
|
||||
override fun multiply(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r }
|
||||
|
||||
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
|
||||
|
||||
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r }
|
||||
|
||||
override fun divide(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r }
|
||||
|
||||
override fun StructureND<Double>.div(arg: Double): DoubleBufferND =
|
||||
mapInline(toBufferND()) { it / arg }
|
||||
|
||||
override fun Double.div(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { this / it }
|
||||
|
||||
override fun StructureND<Double>.unaryPlus(): DoubleBufferND = toBufferND()
|
||||
|
||||
override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l + r }
|
||||
|
||||
override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l - r }
|
||||
|
||||
override fun StructureND<Double>.times(other: StructureND<Double>): DoubleBufferND =
|
||||
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l * r }
|
||||
|
||||
override fun StructureND<Double>.times(k: Number): DoubleBufferND =
|
||||
mapInline(toBufferND()) { it * k.toDouble() }
|
||||
|
||||
override fun StructureND<Double>.div(k: Number): DoubleBufferND =
|
||||
mapInline(toBufferND()) { it / k.toDouble() }
|
||||
|
||||
override fun Number.times(other: StructureND<Double>): DoubleBufferND = other * this
|
||||
|
||||
override fun StructureND<Double>.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg }
|
||||
|
||||
override fun StructureND<Double>.minus(arg: Double): StructureND<Double> = mapInline(toBufferND()) { it - arg }
|
||||
|
||||
override fun Double.plus(arg: StructureND<Double>): StructureND<Double> = arg + this
|
||||
|
||||
override fun Double.minus(arg: StructureND<Double>): StructureND<Double> = mapInline(arg.toBufferND()) { this - it }
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): DoubleBufferND =
|
||||
mapInline(a.toBufferND()) { it * value }
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) }
|
||||
|
||||
override fun exp(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.exp(it) }
|
||||
|
||||
override fun ln(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.ln(it) }
|
||||
|
||||
override fun sin(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.sin(it) }
|
||||
|
||||
override fun cos(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.cos(it) }
|
||||
|
||||
override fun tan(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.tan(it) }
|
||||
|
||||
override fun asin(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.asin(it) }
|
||||
|
||||
override fun acos(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.acos(it) }
|
||||
|
||||
override fun atan(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.atan(it) }
|
||||
|
||||
override fun sinh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.sinh(it) }
|
||||
|
||||
override fun cosh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.cosh(it) }
|
||||
|
||||
override fun tanh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.tanh(it) }
|
||||
|
||||
override fun asinh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.asinh(it) }
|
||||
|
||||
override fun acosh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.acosh(it) }
|
||||
|
||||
override fun atanh(arg: StructureND<Double>): DoubleBufferND =
|
||||
mapInline(arg.toBufferND()) { kotlin.math.atanh(it) }
|
||||
|
||||
public companion object : DoubleFieldOpsND()
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class DoubleFieldND(
|
||||
shape: IntArray,
|
||||
) : BufferedFieldND<Double, DoubleField>(shape, DoubleField, ::DoubleBuffer),
|
||||
NumbersAddOperations<StructureND<Double>>,
|
||||
ScaleOperations<StructureND<Double>>,
|
||||
ExtendedField<StructureND<Double>> {
|
||||
public class DoubleFieldND(override val shape: Shape) :
|
||||
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
||||
|
||||
override val zero: BufferND<Double> by lazy { produce { zero } }
|
||||
override val one: BufferND<Double> by lazy { produce { one } }
|
||||
|
||||
override fun number(value: Number): BufferND<Double> {
|
||||
override fun number(value: Number): DoubleBufferND {
|
||||
val d = value.toDouble() // minimize conversions
|
||||
return produce { d }
|
||||
return structureND(shape) { d }
|
||||
}
|
||||
|
||||
override val StructureND<Double>.buffer: DoubleBuffer
|
||||
get() = when {
|
||||
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
|
||||
this@DoubleFieldND.shape,
|
||||
shape
|
||||
)
|
||||
this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer
|
||||
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun StructureND<Double>.map(
|
||||
transform: DoubleField.(Double) -> Double,
|
||||
): BufferND<Double> {
|
||||
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
|
||||
return BufferND(strides, buffer)
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
||||
val array = DoubleArray(strides.linearSize) { offset ->
|
||||
val index = strides.index(offset)
|
||||
DoubleField.initializer(index)
|
||||
}
|
||||
return BufferND(strides, DoubleBuffer(array))
|
||||
}
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun StructureND<Double>.mapIndexed(
|
||||
transform: DoubleField.(index: IntArray, Double) -> Double,
|
||||
): BufferND<Double> = BufferND(
|
||||
strides,
|
||||
buffer = DoubleBuffer(strides.linearSize) { offset ->
|
||||
DoubleField.transform(
|
||||
strides.index(offset),
|
||||
buffer.array[offset]
|
||||
)
|
||||
})
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE")
|
||||
override inline fun combine(
|
||||
a: StructureND<Double>,
|
||||
b: StructureND<Double>,
|
||||
transform: DoubleField.(Double, Double) -> Double,
|
||||
): BufferND<Double> {
|
||||
val buffer = DoubleBuffer(strides.linearSize) { offset ->
|
||||
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
||||
}
|
||||
return BufferND(strides, buffer)
|
||||
}
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
|
||||
|
||||
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
|
||||
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
|
||||
|
||||
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
|
||||
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
|
||||
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
|
||||
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
|
||||
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
|
||||
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
|
||||
|
||||
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
|
||||
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
|
||||
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
|
||||
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
|
||||
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
|
||||
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
|
||||
}
|
||||
|
||||
public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND
|
||||
|
||||
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
|
||||
|
||||
/**
|
||||
* Produce a context for n-dimensional operations inside this real field
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
|
||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
return DoubleFieldND(shape).run(action)
|
||||
|
@ -0,0 +1,119 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import kotlin.native.concurrent.ThreadLocal
|
||||
|
||||
/**
|
||||
* A converter from linear index to multivariate index
|
||||
*/
|
||||
public interface ShapeIndex{
|
||||
public val shape: Shape
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
*/
|
||||
public fun offset(index: IntArray): Int
|
||||
|
||||
/**
|
||||
* Get multidimensional from linear
|
||||
*/
|
||||
public fun index(offset: Int): IntArray
|
||||
|
||||
/**
|
||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||
*/
|
||||
public val linearSize: Int
|
||||
|
||||
// TODO introduce a fast way to calculate index of the next element?
|
||||
|
||||
/**
|
||||
* Iterate over ND indices in a natural order
|
||||
*/
|
||||
public fun indices(): Sequence<IntArray>
|
||||
|
||||
override fun equals(other: Any?): Boolean
|
||||
override fun hashCode(): Int
|
||||
}
|
||||
|
||||
/**
|
||||
* Linear transformation of indexes
|
||||
*/
|
||||
public abstract class Strides: ShapeIndex {
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
||||
public abstract val strides: IntArray
|
||||
|
||||
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
// TODO introduce a fast way to calculate index of the next element?
|
||||
|
||||
/**
|
||||
* Iterate over ND indices in a natural order
|
||||
*/
|
||||
public override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple implementation of [Strides].
|
||||
*/
|
||||
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
|
||||
override val linearSize: Int get() = strides[shape.size]
|
||||
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
|
||||
shape.forEach {
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
|
||||
override fun index(offset: Int): IntArray {
|
||||
val res = IntArray(shape.size)
|
||||
var current = offset
|
||||
var strideIndex = strides.size - 2
|
||||
|
||||
while (strideIndex >= 0) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex--
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is DefaultStrides) return false
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
* Cached builder for default strides
|
||||
*/
|
||||
public operator fun invoke(shape: IntArray): Strides =
|
||||
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||
}
|
||||
}
|
||||
|
||||
@ThreadLocal
|
||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
@ -6,34 +6,27 @@
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.NumbersAddOps
|
||||
import space.kscience.kmath.operations.ShortRing
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.ShortBuffer
|
||||
import space.kscience.kmath.operations.bufferAlgebra
|
||||
import kotlin.contracts.InvocationKind
|
||||
import kotlin.contracts.contract
|
||||
|
||||
public sealed class ShortRingOpsND : BufferedRingOpsND<Short, ShortRing>(ShortRing.bufferAlgebra) {
|
||||
public companion object : ShortRingOpsND()
|
||||
}
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public class ShortRingND(
|
||||
shape: IntArray,
|
||||
) : BufferedRingND<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto),
|
||||
NumbersAddOperations<StructureND<Short>> {
|
||||
|
||||
override val zero: BufferND<Short> by lazy { produce { zero } }
|
||||
override val one: BufferND<Short> by lazy { produce { one } }
|
||||
override val shape: Shape
|
||||
) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
|
||||
|
||||
override fun number(value: Number): BufferND<Short> {
|
||||
val d = value.toShort() // minimize conversions
|
||||
return produce { d }
|
||||
return structureND(shape) { d }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fast element production using function inlining.
|
||||
*/
|
||||
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
|
||||
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
|
||||
|
||||
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
|
||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||
return ShortRingND(shape).run(action)
|
||||
|
@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import kotlin.jvm.JvmName
|
||||
import kotlin.math.abs
|
||||
import kotlin.native.concurrent.ThreadLocal
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
public interface StructureFeature : Feature<StructureFeature>
|
||||
@ -72,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
||||
if (st1 === st2) return true
|
||||
|
||||
// fast comparison of buffers if possible
|
||||
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
|
||||
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
|
||||
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
||||
|
||||
//element by element comparison if it could not be avoided
|
||||
@ -88,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
||||
if (st1 === st2) return true
|
||||
|
||||
// fast comparison of buffers if possible
|
||||
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
|
||||
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
|
||||
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
||||
|
||||
//element by element comparison if it could not be avoided
|
||||
@ -187,11 +186,11 @@ public fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.contentEquals(
|
||||
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
||||
*/
|
||||
@PerformancePitfall
|
||||
public fun <T : Comparable<T>> GroupND<T, Ring<T>>.contentEquals(
|
||||
public fun <T : Comparable<T>> GroupOpsND<T, Ring<T>>.contentEquals(
|
||||
st1: StructureND<T>,
|
||||
st2: StructureND<T>,
|
||||
absoluteTolerance: T,
|
||||
): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance }
|
||||
): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance }
|
||||
|
||||
/**
|
||||
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
||||
@ -231,107 +230,10 @@ public interface MutableStructureND<T> : StructureND<T> {
|
||||
* Transform a structure element-by element in place.
|
||||
*/
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
|
||||
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (index: IntArray, t: T) -> T): Unit =
|
||||
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||
|
||||
/**
|
||||
* A way to convert ND indices to linear one and back.
|
||||
*/
|
||||
public interface Strides {
|
||||
/**
|
||||
* Shape of NDStructure
|
||||
*/
|
||||
public val shape: IntArray
|
||||
|
||||
/**
|
||||
* Array strides
|
||||
*/
|
||||
public val strides: IntArray
|
||||
|
||||
/**
|
||||
* Get linear index from multidimensional index
|
||||
*/
|
||||
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||
value * strides[i]
|
||||
}.sum()
|
||||
|
||||
/**
|
||||
* Get multidimensional from linear
|
||||
*/
|
||||
public fun index(offset: Int): IntArray
|
||||
|
||||
/**
|
||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||
*/
|
||||
public val linearSize: Int
|
||||
|
||||
// TODO introduce a fast way to calculate index of the next element?
|
||||
|
||||
/**
|
||||
* Iterate over ND indices in a natural order
|
||||
*/
|
||||
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple implementation of [Strides].
|
||||
*/
|
||||
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
||||
override val linearSize: Int
|
||||
get() = strides[shape.size]
|
||||
|
||||
/**
|
||||
* Strides for memory access
|
||||
*/
|
||||
override val strides: IntArray by lazy {
|
||||
sequence {
|
||||
var current = 1
|
||||
yield(1)
|
||||
|
||||
shape.forEach {
|
||||
current *= it
|
||||
yield(current)
|
||||
}
|
||||
}.toList().toIntArray()
|
||||
}
|
||||
|
||||
override fun index(offset: Int): IntArray {
|
||||
val res = IntArray(shape.size)
|
||||
var current = offset
|
||||
var strideIndex = strides.size - 2
|
||||
|
||||
while (strideIndex >= 0) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex--
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other !is DefaultStrides) return false
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int = shape.contentHashCode()
|
||||
|
||||
@ThreadLocal
|
||||
public companion object {
|
||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||
|
||||
/**
|
||||
* Cached builder for default strides
|
||||
*/
|
||||
public operator fun invoke(shape: IntArray): Strides =
|
||||
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||
}
|
||||
}
|
||||
|
||||
public inline fun <reified T : Any> StructureND<T>.combine(
|
||||
public inline fun <reified T : Any> StructureND<T>.zip(
|
||||
struct: StructureND<T>,
|
||||
crossinline block: (T, T) -> T,
|
||||
): StructureND<T> {
|
||||
|
@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
|
||||
public fun <T, A : Algebra<T>> AlgebraND<T, A>.structureND(
|
||||
shapeFirst: Int,
|
||||
vararg shapeRest: Int,
|
||||
initializer: A.(IntArray) -> T
|
||||
): StructureND<T> = structureND(Shape(shapeFirst, *shapeRest), initializer)
|
||||
|
||||
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = structureND(shape) { zero }
|
||||
|
||||
@JvmName("zeroVarArg")
|
||||
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
|
||||
shapeFirst: Int,
|
||||
vararg shapeRest: Int,
|
||||
): StructureND<T> = structureND(shapeFirst, *shapeRest) { zero }
|
||||
|
||||
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = structureND(shape) { one }
|
||||
|
||||
@JvmName("oneVarArg")
|
||||
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
|
||||
shapeFirst: Int,
|
||||
vararg shapeRest: Int,
|
||||
): StructureND<T> = structureND(shapeFirst, *shapeRest) { one }
|
@ -117,15 +117,15 @@ public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = r
|
||||
*
|
||||
* @param T the type of element of this semispace.
|
||||
*/
|
||||
public interface GroupOperations<T> : Algebra<T> {
|
||||
public interface GroupOps<T> : Algebra<T> {
|
||||
/**
|
||||
* Addition of two elements.
|
||||
*
|
||||
* @param a the augend.
|
||||
* @param b the addend.
|
||||
* @param left the augend.
|
||||
* @param right the addend.
|
||||
* @return the sum.
|
||||
*/
|
||||
public fun add(a: T, b: T): T
|
||||
public fun add(left: T, right: T): T
|
||||
|
||||
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
|
||||
|
||||
@ -149,20 +149,20 @@ public interface GroupOperations<T> : Algebra<T> {
|
||||
* Addition of two elements.
|
||||
*
|
||||
* @receiver the augend.
|
||||
* @param b the addend.
|
||||
* @param other the addend.
|
||||
* @return the sum.
|
||||
*/
|
||||
public operator fun T.plus(b: T): T = add(this, b)
|
||||
public operator fun T.plus(other: T): T = add(this, other)
|
||||
|
||||
/**
|
||||
* Subtraction of two elements.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @param other the subtrahend.
|
||||
* @return the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: T): T = add(this, -b)
|
||||
|
||||
public operator fun T.minus(other: T): T = add(this, -other)
|
||||
// Dynamic dispatch of operations
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> { arg -> +arg }
|
||||
MINUS_OPERATION -> { arg -> -arg }
|
||||
@ -193,7 +193,7 @@ public interface GroupOperations<T> : Algebra<T> {
|
||||
*
|
||||
* @param T the type of element of this semispace.
|
||||
*/
|
||||
public interface Group<T> : GroupOperations<T> {
|
||||
public interface Group<T> : GroupOps<T> {
|
||||
/**
|
||||
* The neutral element of addition.
|
||||
*/
|
||||
@ -206,22 +206,22 @@ public interface Group<T> : GroupOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this semiring.
|
||||
*/
|
||||
public interface RingOperations<T> : GroupOperations<T> {
|
||||
public interface RingOps<T> : GroupOps<T> {
|
||||
/**
|
||||
* Multiplies two elements.
|
||||
*
|
||||
* @param a the multiplier.
|
||||
* @param b the multiplicand.
|
||||
* @param left the multiplier.
|
||||
* @param right the multiplicand.
|
||||
*/
|
||||
public fun multiply(a: T, b: T): T
|
||||
public fun multiply(left: T, right: T): T
|
||||
|
||||
/**
|
||||
* Multiplies this element by scalar.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param b the multiplicand.
|
||||
* @param other the multiplicand.
|
||||
*/
|
||||
public operator fun T.times(b: T): T = multiply(this, b)
|
||||
public operator fun T.times(other: T): T = multiply(this, other)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
TIMES_OPERATION -> ::multiply
|
||||
@ -242,7 +242,7 @@ public interface RingOperations<T> : GroupOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this ring.
|
||||
*/
|
||||
public interface Ring<T> : Group<T>, RingOperations<T> {
|
||||
public interface Ring<T> : Group<T>, RingOps<T> {
|
||||
/**
|
||||
* The neutral element of multiplication
|
||||
*/
|
||||
@ -256,24 +256,24 @@ public interface Ring<T> : Group<T>, RingOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this semifield.
|
||||
*/
|
||||
public interface FieldOperations<T> : RingOperations<T> {
|
||||
public interface FieldOps<T> : RingOps<T> {
|
||||
/**
|
||||
* Division of two elements.
|
||||
*
|
||||
* @param a the dividend.
|
||||
* @param b the divisor.
|
||||
* @param left the dividend.
|
||||
* @param right the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
public fun divide(a: T, b: T): T
|
||||
public fun divide(left: T, right: T): T
|
||||
|
||||
/**
|
||||
* Division of two elements.
|
||||
*
|
||||
* @receiver the dividend.
|
||||
* @param b the divisor.
|
||||
* @param other the divisor.
|
||||
* @return the quotient.
|
||||
*/
|
||||
public operator fun T.div(b: T): T = divide(this, b)
|
||||
public operator fun T.div(other: T): T = divide(this, other)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
|
||||
DIV_OPERATION -> ::divide
|
||||
@ -295,6 +295,6 @@ public interface FieldOperations<T> : RingOperations<T> {
|
||||
*
|
||||
* @param T the type of element of this field.
|
||||
*/
|
||||
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> {
|
||||
public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlgebra<T> {
|
||||
override fun number(value: Number): T = scale(one, value.toDouble())
|
||||
}
|
||||
|
@ -6,7 +6,7 @@
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.BufferedRingND
|
||||
import space.kscience.kmath.nd.BufferedRingOpsND
|
||||
import space.kscience.kmath.operations.BigInt.Companion.BASE
|
||||
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
@ -26,7 +26,7 @@ private typealias TBase = ULong
|
||||
* @author Peter Klimai
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> {
|
||||
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
|
||||
override val zero: BigInt = BigInt.ZERO
|
||||
override val one: BigInt = BigInt.ONE
|
||||
|
||||
@ -34,10 +34,10 @@ public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOp
|
||||
|
||||
@Suppress("EXTENSION_SHADOWED_BY_MEMBER")
|
||||
override fun BigInt.unaryMinus(): BigInt = -this
|
||||
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b)
|
||||
override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right)
|
||||
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
|
||||
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b)
|
||||
override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b)
|
||||
override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right)
|
||||
override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right)
|
||||
|
||||
public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
|
||||
public operator fun String.unaryMinus(): BigInt =
|
||||
@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -
|
||||
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||
Buffer.boxing(size, initializer)
|
||||
|
||||
public fun BigIntField.nd(vararg shape: Int): BufferedRingND<BigInt, BigIntField> =
|
||||
BufferedRingND(shape, BigIntField, BigInt::buffer)
|
||||
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
|
||||
get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer))
|
||||
|
@ -5,32 +5,34 @@
|
||||
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.ShortBuffer
|
||||
|
||||
public interface WithSize {
|
||||
public val size: Int
|
||||
}
|
||||
|
||||
/**
|
||||
* An algebra over [Buffer]
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||
public val bufferFactory: BufferFactory<T>
|
||||
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||
public val elementAlgebra: A
|
||||
public val size: Int
|
||||
public val bufferFactory: BufferFactory<T>
|
||||
|
||||
public fun buffer(vararg elements: T): Buffer<T> {
|
||||
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
|
||||
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
||||
return bufferFactory(size) { elements[it] }
|
||||
}
|
||||
|
||||
//TODO move to multi-receiver inline extension
|
||||
public fun Buffer<T>.map(block: (T) -> T): Buffer<T> = bufferFactory(size) { block(get(it)) }
|
||||
public fun Buffer<T>.map(block: A.(T) -> T): Buffer<T> = mapInline(this, block)
|
||||
|
||||
public fun Buffer<T>.zip(other: Buffer<T>, block: (left: T, right: T) -> T): Buffer<T> {
|
||||
require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" }
|
||||
return bufferFactory(size) { block(this[it], other[it]) }
|
||||
}
|
||||
public fun Buffer<T>.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer<T> = mapIndexedInline(this, block)
|
||||
|
||||
public fun Buffer<T>.zip(other: Buffer<T>, block: A.(left: T, right: T) -> T): Buffer<T> =
|
||||
zipInline(this, other, block)
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
|
||||
val operationFunction = elementAlgebra.unaryOperationFunction(operation)
|
||||
@ -45,112 +47,149 @@ public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T> BufferField<T, *>.buffer(initializer: (Int) -> T): Buffer<T> {
|
||||
/**
|
||||
* Inline map
|
||||
*/
|
||||
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
|
||||
buffer: Buffer<T>,
|
||||
crossinline block: A.(T) -> T
|
||||
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
|
||||
|
||||
/**
|
||||
* Inline map
|
||||
*/
|
||||
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
|
||||
buffer: Buffer<T>,
|
||||
crossinline block: A.(index: Int, arg: T) -> T
|
||||
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
|
||||
|
||||
/**
|
||||
* Inline zip
|
||||
*/
|
||||
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
|
||||
l: Buffer<T>,
|
||||
r: Buffer<T>,
|
||||
crossinline block: A.(l: T, r: T) -> T
|
||||
): Buffer<T> {
|
||||
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
|
||||
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
|
||||
}
|
||||
|
||||
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
|
||||
return bufferFactory(size, initializer)
|
||||
}
|
||||
|
||||
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
|
||||
return bufferFactory(size, initializer)
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::sin)
|
||||
mapInline(arg) { sin(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::cos)
|
||||
mapInline(arg) { cos(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::tan)
|
||||
mapInline(arg) { tan(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::asin)
|
||||
mapInline(arg) { asin(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::acos)
|
||||
mapInline(arg) { acos(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::atan)
|
||||
mapInline(arg) { atan(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::exp)
|
||||
mapInline(arg) { exp(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::ln)
|
||||
mapInline(arg) { ln(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::sinh)
|
||||
mapInline(arg) { sinh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::cosh)
|
||||
mapInline(arg) { cosh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::tanh)
|
||||
mapInline(arg) { tanh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::asinh)
|
||||
mapInline(arg) { asinh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::acosh)
|
||||
mapInline(arg) { acosh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
|
||||
arg.map(elementAlgebra::atanh)
|
||||
mapInline(arg) { atanh(it) }
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
|
||||
with(elementAlgebra) { arg.map { power(it, pow) } }
|
||||
mapInline(arg) { power(it, pow) }
|
||||
|
||||
|
||||
@UnstableKMathAPI
|
||||
public class BufferField<T, A : Field<T>>(
|
||||
override val bufferFactory: BufferFactory<T>,
|
||||
public open class BufferRingOps<T, A: Ring<T>>(
|
||||
override val elementAlgebra: A,
|
||||
override val bufferFactory: BufferFactory<T>,
|
||||
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
|
||||
|
||||
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
||||
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
|
||||
super<BufferAlgebra>.unaryOperationFunction(operation)
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
|
||||
super<BufferAlgebra>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
|
||||
get() = BufferRingOps(ShortRing, ::ShortBuffer)
|
||||
|
||||
public open class BufferFieldOps<T, A : Field<T>>(
|
||||
elementAlgebra: A,
|
||||
bufferFactory: BufferFactory<T>,
|
||||
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
|
||||
|
||||
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
|
||||
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
|
||||
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
|
||||
|
||||
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
|
||||
super<BufferRingOps>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
public class BufferField<T, A : Field<T>>(
|
||||
elementAlgebra: A,
|
||||
bufferFactory: BufferFactory<T>,
|
||||
override val size: Int
|
||||
) : BufferAlgebra<T, A>, Field<Buffer<T>> {
|
||||
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
|
||||
|
||||
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
|
||||
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
|
||||
|
||||
|
||||
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::add)
|
||||
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::multiply)
|
||||
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::divide)
|
||||
|
||||
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = with(elementAlgebra) { a.map { scale(it, value) } }
|
||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = with(elementAlgebra) { map { -it } }
|
||||
|
||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
|
||||
return super<BufferAlgebra>.unaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> {
|
||||
return super<BufferAlgebra>.binaryOperationFunction(operation)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate full buffer field from given buffer operations
|
||||
*/
|
||||
public fun <T, A : Field<T>> BufferFieldOps<T, A>.withSize(size: Int): BufferField<T, A> =
|
||||
BufferField(elementAlgebra, bufferFactory, size)
|
||||
|
||||
//Double buffer specialization
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
|
||||
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
||||
return bufferFactory(size) { elements[it].toDouble() }
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>, size: Int): BufferField<T, A> =
|
||||
BufferField(bufferFactory, this, size)
|
||||
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>): BufferFieldOps<T, A> =
|
||||
BufferFieldOps(this, bufferFactory)
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun DoubleField.bufferAlgebra(size: Int): BufferField<Double, DoubleField> =
|
||||
BufferField(::DoubleBuffer, DoubleField, size)
|
||||
public val DoubleField.bufferAlgebra: BufferFieldOps<Double, DoubleField>
|
||||
get() = BufferFieldOps(DoubleField, ::DoubleBuffer)
|
||||
|
||||
|
@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer
|
||||
*
|
||||
* @property size the size of buffers to operate on.
|
||||
*/
|
||||
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOperations() {
|
||||
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOps() {
|
||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
||||
|
||||
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.sinh(arg)
|
||||
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.sinh(arg)
|
||||
|
||||
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.cosh(arg)
|
||||
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.cosh(arg)
|
||||
|
||||
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.tanh(arg)
|
||||
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.tanh(arg)
|
||||
|
||||
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.asinh(arg)
|
||||
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.asinh(arg)
|
||||
|
||||
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.acosh(arg)
|
||||
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.acosh(arg)
|
||||
|
||||
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOperations>.atanh(arg)
|
||||
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOps>.atanh(arg)
|
||||
|
||||
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
|
||||
//
|
||||
|
@ -12,39 +12,40 @@ import space.kscience.kmath.structures.DoubleBuffer
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* [ExtendedFieldOperations] over [DoubleBuffer].
|
||||
* [ExtendedFieldOps] over [DoubleBuffer].
|
||||
*/
|
||||
public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
||||
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
||||
|
||||
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
|
||||
DoubleBuffer(size) { -array[it] }
|
||||
} else {
|
||||
DoubleBuffer(size) { -get(it) }
|
||||
}
|
||||
|
||||
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
override fun add(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||
require(right.size == left.size) {
|
||||
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||
}
|
||||
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] })
|
||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] })
|
||||
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||
val aArray = left.array
|
||||
val bArray = right.array
|
||||
DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] })
|
||||
} else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
|
||||
}
|
||||
|
||||
override fun Buffer<Double>.plus(b: Buffer<Double>): DoubleBuffer = add(this, b)
|
||||
override fun Buffer<Double>.plus(other: Buffer<Double>): DoubleBuffer = add(this, other)
|
||||
|
||||
override fun Buffer<Double>.minus(b: Buffer<Double>): DoubleBuffer {
|
||||
require(b.size == this.size) {
|
||||
"The size of the first buffer ${this.size} should be the same as for second one: ${b.size} "
|
||||
override fun Buffer<Double>.minus(other: Buffer<Double>): DoubleBuffer {
|
||||
require(other.size == this.size) {
|
||||
"The size of the first buffer ${this.size} should be the same as for second one: ${other.size} "
|
||||
}
|
||||
|
||||
return if (this is DoubleBuffer && b is DoubleBuffer) {
|
||||
return if (this is DoubleBuffer && other is DoubleBuffer) {
|
||||
val aArray = this.array
|
||||
val bArray = b.array
|
||||
val bArray = other.array
|
||||
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
|
||||
} else DoubleBuffer(DoubleArray(this.size) { this[it] - b[it] })
|
||||
} else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] })
|
||||
}
|
||||
|
||||
//
|
||||
@ -66,29 +67,29 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
|
||||
// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue })
|
||||
// }
|
||||
|
||||
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
override fun multiply(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||
require(right.size == left.size) {
|
||||
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||
}
|
||||
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] })
|
||||
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||
val aArray = left.array
|
||||
val bArray = right.array
|
||||
DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] })
|
||||
} else
|
||||
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] })
|
||||
DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] })
|
||||
}
|
||||
|
||||
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer {
|
||||
require(b.size == a.size) {
|
||||
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} "
|
||||
override fun divide(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
|
||||
require(right.size == left.size) {
|
||||
"The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
|
||||
}
|
||||
|
||||
return if (a is DoubleBuffer && b is DoubleBuffer) {
|
||||
val aArray = a.array
|
||||
val bArray = b.array
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] })
|
||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] })
|
||||
return if (left is DoubleBuffer && right is DoubleBuffer) {
|
||||
val aArray = left.array
|
||||
val bArray = right.array
|
||||
DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] })
|
||||
} else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] })
|
||||
}
|
||||
|
||||
override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) {
|
||||
@ -185,7 +186,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
|
||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
|
||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
|
||||
|
||||
public companion object : DoubleBufferOperations()
|
||||
public companion object : DoubleBufferOps()
|
||||
}
|
||||
|
||||
public object DoubleL2Norm : Norm<Point<Double>, Double> {
|
@ -139,10 +139,10 @@ public interface ScaleOperations<T> : Algebra<T> {
|
||||
* Multiplication of this number by element.
|
||||
*
|
||||
* @receiver the multiplier.
|
||||
* @param b the multiplicand.
|
||||
* @param other the multiplicand.
|
||||
* @return the product.
|
||||
*/
|
||||
public operator fun Number.times(b: T): T = b * this
|
||||
public operator fun Number.times(other: T): T = other * this
|
||||
}
|
||||
|
||||
/**
|
||||
@ -150,38 +150,38 @@ public interface ScaleOperations<T> : Algebra<T> {
|
||||
* TODO to be removed and replaced by extensions after multiple receivers are there
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public interface NumbersAddOperations<T> : Ring<T>, NumericAlgebra<T> {
|
||||
public interface NumbersAddOps<T> : RingOps<T>, NumericAlgebra<T> {
|
||||
/**
|
||||
* Addition of element and scalar.
|
||||
*
|
||||
* @receiver the augend.
|
||||
* @param b the addend.
|
||||
* @param other the addend.
|
||||
*/
|
||||
public operator fun T.plus(b: Number): T = this + number(b)
|
||||
public operator fun T.plus(other: Number): T = this + number(other)
|
||||
|
||||
/**
|
||||
* Addition of scalar and element.
|
||||
*
|
||||
* @receiver the augend.
|
||||
* @param b the addend.
|
||||
* @param other the addend.
|
||||
*/
|
||||
public operator fun Number.plus(b: T): T = b + this
|
||||
public operator fun Number.plus(other: T): T = other + this
|
||||
|
||||
/**
|
||||
* Subtraction of element from number.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @param other the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun T.minus(b: Number): T = this - number(b)
|
||||
public operator fun T.minus(other: Number): T = this - number(other)
|
||||
|
||||
/**
|
||||
* Subtraction of number from element.
|
||||
*
|
||||
* @receiver the minuend.
|
||||
* @param b the subtrahend.
|
||||
* @param other the subtrahend.
|
||||
* @receiver the difference.
|
||||
*/
|
||||
public operator fun Number.minus(b: T): T = -b + this
|
||||
public operator fun Number.minus(other: T): T = -other + this
|
||||
}
|
@ -10,8 +10,8 @@ import kotlin.math.pow as kpow
|
||||
/**
|
||||
* Advanced Number-like semifield that implements basic operations.
|
||||
*/
|
||||
public interface ExtendedFieldOperations<T> :
|
||||
FieldOperations<T>,
|
||||
public interface ExtendedFieldOps<T> :
|
||||
FieldOps<T>,
|
||||
TrigonometricOperations<T>,
|
||||
PowerOperations<T>,
|
||||
ExponentialOperations<T>,
|
||||
@ -35,14 +35,14 @@ public interface ExtendedFieldOperations<T> :
|
||||
ExponentialOperations.ACOSH_OPERATION -> ::acosh
|
||||
ExponentialOperations.ASINH_OPERATION -> ::asinh
|
||||
ExponentialOperations.ATANH_OPERATION -> ::atanh
|
||||
else -> super<FieldOperations>.unaryOperationFunction(operation)
|
||||
else -> super<FieldOps>.unaryOperationFunction(operation)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Advanced Number-like field that implements basic operations.
|
||||
*/
|
||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T>{
|
||||
public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T>{
|
||||
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
|
||||
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
|
||||
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||
@ -73,10 +73,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
|
||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override inline fun add(a: Double, b: Double): Double = a + b
|
||||
override inline fun add(left: Double, right: Double): Double = left + right
|
||||
|
||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
||||
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||
|
||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||
|
||||
@ -102,10 +102,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
|
||||
override inline fun norm(arg: Double): Double = abs(arg)
|
||||
|
||||
override inline fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.plus(b: Double): Double = this + b
|
||||
override inline fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.times(b: Double): Double = this * b
|
||||
override inline fun Double.div(b: Double): Double = this / b
|
||||
override inline fun Double.plus(other: Double): Double = this + other
|
||||
override inline fun Double.minus(other: Double): Double = this - other
|
||||
override inline fun Double.times(other: Double): Double = this * other
|
||||
override inline fun Double.div(other: Double): Double = this / other
|
||||
}
|
||||
|
||||
public val Double.Companion.algebra: DoubleField get() = DoubleField
|
||||
@ -126,12 +126,12 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
else -> super.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override inline fun add(a: Float, b: Float): Float = a + b
|
||||
override inline fun add(left: Float, right: Float): Float = left + right
|
||||
override fun scale(a: Float, value: Double): Float = a * value.toFloat()
|
||||
|
||||
override inline fun multiply(a: Float, b: Float): Float = a * b
|
||||
override inline fun multiply(left: Float, right: Float): Float = left * right
|
||||
|
||||
override inline fun divide(a: Float, b: Float): Float = a / b
|
||||
override inline fun divide(left: Float, right: Float): Float = left / right
|
||||
|
||||
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
|
||||
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
|
||||
@ -155,10 +155,10 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
|
||||
override inline fun norm(arg: Float): Float = abs(arg)
|
||||
|
||||
override inline fun Float.unaryMinus(): Float = -this
|
||||
override inline fun Float.plus(b: Float): Float = this + b
|
||||
override inline fun Float.minus(b: Float): Float = this - b
|
||||
override inline fun Float.times(b: Float): Float = this * b
|
||||
override inline fun Float.div(b: Float): Float = this / b
|
||||
override inline fun Float.plus(other: Float): Float = this + other
|
||||
override inline fun Float.minus(other: Float): Float = this - other
|
||||
override inline fun Float.times(other: Float): Float = this * other
|
||||
override inline fun Float.div(other: Float): Float = this / other
|
||||
}
|
||||
|
||||
public val Float.Companion.algebra: FloatField get() = FloatField
|
||||
@ -175,14 +175,14 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Int = value.toInt()
|
||||
override inline fun add(a: Int, b: Int): Int = a + b
|
||||
override inline fun multiply(a: Int, b: Int): Int = a * b
|
||||
override inline fun add(left: Int, right: Int): Int = left + right
|
||||
override inline fun multiply(left: Int, right: Int): Int = left * right
|
||||
override inline fun norm(arg: Int): Int = abs(arg)
|
||||
|
||||
override inline fun Int.unaryMinus(): Int = -this
|
||||
override inline fun Int.plus(b: Int): Int = this + b
|
||||
override inline fun Int.minus(b: Int): Int = this - b
|
||||
override inline fun Int.times(b: Int): Int = this * b
|
||||
override inline fun Int.plus(other: Int): Int = this + other
|
||||
override inline fun Int.minus(other: Int): Int = this - other
|
||||
override inline fun Int.times(other: Int): Int = this * other
|
||||
}
|
||||
|
||||
public val Int.Companion.algebra: IntRing get() = IntRing
|
||||
@ -199,14 +199,14 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Short = value.toShort()
|
||||
override inline fun add(a: Short, b: Short): Short = (a + b).toShort()
|
||||
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort()
|
||||
override inline fun add(left: Short, right: Short): Short = (left + right).toShort()
|
||||
override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
|
||||
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
|
||||
|
||||
override inline fun Short.unaryMinus(): Short = (-this).toShort()
|
||||
override inline fun Short.plus(b: Short): Short = (this + b).toShort()
|
||||
override inline fun Short.minus(b: Short): Short = (this - b).toShort()
|
||||
override inline fun Short.times(b: Short): Short = (this * b).toShort()
|
||||
override inline fun Short.plus(other: Short): Short = (this + other).toShort()
|
||||
override inline fun Short.minus(other: Short): Short = (this - other).toShort()
|
||||
override inline fun Short.times(other: Short): Short = (this * other).toShort()
|
||||
}
|
||||
|
||||
public val Short.Companion.algebra: ShortRing get() = ShortRing
|
||||
@ -223,14 +223,14 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
|
||||
get() = 1
|
||||
|
||||
override fun number(value: Number): Byte = value.toByte()
|
||||
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte()
|
||||
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte()
|
||||
override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte()
|
||||
override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte()
|
||||
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
|
||||
|
||||
override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
|
||||
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte()
|
||||
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte()
|
||||
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte()
|
||||
override inline fun Byte.plus(other: Byte): Byte = (this + other).toByte()
|
||||
override inline fun Byte.minus(other: Byte): Byte = (this - other).toByte()
|
||||
override inline fun Byte.times(other: Byte): Byte = (this * other).toByte()
|
||||
}
|
||||
|
||||
public val Byte.Companion.algebra: ByteRing get() = ByteRing
|
||||
@ -247,14 +247,14 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
|
||||
get() = 1L
|
||||
|
||||
override fun number(value: Number): Long = value.toLong()
|
||||
override inline fun add(a: Long, b: Long): Long = a + b
|
||||
override inline fun multiply(a: Long, b: Long): Long = a * b
|
||||
override inline fun add(left: Long, right: Long): Long = left + right
|
||||
override inline fun multiply(left: Long, right: Long): Long = left * right
|
||||
override fun norm(arg: Long): Long = abs(arg)
|
||||
|
||||
override inline fun Long.unaryMinus(): Long = (-this)
|
||||
override inline fun Long.plus(b: Long): Long = (this + b)
|
||||
override inline fun Long.minus(b: Long): Long = (this - b)
|
||||
override inline fun Long.times(b: Long): Long = (this * b)
|
||||
override inline fun Long.plus(other: Long): Long = (this + other)
|
||||
override inline fun Long.minus(other: Long): Long = (this - other)
|
||||
override inline fun Long.times(other: Long): Long = (this * other)
|
||||
}
|
||||
|
||||
public val Long.Companion.algebra: LongRing get() = LongRing
|
||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.structures
|
||||
|
||||
import space.kscience.kmath.nd.get
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.structureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.testutils.FieldVerifier
|
||||
@ -21,7 +22,7 @@ internal class NDFieldTest {
|
||||
|
||||
@Test
|
||||
fun testStrides() {
|
||||
val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() }
|
||||
val ndArray = DoubleField.ndAlgebra.structureND(10, 10) { (it[0] + it[1]).toDouble() }
|
||||
assertEquals(ndArray[5, 5], 10.0)
|
||||
}
|
||||
}
|
||||
|
@ -7,10 +7,7 @@ package space.kscience.kmath.structures
|
||||
|
||||
import space.kscience.kmath.linear.linearSpace
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.combine
|
||||
import space.kscience.kmath.nd.get
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Norm
|
||||
import space.kscience.kmath.operations.algebra
|
||||
@ -22,9 +19,9 @@ import kotlin.test.assertEquals
|
||||
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
class NumberNDFieldTest {
|
||||
val algebra = DoubleField.ndAlgebra(3, 3)
|
||||
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() }
|
||||
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() }
|
||||
val algebra = DoubleField.ndAlgebra
|
||||
val array1 = algebra.structureND(3, 3) { (i, j) -> (i + j).toDouble() }
|
||||
val array2 = algebra.structureND(3, 3) { (i, j) -> (i - j).toDouble() }
|
||||
|
||||
@Test
|
||||
fun testSum() {
|
||||
@ -77,7 +74,7 @@ class NumberNDFieldTest {
|
||||
|
||||
@Test
|
||||
fun combineTest() {
|
||||
val division = array1.combine(array2, Double::div)
|
||||
val division = array1.zip(array2, Double::div)
|
||||
}
|
||||
|
||||
object L2Norm : Norm<StructureND<Number>, Double> {
|
||||
|
@ -18,9 +18,9 @@ public object JBigIntegerField : Ring<BigInteger>, NumericAlgebra<BigInteger> {
|
||||
override val one: BigInteger get() = BigInteger.ONE
|
||||
|
||||
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
|
||||
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b)
|
||||
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b)
|
||||
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b)
|
||||
override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right)
|
||||
override operator fun BigInteger.minus(other: BigInteger): BigInteger = subtract(other)
|
||||
override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right)
|
||||
|
||||
override operator fun BigInteger.unaryMinus(): BigInteger = negate()
|
||||
}
|
||||
@ -39,15 +39,15 @@ public abstract class JBigDecimalFieldBase internal constructor(
|
||||
override val one: BigDecimal
|
||||
get() = BigDecimal.ONE
|
||||
|
||||
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b)
|
||||
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b)
|
||||
override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right)
|
||||
override operator fun BigDecimal.minus(other: BigDecimal): BigDecimal = subtract(other)
|
||||
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
|
||||
|
||||
override fun scale(a: BigDecimal, value: Double): BigDecimal =
|
||||
a.multiply(value.toBigDecimal(mathContext), mathContext)
|
||||
|
||||
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext)
|
||||
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext)
|
||||
override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext)
|
||||
override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext)
|
||||
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
|
||||
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
|
||||
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)
|
||||
|
@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.map
|
||||
import kotlinx.coroutines.flow.runningReduce
|
||||
import kotlinx.coroutines.flow.scan
|
||||
import space.kscience.kmath.operations.GroupOperations
|
||||
import space.kscience.kmath.operations.GroupOps
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.operations.invoke
|
||||
|
||||
public fun <T> Flow<T>.cumulativeSum(group: GroupOperations<T>): Flow<T> =
|
||||
public fun <T> Flow<T>.cumulativeSum(group: GroupOps<T>): Flow<T> =
|
||||
group { runningReduce { sum, element -> sum + element } }
|
||||
|
||||
@ExperimentalCoroutinesApi
|
||||
|
@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
|
||||
* Map one [BufferND] using function without indices.
|
||||
*/
|
||||
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
|
||||
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
|
||||
return BufferND(strides, DoubleBuffer(array))
|
||||
val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
|
||||
return BufferND(indexes, DoubleBuffer(array))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -104,12 +104,12 @@ public class PolynomialSpace<T, C>(
|
||||
Polynomial(coefficients.map { -it })
|
||||
}
|
||||
|
||||
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> {
|
||||
val dim = max(a.coefficients.size, b.coefficients.size)
|
||||
override fun add(left: Polynomial<T>, right: Polynomial<T>): Polynomial<T> {
|
||||
val dim = max(left.coefficients.size, right.coefficients.size)
|
||||
|
||||
return ring {
|
||||
Polynomial(List(dim) { index ->
|
||||
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero }
|
||||
left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero }
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ public object Euclidean2DSpace : GeometrySpace<Vector2D>, ScaleOperations<Vector
|
||||
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
|
||||
|
||||
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
|
||||
override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y)
|
||||
override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y)
|
||||
override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value)
|
||||
override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y
|
||||
}
|
||||
|
@ -47,8 +47,8 @@ public object Euclidean3DSpace : GeometrySpace<Vector3D>, ScaleOperations<Vector
|
||||
|
||||
override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm()
|
||||
|
||||
override fun add(a: Vector3D, b: Vector3D): Vector3D =
|
||||
Vector3D(a.x + b.x, a.y + b.y, a.z + b.z)
|
||||
override fun add(left: Vector3D, right: Vector3D): Vector3D =
|
||||
Vector3D(left.x + right.x, left.y + right.y, left.z + right.z)
|
||||
|
||||
override fun scale(a: Vector3D, value: Double): Vector3D =
|
||||
Vector3D(a.x * value, a.y * value, a.z * value)
|
||||
|
@ -28,10 +28,9 @@ public class DoubleHistogramSpace(
|
||||
|
||||
public val dimension: Int get() = lower.size
|
||||
|
||||
private val shape = IntArray(binNums.size) { binNums[it] + 2 }
|
||||
override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
|
||||
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
|
||||
|
||||
override val strides: Strides get() = histogramValueSpace.strides
|
||||
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||
|
||||
/**
|
||||
@ -52,7 +51,7 @@ public class DoubleHistogramSpace(
|
||||
val lowerBoundary = index.mapIndexed { axis, i ->
|
||||
when (i) {
|
||||
0 -> Double.NEGATIVE_INFINITY
|
||||
strides.shape[axis] - 1 -> upper[axis]
|
||||
shape[axis] - 1 -> upper[axis]
|
||||
else -> lower[axis] + (i.toDouble()) * binSize[axis]
|
||||
}
|
||||
}.asBuffer()
|
||||
@ -60,7 +59,7 @@ public class DoubleHistogramSpace(
|
||||
val upperBoundary = index.mapIndexed { axis, i ->
|
||||
when (i) {
|
||||
0 -> lower[axis]
|
||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||
shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
|
||||
}
|
||||
}.asBuffer()
|
||||
@ -75,7 +74,7 @@ public class DoubleHistogramSpace(
|
||||
}
|
||||
|
||||
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
||||
val ndCounter = StructureND.auto(strides) { Counter.real() }
|
||||
val ndCounter = StructureND.auto(shape) { Counter.real() }
|
||||
val hBuilder = HistogramBuilder<Double> { point, value ->
|
||||
val index = getIndex(point)
|
||||
ndCounter[index].add(value.toDouble())
|
||||
|
@ -8,8 +8,9 @@ package space.kscience.kmath.histogram
|
||||
import space.kscience.kmath.domains.Domain
|
||||
import space.kscience.kmath.linear.Point
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.DefaultStrides
|
||||
import space.kscience.kmath.nd.FieldND
|
||||
import space.kscience.kmath.nd.Strides
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.Group
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
@ -34,10 +35,10 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
||||
return context.produceBin(index, values[index])
|
||||
}
|
||||
|
||||
override val dimension: Int get() = context.strides.shape.size
|
||||
override val dimension: Int get() = context.shape.size
|
||||
|
||||
override val bins: Iterable<Bin<T>>
|
||||
get() = context.strides.indices().map {
|
||||
get() = DefaultStrides(context.shape).indices().map {
|
||||
context.produceBin(it, values[it])
|
||||
}.asIterable()
|
||||
|
||||
@ -49,7 +50,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
||||
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
||||
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
|
||||
//public val valueSpace: Space<V>
|
||||
public val strides: Strides
|
||||
public val shape: Shape
|
||||
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
|
||||
|
||||
/**
|
||||
@ -66,10 +67,10 @@ public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
||||
|
||||
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
|
||||
|
||||
override fun add(a: IndexedHistogram<T, V>, b: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
|
||||
require(a.context == this) { "Can't operate on a histogram produced by external space" }
|
||||
require(b.context == this) { "Can't operate on a histogram produced by external space" }
|
||||
return IndexedHistogram(this, histogramValueSpace { a.values + b.values })
|
||||
override fun add(left: IndexedHistogram<T, V>, right: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
|
||||
require(left.context == this) { "Can't operate on a histogram produced by external space" }
|
||||
require(right.context == this) { "Can't operate on a histogram produced by external space" }
|
||||
return IndexedHistogram(this, histogramValueSpace { left.values + right.values })
|
||||
}
|
||||
|
||||
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.histogram
|
||||
|
||||
import space.kscience.kmath.nd.DefaultStrides
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.real.DoubleVector
|
||||
import kotlin.random.Random
|
||||
@ -69,7 +70,7 @@ internal class MultivariateHistogramTest {
|
||||
}
|
||||
val res = histogram1 - histogram2
|
||||
assertTrue {
|
||||
strides.indices().all { index ->
|
||||
DefaultStrides(shape).indices().all { index ->
|
||||
res.values[index] <= histogram1.values[index]
|
||||
}
|
||||
}
|
||||
|
@ -88,20 +88,20 @@ public class TreeHistogramSpace(
|
||||
TreeHistogramBuilder(binFactory).apply(block).build()
|
||||
|
||||
override fun add(
|
||||
a: UnivariateHistogram,
|
||||
b: UnivariateHistogram,
|
||||
left: UnivariateHistogram,
|
||||
right: UnivariateHistogram,
|
||||
): UnivariateHistogram {
|
||||
// require(a.context == this) { "Histogram $a does not belong to this context" }
|
||||
// require(b.context == this) { "Histogram $b does not belong to this context" }
|
||||
val bins = TreeMap<Double, UnivariateBin>().apply {
|
||||
(a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def ->
|
||||
(left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def ->
|
||||
put(
|
||||
def.center,
|
||||
UnivariateBin(
|
||||
def,
|
||||
value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0),
|
||||
standardDeviation = (a[def.center]?.standardDeviation
|
||||
?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0)
|
||||
value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0),
|
||||
standardDeviation = (left[def.center]?.standardDeviation
|
||||
?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
@ -28,10 +28,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override inline fun add(a: Double, b: Double): Double = a + b
|
||||
override inline fun add(left: Double, right: Double): Double = left + right
|
||||
|
||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
||||
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||
|
||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||
|
||||
@ -57,10 +57,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
|
||||
override inline fun norm(arg: Double): Double = FastMath.abs(arg)
|
||||
|
||||
override inline fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.plus(b: Double): Double = this + b
|
||||
override inline fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.times(b: Double): Double = this * b
|
||||
override inline fun Double.div(b: Double): Double = this / b
|
||||
override inline fun Double.plus(other: Double): Double = this + other
|
||||
override inline fun Double.minus(other: Double): Double = this - other
|
||||
override inline fun Double.times(other: Double): Double = this * other
|
||||
override inline fun Double.div(other: Double): Double = this / other
|
||||
}
|
||||
|
||||
/**
|
||||
@ -79,10 +79,10 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
|
||||
else -> super<ExtendedField>.binaryOperationFunction(operation)
|
||||
}
|
||||
|
||||
override inline fun add(a: Double, b: Double): Double = a + b
|
||||
override inline fun add(left: Double, right: Double): Double = left + right
|
||||
|
||||
override inline fun multiply(a: Double, b: Double): Double = a * b
|
||||
override inline fun divide(a: Double, b: Double): Double = a / b
|
||||
override inline fun multiply(left: Double, right: Double): Double = left * right
|
||||
override inline fun divide(left: Double, right: Double): Double = left / right
|
||||
|
||||
override inline fun scale(a: Double, value: Double): Double = a * value
|
||||
|
||||
@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
|
||||
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
|
||||
|
||||
override inline fun Double.unaryMinus(): Double = -this
|
||||
override inline fun Double.plus(b: Double): Double = this + b
|
||||
override inline fun Double.minus(b: Double): Double = this - b
|
||||
override inline fun Double.times(b: Double): Double = this * b
|
||||
override inline fun Double.div(b: Double): Double = this / b
|
||||
override inline fun Double.plus(other: Double): Double = this + other
|
||||
override inline fun Double.minus(other: Double): Double = this - other
|
||||
override inline fun Double.times(other: Double): Double = this * other
|
||||
override inline fun Double.div(other: Double): Double = this / other
|
||||
}
|
||||
|
@ -106,8 +106,8 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
is Symbol -> toSVar()
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
||||
GroupOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
||||
GroupOps.PLUS_OPERATION -> +value.toSFun<X>()
|
||||
GroupOps.MINUS_OPERATION -> -value.toSFun<X>()
|
||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||
@ -124,10 +124,10 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||
GroupOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||
GroupOps.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||
GroupOps.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||
RingOps.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||
FieldOps.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
|
14
kmath-multik/build.gradle.kts
Normal file
14
kmath-multik/build.gradle.kts
Normal file
@ -0,0 +1,14 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.jvm")
|
||||
}
|
||||
|
||||
description = "JetBrains Multik connector"
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-tensors"))
|
||||
api("org.jetbrains.kotlinx:multik-default:0.1.0")
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
}
|
@ -0,0 +1,137 @@
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.math.cos
|
||||
import org.jetbrains.kotlinx.multik.api.math.sin
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.zeros
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.nd.FieldOpsND
|
||||
import space.kscience.kmath.nd.RingOpsND
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.*
|
||||
|
||||
/**
|
||||
* A ring algebra for Multik operations
|
||||
*/
|
||||
public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
|
||||
public val type: DataType,
|
||||
override val elementAlgebra: A
|
||||
) : RingOpsND<T, A> {
|
||||
|
||||
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
|
||||
|
||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||
val res = mk.zeros<T, DN>(shape, type).asDNArray()
|
||||
for (index in res.multiIndices) {
|
||||
res[index] = elementAlgebra.initializer(index)
|
||||
}
|
||||
return res.wrap()
|
||||
}
|
||||
|
||||
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
|
||||
this
|
||||
} else {
|
||||
structureND(shape) { get(it) }
|
||||
}
|
||||
|
||||
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
|
||||
//taken directly from Multik sources
|
||||
val array = asMultik().array
|
||||
val data = initMemoryView<T>(array.size, type)
|
||||
var count = 0
|
||||
for (el in array) data[count++] = elementAlgebra.transform(el)
|
||||
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
|
||||
}
|
||||
|
||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
|
||||
//taken directly from Multik sources
|
||||
val array = asMultik().array
|
||||
val data = initMemoryView<T>(array.size, type)
|
||||
val indexIter = array.multiIndices.iterator()
|
||||
var index = 0
|
||||
for (item in array) {
|
||||
if (indexIter.hasNext()) {
|
||||
data[index++] = elementAlgebra.transform(indexIter.next(), item)
|
||||
} else {
|
||||
throw ArithmeticException("Index overflow has happened.")
|
||||
}
|
||||
}
|
||||
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
|
||||
}
|
||||
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
val leftArray = left.asMultik().array
|
||||
val rightArray = right.asMultik().array
|
||||
val data = initMemoryView<T>(leftArray.size, type)
|
||||
var counter = 0
|
||||
val leftIterator = leftArray.iterator()
|
||||
val rightIterator = rightArray.iterator()
|
||||
//iterating them together
|
||||
while (leftIterator.hasNext()) {
|
||||
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
|
||||
}
|
||||
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
|
||||
}
|
||||
|
||||
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
|
||||
|
||||
override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
|
||||
(left.asMultik().array + right.asMultik().array).wrap()
|
||||
|
||||
override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
|
||||
asMultik().array.plus(arg).wrap()
|
||||
|
||||
override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
|
||||
|
||||
override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
|
||||
|
||||
override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
|
||||
|
||||
override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
|
||||
left.asMultik().array.times(right.asMultik().array).wrap()
|
||||
|
||||
override fun StructureND<T>.times(arg: T): MultikTensor<T> =
|
||||
asMultik().array.times(arg).wrap()
|
||||
|
||||
override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
|
||||
|
||||
override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
|
||||
|
||||
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
|
||||
asMultik().array.plus(other.asMultik().array).wrap()
|
||||
|
||||
override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
|
||||
asMultik().array.minus(other.asMultik().array).wrap()
|
||||
|
||||
override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
|
||||
asMultik().array.times(other.asMultik().array).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* A field algebra for multik operations
|
||||
*/
|
||||
public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
|
||||
type: DataType,
|
||||
elementAlgebra: A
|
||||
) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
|
||||
override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
|
||||
asMultik().array.div(other.asMultik().array).wrap()
|
||||
}
|
||||
|
||||
public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
|
||||
get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
|
||||
|
||||
public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
|
||||
get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
|
||||
|
||||
public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
|
||||
get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
|
||||
|
||||
public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
|
||||
get() = MultikRingOpsND(DataType.IntDataType, IntRing)
|
||||
|
||||
public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
|
||||
get() = MultikRingOpsND(DataType.LongDataType, LongRing)
|
@ -0,0 +1,214 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.zeros
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.mapInPlace
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
|
||||
@JvmInline
|
||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||
override val shape: IntArray get() = array.shape
|
||||
|
||||
override fun get(index: IntArray): T = array[index]
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
array[index] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public class MultikTensorAlgebra<T> internal constructor(
|
||||
public val type: DataType,
|
||||
public val elementAlgebra: Ring<T>,
|
||||
public val comparator: Comparator<T>
|
||||
) : TensorAlgebra<T> {
|
||||
|
||||
/**
|
||||
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
|
||||
* are not reflected back onto the source
|
||||
*/
|
||||
public fun Tensor<T>.asMultik(): MultikTensor<T> {
|
||||
return if (this is MultikTensor) {
|
||||
this
|
||||
} else {
|
||||
val res = mk.zeros<T, DN>(shape, type).asDNArray()
|
||||
for (index in res.multiIndices) {
|
||||
res[index] = this[index]
|
||||
}
|
||||
res.wrap()
|
||||
}
|
||||
}
|
||||
|
||||
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
||||
|
||||
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
|
||||
get(intArrayOf(0))
|
||||
} else null
|
||||
|
||||
override fun T.plus(other: Tensor<T>): MultikTensor<T> =
|
||||
other.plus(this)
|
||||
|
||||
override fun Tensor<T>.plus(value: T): MultikTensor<T> =
|
||||
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.plus(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.plusAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.plusAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.add(t, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.plusAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.add(t, other[index]) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun T.minus(other: Tensor<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap()
|
||||
|
||||
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
|
||||
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.minus(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.minusAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.minusAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.run { t - value } }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.minusAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.run { t - other[index] } }
|
||||
}
|
||||
}
|
||||
|
||||
override fun T.times(other: Tensor<T>): MultikTensor<T> =
|
||||
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
|
||||
|
||||
override fun Tensor<T>.times(value: T): Tensor<T> =
|
||||
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.times(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.timesAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.timesAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.multiply(t, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.timesAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.multiply(t, other[index]) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
|
||||
asMultik().array.unaryMinus().wrap()
|
||||
|
||||
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
|
||||
|
||||
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
|
||||
|
||||
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
||||
require(shape.all { it > 0 })
|
||||
require(shape.fold(1, Int::times) == this.shape.size) {
|
||||
"Cannot reshape array of size ${this.shape.size} into a new shape ${
|
||||
shape.joinToString(
|
||||
prefix = "(",
|
||||
postfix = ")"
|
||||
)
|
||||
}"
|
||||
}
|
||||
|
||||
val mt = asMultik().array
|
||||
return if (mt.shape.contentEquals(shape)) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
this as NDArray<T, DN>
|
||||
} else {
|
||||
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
|
||||
}.wrap()
|
||||
}
|
||||
|
||||
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
|
||||
|
||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||
TODO("Diagonal embedding not implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
||||
elementAlgebra.add(acc, t)
|
||||
}
|
||||
|
||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.min(): T =
|
||||
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
|
||||
|
||||
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.max(): T =
|
||||
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
|
||||
|
||||
|
||||
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double>
|
||||
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) }
|
||||
|
||||
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float>
|
||||
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
|
||||
|
||||
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short>
|
||||
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) }
|
||||
|
||||
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int>
|
||||
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) }
|
||||
|
||||
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long>
|
||||
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) }
|
@ -0,0 +1,13 @@
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.junit.jupiter.api.Test
|
||||
import space.kscience.kmath.nd.one
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
|
||||
internal class MultikNDTest {
|
||||
@Test
|
||||
fun basicAlgebra(): Unit = DoubleField.multikND{
|
||||
one(2,2) + 1.0
|
||||
}
|
||||
}
|
@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
|
||||
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
|
||||
val arrayShape = array.shape().toIntArray()
|
||||
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
|
||||
return array
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Represents [AlgebraND] over [Nd4jArrayAlgebra].
|
||||
*
|
||||
@ -39,33 +32,35 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
||||
*/
|
||||
public val StructureND<T>.ndArray: INDArray
|
||||
|
||||
override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
override fun structureND(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
val struct = Nd4j.create(*shape)!!.wrap()
|
||||
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
||||
struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
|
||||
return struct
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
||||
val newStruct = ndArray.dup().wrap()
|
||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
|
||||
return newStruct
|
||||
}
|
||||
|
||||
override fun StructureND<T>.mapIndexed(
|
||||
transform: C.(index: IntArray, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) }
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) }
|
||||
return new
|
||||
}
|
||||
|
||||
override fun combine(
|
||||
a: StructureND<T>,
|
||||
b: StructureND<T>,
|
||||
override fun zip(
|
||||
left: StructureND<T>,
|
||||
right: StructureND<T>,
|
||||
transform: C.(T, T) -> T,
|
||||
): Nd4jArrayStructure<T> {
|
||||
val new = Nd4j.create(*shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
||||
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||
val new = Nd4j.create(*left.shape).wrap()
|
||||
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
||||
return new
|
||||
}
|
||||
}
|
||||
@ -76,16 +71,13 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||
|
||||
override val zero: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
override fun add(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
left.ndArray.add(right.ndArray).wrap()
|
||||
|
||||
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
a.ndArray.add(b.ndArray).wrap()
|
||||
|
||||
override operator fun StructureND<T>.minus(b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
ndArray.sub(b.ndArray).wrap()
|
||||
override operator fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
ndArray.sub(other.ndArray).wrap()
|
||||
|
||||
override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> =
|
||||
ndArray.neg().wrap()
|
||||
@ -101,13 +93,10 @@ public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4j
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> {
|
||||
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
|
||||
|
||||
override val one: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
|
||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
a.ndArray.mul(b.ndArray).wrap()
|
||||
override fun multiply(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
left.ndArray.mul(right.ndArray).wrap()
|
||||
//
|
||||
// override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
|
||||
// check(this)
|
||||
@ -125,21 +114,12 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
|
||||
// }
|
||||
|
||||
public companion object {
|
||||
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
/**
|
||||
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
|
||||
*/
|
||||
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
|
||||
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
|
||||
|
||||
/**
|
||||
* Creates a most suitable implementation of [RingND] using reified class.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
|
||||
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
|
||||
public inline fun <reified T : Number> auto(): Nd4jArrayRingOps<T, Ring<T>> = when {
|
||||
T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
|
||||
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
||||
}
|
||||
}
|
||||
@ -151,38 +131,21 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
|
||||
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
a.ndArray.div(b.ndArray).wrap()
|
||||
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
|
||||
|
||||
override fun divide(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
left.ndArray.div(right.ndArray).wrap()
|
||||
|
||||
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
|
||||
|
||||
public companion object {
|
||||
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
/**
|
||||
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
|
||||
*/
|
||||
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
|
||||
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
|
||||
|
||||
/**
|
||||
* Creates an [FieldND] for [Double] values or pull it from cache if it was created previously.
|
||||
*/
|
||||
public fun real(vararg shape: Int): Nd4jArrayRing<Double, DoubleField> =
|
||||
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
|
||||
|
||||
/**
|
||||
* Creates a most suitable implementation of [FieldND] using reified class.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
|
||||
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
|
||||
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
|
||||
public inline fun <reified T : Any> auto(): Nd4jArrayField<T, Field<T>> = when {
|
||||
T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||
T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||
}
|
||||
}
|
||||
@ -191,8 +154,9 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4
|
||||
/**
|
||||
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
|
||||
*/
|
||||
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>,
|
||||
Nd4jArrayField<T, F> {
|
||||
public sealed interface Nd4jArrayExtendedFieldOps<T, out F : ExtendedField<T>> :
|
||||
ExtendedFieldOps<StructureND<T>>, Nd4jArrayField<T, F> {
|
||||
|
||||
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
|
||||
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
|
||||
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
|
||||
@ -221,63 +185,59 @@ public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : Ex
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
||||
*/
|
||||
public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> {
|
||||
override val elementContext: DoubleField get() = DoubleField
|
||||
public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Double, DoubleField> {
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Double>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Double> -> checkShape(ndArray)
|
||||
is Nd4jArrayStructure<Double> -> ndArray
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> {
|
||||
return a.ndArray.mul(value).wrap()
|
||||
}
|
||||
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> = a.ndArray.mul(value).wrap()
|
||||
|
||||
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
||||
return ndArray.div(arg).wrap()
|
||||
}
|
||||
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> = ndArray.div(arg).wrap()
|
||||
|
||||
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
|
||||
return ndArray.add(arg).wrap()
|
||||
}
|
||||
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> = ndArray.add(arg).wrap()
|
||||
|
||||
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
|
||||
return ndArray.sub(arg).wrap()
|
||||
}
|
||||
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> = ndArray.sub(arg).wrap()
|
||||
|
||||
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
|
||||
return ndArray.mul(arg).wrap()
|
||||
}
|
||||
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> = ndArray.mul(arg).wrap()
|
||||
|
||||
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
||||
return arg.ndArray.rdiv(this).wrap()
|
||||
}
|
||||
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
|
||||
arg.ndArray.rdiv(this).wrap()
|
||||
|
||||
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
|
||||
public companion object : DoubleNd4jArrayFieldOps()
|
||||
}
|
||||
|
||||
public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape))
|
||||
public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps
|
||||
|
||||
public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND<Double, DoubleField>
|
||||
|
||||
public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField =
|
||||
DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
|
||||
|
||||
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
||||
*/
|
||||
public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> {
|
||||
override val elementContext: FloatField get() = FloatField
|
||||
public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Float, FloatField> {
|
||||
override val elementAlgebra: FloatField get() = FloatField
|
||||
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = asFloatStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Float>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Float> -> checkShape(ndArray)
|
||||
is Nd4jArrayStructure<Float> -> ndArray
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
@ -303,21 +263,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend
|
||||
|
||||
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
|
||||
public companion object : FloatNd4jArrayFieldOps()
|
||||
}
|
||||
|
||||
public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND<Float, FloatField>
|
||||
|
||||
public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps
|
||||
|
||||
public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField =
|
||||
FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
|
||||
|
||||
/**
|
||||
* Represents [RingND] over [Nd4jArrayIntStructure].
|
||||
*/
|
||||
public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
|
||||
override val elementContext: IntRing
|
||||
get() = IntRing
|
||||
public open class IntNd4jArrayRingOps : Nd4jArrayRingOps<Int, IntRing> {
|
||||
override val elementAlgebra: IntRing get() = IntRing
|
||||
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = asIntStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Int>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Int> -> checkShape(ndArray)
|
||||
is Nd4jArrayStructure<Int> -> ndArray
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
@ -334,4 +302,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int,
|
||||
|
||||
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
|
||||
public companion object : IntNd4jArrayRingOps()
|
||||
}
|
||||
|
||||
public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps
|
||||
|
||||
public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND<Int, IntRing>
|
||||
|
||||
public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing =
|
||||
IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest))
|
@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.one
|
||||
import space.kscience.kmath.nd.structureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import kotlin.math.PI
|
||||
import kotlin.test.Test
|
||||
@ -19,7 +23,7 @@ import kotlin.test.fail
|
||||
internal class Nd4jArrayAlgebraTest {
|
||||
@Test
|
||||
fun testProduce() {
|
||||
val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
||||
val res = DoubleField.nd4j.structureND(2, 2) { it.sum().toDouble() }
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
|
||||
expected[intArrayOf(0, 0)] = 0.0
|
||||
expected[intArrayOf(0, 1)] = 1.0
|
||||
@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest {
|
||||
|
||||
@Test
|
||||
fun testMap() {
|
||||
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } }
|
||||
val res = IntRing.nd4j {
|
||||
one(2, 2).map { it + it * 2 }
|
||||
}
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||
expected[intArrayOf(0, 0)] = 3
|
||||
expected[intArrayOf(0, 1)] = 3
|
||||
@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest {
|
||||
|
||||
@Test
|
||||
fun testAdd() {
|
||||
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 }
|
||||
val res = IntRing.nd4j { one(2, 2) + 25 }
|
||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||
expected[intArrayOf(0, 0)] = 26
|
||||
expected[intArrayOf(0, 1)] = 26
|
||||
@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
|
||||
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 }
|
||||
fun testSin() = DoubleField.nd4j{
|
||||
val initial = structureND(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
|
||||
val transformed = sin(initial)
|
||||
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||
val expected = structureND(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||
|
||||
println(transformed)
|
||||
assertTrue { StructureND.contentEquals(transformed, expected) }
|
||||
|
@ -41,8 +41,8 @@ public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler
|
||||
|
||||
override val zero: Sampler<T> = ConstantSampler(algebra.zero)
|
||||
|
||||
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
|
||||
override fun add(left: Sampler<T>, right: Sampler<T>): Sampler<T> = BasicSampler { generator ->
|
||||
left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
|
||||
}
|
||||
|
||||
override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator ->
|
||||
|
@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) {
|
||||
}
|
||||
|
||||
is MST.Unary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> value.toIExpr()
|
||||
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
||||
GroupOps.PLUS_OPERATION -> value.toIExpr()
|
||||
GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
||||
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
|
||||
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
|
||||
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
|
||||
@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) {
|
||||
}
|
||||
|
||||
is MST.Binary -> when (operation) {
|
||||
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
||||
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
||||
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
||||
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
||||
GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
||||
GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
||||
RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
||||
FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
||||
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
|
||||
else -> error("Binary operation $operation not defined in $this")
|
||||
}
|
||||
|
@ -5,7 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
import space.kscience.kmath.operations.Algebra
|
||||
import space.kscience.kmath.operations.RingOps
|
||||
|
||||
/**
|
||||
* Algebra over a ring on [Tensor].
|
||||
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
|
||||
*
|
||||
* @param T the type of items in the tensors.
|
||||
*/
|
||||
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||
*
|
||||
@ -53,7 +53,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
* @param other tensor to be added.
|
||||
* @return the sum of this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
|
||||
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Adds the scalar [value] to each element of this tensor.
|
||||
@ -93,7 +93,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
* @param other tensor to be subtracted.
|
||||
* @return the difference between this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
|
||||
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Subtracts the scalar [value] from each element of this tensor.
|
||||
@ -134,7 +134,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
* @param other tensor to be multiplied.
|
||||
* @return the product of this tensor and [other].
|
||||
*/
|
||||
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
|
||||
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
|
||||
|
||||
/**
|
||||
* Multiplies the scalar [value] by each element of this tensor.
|
||||
@ -155,7 +155,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
*
|
||||
* @return tensor negation of the original tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.unaryMinus(): Tensor<T>
|
||||
override fun Tensor<T>.unaryMinus(): Tensor<T>
|
||||
|
||||
/**
|
||||
* Returns the tensor at index i
|
||||
@ -323,4 +323,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
|
||||
*/
|
||||
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right
|
||||
|
||||
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ import kotlin.math.*
|
||||
public open class DoubleTensorAlgebra :
|
||||
TensorPartialDivisionAlgebra<Double>,
|
||||
AnalyticTensorAlgebra<Double>,
|
||||
LinearOpsTensorAlgebra<Double> {
|
||||
LinearOpsTensorAlgebra<Double>{
|
||||
|
||||
public companion object : DoubleTensorAlgebra()
|
||||
|
||||
@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra :
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
|
||||
DoubleTensor {
|
||||
override fun diagonalEmbedding(
|
||||
diagonalEntries: Tensor<Double>,
|
||||
offset: Int,
|
||||
dim1: Int,
|
||||
dim2: Int
|
||||
): DoubleTensor {
|
||||
val n = diagonalEntries.shape.size
|
||||
val d1 = minusIndexFrom(n + 1, dim1)
|
||||
val d2 = minusIndexFrom(n + 1, dim2)
|
||||
|
@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
|
||||
*
|
||||
* @param shape the shape of the tensor.
|
||||
*/
|
||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||
override val strides: IntArray
|
||||
get() = stridesFromShape(shape)
|
||||
|
||||
@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other == null || this::class != other::class) return false
|
||||
|
||||
other as TensorLinearStructure
|
||||
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
return shape.contentHashCode()
|
||||
}
|
||||
}
|
||||
|
@ -26,8 +26,11 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
||||
|
||||
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||
is BufferedTensor<T> -> this
|
||||
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
|
||||
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
||||
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
|
||||
BufferedTensor(this.shape, this.buffer, 0)
|
||||
} else {
|
||||
this.copyToBufferedTensor()
|
||||
}
|
||||
else -> this.copyToBufferedTensor()
|
||||
}
|
||||
|
||||
|
@ -5,4 +5,12 @@
|
||||
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
public fun DoubleTensorAlgebra.ones(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import kotlin.jvm.JvmName
|
||||
|
||||
@JvmName("varArgOne")
|
||||
public fun DoubleTensorAlgebra.one(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
|
||||
public fun DoubleTensorAlgebra.one(shape: Shape): DoubleTensor = ones(shape)
|
||||
@JvmName("varArgZero")
|
||||
public fun DoubleTensorAlgebra.zero(vararg shape: Int): DoubleTensor = zeros(intArrayOf(*shape))
|
||||
public fun DoubleTensorAlgebra.zero(shape: Shape): DoubleTensor = zeros(shape)
|
@ -0,0 +1,124 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.viktor
|
||||
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.ExtendedFieldOps
|
||||
import space.kscience.kmath.operations.NumbersAddOps
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public open class ViktorFieldOpsND :
|
||||
FieldOpsND<Double, DoubleField>,
|
||||
ExtendedFieldOps<StructureND<Double>> {
|
||||
|
||||
public val StructureND<Double>.f64Buffer: F64Array
|
||||
get() = when (this) {
|
||||
is ViktorStructureND -> this.f64Buffer
|
||||
else -> structureND(shape) { this@f64Buffer[it] }.f64Buffer
|
||||
}
|
||||
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
||||
F64Array(*shape).apply {
|
||||
DefaultStrides(shape).indices().forEach { index ->
|
||||
set(value = DoubleField.initializer(index), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun StructureND<Double>.unaryMinus(): StructureND<Double> = -1 * this
|
||||
|
||||
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): ViktorStructureND =
|
||||
F64Array(*shape).apply {
|
||||
DefaultStrides(shape).indices().forEach { index ->
|
||||
set(value = DoubleField.transform(this@map[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun StructureND<Double>.mapIndexed(
|
||||
transform: DoubleField.(index: IntArray, Double) -> Double,
|
||||
): ViktorStructureND = F64Array(*shape).apply {
|
||||
DefaultStrides(shape).indices().forEach { index ->
|
||||
set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun zip(
|
||||
left: StructureND<Double>,
|
||||
right: StructureND<Double>,
|
||||
transform: DoubleField.(Double, Double) -> Double,
|
||||
): ViktorStructureND {
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
return F64Array(*left.shape).apply {
|
||||
DefaultStrides(left.shape).indices().forEach { index ->
|
||||
set(value = DoubleField.transform(left[index], right[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
}
|
||||
|
||||
override fun add(left: StructureND<Double>, right: StructureND<Double>): ViktorStructureND =
|
||||
(left.f64Buffer + right.f64Buffer).asStructure()
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
|
||||
(a.f64Buffer * value).asStructure()
|
||||
|
||||
override fun StructureND<Double>.plus(other: StructureND<Double>): ViktorStructureND =
|
||||
(f64Buffer + other.f64Buffer).asStructure()
|
||||
|
||||
override fun StructureND<Double>.minus(other: StructureND<Double>): ViktorStructureND =
|
||||
(f64Buffer - other.f64Buffer).asStructure()
|
||||
|
||||
override fun StructureND<Double>.times(k: Number): ViktorStructureND =
|
||||
(f64Buffer * k.toDouble()).asStructure()
|
||||
|
||||
override fun StructureND<Double>.plus(arg: Double): ViktorStructureND =
|
||||
(f64Buffer.plus(arg)).asStructure()
|
||||
|
||||
override fun sin(arg: StructureND<Double>): ViktorStructureND = arg.map { sin(it) }
|
||||
override fun cos(arg: StructureND<Double>): ViktorStructureND = arg.map { cos(it) }
|
||||
override fun tan(arg: StructureND<Double>): ViktorStructureND = arg.map { tan(it) }
|
||||
override fun asin(arg: StructureND<Double>): ViktorStructureND = arg.map { asin(it) }
|
||||
override fun acos(arg: StructureND<Double>): ViktorStructureND = arg.map { acos(it) }
|
||||
override fun atan(arg: StructureND<Double>): ViktorStructureND = arg.map { atan(it) }
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): ViktorStructureND = arg.map { it.pow(pow) }
|
||||
|
||||
override fun exp(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.exp().asStructure()
|
||||
|
||||
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
|
||||
|
||||
override fun sinh(arg: StructureND<Double>): ViktorStructureND = arg.map { sinh(it) }
|
||||
|
||||
override fun cosh(arg: StructureND<Double>): ViktorStructureND = arg.map { cosh(it) }
|
||||
|
||||
override fun asinh(arg: StructureND<Double>): ViktorStructureND = arg.map { asinh(it) }
|
||||
|
||||
override fun acosh(arg: StructureND<Double>): ViktorStructureND = arg.map { acosh(it) }
|
||||
|
||||
override fun atanh(arg: StructureND<Double>): ViktorStructureND = arg.map { atanh(it) }
|
||||
|
||||
public companion object : ViktorFieldOpsND()
|
||||
}
|
||||
|
||||
public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND
|
||||
|
||||
public open class ViktorFieldND(
|
||||
override val shape: Shape
|
||||
) : ViktorFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
||||
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
|
||||
override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() }
|
||||
|
||||
override fun number(value: Number): ViktorStructureND =
|
||||
F64Array.full(init = value.toDouble(), shape = shape).asStructure()
|
||||
}
|
||||
|
||||
public fun DoubleField.viktorAlgebra(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
|
||||
|
||||
public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
|
@ -7,12 +7,8 @@ package space.kscience.kmath.viktor
|
||||
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.ExtendedField
|
||||
import space.kscience.kmath.operations.NumbersAddOperations
|
||||
import space.kscience.kmath.operations.ScaleOperations
|
||||
import space.kscience.kmath.nd.DefaultStrides
|
||||
import space.kscience.kmath.nd.MutableStructureND
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND<Double> {
|
||||
@ -31,96 +27,4 @@ public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructur
|
||||
|
||||
public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this)
|
||||
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
||||
NumbersAddOperations<StructureND<Double>>, ExtendedField<StructureND<Double>>,
|
||||
ScaleOperations<StructureND<Double>> {
|
||||
|
||||
public val StructureND<Double>.f64Buffer: F64Array
|
||||
get() = when {
|
||||
!shape.contentEquals(this@ViktorFieldND.shape) -> throw ShapeMismatchException(
|
||||
this@ViktorFieldND.shape,
|
||||
shape
|
||||
)
|
||||
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
|
||||
else -> produce { this@f64Buffer[it] }.f64Buffer
|
||||
}
|
||||
|
||||
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
|
||||
override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() }
|
||||
|
||||
private val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
override val elementContext: DoubleField get() = DoubleField
|
||||
|
||||
override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
||||
F64Array(*shape).apply {
|
||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||
set(value = DoubleField.initializer(index), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun StructureND<Double>.unaryMinus(): StructureND<Double> = -1 * this
|
||||
|
||||
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): ViktorStructureND =
|
||||
F64Array(*this@ViktorFieldND.shape).apply {
|
||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||
set(value = DoubleField.transform(this@map[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun StructureND<Double>.mapIndexed(
|
||||
transform: DoubleField.(index: IntArray, Double) -> Double,
|
||||
): ViktorStructureND = F64Array(*this@ViktorFieldND.shape).apply {
|
||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||
set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun combine(
|
||||
a: StructureND<Double>,
|
||||
b: StructureND<Double>,
|
||||
transform: DoubleField.(Double, Double) -> Double,
|
||||
): ViktorStructureND = F64Array(*shape).apply {
|
||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||
set(value = DoubleField.transform(a[index], b[index]), indices = index)
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
override fun add(a: StructureND<Double>, b: StructureND<Double>): ViktorStructureND =
|
||||
(a.f64Buffer + b.f64Buffer).asStructure()
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
|
||||
(a.f64Buffer * value).asStructure()
|
||||
|
||||
override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND =
|
||||
(f64Buffer + b.f64Buffer).asStructure()
|
||||
|
||||
override inline fun StructureND<Double>.minus(b: StructureND<Double>): ViktorStructureND =
|
||||
(f64Buffer - b.f64Buffer).asStructure()
|
||||
|
||||
override inline fun StructureND<Double>.times(k: Number): ViktorStructureND =
|
||||
(f64Buffer * k.toDouble()).asStructure()
|
||||
|
||||
override inline fun StructureND<Double>.plus(arg: Double): ViktorStructureND =
|
||||
(f64Buffer.plus(arg)).asStructure()
|
||||
|
||||
override fun number(value: Number): ViktorStructureND =
|
||||
F64Array.full(init = value.toDouble(), shape = shape).asStructure()
|
||||
|
||||
override fun sin(arg: StructureND<Double>): ViktorStructureND = arg.map { sin(it) }
|
||||
override fun cos(arg: StructureND<Double>): ViktorStructureND = arg.map { cos(it) }
|
||||
override fun tan(arg: StructureND<Double>): ViktorStructureND = arg.map { tan(it) }
|
||||
override fun asin(arg: StructureND<Double>): ViktorStructureND = arg.map { asin(it) }
|
||||
override fun acos(arg: StructureND<Double>): ViktorStructureND = arg.map { acos(it) }
|
||||
override fun atan(arg: StructureND<Double>): ViktorStructureND = arg.map { atan(it) }
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): ViktorStructureND = arg.map { it.pow(pow) }
|
||||
|
||||
override fun exp(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.exp().asStructure()
|
||||
|
||||
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
|
||||
}
|
||||
|
||||
public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
|
||||
|
@ -1,16 +1,18 @@
|
||||
pluginManagement {
|
||||
repositories {
|
||||
mavenLocal()
|
||||
maven("https://repo.kotlin.link")
|
||||
mavenCentral()
|
||||
gradlePluginPortal()
|
||||
}
|
||||
|
||||
val kotlinVersion = "1.6.0-M1"
|
||||
val kotlinVersion = "1.6.0-RC"
|
||||
val toolsVersion = "0.10.5"
|
||||
|
||||
plugins {
|
||||
id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
|
||||
id("ru.mipt.npm.gradle.project") version "0.10.5"
|
||||
id("ru.mipt.npm.gradle.project") version toolsVersion
|
||||
id("ru.mipt.npm.gradle.jvm") version toolsVersion
|
||||
id("ru.mipt.npm.gradle.mpp") version toolsVersion
|
||||
kotlin("multiplatform") version kotlinVersion
|
||||
kotlin("plugin.allopen") version kotlinVersion
|
||||
}
|
||||
@ -30,6 +32,7 @@ include(
|
||||
":kmath-histograms",
|
||||
":kmath-commons",
|
||||
":kmath-viktor",
|
||||
":kmath-multik",
|
||||
":kmath-optimization",
|
||||
":kmath-stat",
|
||||
":kmath-nd4j",
|
||||
|
Loading…
Reference in New Issue
Block a user