Merge pull request #419 from mipt-npm/feature/multik

Feature/multik
This commit is contained in:
Alexander Nozik 2021-10-18 13:06:34 +03:00 committed by GitHub
commit ae8655d6af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 1868 additions and 1281 deletions

View File

@ -1,6 +0,0 @@
<component name="CopyrightManager">
<copyright>
<option name="notice" value="Copyright 2018-2021 KMath contributors.&#10;Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file." />
<option name="myName" value="kmath" />
</copyright>
</component>

View File

@ -1,21 +0,0 @@
<component name="CopyrightManager">
<settings default="kmath">
<module2copyright>
<element module="Apply copyright" copyright="kmath" />
</module2copyright>
<LanguageOptions name="Groovy">
<option name="fileTypeOverride" value="1" />
</LanguageOptions>
<LanguageOptions name="HTML">
<option name="fileTypeOverride" value="1" />
<option name="prefixLines" value="false" />
</LanguageOptions>
<LanguageOptions name="Properties">
<option name="fileTypeOverride" value="1" />
</LanguageOptions>
<LanguageOptions name="XML">
<option name="fileTypeOverride" value="1" />
<option name="prefixLines" value="false" />
</LanguageOptions>
</settings>
</component>

View File

@ -1,4 +0,0 @@
<component name="DependencyValidationManager">
<scope name="Apply copyright"
pattern="!file[*]:*//testData//*&amp;&amp;!file[*]:testData//*&amp;&amp;!file[*]:*.gradle.kts&amp;&amp;!file[*]:*.gradle&amp;&amp;!file[group:kotlin-ultimate]:*/&amp;&amp;!file[kotlin.libraries]:stdlib/api//*"/>
</component>

View File

@ -42,6 +42,9 @@
- Use `Symbol` factory function instead of `StringSymbol` - Use `Symbol` factory function instead of `StringSymbol`
- New discoverability pattern: `<Type>.algebra.<nd/etc>` - New discoverability pattern: `<Type>.algebra.<nd/etc>`
- Adjusted commons-math API for linear solvers to match conventions. - Adjusted commons-math API for linear solvers to match conventions.
- Buffer algebra does not require size anymore
- Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`

View File

@ -48,6 +48,7 @@ kotlin {
implementation(project(":kmath-nd4j")) implementation(project(":kmath-nd4j"))
implementation(project(":kmath-kotlingrad")) implementation(project(":kmath-kotlingrad"))
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(projects.kmathMultik)
implementation("org.nd4j:nd4j-native:1.0.0-M1") implementation("org.nd4j:nd4j-native:1.0.0-M1")
// uncomment if your system supports AVX2 // uncomment if your system supports AVX2
// val os = System.getProperty("os.name") // val os = System.getProperty("os.name")

View File

@ -9,56 +9,85 @@ import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ones
import org.jetbrains.kotlinx.multik.ndarray.data.DN
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.multik.multikND
import space.kscience.kmath.multik.multikTensorAlgebra
import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd4j.nd4j import space.kscience.kmath.nd4j.nd4j
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.ones import space.kscience.kmath.tensors.core.one
import space.kscience.kmath.tensors.core.tensorAlgebra import space.kscience.kmath.tensors.core.tensorAlgebra
import space.kscience.kmath.viktor.viktorAlgebra
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class NDFieldBenchmark { internal class NDFieldBenchmark {
@Benchmark @Benchmark
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) { fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += one } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) { fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) { fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@Benchmark
fun multikAdd(blackhole: Blackhole) = with(multikField) {
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@Benchmark
fun viktorAdd(blackhole: Blackhole) = with(viktorField) {
var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) { fun tensorAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
var res: DoubleTensor = ones(dim, dim) var res: DoubleTensor = one(shape)
repeat(n) { res = res + 1.0 } repeat(n) { res = res + 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark @Benchmark
fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) { fun tensorInPlaceAdd(blackhole: Blackhole) = with(Double.tensorAlgebra) {
val res: DoubleTensor = ones(dim, dim) val res: DoubleTensor = one(shape)
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
@Benchmark
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) {
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
// @Benchmark // @Benchmark
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { // fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
// var res: StructureND<Double> = one // var res: StructureND<Double> = one(dim, dim)
// repeat(n) { res += 1.0 } // repeat(n) { res += 1.0 }
// blackhole.consume(res) // blackhole.consume(res)
// } // }
@ -66,9 +95,12 @@ internal class NDFieldBenchmark {
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val shape = intArrayOf(dim, dim)
private val specializedField = DoubleField.ndAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) private val specializedField = DoubleField.ndAlgebra
private val nd4jField = DoubleField.nd4j(dim, dim) private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
private val nd4jField = DoubleField.nd4j
private val multikField = DoubleField.multikND
private val viktorField = DoubleField.viktorAlgebra
} }
} }

View File

@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.autoNdAlgebra
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.viktor.ViktorNDField import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class ViktorBenchmark { internal class ViktorBenchmark {
@Benchmark @Benchmark
fun automaticFieldAddition(blackhole: Blackhole) { fun automaticFieldAddition(blackhole: Blackhole) {
with(autoField) { with(autoField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -30,7 +29,7 @@ internal class ViktorBenchmark {
@Benchmark @Benchmark
fun realFieldAddition(blackhole: Blackhole) { fun realFieldAddition(blackhole: Blackhole) {
with(realField) { with(realField) {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -39,7 +38,7 @@ internal class ViktorBenchmark {
@Benchmark @Benchmark
fun viktorFieldAddition(blackhole: Blackhole) { fun viktorFieldAddition(blackhole: Blackhole) {
with(viktorField) { with(viktorField) {
var res = one var res = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
} }
@ -56,10 +55,11 @@ internal class ViktorBenchmark {
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realField = DoubleField.ndAlgebra(dim, dim) private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorNDField(dim, dim) private val viktorField = ViktorFieldND(dim, dim)
} }
} }

View File

@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.nd.autoNdAlgebra import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.viktor.ViktorFieldND import space.kscience.kmath.viktor.ViktorFieldND
@State(Scope.Benchmark) @State(Scope.Benchmark)
internal class ViktorLogBenchmark { internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun realFieldLog(blackhole: Blackhole) { fun realFieldLog(blackhole: Blackhole) {
with(realNdField) { with(realField) {
val fortyTwo = produce { 42.0 } val fortyTwo = structureND(shape) { 42.0 }
var res = one var res = one(shape)
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)
} }
@ -30,7 +33,7 @@ internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun viktorFieldLog(blackhole: Blackhole) { fun viktorFieldLog(blackhole: Blackhole) {
with(viktorField) { with(viktorField) {
val fortyTwo = produce { 42.0 } val fortyTwo = structureND(shape) { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)
@ -48,10 +51,11 @@ internal class ViktorLogBenchmark {
private companion object { private companion object {
private const val dim = 1000 private const val dim = 1000
private const val n = 100 private const val n = 100
private val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
private val autoField = DoubleField.autoNdAlgebra(dim, dim) private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
private val realNdField = DoubleField.ndAlgebra(dim, dim) private val realField = DoubleField.ndAlgebra
private val viktorField = ViktorFieldND(intArrayOf(dim, dim)) private val viktorField = ViktorFieldND(dim, dim)
} }
} }

View File

@ -29,6 +29,11 @@ dependencies {
implementation(project(":kmath-tensors")) implementation(project(":kmath-tensors"))
implementation(project(":kmath-symja")) implementation(project(":kmath-symja"))
implementation(project(":kmath-for-real")) implementation(project(":kmath-for-real"))
//jafama
implementation(project(":kmath-jafama"))
//multik
implementation(projects.kmathMultik)
implementation("org.nd4j:nd4j-native:1.0.0-beta7") implementation("org.nd4j:nd4j-native:1.0.0-beta7")
@ -42,11 +47,12 @@ dependencies {
// } else // } else
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
implementation("org.slf4j:slf4j-simple:1.7.31") // multik implementation
implementation("org.jetbrains.kotlinx:multik-default:0.1.0")
implementation("org.slf4j:slf4j-simple:1.7.32")
// plotting // plotting
implementation("space.kscience:plotlykt-server:0.4.2") implementation("space.kscience:plotlykt-server:0.5.0")
//jafama
implementation(project(":kmath-jafama"))
} }
kotlin.sourceSets.all { kotlin.sourceSets.all {

View File

@ -9,6 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator
import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value import space.kscience.kmath.integration.value
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.structureND
import space.kscience.kmath.nd.withNdAlgebra import space.kscience.kmath.nd.withNdAlgebra
import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.algebra
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -17,7 +18,7 @@ fun main(): Unit = Double.algebra {
withNdAlgebra(2, 2) { withNdAlgebra(2, 2) {
//Produce a diagonal StructureND //Produce a diagonal StructureND
fun diagonal(v: Double) = produce { (i, j) -> fun diagonal(v: Double) = structureND { (i, j) ->
if (i == j) v else 0.0 if (i == j) v else 0.0
} }

View File

@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra
import space.kscience.kmath.complex.ndAlgebra import space.kscience.kmath.complex.ndAlgebra
import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.structureND
fun main() = Complex.algebra { fun main() = Complex.algebra {
val complex = 2 + 2 * i val complex = 2 + 2 * i
println(complex * 8 - 5 * i) println(complex * 8 - 5 * i)
//flat buffer //flat buffer
val buffer = bufferAlgebra(8).run { val buffer = with(bufferAlgebra){
buffer { Complex(it, -it) }.map { Complex(it.im, it.re) } buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) }
} }
println(buffer) println(buffer)
// 2d element // 2d element
val element: BufferND<Complex> = ndAlgebra(2, 2).produce { (i, j) -> val element: BufferND<Complex> = ndAlgebra.structureND(2, 2) { (i, j) ->
Complex(i - j, i + j) Complex(i - j, i + j)
} }
println(element) println(element)
// 1d element operation // 1d element operation
val result: StructureND<Complex> = ndAlgebra(8).run { val result: StructureND<Complex> = ndAlgebra{
val a = produce { (it) -> i * it - it.toDouble() } val a = structureND(8) { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)

View File

@ -0,0 +1,24 @@
package space.kscience.kmath.operations
import space.kscience.kmath.commons.linear.CMLinearSpace
import space.kscience.kmath.linear.matrix
import space.kscience.kmath.nd.DoubleBufferND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.viktor.ViktorStructureND
import space.kscience.kmath.viktor.viktorAlgebra
fun main() {
val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.structureND(Shape(2, 2)) { (i, j) ->
if (i == j) 2.0 else 0.0
}
val cmMatrix: Structure2D<Double> = CMLinearSpace.matrix(2, 2)(0.0, 1.0, 0.0, 3.0)
val res: DoubleBufferND = DoubleField.ndAlgebra {
exp(viktorStructure) + 2.0 * cmMatrix
}
println(res)
}

View File

@ -12,6 +12,7 @@ import space.kscience.kmath.linear.transpose
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -54,7 +55,7 @@ fun complexExample() {
val x = one * 2.5 val x = one * 2.5
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
//a structure generator specific to this context //a structure generator specific to this context
val matrix = produce { (k, l) -> k + l * i } val matrix = structureND { (k, l) -> k + l * i }
//Perform sum //Perform sum
val sum = matrix + x + 1.0 val sum = matrix + x + 1.0

View File

@ -8,13 +8,11 @@ package space.kscience.kmath.structures
import kotlinx.coroutines.DelicateCoroutinesApi import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.autoNdAlgebra import space.kscience.kmath.nd4j.nd4j
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd4j.Nd4jArrayField
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.viktor.ViktorNDField import space.kscience.kmath.viktor.ViktorFieldND
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -31,37 +29,39 @@ fun main() {
Nd4j.zeros(0) Nd4j.zeros(0)
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
val shape = Shape(dim, dim)
// automatically build context most suited for given type. // automatically build context most suited for given type.
val autoField = DoubleField.autoNdAlgebra(dim, dim) val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
// specialized nd-field for Double. It works as generic Double field as well. // specialized nd-field for Double. It works as generic Double field as well.
val realField = DoubleField.ndAlgebra(dim, dim) val realField = DoubleField.ndAlgebra
//A generic boxing field. It should be used for objects, not primitives. //A generic boxing field. It should be used for objects, not primitives.
val boxingField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim) val boxingField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
// Nd4j specialized field. // Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim) val nd4jField = DoubleField.nd4j
//viktor field //viktor field
val viktorField = ViktorNDField(dim, dim) val viktorField = ViktorFieldND(dim, dim)
//parallel processing based on Java Streams //parallel processing based on Java Streams
val parallelField = DoubleField.ndStreaming(dim, dim) val parallelField = DoubleField.ndStreaming(dim, dim)
measureAndPrint("Boxing addition") { measureAndPrint("Boxing addition") {
boxingField { boxingField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Specialized addition") { measureAndPrint("Specialized addition") {
realField { realField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Nd4j specialized addition") { measureAndPrint("Nd4j specialized addition") {
nd4jField { nd4jField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
@ -82,13 +82,13 @@ fun main() {
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField { autoField {
var res: StructureND<Double> = one var res: StructureND<Double> = one(shape)
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
} }
} }
measureAndPrint("Lazy addition") { measureAndPrint("Lazy addition") {
val res = realField.one.mapAsync(GlobalScope) { val res = realField.one(shape).mapAsync(GlobalScope) {
var c = 0.0 var c = 0.0
repeat(n) { repeat(n) {
c += 1.0 c += 1.0

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
import java.util.* import java.util.*
import java.util.stream.IntStream import java.util.stream.IntStream
@ -17,17 +17,17 @@ import java.util.stream.IntStream
* execution. * execution.
*/ */
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>, class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>, NumbersAddOps<StructureND<Double>>,
ExtendedField<StructureND<Double>> { ExtendedField<StructureND<Double>> {
private val strides = DefaultStrides(shape) private val strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override val zero: BufferND<Double> by lazy { produce { zero } } override val zero: BufferND<Double> by lazy { structureND(shape) { zero } }
override val one: BufferND<Double> by lazy { produce { one } } override val one: BufferND<Double> by lazy { structureND(shape) { one } }
override fun number(value: Number): BufferND<Double> { override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce { d } return structureND(shape) { d }
} }
private val StructureND<Double>.buffer: DoubleBuffer private val StructureND<Double>.buffer: DoubleBuffer
@ -36,11 +36,11 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
this@StreamDoubleFieldND.shape, this@StreamDoubleFieldND.shape,
shape shape
) )
this is BufferND && this.strides == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }
override fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> { override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
val index = strides.index(offset) val index = strides.index(offset)
DoubleField.initializer(index) DoubleField.initializer(index)
@ -69,13 +69,13 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
return BufferND(strides, array.asBuffer()) return BufferND(strides, array.asBuffer())
} }
override fun combine( override fun zip(
a: StructureND<Double>, left: StructureND<Double>,
b: StructureND<Double>, right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double, transform: DoubleField.(Double, Double) -> Double,
): BufferND<Double> { ): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset]) DoubleField.transform(left.buffer.array[offset], right.buffer.array[offset])
}.toArray() }.toArray()
return BufferND(strides, array.asBuffer()) return BufferND(strides, array.asBuffer())
} }

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.buffer import space.kscience.kmath.operations.buffer
import space.kscience.kmath.operations.bufferAlgebra import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.operations.withSize
inline fun <reified R : Any> MutableBuffer.Companion.same( inline fun <reified R : Any> MutableBuffer.Companion.same(
n: Int, n: Int,
@ -16,7 +17,7 @@ inline fun <reified R : Any> MutableBuffer.Companion.same(
fun main() { fun main() {
with(DoubleField.bufferAlgebra(5)) { with(DoubleField.bufferAlgebra.withSize(5)) {
println(number(2.0) + buffer(1, 2, 3, 4, 5)) println(number(2.0) + buffer(1, 2, 3, 4, 5))
} }
} }

View File

@ -0,0 +1,24 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.tensors
import edu.mcgill.kaliningraph.power
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.linalg.dot
import org.jetbrains.kotlinx.multik.api.math.exp
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
import org.jetbrains.kotlinx.multik.ndarray.operations.unaryMinus
import space.kscience.kmath.multik.multikND
import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField
fun main(): Unit = with(DoubleField.multikND) {
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0))
one(a.shape) - a
}

View File

@ -1,5 +1,5 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult
import com.github.h0tk3y.betterParse.parser.Parser import com.github.h0tk3y.betterParse.parser.Parser
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.operations.FieldOperations import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations import space.kscience.kmath.operations.RingOps
/** /**
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4. * better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
.or(binaryFunction) .or(binaryFunction)
.or(unaryFunction) .or(unaryFunction)
.or(singular) .or(singular)
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) }) .or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) })
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar) .or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b -> private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = div or mul use TokenMatch::type operator = div or mul use TokenMatch::type
) { a, op, b -> ) { a, op, b ->
if (op == div) if (op == div)
MST.Binary(FieldOperations.DIV_OPERATION, a, b) MST.Binary(FieldOps.DIV_OPERATION, a, b)
else else
MST.Binary(RingOperations.TIMES_OPERATION, a, b) MST.Binary(RingOps.TIMES_OPERATION, a, b)
} }
private val subSumChain: Parser<MST> by leftAssociative( private val subSumChain: Parser<MST> by leftAssociative(
@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
operator = plus or minus use TokenMatch::type operator = plus or minus use TokenMatch::type
) { a, op, b -> ) { a, op, b ->
if (op == plus) if (op == plus)
MST.Binary(GroupOperations.PLUS_OPERATION, a, b) MST.Binary(GroupOps.PLUS_OPERATION, a, b)
else else
MST.Binary(GroupOperations.MINUS_OPERATION, a, b) MST.Binary(GroupOps.MINUS_OPERATION, a, b)
} }
override val rootParser: Parser<MST> by subSumChain override val rootParser: Parser<MST> by subSumChain

View File

@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node ->
@UnstableKMathAPI @UnstableKMathAPI
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-')) private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
UnaryMinusSyntax( UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION, operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax( operand = OperandSyntax(
operand = NumberSyntax(string = s.removePrefix("-")), operand = NumberSyntax(string = s.removePrefix("-")),
parentheses = true, parentheses = true,
@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
val exponent = afterE.toDouble().toString().removeSuffix(".0") val exponent = afterE.toDouble().toString().removeSuffix(".0")
return MultiplicationSyntax( return MultiplicationSyntax(
operation = RingOperations.TIMES_OPERATION, operation = RingOps.TIMES_OPERATION,
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true), left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
right = OperandSyntax( right = OperandSyntax(
operand = SuperscriptSyntax( operand = SuperscriptSyntax(
@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
if (toString.startsWith('-')) if (toString.startsWith('-'))
return UnaryMinusSyntax( return UnaryMinusSyntax(
operation = GroupOperations.MINUS_OPERATION, operation = GroupOps.MINUS_OPERATION,
operand = OperandSyntax(operand = infty, parentheses = true), operand = OperandSyntax(operand = infty, parentheses = true),
) )
@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.PLUS_OPERATION]. * The default instance configured with [GroupOps.PLUS_OPERATION].
*/ */
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION)) public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION))
} }
} }
@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.MINUS_OPERATION]. * The default instance configured with [GroupOps.MINUS_OPERATION].
*/ */
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION)) public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION))
} }
} }
@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.PLUS_OPERATION]. * The default instance configured with [GroupOps.PLUS_OPERATION].
*/ */
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION)) public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION))
} }
} }
@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [GroupOperations.MINUS_OPERATION]. * The default instance configured with [GroupOps.MINUS_OPERATION].
*/ */
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION)) public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION))
} }
} }
@ -295,9 +295,9 @@ public class Fraction(operations: Collection<String>?) : Binary(operations) {
public companion object { public companion object {
/** /**
* The default instance configured with [FieldOperations.DIV_OPERATION]. * The default instance configured with [FieldOps.DIV_OPERATION].
*/ */
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION)) public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION))
} }
} }
@ -422,9 +422,9 @@ public class Multiplication(operations: Collection<String>?) : Binary(operations
public companion object { public companion object {
/** /**
* The default instance configured with [RingOperations.TIMES_OPERATION]. * The default instance configured with [RingOps.TIMES_OPERATION].
*/ */
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION)) public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION))
} }
} }

View File

@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.FieldOperations import space.kscience.kmath.operations.FieldOps
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.PowerOperations import space.kscience.kmath.operations.PowerOperations
import space.kscience.kmath.operations.RingOperations import space.kscience.kmath.operations.RingOps
/** /**
* Removes unnecessary times (&times;) symbols from [MultiplicationSyntax]. * Removes unnecessary times (&times;) symbols from [MultiplicationSyntax].
@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
is BinarySyntax -> when (it.operation) { is BinarySyntax -> when (it.operation) {
PowerOperations.POW_OPERATION -> 1 PowerOperations.POW_OPERATION -> 1
RingOperations.TIMES_OPERATION -> 3 RingOps.TIMES_OPERATION -> 3
FieldOperations.DIV_OPERATION -> 3 FieldOps.DIV_OPERATION -> 3
GroupOperations.MINUS_OPERATION -> 4 GroupOps.MINUS_OPERATION -> 4
GroupOperations.PLUS_OPERATION -> 4 GroupOps.PLUS_OPERATION -> 4
else -> 0 else -> 0
} }

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testLatex import space.kscience.kmath.ast.rendering.TestUtils.testLatex
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test import kotlin.test.Test
internal class TestLatex { internal class TestLatex {
@ -36,7 +36,7 @@ internal class TestLatex {
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)") fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
@Test @Test
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1") fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1")
@Test @Test
fun unaryMinus() = testLatex("-x", "-x") fun unaryMinus() = testLatex("-x", "-x")

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
import space.kscience.kmath.ast.rendering.TestUtils.testMathML import space.kscience.kmath.ast.rendering.TestUtils.testMathML
import space.kscience.kmath.expressions.MST import space.kscience.kmath.expressions.MST
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import kotlin.test.Test import kotlin.test.Test
internal class TestMathML { internal class TestMathML {
@ -47,7 +47,7 @@ internal class TestMathML {
@Test @Test
fun unaryPlus() = fun unaryPlus() =
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>") testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
@Test @Test
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>") fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")

View File

@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value) override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value)) GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value) GroupOps.PLUS_OPERATION -> visit(mst.value)
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value)) PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64) TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64) TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right)) FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64) PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
else -> super.visitBinary(mst) else -> super.visitBinary(mst)
} }
@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, targ
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value) override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) { override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
GroupOperations.PLUS_OPERATION -> visit(mst.value) GroupOps.PLUS_OPERATION -> visit(mst.value)
else -> super.visitUnary(mst) else -> super.visitUnary(mst)
} }
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) { override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right)) GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right)) GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right)) RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
else -> super.visitBinary(mst) else -> super.visitBinary(mst)
} }
} }

View File

@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
import space.kscience.kmath.expressions.* import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
/** /**
* A field over commons-math [DerivativeStructure]. * A field over commons-math [DerivativeStructure].
@ -22,7 +22,7 @@ public class DerivativeStructureField(
public val order: Int, public val order: Int,
bindings: Map<Symbol, Double>, bindings: Map<Symbol, Double>,
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>, ) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
NumbersAddOperations<DerivativeStructure> { NumbersAddOps<DerivativeStructure> {
public val numberOfVariables: Int = bindings.size public val numberOfVariables: Int = bindings.size
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) } override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
@ -70,12 +70,12 @@ public class DerivativeStructureField(
override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate() override fun DerivativeStructure.unaryMinus(): DerivativeStructure = negate()
override fun add(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.add(b) override fun add(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.add(right)
override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value) override fun scale(a: DerivativeStructure, value: Double): DerivativeStructure = a.multiply(value)
override fun multiply(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.multiply(b) override fun multiply(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.multiply(right)
override fun divide(a: DerivativeStructure, b: DerivativeStructure): DerivativeStructure = a.divide(b) override fun divide(left: DerivativeStructure, right: DerivativeStructure): DerivativeStructure = left.divide(right)
override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin() override fun sin(arg: DerivativeStructure): DerivativeStructure = arg.sin()
override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos() override fun cos(arg: DerivativeStructure): DerivativeStructure = arg.cos()
override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan() override fun tan(arg: DerivativeStructure): DerivativeStructure = arg.tan()
@ -99,10 +99,10 @@ public class DerivativeStructureField(
override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp() override fun exp(arg: DerivativeStructure): DerivativeStructure = arg.exp()
override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log() override fun ln(arg: DerivativeStructure): DerivativeStructure = arg.log()
override operator fun DerivativeStructure.plus(b: Number): DerivativeStructure = add(b.toDouble()) override operator fun DerivativeStructure.plus(other: Number): DerivativeStructure = add(other.toDouble())
override operator fun DerivativeStructure.minus(b: Number): DerivativeStructure = subtract(b.toDouble()) override operator fun DerivativeStructure.minus(other: Number): DerivativeStructure = subtract(other.toDouble())
override operator fun Number.plus(b: DerivativeStructure): DerivativeStructure = b + this override operator fun Number.plus(other: DerivativeStructure): DerivativeStructure = other + this
override operator fun Number.minus(b: DerivativeStructure): DerivativeStructure = b - this override operator fun Number.minus(other: DerivativeStructure): DerivativeStructure = other - this
} }
/** /**

View File

@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
public object ComplexField : public object ComplexField :
ExtendedField<Complex>, ExtendedField<Complex>,
Norm<Complex, Complex>, Norm<Complex, Complex>,
NumbersAddOperations<Complex>, NumbersAddOps<Complex>,
ScaleOperations<Complex> { ScaleOperations<Complex> {
override val zero: Complex = 0.0.toComplex() override val zero: Complex = 0.0.toComplex()
@ -77,33 +77,33 @@ public object ComplexField :
override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value) override fun scale(a: Complex, value: Double): Complex = Complex(a.re * value, a.im * value)
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im) override fun add(left: Complex, right: Complex): Complex = Complex(left.re + right.re, left.im + right.im)
// override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble()) // override fun multiply(a: Complex, k: Number): Complex = Complex(a.re * k.toDouble(), a.im * k.toDouble())
override fun multiply(a: Complex, b: Complex): Complex = override fun multiply(left: Complex, right: Complex): Complex =
Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re) Complex(left.re * right.re - left.im * right.im, left.re * right.im + left.im * right.re)
override fun divide(a: Complex, b: Complex): Complex = when { override fun divide(left: Complex, right: Complex): Complex = when {
abs(b.im) < abs(b.re) -> { abs(right.im) < abs(right.re) -> {
val wr = b.im / b.re val wr = right.im / right.re
val wd = b.re + wr * b.im val wd = right.re + wr * right.im
if (wd.isNaN() || wd == 0.0) if (wd.isNaN() || wd == 0.0)
throw ArithmeticException("Division by zero or infinity") throw ArithmeticException("Division by zero or infinity")
else else
Complex((a.re + a.im * wr) / wd, (a.im - a.re * wr) / wd) Complex((left.re + left.im * wr) / wd, (left.im - left.re * wr) / wd)
} }
b.im == 0.0 -> throw ArithmeticException("Division by zero") right.im == 0.0 -> throw ArithmeticException("Division by zero")
else -> { else -> {
val wr = b.re / b.im val wr = right.re / right.im
val wd = b.im + wr * b.re val wd = right.im + wr * right.re
if (wd.isNaN() || wd == 0.0) if (wd.isNaN() || wd == 0.0)
throw ArithmeticException("Division by zero or infinity") throw ArithmeticException("Division by zero or infinity")
else else
Complex((a.re * wr + a.im) / wd, (a.im * wr - a.re) / wd) Complex((left.re * wr + left.im) / wd, (left.im * wr - left.re) / wd)
} }
} }
@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) {
public val Complex.Companion.algebra: ComplexField get() = ComplexField public val Complex.Companion.algebra: ComplexField get() = ComplexField
/** /**
* Creates a complex number with real part equal to this real. * Creates a complex number with real part equal to this real.
* *

View File

@ -6,13 +6,8 @@
package space.kscience.kmath.complex package space.kscience.kmath.complex
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.BufferedFieldND import space.kscience.kmath.operations.*
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.BufferField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
@ -22,100 +17,61 @@ import kotlin.contracts.contract
* An optimized nd-field for complex numbers * An optimized nd-field for complex numbers
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class ComplexFieldND( public sealed class ComplexFieldOpsND : BufferedFieldOpsND<Complex, ComplexField>(ComplexField.bufferAlgebra),
shape: IntArray, ScaleOperations<StructureND<Complex>>, ExtendedFieldOps<StructureND<Complex>> {
) : BufferedFieldND<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
NumbersAddOperations<StructureND<Complex>>,
ExtendedField<StructureND<Complex>> {
override val zero: BufferND<Complex> by lazy { produce { zero } } override fun StructureND<Complex>.toBufferND(): BufferND<Complex> = when (this) {
override val one: BufferND<Complex> by lazy { produce { one } } is BufferND -> this
else -> {
override fun number(value: Number): BufferND<Complex> { val indexer = indexerBuilder(shape)
val d = value.toComplex() // minimize conversions BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) })
return produce { d } }
} }
// //TODO do specialization
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun map(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double) -> Double,
// ): RealNDElement {
// check(arg)
// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
// return BufferedNDFieldElement(this, array)
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun mapIndexed(
// arg: AbstractNDBuffer<Double>,
// transform: DoubleField.(index: IntArray, Double) -> Double,
// ): RealNDElement {
// check(arg)
// return BufferedNDFieldElement(
// this,
// RealBuffer(arg.strides.linearSize) { offset ->
// elementContext.transform(
// arg.strides.index(offset),
// arg.buffer[offset]
// )
// })
// }
//
// @Suppress("OVERRIDE_BY_INLINE")
// override inline fun combine(
// a: AbstractNDBuffer<Double>,
// b: AbstractNDBuffer<Double>,
// transform: DoubleField.(Double, Double) -> Double,
// ): RealNDElement {
// check(a, b)
// val buffer = RealBuffer(strides.linearSize) { offset ->
// elementContext.transform(a.buffer[offset], b.buffer[offset])
// }
// return BufferedNDFieldElement(this, buffer)
// }
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) } override fun scale(a: StructureND<Complex>, value: Double): BufferND<Complex> =
mapInline(a.toBufferND()) { it * value }
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) } override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> =
mapInline(arg.toBufferND()) { power(it, pow) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) } override fun exp(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { exp(it) }
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { ln(it) }
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) } override fun sin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sin(it) }
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) } override fun cos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cos(it) }
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) } override fun tan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tan(it) }
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) } override fun asin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asin(it) }
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) } override fun acos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acos(it) }
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) } override fun atan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atan(it) }
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) } override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sinh(it) }
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) } override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cosh(it) }
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) } override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tanh(it) }
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) } override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asinh(it) }
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) } override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acosh(it) }
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) } override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atanh(it) }
}
public companion object : ComplexFieldOpsND()
/**
* Fast element production using function inlining
*/
public inline fun BufferedFieldND<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND<Complex> {
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferND(strides, buffer)
} }
@UnstableKMathAPI @UnstableKMathAPI
public fun ComplexField.bufferAlgebra(size: Int): BufferField<Complex, ComplexField> = public val ComplexField.bufferAlgebra: BufferFieldOps<Complex, ComplexField>
bufferAlgebra(Buffer.Companion::complex, size) get() = bufferAlgebra(Buffer.Companion::complex)
@OptIn(UnstableKMathAPI::class)
public class ComplexFieldND(override val shape: Shape) :
ComplexFieldOpsND(), FieldND<Complex, ComplexField>, NumbersAddOps<StructureND<Complex>> {
override fun number(value: Number): BufferND<Complex> {
val d = value.toDouble() // minimize conversions
return structureND(shape) { d.toComplex() }
}
}
public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape) public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)

View File

@ -44,7 +44,7 @@ public val Quaternion.r: Double
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>, public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
ExponentialOperations<Quaternion>, NumbersAddOperations<Quaternion>, ScaleOperations<Quaternion> { ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
override val zero: Quaternion = 0.toQuaternion() override val zero: Quaternion = 0.toQuaternion()
override val one: Quaternion = 1.toQuaternion() override val one: Quaternion = 1.toQuaternion()
@ -63,27 +63,27 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
*/ */
public val k: Quaternion = Quaternion(0, 0, 0, 1) public val k: Quaternion = Quaternion(0, 0, 0, 1)
override fun add(a: Quaternion, b: Quaternion): Quaternion = override fun add(left: Quaternion, right: Quaternion): Quaternion =
Quaternion(a.w + b.w, a.x + b.x, a.y + b.y, a.z + b.z) Quaternion(left.w + right.w, left.x + right.x, left.y + right.y, left.z + right.z)
override fun scale(a: Quaternion, value: Double): Quaternion = override fun scale(a: Quaternion, value: Double): Quaternion =
Quaternion(a.w * value, a.x * value, a.y * value, a.z * value) Quaternion(a.w * value, a.x * value, a.y * value, a.z * value)
override fun multiply(a: Quaternion, b: Quaternion): Quaternion = Quaternion( override fun multiply(left: Quaternion, right: Quaternion): Quaternion = Quaternion(
a.w * b.w - a.x * b.x - a.y * b.y - a.z * b.z, left.w * right.w - left.x * right.x - left.y * right.y - left.z * right.z,
a.w * b.x + a.x * b.w + a.y * b.z - a.z * b.y, left.w * right.x + left.x * right.w + left.y * right.z - left.z * right.y,
a.w * b.y - a.x * b.z + a.y * b.w + a.z * b.x, left.w * right.y - left.x * right.z + left.y * right.w + left.z * right.x,
a.w * b.z + a.x * b.y - a.y * b.x + a.z * b.w, left.w * right.z + left.x * right.y - left.y * right.x + left.z * right.w,
) )
override fun divide(a: Quaternion, b: Quaternion): Quaternion { override fun divide(left: Quaternion, right: Quaternion): Quaternion {
val s = b.w * b.w + b.x * b.x + b.y * b.y + b.z * b.z val s = right.w * right.w + right.x * right.x + right.y * right.y + right.z * right.z
return Quaternion( return Quaternion(
(b.w * a.w + b.x * a.x + b.y * a.y + b.z * a.z) / s, (right.w * left.w + right.x * left.x + right.y * left.y + right.z * left.z) / s,
(b.w * a.x - b.x * a.w - b.y * a.z + b.z * a.y) / s, (right.w * left.x - right.x * left.w - right.y * left.z + right.z * left.y) / s,
(b.w * a.y + b.x * a.z - b.y * a.w - b.z * a.x) / s, (right.w * left.y + right.x * left.z - right.y * left.w - right.z * left.x) / s,
(b.w * a.z - b.x * a.y + b.y * a.x - b.z * a.w) / s, (right.w * left.z - right.x * left.y + right.y * left.x - right.z * left.w) / s,
) )
} }
@ -158,16 +158,16 @@ public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>,
return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z) return Quaternion(ln(n), th * arg.x, th * arg.y, th * arg.z)
} }
override operator fun Number.plus(b: Quaternion): Quaternion = Quaternion(toDouble() + b.w, b.x, b.y, b.z) override operator fun Number.plus(other: Quaternion): Quaternion = Quaternion(toDouble() + other.w, other.x, other.y, other.z)
override operator fun Number.minus(b: Quaternion): Quaternion = override operator fun Number.minus(other: Quaternion): Quaternion =
Quaternion(toDouble() - b.w, -b.x, -b.y, -b.z) Quaternion(toDouble() - other.w, -other.x, -other.y, -other.z)
override operator fun Quaternion.plus(b: Number): Quaternion = Quaternion(w + b.toDouble(), x, y, z) override operator fun Quaternion.plus(other: Number): Quaternion = Quaternion(w + other.toDouble(), x, y, z)
override operator fun Quaternion.minus(b: Number): Quaternion = Quaternion(w - b.toDouble(), x, y, z) override operator fun Quaternion.minus(other: Number): Quaternion = Quaternion(w - other.toDouble(), x, y, z)
override operator fun Number.times(b: Quaternion): Quaternion = override operator fun Number.times(other: Quaternion): Quaternion =
Quaternion(toDouble() * b.w, toDouble() * b.x, toDouble() * b.y, toDouble() * b.z) Quaternion(toDouble() * other.w, toDouble() * other.x, toDouble() * other.y, toDouble() * other.z)
override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z) override fun Quaternion.unaryMinus(): Quaternion = Quaternion(-w, -x, -y, -z)
override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg) override fun norm(arg: Quaternion): Quaternion = sqrt(arg.conjugate * arg)

View File

@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
override val zero: Expression<T> get() = const(algebra.zero) override val zero: Expression<T> get() = const(algebra.zero)
override fun Expression<T>.unaryMinus(): Expression<T> = override fun Expression<T>.unaryMinus(): Expression<T> =
unaryOperation(GroupOperations.MINUS_OPERATION, this) unaryOperation(GroupOps.MINUS_OPERATION, this)
/** /**
* Builds an Expression of addition of two another expressions. * Builds an Expression of addition of two another expressions.
*/ */
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> = override fun add(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperation(GroupOperations.PLUS_OPERATION, a, b) binaryOperation(GroupOps.PLUS_OPERATION, left, right)
// /** // /**
// * Builds an Expression of multiplication of expression by number. // * Builds an Expression of multiplication of expression by number.
@ -88,8 +88,8 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
/** /**
* Builds an Expression of multiplication of two expressions. * Builds an Expression of multiplication of two expressions.
*/ */
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> = override fun multiply(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg) public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
@ -107,8 +107,8 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
/** /**
* Builds an Expression of division an expression by another one. * Builds an Expression of division an expression by another one.
*/ */
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> = override fun divide(left: Expression<T>, right: Expression<T>): Expression<T> =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg) public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this

View File

@ -31,18 +31,18 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value) override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b) override fun add(left: MST, right: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = override operator fun MST.unaryPlus(): MST.Unary =
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this) unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
override operator fun MST.unaryMinus(): MST.Unary = override operator fun MST.unaryMinus(): MST.Unary =
unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this) unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
override operator fun MST.minus(b: MST): MST.Binary = override operator fun MST.minus(other: MST): MST.Binary =
binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b) binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, other)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value)) binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstNumericAlgebra.binaryOperationFunction(operation) MstNumericAlgebra.binaryOperationFunction(operation)
@ -56,23 +56,23 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
*/ */
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> { public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstGroup.zero override inline val zero: MST.Numeric get() = MstGroup.zero
override val one: MST.Numeric = number(1.0) override val one: MST.Numeric = number(1.0)
override fun number(value: Number): MST.Numeric = MstGroup.number(value) override fun number(value: Number): MST.Numeric = MstGroup.number(value)
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstGroup.add(left, right)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = override fun multiply(left: MST, right: MST): MST.Binary =
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b) binaryOperationFunction(RingOps.TIMES_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstGroup { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstGroup { this@minus - other }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstGroup.binaryOperationFunction(operation) MstGroup.binaryOperationFunction(operation)
@ -86,24 +86,24 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
*/ */
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> { public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
override inline val zero: MST.Numeric get() = MstRing.zero override inline val zero: MST.Numeric get() = MstRing.zero
override inline val one: MST.Numeric get() = MstRing.one override inline val one: MST.Numeric get() = MstRing.one
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value) override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
override fun number(value: Number): MST.Numeric = MstRing.number(value) override fun number(value: Number): MST.Numeric = MstRing.number(value)
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstRing.add(left, right)
override fun scale(a: MST, value: Double): MST.Binary = override fun scale(a: MST, value: Double): MST.Binary =
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value)) MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b) override fun multiply(left: MST, right: MST): MST.Binary = MstRing.multiply(left, right)
override fun divide(a: MST, b: MST): MST.Binary = override fun divide(left: MST, right: MST): MST.Binary =
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b) binaryOperationFunction(FieldOps.DIV_OPERATION)(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstRing { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstRing { this@minus - other }
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary = override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
MstRing.binaryOperationFunction(operation) MstRing.binaryOperationFunction(operation)
@ -134,17 +134,17 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg) override fun asinh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ASINH_OPERATION)(arg)
override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg) override fun acosh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ACOSH_OPERATION)(arg)
override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg) override fun atanh(arg: MST): MST.Unary = unaryOperationFunction(ExponentialOperations.ATANH_OPERATION)(arg)
override fun add(a: MST, b: MST): MST.Binary = MstField.add(a, b) override fun add(left: MST, right: MST): MST.Binary = MstField.add(left, right)
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg) override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
override fun scale(a: MST, value: Double): MST = override fun scale(a: MST, value: Double): MST =
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value)) binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b) override fun multiply(left: MST, right: MST): MST.Binary = MstField.multiply(left, right)
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b) override fun divide(left: MST, right: MST): MST.Binary = MstField.divide(left, right)
override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus } override operator fun MST.unaryPlus(): MST.Unary = MstField { +this@unaryPlus }
override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus } override operator fun MST.unaryMinus(): MST.Unary = MstField { -this@unaryMinus }
override operator fun MST.minus(b: MST): MST.Binary = MstField { this@minus - b } override operator fun MST.minus(other: MST): MST.Binary = MstField { this@minus - other }
override fun power(arg: MST, pow: Number): MST.Binary = override fun power(arg: MST, pow: Number): MST.Binary =
binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow)) binaryOperationFunction(PowerOperations.POW_OPERATION)(arg, number(pow))

View File

@ -59,7 +59,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
public open class SimpleAutoDiffField<T : Any, F : Field<T>>( public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
public val context: F, public val context: F,
bindings: Map<Symbol, T>, bindings: Map<Symbol, T>,
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOperations<AutoDiffValue<T>> { ) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
override val zero: AutoDiffValue<T> get() = const(context.zero) override val zero: AutoDiffValue<T> get() = const(context.zero)
override val one: AutoDiffValue<T> get() = const(context.one) override val one: AutoDiffValue<T> get() = const(context.one)
@ -168,22 +168,22 @@ public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
// Basic math (+, -, *, /) // Basic math (+, -, *, /)
override fun add(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun add(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value + b.value }) { z -> derive(const { left.value + right.value }) { z ->
a.d += z.d left.d += z.d
b.d += z.d right.d += z.d
} }
override fun multiply(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun multiply(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value * b.value }) { z -> derive(const { left.value * right.value }) { z ->
a.d += z.d * b.value left.d += z.d * right.value
b.d += z.d * a.value right.d += z.d * left.value
} }
override fun divide(a: AutoDiffValue<T>, b: AutoDiffValue<T>): AutoDiffValue<T> = override fun divide(left: AutoDiffValue<T>, right: AutoDiffValue<T>): AutoDiffValue<T> =
derive(const { a.value / b.value }) { z -> derive(const { left.value / right.value }) { z ->
a.d += z.d / b.value left.d += z.d / right.value
b.d -= z.d * a.value / (b.value * b.value) right.d -= z.d * left.value / (right.value * right.value)
} }
override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> = override fun scale(a: AutoDiffValue<T>, value: Double): AutoDiffValue<T> =

View File

@ -6,12 +6,10 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.BufferedRingND import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND import space.kscience.kmath.nd.asND
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.VirtualBuffer import space.kscience.kmath.structures.VirtualBuffer
@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices
public class BufferedLinearSpace<T, out A : Ring<T>>( public class BufferedLinearSpace<T, out A : Ring<T>>(
override val elementAlgebra: A, private val bufferAlgebra: BufferAlgebra<T, A>
private val bufferFactory: BufferFactory<T>,
) : LinearSpace<T, A> { ) : LinearSpace<T, A> {
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
private fun ndRing( private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
rows: Int,
cols: Int,
): BufferedRingND<T, A> = elementAlgebra.ndAlgebra(bufferFactory, rows, cols)
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> = override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() ndAlgebra.structureND(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> = override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
bufferFactory(size) { elementAlgebra.initializer(it) } bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.unaryMinus(): Matrix<T> = ndAlgebra {
asND().map { -it }.as2D() asND().map { -it }.as2D()
} }
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }
@ -88,11 +83,11 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
} }
} }
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run { override fun Matrix<T>.times(value: T): Matrix<T> = ndAlgebra {
asND().map { it * value }.as2D() asND().map { it * value }.as2D()
} }
} }
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> = public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
BufferedLinearSpace(this, bufferFactory) BufferedLinearSpace(BufferRingOps(this, bufferFactory))

View File

@ -6,11 +6,12 @@
package space.kscience.kmath.linear package space.kscience.kmath.linear
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DoubleFieldND import space.kscience.kmath.nd.DoubleFieldOpsND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.asND import space.kscience.kmath.nd.asND
import space.kscience.kmath.operations.DoubleBufferOperations import space.kscience.kmath.operations.DoubleBufferOps
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
private fun ndRing(
rows: Int,
cols: Int,
): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols))
override fun buildMatrix( override fun buildMatrix(
rows: Int, rows: Int,
columns: Int, columns: Int,
initializer: DoubleField.(i: Int, j: Int) -> Double initializer: DoubleField.(i: Int, j: Int) -> Double
): Matrix<Double> = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D() ): Matrix<Double> = DoubleFieldOpsND.structureND(intArrayOf(rows, columns)) { (i, j) ->
DoubleField.initializer(i, j)
}.as2D()
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer = override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
DoubleBuffer(size) { DoubleField.initializer(it) } DoubleBuffer(size) { DoubleField.initializer(it) }
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.unaryMinus(): Matrix<Double> = DoubleFieldOpsND {
asND().map { -it }.as2D() asND().map { -it }.as2D()
} }
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
asND().plus(other.asND()).as2D() asND().plus(other.asND()).as2D()
} }
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" } require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
asND().minus(other.asND()).as2D() asND().minus(other.asND()).as2D()
} }
@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
} }
override fun Matrix<Double>.times(value: Double): Matrix<Double> = ndRing(rowNum, colNum).run { override fun Matrix<Double>.times(value: Double): Matrix<Double> = DoubleFieldOpsND {
asND().map { it * value }.as2D() asND().map { it * value }.as2D()
} }
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@plus + other this@plus + other
} }
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
this@minus - other this@minus - other
} }
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOperations.run { public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@times, value) scale(this@times, value)
} }
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOperations.run { public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOps.run {
scale(this@div, 1.0 / value) scale(this@div, 1.0 / value)
} }

View File

@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.Structure2D import space.kscience.kmath.nd.Structure2D
import space.kscience.kmath.nd.StructureFeature import space.kscience.kmath.nd.StructureFeature
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.operations.BufferRingOps
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -188,7 +189,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
public fun <T : Any, A : Ring<T>> buffered( public fun <T : Any, A : Ring<T>> buffered(
algebra: A, algebra: A,
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing, bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
): LinearSpace<T, A> = BufferedLinearSpace(algebra, bufferFactory) ): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
@Deprecated("use DoubleField.linearSpace") @Deprecated("use DoubleField.linearSpace")
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer) public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)

View File

@ -7,7 +7,6 @@ package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
/** /**
@ -19,6 +18,14 @@ import kotlin.reflect.KClass
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) : public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.") RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
public typealias Shape = IntArray
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
public interface WithShape {
public val shape: Shape
}
/** /**
* The base interface for all ND-algebra implementations. * The base interface for all ND-algebra implementations.
* *
@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
* @param C the type of the element context. * @param C the type of the element context.
*/ */
public interface AlgebraND<T, out C : Algebra<T>> { public interface AlgebraND<T, out C : Algebra<T>> {
/**
* The shape of ND-structures this algebra operates on.
*/
public val shape: IntArray
/** /**
* The algebra over elements of ND structure. * The algebra over elements of ND structure.
*/ */
public val elementContext: C public val elementAlgebra: C
/** /**
* Produces a new NDStructure using given initializer function. * Produces a new [StructureND] using given initializer function.
*/ */
public fun produce(initializer: C.(IntArray) -> T): StructureND<T> public fun structureND(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
/** /**
* Maps elements from one structure to another one by applying [transform] to them. * Maps elements from one structure to another one by applying [transform] to them.
@ -54,7 +56,7 @@ public interface AlgebraND<T, out C : Algebra<T>> {
/** /**
* Combines two structures into one. * Combines two structures into one.
*/ */
public fun combine(a: StructureND<T>, b: StructureND<T>, transform: C.(T, T) -> T): StructureND<T> public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
/** /**
* Element-wise invocation of function working on [T] on a [StructureND]. * Element-wise invocation of function working on [T] on a [StructureND].
@ -77,7 +79,6 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public companion object public companion object
} }
/** /**
* Get a feature of the structure in this scope. Structure features take precedence other context features. * Get a feature of the structure in this scope. Structure features take precedence other context features.
* *
@ -89,46 +90,22 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? = public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
getFeature(structure, F::class) getFeature(structure, F::class)
/**
* Checks if given elements are consistent with this context.
*
* @param structures the structures to check.
* @return the array of valid structures.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(vararg structures: StructureND<T>): Array<out StructureND<T>> =
structures
.map(StructureND<T>::shape)
.singleOrNull { !shape.contentEquals(it) }
?.let<IntArray, Array<out StructureND<T>>> { throw ShapeMismatchException(shape, it) }
?: structures
/**
* Checks if given element is consistent with this context.
*
* @param element the structure to check.
* @return the valid structure.
*/
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND<T>): StructureND<T> {
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
return element
}
/** /**
* Space of [StructureND]. * Space of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param S the type of group over structure elements. * @param A the type of group over structure elements.
*/ */
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> { public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
/** /**
* Element-wise addition. * Element-wise addition.
* *
* @param a the augend. * @param left the augend.
* @param b the addend. * @param right the addend.
* @return the sum. * @return the sum.
*/ */
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun add(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> add(aValue, bValue) } zip(left, right) { aValue, bValue -> add(aValue, bValue) }
// TODO move to extensions after KEEP-176 // TODO move to extensions after KEEP-176
@ -157,7 +134,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
* @param arg the addend. * @param arg the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(this@plus, value) } public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this
/** /**
* Subtracts an ND structure from an element of it. * Subtracts an ND structure from an element of it.
@ -171,22 +148,26 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
public companion object public companion object
} }
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
override val zero: StructureND<T> get() = structureND(shape) { elementAlgebra.zero }
}
/** /**
* Ring of [StructureND]. * Ring of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param R the type of ring over structure elements. * @param A the type of ring over structure elements.
*/ */
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> { public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, GroupOpsND<T, A> {
/** /**
* Element-wise multiplication. * Element-wise multiplication.
* *
* @param a the multiplicand. * @param left the multiplicand.
* @param b the multiplier. * @param right the multiplier.
* @return the product. * @return the product.
*/ */
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun multiply(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> multiply(aValue, bValue) } zip(left, right) { aValue, bValue -> multiply(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after KEEP-176
@ -211,24 +192,32 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
public companion object public companion object
} }
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
}
/** /**
* Field of [StructureND]. * Field of [StructureND].
* *
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param F the type field over structure elements. * @param A the type field over structure elements.
*/ */
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F> { public interface FieldOpsND<T, out A : Field<T>> :
FieldOps<StructureND<T>>,
RingOpsND<T, A>,
ScaleOperations<StructureND<T>> {
/** /**
* Element-wise division. * Element-wise division.
* *
* @param a the dividend. * @param left the dividend.
* @param b the divisor. * @param right the divisor.
* @return the quotient. * @return the quotient.
*/ */
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> = override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> =
combine(a, b) { aValue, bValue -> divide(aValue, bValue) } zip(left, right) { aValue, bValue -> divide(aValue, bValue) }
//TODO move to extensions after KEEP-176 //TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
/** /**
* Divides an ND structure by an element of it. * Divides an ND structure by an element of it.
* *
@ -247,42 +236,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
*/ */
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) } public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
/**
* Element-wise scaling.
*
* @param a the multiplicand.
* @param value the multiplier.
* @return the product.
*/
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
// @ThreadLocal
// public companion object {
// private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
//
// /**
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
// */
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
//
// /**
// * Create an ND field with boxing generic buffer.
// */
// public fun <T : Any, F : Field<T>> boxing(
// field: F,
// vararg shape: Int,
// bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
// ): BufferedNDField<T, F> = BufferedNDField(shape, field, bufferFactory)
//
// /**
// * Create a most suitable implementation for nd-field using reified class.
// */
// @Suppress("UNCHECKED_CAST")
// public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): NDField<T, F> =
// when {
// T::class == Double::class -> real(*shape) as NDField<T, F>
// T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
// }
// }
} }
public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
}

View File

@ -3,145 +3,177 @@
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/ */
@file:OptIn(UnstableKMathAPI::class)
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.jvm.JvmName
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> { public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val strides: Strides public val indexerBuilder: (IntArray) -> ShapeIndex
public val bufferFactory: BufferFactory<T> public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
override fun produce(initializer: A.(IntArray) -> T): BufferND<T> = BufferND( override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
strides, val indexer = indexerBuilder(shape)
bufferFactory(strides.linearSize) { offset -> return BufferND(
elementContext.initializer(strides.index(offset)) indexer,
bufferAlgebra.buffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
}
)
}
public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
is BufferND -> this
else -> {
val indexer = indexerBuilder(shape)
BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> = mapInline(toBufferND(), transform)
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> =
mapIndexedInline(toBufferND(), transform)
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> =
zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
}
}
public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
arg: BufferND<T>,
crossinline transform: A.(T) -> T
): BufferND<T> {
val indexes = arg.indexes
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
}
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
arg: BufferND<T>,
crossinline transform: A.(index: IntArray, arg: T) -> T
): BufferND<T> {
val indexes = arg.indexes
return BufferND(
indexes,
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
transform(indexes.index(offset), value)
} }
) )
public val StructureND<T>.buffer: Buffer<T>
get() = when {
!shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException(
this@BufferAlgebraND.shape,
shape
)
this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
}
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(buffer[offset])
}
return BufferND(strides, buffer)
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(
strides.index(offset),
buffer[offset]
)
}
return BufferND(strides, buffer)
}
override fun combine(a: StructureND<T>, b: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> {
val buffer = bufferFactory(strides.linearSize) { offset ->
elementContext.transform(a.buffer[offset], b.buffer[offset])
}
return BufferND(strides, buffer)
}
} }
public open class BufferedGroupND<T, out A : Group<T>>( internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
final override val shape: IntArray, l: BufferND<T>,
final override val elementContext: A, r: BufferND<T>,
final override val bufferFactory: BufferFactory<T>, crossinline block: A.(l: T, r: T) -> T
) : GroupND<T, A>, BufferAlgebraND<T, A> { ): BufferND<T> {
override val strides: Strides = DefaultStrides(shape) require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
override val zero: BufferND<T> by lazy { produce { zero } } val indexes = l.indexes
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) } return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
} }
public open class BufferedRingND<T, out R : Ring<T>>( public open class BufferedGroupNDOps<T, out A : Group<T>>(
shape: IntArray, override val bufferAlgebra: BufferAlgebra<T, A>,
elementContext: R, override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
bufferFactory: BufferFactory<T>, ) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
) : BufferedGroupND<T, R>(shape, elementContext, bufferFactory), RingND<T, R> { override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
override val one: BufferND<T> by lazy { produce { one } }
} }
public open class BufferedFieldND<T, out R : Field<T>>( public open class BufferedRingOpsND<T, out A : Ring<T>>(
shape: IntArray, bufferAlgebra: BufferAlgebra<T, A>,
elementContext: R, indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
bufferFactory: BufferFactory<T>, ) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
) : BufferedRingND<T, R>(shape, elementContext, bufferFactory), FieldND<T, R> {
public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
} }
// group factories public val <T, A : Group<T>> BufferAlgebra<T, A>.nd: BufferedGroupNDOps<T, A> get() = BufferedGroupNDOps(this)
public fun <T, A : Group<T>> A.ndAlgebra( public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get() = BufferedRingOpsND(this)
bufferFactory: BufferFactory<T>, public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
vararg shape: Int,
): BufferedGroupND<T, A> = BufferedGroupND(shape, this, bufferFactory)
@JvmName("withNdGroup")
public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
noinline bufferFactory: BufferFactory<T>,
vararg shape: Int,
action: BufferedGroupND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
//ring factories public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.structureND(
public fun <T, A : Ring<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>,
vararg shape: Int, vararg shape: Int,
): BufferedRingND<T, A> = BufferedRingND(shape, this, bufferFactory) initializer: A.(IntArray) -> T
): BufferND<T> = structureND(shape, initializer)
@JvmName("withNdRing") public fun <T, EA : Algebra<T>, A> A.structureND(
public inline fun <T, A : Ring<T>, R> A.withNdAlgebra( initializer: EA.(IntArray) -> T
noinline bufferFactory: BufferFactory<T>, ): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = structureND(shape, initializer)
vararg shape: Int,
action: BufferedRingND<T, A>.() -> R,
): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ndAlgebra(bufferFactory, *shape).run(action)
}
//field factories //// group factories
public fun <T, A : Field<T>> A.ndAlgebra( //public fun <T, A : Group<T>> A.ndAlgebra(
bufferFactory: BufferFactory<T>, // bufferAlgebra: BufferAlgebra<T, A>,
vararg shape: Int, // vararg shape: Int,
): BufferedFieldND<T, A> = BufferedFieldND(shape, this, bufferFactory) //): BufferedGroupNDOps<T, A> = BufferedGroupNDOps(bufferAlgebra)
//
//@JvmName("withNdGroup")
//public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedGroupNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}
/** ////ring factories
* Create a [FieldND] for this [Field] inferring proper buffer factory from the type //public fun <T, A : Ring<T>> A.ndAlgebra(
*/ // bufferFactory: BufferFactory<T>,
@UnstableKMathAPI // vararg shape: Int,
@Suppress("UNCHECKED_CAST") //): BufferedRingNDOps<T, A> = BufferedRingNDOps(shape, this, bufferFactory)
public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra( //
vararg shape: Int, //@JvmName("withNdRing")
): FieldND<T, A> = when (this) { //public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
DoubleField -> DoubleFieldND(shape) as FieldND<T, A> // noinline bufferFactory: BufferFactory<T>,
else -> BufferedFieldND(shape, this, Buffer.Companion::auto) // vararg shape: Int,
} // action: BufferedRingNDOps<T, A>.() -> R,
//): R {
@JvmName("withNdField") // contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
public inline fun <T, A : Field<T>, R> A.withNdAlgebra( // return ndAlgebra(bufferFactory, *shape).run(action)
noinline bufferFactory: BufferFactory<T>, //}
vararg shape: Int, //
action: BufferedFieldND<T, A>.() -> R, ////field factories
): R { //public fun <T, A : Field<T>> A.ndAlgebra(
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } // bufferFactory: BufferFactory<T>,
return ndAlgebra(bufferFactory, *shape).run(action) // vararg shape: Int,
} //): BufferedFieldNDOps<T, A> = BufferedFieldNDOps(shape, this, bufferFactory)
//
///**
// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type
// */
//@UnstableKMathAPI
//@Suppress("UNCHECKED_CAST")
//public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
// vararg shape: Int,
//): FieldND<T, A> = when (this) {
// DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto)
//}
//
//@JvmName("withNdField")
//public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
// noinline bufferFactory: BufferFactory<T>,
// vararg shape: Int,
// action: BufferedFieldNDOps<T, A>.() -> R,
//): R {
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
// return ndAlgebra(bufferFactory, *shape).run(action)
//}

View File

@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Represents [StructureND] over [Buffer]. * Represents [StructureND] over [Buffer].
* *
* @param T the type of items. * @param T the type of items.
* @param strides The strides to access elements of [Buffer] by linear indices. * @param indexes The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val strides: Strides, public val indexes: ShapeIndex,
public val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
init { override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
if (strides.linearSize != buffer.size) {
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
}
}
override operator fun get(index: IntArray): T = buffer[strides.offset(index)] override val shape: IntArray get() = indexes.shape
override val shape: IntArray get() = strides.shape
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map {
it to this[it] it to this[it]
} }
@ -49,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): BufferND<R> { ): BufferND<R> {
return if (this is BufferND<T>) return if (this is BufferND<T>)
BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) }) BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
@ -61,14 +55,14 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* *
* @param T the type of items. * @param T the type of items.
* @param strides The strides to access elements of [MutableBuffer] by linear indices. * @param strides The strides to access elements of [MutableBuffer] by linear indices.
* @param mutableBuffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public class MutableBufferND<T>( public class MutableBufferND<T>(
strides: Strides, strides: ShapeIndex,
public val mutableBuffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) { ) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[strides.offset(index)] = value buffer[indexes.offset(index)] = value
} }
} }
@ -80,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) }) MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -6,108 +6,186 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.*
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.math.pow
public class DoubleBufferND(
indexes: ShapeIndex,
override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer)
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
override fun StructureND<Double>.toBufferND(): DoubleBufferND = when (this) {
is DoubleBufferND -> this
else -> {
val indexer = indexerBuilder(shape)
DoubleBufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
}
}
private inline fun mapInline(
arg: DoubleBufferND,
transform: (Double) -> Double
): DoubleBufferND {
val indexes = arg.indexes
val array = arg.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
}
private inline fun zipInline(
l: DoubleBufferND,
r: DoubleBufferND,
block: (l: Double, r: Double) -> Double
): DoubleBufferND {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes
val lArray = l.buffer.array
val rArray = r.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })
}
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): BufferND<Double> =
mapInline(toBufferND()) { DoubleField.transform(it) }
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
val indexer = indexerBuilder(shape)
return DoubleBufferND(
indexer,
DoubleBuffer(indexer.linearSize) { offset ->
elementAlgebra.initializer(indexer.index(offset))
}
)
}
override fun add(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l + r }
override fun multiply(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l, r -> l * r }
override fun StructureND<Double>.unaryMinus(): DoubleBufferND = mapInline(toBufferND()) { -it }
override fun StructureND<Double>.div(other: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), other.toBufferND()) { l, r -> l / r }
override fun divide(left: StructureND<Double>, right: StructureND<Double>): DoubleBufferND =
zipInline(left.toBufferND(), right.toBufferND()) { l: Double, r: Double -> l / r }
override fun StructureND<Double>.div(arg: Double): DoubleBufferND =
mapInline(toBufferND()) { it / arg }
override fun Double.div(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { this / it }
override fun StructureND<Double>.unaryPlus(): DoubleBufferND = toBufferND()
override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l + r }
override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l - r }
override fun StructureND<Double>.times(other: StructureND<Double>): DoubleBufferND =
zipInline(toBufferND(), other.toBufferND()) { l: Double, r: Double -> l * r }
override fun StructureND<Double>.times(k: Number): DoubleBufferND =
mapInline(toBufferND()) { it * k.toDouble() }
override fun StructureND<Double>.div(k: Number): DoubleBufferND =
mapInline(toBufferND()) { it / k.toDouble() }
override fun Number.times(other: StructureND<Double>): DoubleBufferND = other * this
override fun StructureND<Double>.plus(arg: Double): DoubleBufferND = mapInline(toBufferND()) { it + arg }
override fun StructureND<Double>.minus(arg: Double): StructureND<Double> = mapInline(toBufferND()) { it - arg }
override fun Double.plus(arg: StructureND<Double>): StructureND<Double> = arg + this
override fun Double.minus(arg: StructureND<Double>): StructureND<Double> = mapInline(arg.toBufferND()) { this - it }
override fun scale(a: StructureND<Double>, value: Double): DoubleBufferND =
mapInline(a.toBufferND()) { it * value }
override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND =
mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) }
override fun exp(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.exp(it) }
override fun ln(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.ln(it) }
override fun sin(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.sin(it) }
override fun cos(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.cos(it) }
override fun tan(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.tan(it) }
override fun asin(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.asin(it) }
override fun acos(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.acos(it) }
override fun atan(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.atan(it) }
override fun sinh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.sinh(it) }
override fun cosh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.cosh(it) }
override fun tanh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.tanh(it) }
override fun asinh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.asinh(it) }
override fun acosh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.acosh(it) }
override fun atanh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.atanh(it) }
public companion object : DoubleFieldOpsND()
}
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class DoubleFieldND( public class DoubleFieldND(override val shape: Shape) :
shape: IntArray, DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
) : BufferedFieldND<Double, DoubleField>(shape, DoubleField, ::DoubleBuffer),
NumbersAddOperations<StructureND<Double>>,
ScaleOperations<StructureND<Double>>,
ExtendedField<StructureND<Double>> {
override val zero: BufferND<Double> by lazy { produce { zero } } override fun number(value: Number): DoubleBufferND {
override val one: BufferND<Double> by lazy { produce { one } }
override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce { d } return structureND(shape) { d }
} }
override val StructureND<Double>.buffer: DoubleBuffer
get() = when {
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
this@DoubleFieldND.shape,
shape
)
this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.map(
transform: DoubleField.(Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
return BufferND(strides, buffer)
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset ->
val index = strides.index(offset)
DoubleField.initializer(index)
}
return BufferND(strides, DoubleBuffer(array))
}
@Suppress("OVERRIDE_BY_INLINE")
override inline fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): BufferND<Double> = BufferND(
strides,
buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(
strides.index(offset),
buffer.array[offset]
)
})
@Suppress("OVERRIDE_BY_INLINE")
override inline fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): BufferND<Double> {
val buffer = DoubleBuffer(strides.linearSize) { offset ->
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
}
return BufferND(strides, buffer)
}
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
} }
public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape) public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
/** /**
* Produce a context for n-dimensional operations inside this real field * Produce a context for n-dimensional operations inside this real field
*/ */
@UnstableKMathAPI
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R { public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return DoubleFieldND(shape).run(action) return DoubleFieldND(shape).run(action)

View File

@ -0,0 +1,119 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import kotlin.native.concurrent.ThreadLocal
/**
* A converter from linear index to multivariate index
*/
public interface ShapeIndex{
public val shape: Shape
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray>
override fun equals(other: Any?): Boolean
override fun hashCode(): Int
}
/**
* Linear transformation of indexes
*/
public abstract class Strides: ShapeIndex {
/**
* Array strides
*/
public abstract val strides: IntArray
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
override val linearSize: Int get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
public companion object {
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
}
@ThreadLocal
private val defaultStridesCache = HashMap<IntArray, Strides>()

View File

@ -6,34 +6,27 @@
package space.kscience.kmath.nd package space.kscience.kmath.nd
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.NumbersAddOperations import space.kscience.kmath.operations.NumbersAddOps
import space.kscience.kmath.operations.ShortRing import space.kscience.kmath.operations.ShortRing
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.operations.bufferAlgebra
import space.kscience.kmath.structures.ShortBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
public sealed class ShortRingOpsND : BufferedRingOpsND<Short, ShortRing>(ShortRing.bufferAlgebra) {
public companion object : ShortRingOpsND()
}
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public class ShortRingND( public class ShortRingND(
shape: IntArray, override val shape: Shape
) : BufferedRingND<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto), ) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
NumbersAddOperations<StructureND<Short>> {
override val zero: BufferND<Short> by lazy { produce { zero } }
override val one: BufferND<Short> by lazy { produce { one } }
override fun number(value: Number): BufferND<Short> { override fun number(value: Number): BufferND<Short> {
val d = value.toShort() // minimize conversions val d = value.toShort() // minimize conversions
return produce { d } return structureND(shape) { d }
} }
} }
/**
* Fast element production using function inlining.
*/
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R { public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) } contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
return ShortRingND(shape).run(action) return ShortRingND(shape).run(action)

View File

@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
import kotlin.math.abs import kotlin.math.abs
import kotlin.native.concurrent.ThreadLocal
import kotlin.reflect.KClass import kotlin.reflect.KClass
public interface StructureFeature : Feature<StructureFeature> public interface StructureFeature : Feature<StructureFeature>
@ -72,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -88,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides) if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -187,11 +186,11 @@ public fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.contentEquals(
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
*/ */
@PerformancePitfall @PerformancePitfall
public fun <T : Comparable<T>> GroupND<T, Ring<T>>.contentEquals( public fun <T : Comparable<T>> GroupOpsND<T, Ring<T>>.contentEquals(
st1: StructureND<T>, st1: StructureND<T>,
st2: StructureND<T>, st2: StructureND<T>,
absoluteTolerance: T, absoluteTolerance: T,
): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance } ): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance }
/** /**
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance]. * Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
@ -231,107 +230,10 @@ public interface MutableStructureND<T> : StructureND<T> {
* Transform a structure element-by element in place. * Transform a structure element-by element in place.
*/ */
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (IntArray, T) -> T): Unit = public inline fun <T> MutableStructureND<T>.mapInPlace(action: (index: IntArray, t: T) -> T): Unit =
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) } elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
/** public inline fun <reified T : Any> StructureND<T>.zip(
* A way to convert ND indices to linear one and back.
*/
public interface Strides {
/**
* Shape of NDStructure
*/
public val shape: IntArray
/**
* Array strides
*/
public val strides: IntArray
/**
* Get linear index from multidimensional index
*/
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
value * strides[i]
}.sum()
/**
* Get multidimensional from linear
*/
public fun index(offset: Int): IntArray
/**
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
*/
public val linearSize: Int
// TODO introduce a fast way to calculate index of the next element?
/**
* Iterate over ND indices in a natural order
*/
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
}
/**
* Simple implementation of [Strides].
*/
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
override val linearSize: Int
get() = strides[shape.size]
/**
* Strides for memory access
*/
override val strides: IntArray by lazy {
sequence {
var current = 1
yield(1)
shape.forEach {
current *= it
yield(current)
}
}.toList().toIntArray()
}
override fun index(offset: Int): IntArray {
val res = IntArray(shape.size)
var current = offset
var strideIndex = strides.size - 2
while (strideIndex >= 0) {
res[strideIndex] = (current / strides[strideIndex])
current %= strides[strideIndex]
strideIndex--
}
return res
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is DefaultStrides) return false
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int = shape.contentHashCode()
@ThreadLocal
public companion object {
private val defaultStridesCache = HashMap<IntArray, Strides>()
/**
* Cached builder for default strides
*/
public operator fun invoke(shape: IntArray): Strides =
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
}
}
public inline fun <reified T : Any> StructureND<T>.combine(
struct: StructureND<T>, struct: StructureND<T>,
crossinline block: (T, T) -> T, crossinline block: (T, T) -> T,
): StructureND<T> { ): StructureND<T> {

View File

@ -0,0 +1,34 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.nd
import space.kscience.kmath.operations.Algebra
import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.Ring
import kotlin.jvm.JvmName
public fun <T, A : Algebra<T>> AlgebraND<T, A>.structureND(
shapeFirst: Int,
vararg shapeRest: Int,
initializer: A.(IntArray) -> T
): StructureND<T> = structureND(Shape(shapeFirst, *shapeRest), initializer)
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = structureND(shape) { zero }
@JvmName("zeroVarArg")
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = structureND(shapeFirst, *shapeRest) { zero }
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = structureND(shape) { one }
@JvmName("oneVarArg")
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
shapeFirst: Int,
vararg shapeRest: Int,
): StructureND<T> = structureND(shapeFirst, *shapeRest) { one }

View File

@ -117,15 +117,15 @@ public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = r
* *
* @param T the type of element of this semispace. * @param T the type of element of this semispace.
*/ */
public interface GroupOperations<T> : Algebra<T> { public interface GroupOps<T> : Algebra<T> {
/** /**
* Addition of two elements. * Addition of two elements.
* *
* @param a the augend. * @param left the augend.
* @param b the addend. * @param right the addend.
* @return the sum. * @return the sum.
*/ */
public fun add(a: T, b: T): T public fun add(left: T, right: T): T
// Operations to be performed in this context. Could be moved to extensions in case of KEEP-176. // Operations to be performed in this context. Could be moved to extensions in case of KEEP-176.
@ -149,20 +149,20 @@ public interface GroupOperations<T> : Algebra<T> {
* Addition of two elements. * Addition of two elements.
* *
* @receiver the augend. * @receiver the augend.
* @param b the addend. * @param other the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(b: T): T = add(this, b) public operator fun T.plus(other: T): T = add(this, other)
/** /**
* Subtraction of two elements. * Subtraction of two elements.
* *
* @receiver the minuend. * @receiver the minuend.
* @param b the subtrahend. * @param other the subtrahend.
* @return the difference. * @return the difference.
*/ */
public operator fun T.minus(b: T): T = add(this, -b) public operator fun T.minus(other: T): T = add(this, -other)
// Dynamic dispatch of operations
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) { override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
PLUS_OPERATION -> { arg -> +arg } PLUS_OPERATION -> { arg -> +arg }
MINUS_OPERATION -> { arg -> -arg } MINUS_OPERATION -> { arg -> -arg }
@ -193,7 +193,7 @@ public interface GroupOperations<T> : Algebra<T> {
* *
* @param T the type of element of this semispace. * @param T the type of element of this semispace.
*/ */
public interface Group<T> : GroupOperations<T> { public interface Group<T> : GroupOps<T> {
/** /**
* The neutral element of addition. * The neutral element of addition.
*/ */
@ -206,22 +206,22 @@ public interface Group<T> : GroupOperations<T> {
* *
* @param T the type of element of this semiring. * @param T the type of element of this semiring.
*/ */
public interface RingOperations<T> : GroupOperations<T> { public interface RingOps<T> : GroupOps<T> {
/** /**
* Multiplies two elements. * Multiplies two elements.
* *
* @param a the multiplier. * @param left the multiplier.
* @param b the multiplicand. * @param right the multiplicand.
*/ */
public fun multiply(a: T, b: T): T public fun multiply(left: T, right: T): T
/** /**
* Multiplies this element by scalar. * Multiplies this element by scalar.
* *
* @receiver the multiplier. * @receiver the multiplier.
* @param b the multiplicand. * @param other the multiplicand.
*/ */
public operator fun T.times(b: T): T = multiply(this, b) public operator fun T.times(other: T): T = multiply(this, other)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
TIMES_OPERATION -> ::multiply TIMES_OPERATION -> ::multiply
@ -242,7 +242,7 @@ public interface RingOperations<T> : GroupOperations<T> {
* *
* @param T the type of element of this ring. * @param T the type of element of this ring.
*/ */
public interface Ring<T> : Group<T>, RingOperations<T> { public interface Ring<T> : Group<T>, RingOps<T> {
/** /**
* The neutral element of multiplication * The neutral element of multiplication
*/ */
@ -256,24 +256,24 @@ public interface Ring<T> : Group<T>, RingOperations<T> {
* *
* @param T the type of element of this semifield. * @param T the type of element of this semifield.
*/ */
public interface FieldOperations<T> : RingOperations<T> { public interface FieldOps<T> : RingOps<T> {
/** /**
* Division of two elements. * Division of two elements.
* *
* @param a the dividend. * @param left the dividend.
* @param b the divisor. * @param right the divisor.
* @return the quotient. * @return the quotient.
*/ */
public fun divide(a: T, b: T): T public fun divide(left: T, right: T): T
/** /**
* Division of two elements. * Division of two elements.
* *
* @receiver the dividend. * @receiver the dividend.
* @param b the divisor. * @param other the divisor.
* @return the quotient. * @return the quotient.
*/ */
public operator fun T.div(b: T): T = divide(this, b) public operator fun T.div(other: T): T = divide(this, other)
override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) { override fun binaryOperationFunction(operation: String): (left: T, right: T) -> T = when (operation) {
DIV_OPERATION -> ::divide DIV_OPERATION -> ::divide
@ -295,6 +295,6 @@ public interface FieldOperations<T> : RingOperations<T> {
* *
* @param T the type of element of this field. * @param T the type of element of this field.
*/ */
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> { public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlgebra<T> {
override fun number(value: Number): T = scale(one, value.toDouble()) override fun number(value: Number): T = scale(one, value.toDouble())
} }

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.BufferedRingND import space.kscience.kmath.nd.BufferedRingOpsND
import space.kscience.kmath.operations.BigInt.Companion.BASE import space.kscience.kmath.operations.BigInt.Companion.BASE
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
@ -26,7 +26,7 @@ private typealias TBase = ULong
* @author Peter Klimai * @author Peter Klimai
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> { public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
override val zero: BigInt = BigInt.ZERO override val zero: BigInt = BigInt.ZERO
override val one: BigInt = BigInt.ONE override val one: BigInt = BigInt.ONE
@ -34,10 +34,10 @@ public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOp
@Suppress("EXTENSION_SHADOWED_BY_MEMBER") @Suppress("EXTENSION_SHADOWED_BY_MEMBER")
override fun BigInt.unaryMinus(): BigInt = -this override fun BigInt.unaryMinus(): BigInt = -this
override fun add(a: BigInt, b: BigInt): BigInt = a.plus(b) override fun add(left: BigInt, right: BigInt): BigInt = left.plus(right)
override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value)) override fun scale(a: BigInt, value: Double): BigInt = a.times(number(value))
override fun multiply(a: BigInt, b: BigInt): BigInt = a.times(b) override fun multiply(left: BigInt, right: BigInt): BigInt = left.times(right)
override fun divide(a: BigInt, b: BigInt): BigInt = a.div(b) override fun divide(left: BigInt, right: BigInt): BigInt = left.div(right)
public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer") public operator fun String.unaryPlus(): BigInt = this.parseBigInteger() ?: error("Can't parse $this as big integer")
public operator fun String.unaryMinus(): BigInt = public operator fun String.unaryMinus(): BigInt =
@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> = public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
Buffer.boxing(size, initializer) Buffer.boxing(size, initializer)
public fun BigIntField.nd(vararg shape: Int): BufferedRingND<BigInt, BigIntField> = public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
BufferedRingND(shape, BigIntField, BigInt::buffer) get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer))

View File

@ -5,32 +5,34 @@
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.ShortBuffer
public interface WithSize {
public val size: Int
}
/** /**
* An algebra over [Buffer] * An algebra over [Buffer]
*/ */
@UnstableKMathAPI public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
public val bufferFactory: BufferFactory<T>
public val elementAlgebra: A public val elementAlgebra: A
public val size: Int public val bufferFactory: BufferFactory<T>
public fun buffer(vararg elements: T): Buffer<T> { public fun buffer(size: Int, vararg elements: T): Buffer<T> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" } require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it] } return bufferFactory(size) { elements[it] }
} }
//TODO move to multi-receiver inline extension //TODO move to multi-receiver inline extension
public fun Buffer<T>.map(block: (T) -> T): Buffer<T> = bufferFactory(size) { block(get(it)) } public fun Buffer<T>.map(block: A.(T) -> T): Buffer<T> = mapInline(this, block)
public fun Buffer<T>.zip(other: Buffer<T>, block: (left: T, right: T) -> T): Buffer<T> { public fun Buffer<T>.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer<T> = mapIndexedInline(this, block)
require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" }
return bufferFactory(size) { block(this[it], other[it]) } public fun Buffer<T>.zip(other: Buffer<T>, block: A.(left: T, right: T) -> T): Buffer<T> =
} zipInline(this, other, block)
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> { override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
val operationFunction = elementAlgebra.unaryOperationFunction(operation) val operationFunction = elementAlgebra.unaryOperationFunction(operation)
@ -45,112 +47,149 @@ public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
} }
} }
@UnstableKMathAPI /**
public fun <T> BufferField<T, *>.buffer(initializer: (Int) -> T): Buffer<T> { * Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
buffer: Buffer<T>,
crossinline block: A.(T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
/**
* Inline map
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
buffer: Buffer<T>,
crossinline block: A.(index: Int, arg: T) -> T
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
/**
* Inline zip
*/
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
l: Buffer<T>,
r: Buffer<T>,
crossinline block: A.(l: T, r: T) -> T
): Buffer<T> {
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
}
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
return bufferFactory(size, initializer)
}
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
return bufferFactory(size, initializer) return bufferFactory(size, initializer)
} }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sin) mapInline(arg) { sin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cos) mapInline(arg) { cos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tan) mapInline(arg) { tan(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asin) mapInline(arg) { asin(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acos) mapInline(arg) { acos(it) }
@UnstableKMathAPI
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> = public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atan) mapInline(arg) { atan(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::exp) mapInline(arg) { exp(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::ln) mapInline(arg) { ln(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::sinh) mapInline(arg) { sinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::cosh) mapInline(arg) { cosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::tanh) mapInline(arg) { tanh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::asinh) mapInline(arg) { asinh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::acosh) mapInline(arg) { acosh(it) }
@UnstableKMathAPI
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> = public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
arg.map(elementAlgebra::atanh) mapInline(arg) { atanh(it) }
@UnstableKMathAPI
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> = public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
with(elementAlgebra) { arg.map { power(it, pow) } } mapInline(arg) { power(it, pow) }
@UnstableKMathAPI public open class BufferRingOps<T, A: Ring<T>>(
public class BufferField<T, A : Field<T>>(
override val bufferFactory: BufferFactory<T>,
override val elementAlgebra: A, override val elementAlgebra: A,
override val bufferFactory: BufferFactory<T>,
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.unaryOperationFunction(operation)
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferAlgebra>.binaryOperationFunction(operation)
}
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
get() = BufferRingOps(ShortRing, ::ShortBuffer)
public open class BufferFieldOps<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
override fun add(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l + r }
override fun multiply(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l * r }
override fun divide(left: Buffer<T>, right: Buffer<T>): Buffer<T> = zipInline(left, right) { l, r -> l / r }
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
super<BufferRingOps>.binaryOperationFunction(operation)
}
public class BufferField<T, A : Field<T>>(
elementAlgebra: A,
bufferFactory: BufferFactory<T>,
override val size: Int override val size: Int
) : BufferAlgebra<T, A>, Field<Buffer<T>> { ) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero } override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one } override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::add)
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::multiply)
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::divide)
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = with(elementAlgebra) { a.map { scale(it, value) } }
override fun Buffer<T>.unaryMinus(): Buffer<T> = with(elementAlgebra) { map { -it } }
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.unaryOperationFunction(operation)
}
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> {
return super<BufferAlgebra>.binaryOperationFunction(operation)
}
} }
/**
* Generate full buffer field from given buffer operations
*/
public fun <T, A : Field<T>> BufferFieldOps<T, A>.withSize(size: Int): BufferField<T, A> =
BufferField(elementAlgebra, bufferFactory, size)
//Double buffer specialization //Double buffer specialization
@UnstableKMathAPI
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> { public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
require(elements.size == size) { "Expected $size elements but found ${elements.size}" } require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
return bufferFactory(size) { elements[it].toDouble() } return bufferFactory(size) { elements[it].toDouble() }
} }
@UnstableKMathAPI public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>): BufferFieldOps<T, A> =
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>, size: Int): BufferField<T, A> = BufferFieldOps(this, bufferFactory)
BufferField(bufferFactory, this, size)
@UnstableKMathAPI public val DoubleField.bufferAlgebra: BufferFieldOps<Double, DoubleField>
public fun DoubleField.bufferAlgebra(size: Int): BufferField<Double, DoubleField> = get() = BufferFieldOps(DoubleField, ::DoubleBuffer)
BufferField(::DoubleBuffer, DoubleField, size)

View File

@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer
* *
* @property size the size of buffers to operate on. * @property size the size of buffers to operate on.
*/ */
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOperations() { public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOps() {
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } } override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } } override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.sinh(arg) override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.sinh(arg)
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.cosh(arg) override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.cosh(arg)
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.tanh(arg) override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.tanh(arg)
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.asinh(arg) override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.asinh(arg)
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.acosh(arg) override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.acosh(arg)
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOperations>.atanh(arg) override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOps>.atanh(arg)
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() } // override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
// //

View File

@ -12,39 +12,40 @@ import space.kscience.kmath.structures.DoubleBuffer
import kotlin.math.* import kotlin.math.*
/** /**
* [ExtendedFieldOperations] over [DoubleBuffer]. * [ExtendedFieldOps] over [DoubleBuffer].
*/ */
public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Double>>, Norm<Buffer<Double>, Double> { public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) { override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
DoubleBuffer(size) { -array[it] } DoubleBuffer(size) { -array[it] }
} else { } else {
DoubleBuffer(size) { -get(it) } DoubleBuffer(size) { -get(it) }
} }
override fun add(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun add(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] + bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] + bArray[it] })
} else DoubleBuffer(DoubleArray(a.size) { a[it] + b[it] }) } else DoubleBuffer(DoubleArray(left.size) { left[it] + right[it] })
} }
override fun Buffer<Double>.plus(b: Buffer<Double>): DoubleBuffer = add(this, b) override fun Buffer<Double>.plus(other: Buffer<Double>): DoubleBuffer = add(this, other)
override fun Buffer<Double>.minus(b: Buffer<Double>): DoubleBuffer { override fun Buffer<Double>.minus(other: Buffer<Double>): DoubleBuffer {
require(b.size == this.size) { require(other.size == this.size) {
"The size of the first buffer ${this.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${this.size} should be the same as for second one: ${other.size} "
} }
return if (this is DoubleBuffer && b is DoubleBuffer) { return if (this is DoubleBuffer && other is DoubleBuffer) {
val aArray = this.array val aArray = this.array
val bArray = b.array val bArray = other.array
DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] }) DoubleBuffer(DoubleArray(this.size) { aArray[it] - bArray[it] })
} else DoubleBuffer(DoubleArray(this.size) { this[it] - b[it] }) } else DoubleBuffer(DoubleArray(this.size) { this[it] - other[it] })
} }
// //
@ -66,29 +67,29 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
// } else RealBuffer(DoubleArray(a.size) { a[it] / kValue }) // } else RealBuffer(DoubleArray(a.size) { a[it] / kValue })
// } // }
override fun multiply(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun multiply(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] * bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] * bArray[it] })
} else } else
DoubleBuffer(DoubleArray(a.size) { a[it] * b[it] }) DoubleBuffer(DoubleArray(left.size) { left[it] * right[it] })
} }
override fun divide(a: Buffer<Double>, b: Buffer<Double>): DoubleBuffer { override fun divide(left: Buffer<Double>, right: Buffer<Double>): DoubleBuffer {
require(b.size == a.size) { require(right.size == left.size) {
"The size of the first buffer ${a.size} should be the same as for second one: ${b.size} " "The size of the first buffer ${left.size} should be the same as for second one: ${right.size} "
} }
return if (a is DoubleBuffer && b is DoubleBuffer) { return if (left is DoubleBuffer && right is DoubleBuffer) {
val aArray = a.array val aArray = left.array
val bArray = b.array val bArray = right.array
DoubleBuffer(DoubleArray(a.size) { aArray[it] / bArray[it] }) DoubleBuffer(DoubleArray(left.size) { aArray[it] / bArray[it] })
} else DoubleBuffer(DoubleArray(a.size) { a[it] / b[it] }) } else DoubleBuffer(DoubleArray(left.size) { left[it] / right[it] })
} }
override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) { override fun sin(arg: Buffer<Double>): DoubleBuffer = if (arg is DoubleBuffer) {
@ -185,7 +186,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value }) DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value }) } else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
public companion object : DoubleBufferOperations() public companion object : DoubleBufferOps()
} }
public object DoubleL2Norm : Norm<Point<Double>, Double> { public object DoubleL2Norm : Norm<Point<Double>, Double> {

View File

@ -139,10 +139,10 @@ public interface ScaleOperations<T> : Algebra<T> {
* Multiplication of this number by element. * Multiplication of this number by element.
* *
* @receiver the multiplier. * @receiver the multiplier.
* @param b the multiplicand. * @param other the multiplicand.
* @return the product. * @return the product.
*/ */
public operator fun Number.times(b: T): T = b * this public operator fun Number.times(other: T): T = other * this
} }
/** /**
@ -150,38 +150,38 @@ public interface ScaleOperations<T> : Algebra<T> {
* TODO to be removed and replaced by extensions after multiple receivers are there * TODO to be removed and replaced by extensions after multiple receivers are there
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public interface NumbersAddOperations<T> : Ring<T>, NumericAlgebra<T> { public interface NumbersAddOps<T> : RingOps<T>, NumericAlgebra<T> {
/** /**
* Addition of element and scalar. * Addition of element and scalar.
* *
* @receiver the augend. * @receiver the augend.
* @param b the addend. * @param other the addend.
*/ */
public operator fun T.plus(b: Number): T = this + number(b) public operator fun T.plus(other: Number): T = this + number(other)
/** /**
* Addition of scalar and element. * Addition of scalar and element.
* *
* @receiver the augend. * @receiver the augend.
* @param b the addend. * @param other the addend.
*/ */
public operator fun Number.plus(b: T): T = b + this public operator fun Number.plus(other: T): T = other + this
/** /**
* Subtraction of element from number. * Subtraction of element from number.
* *
* @receiver the minuend. * @receiver the minuend.
* @param b the subtrahend. * @param other the subtrahend.
* @receiver the difference. * @receiver the difference.
*/ */
public operator fun T.minus(b: Number): T = this - number(b) public operator fun T.minus(other: Number): T = this - number(other)
/** /**
* Subtraction of number from element. * Subtraction of number from element.
* *
* @receiver the minuend. * @receiver the minuend.
* @param b the subtrahend. * @param other the subtrahend.
* @receiver the difference. * @receiver the difference.
*/ */
public operator fun Number.minus(b: T): T = -b + this public operator fun Number.minus(other: T): T = -other + this
} }

View File

@ -10,8 +10,8 @@ import kotlin.math.pow as kpow
/** /**
* Advanced Number-like semifield that implements basic operations. * Advanced Number-like semifield that implements basic operations.
*/ */
public interface ExtendedFieldOperations<T> : public interface ExtendedFieldOps<T> :
FieldOperations<T>, FieldOps<T>,
TrigonometricOperations<T>, TrigonometricOperations<T>,
PowerOperations<T>, PowerOperations<T>,
ExponentialOperations<T>, ExponentialOperations<T>,
@ -35,14 +35,14 @@ public interface ExtendedFieldOperations<T> :
ExponentialOperations.ACOSH_OPERATION -> ::acosh ExponentialOperations.ACOSH_OPERATION -> ::acosh
ExponentialOperations.ASINH_OPERATION -> ::asinh ExponentialOperations.ASINH_OPERATION -> ::asinh
ExponentialOperations.ATANH_OPERATION -> ::atanh ExponentialOperations.ATANH_OPERATION -> ::atanh
else -> super<FieldOperations>.unaryOperationFunction(operation) else -> super<FieldOps>.unaryOperationFunction(operation)
} }
} }
/** /**
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T>{ public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T>{
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
@ -73,10 +73,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -102,10 +102,10 @@ public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOp
override inline fun norm(arg: Double): Double = abs(arg) override inline fun norm(arg: Double): Double = abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }
public val Double.Companion.algebra: DoubleField get() = DoubleField public val Double.Companion.algebra: DoubleField get() = DoubleField
@ -126,12 +126,12 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
else -> super.binaryOperationFunction(operation) else -> super.binaryOperationFunction(operation)
} }
override inline fun add(a: Float, b: Float): Float = a + b override inline fun add(left: Float, right: Float): Float = left + right
override fun scale(a: Float, value: Double): Float = a * value.toFloat() override fun scale(a: Float, value: Double): Float = a * value.toFloat()
override inline fun multiply(a: Float, b: Float): Float = a * b override inline fun multiply(left: Float, right: Float): Float = left * right
override inline fun divide(a: Float, b: Float): Float = a / b override inline fun divide(left: Float, right: Float): Float = left / right
override inline fun sin(arg: Float): Float = kotlin.math.sin(arg) override inline fun sin(arg: Float): Float = kotlin.math.sin(arg)
override inline fun cos(arg: Float): Float = kotlin.math.cos(arg) override inline fun cos(arg: Float): Float = kotlin.math.cos(arg)
@ -155,10 +155,10 @@ public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override inline fun norm(arg: Float): Float = abs(arg) override inline fun norm(arg: Float): Float = abs(arg)
override inline fun Float.unaryMinus(): Float = -this override inline fun Float.unaryMinus(): Float = -this
override inline fun Float.plus(b: Float): Float = this + b override inline fun Float.plus(other: Float): Float = this + other
override inline fun Float.minus(b: Float): Float = this - b override inline fun Float.minus(other: Float): Float = this - other
override inline fun Float.times(b: Float): Float = this * b override inline fun Float.times(other: Float): Float = this * other
override inline fun Float.div(b: Float): Float = this / b override inline fun Float.div(other: Float): Float = this / other
} }
public val Float.Companion.algebra: FloatField get() = FloatField public val Float.Companion.algebra: FloatField get() = FloatField
@ -175,14 +175,14 @@ public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
get() = 1 get() = 1
override fun number(value: Number): Int = value.toInt() override fun number(value: Number): Int = value.toInt()
override inline fun add(a: Int, b: Int): Int = a + b override inline fun add(left: Int, right: Int): Int = left + right
override inline fun multiply(a: Int, b: Int): Int = a * b override inline fun multiply(left: Int, right: Int): Int = left * right
override inline fun norm(arg: Int): Int = abs(arg) override inline fun norm(arg: Int): Int = abs(arg)
override inline fun Int.unaryMinus(): Int = -this override inline fun Int.unaryMinus(): Int = -this
override inline fun Int.plus(b: Int): Int = this + b override inline fun Int.plus(other: Int): Int = this + other
override inline fun Int.minus(b: Int): Int = this - b override inline fun Int.minus(other: Int): Int = this - other
override inline fun Int.times(b: Int): Int = this * b override inline fun Int.times(other: Int): Int = this * other
} }
public val Int.Companion.algebra: IntRing get() = IntRing public val Int.Companion.algebra: IntRing get() = IntRing
@ -199,14 +199,14 @@ public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short>
get() = 1 get() = 1
override fun number(value: Number): Short = value.toShort() override fun number(value: Number): Short = value.toShort()
override inline fun add(a: Short, b: Short): Short = (a + b).toShort() override inline fun add(left: Short, right: Short): Short = (left + right).toShort()
override inline fun multiply(a: Short, b: Short): Short = (a * b).toShort() override inline fun multiply(left: Short, right: Short): Short = (left * right).toShort()
override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort() override fun norm(arg: Short): Short = if (arg > 0) arg else (-arg).toShort()
override inline fun Short.unaryMinus(): Short = (-this).toShort() override inline fun Short.unaryMinus(): Short = (-this).toShort()
override inline fun Short.plus(b: Short): Short = (this + b).toShort() override inline fun Short.plus(other: Short): Short = (this + other).toShort()
override inline fun Short.minus(b: Short): Short = (this - b).toShort() override inline fun Short.minus(other: Short): Short = (this - other).toShort()
override inline fun Short.times(b: Short): Short = (this * b).toShort() override inline fun Short.times(other: Short): Short = (this * other).toShort()
} }
public val Short.Companion.algebra: ShortRing get() = ShortRing public val Short.Companion.algebra: ShortRing get() = ShortRing
@ -223,14 +223,14 @@ public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
get() = 1 get() = 1
override fun number(value: Number): Byte = value.toByte() override fun number(value: Number): Byte = value.toByte()
override inline fun add(a: Byte, b: Byte): Byte = (a + b).toByte() override inline fun add(left: Byte, right: Byte): Byte = (left + right).toByte()
override inline fun multiply(a: Byte, b: Byte): Byte = (a * b).toByte() override inline fun multiply(left: Byte, right: Byte): Byte = (left * right).toByte()
override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte() override fun norm(arg: Byte): Byte = if (arg > 0) arg else (-arg).toByte()
override inline fun Byte.unaryMinus(): Byte = (-this).toByte() override inline fun Byte.unaryMinus(): Byte = (-this).toByte()
override inline fun Byte.plus(b: Byte): Byte = (this + b).toByte() override inline fun Byte.plus(other: Byte): Byte = (this + other).toByte()
override inline fun Byte.minus(b: Byte): Byte = (this - b).toByte() override inline fun Byte.minus(other: Byte): Byte = (this - other).toByte()
override inline fun Byte.times(b: Byte): Byte = (this * b).toByte() override inline fun Byte.times(other: Byte): Byte = (this * other).toByte()
} }
public val Byte.Companion.algebra: ByteRing get() = ByteRing public val Byte.Companion.algebra: ByteRing get() = ByteRing
@ -247,14 +247,14 @@ public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
get() = 1L get() = 1L
override fun number(value: Number): Long = value.toLong() override fun number(value: Number): Long = value.toLong()
override inline fun add(a: Long, b: Long): Long = a + b override inline fun add(left: Long, right: Long): Long = left + right
override inline fun multiply(a: Long, b: Long): Long = a * b override inline fun multiply(left: Long, right: Long): Long = left * right
override fun norm(arg: Long): Long = abs(arg) override fun norm(arg: Long): Long = abs(arg)
override inline fun Long.unaryMinus(): Long = (-this) override inline fun Long.unaryMinus(): Long = (-this)
override inline fun Long.plus(b: Long): Long = (this + b) override inline fun Long.plus(other: Long): Long = (this + other)
override inline fun Long.minus(b: Long): Long = (this - b) override inline fun Long.minus(other: Long): Long = (this - other)
override inline fun Long.times(b: Long): Long = (this * b) override inline fun Long.times(other: Long): Long = (this * other)
} }
public val Long.Companion.algebra: LongRing get() = LongRing public val Long.Companion.algebra: LongRing get() = LongRing

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.get import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.FieldVerifier import space.kscience.kmath.testutils.FieldVerifier
@ -21,7 +22,7 @@ internal class NDFieldTest {
@Test @Test
fun testStrides() { fun testStrides() {
val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() } val ndArray = DoubleField.ndAlgebra.structureND(10, 10) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0) assertEquals(ndArray[5, 5], 10.0)
} }
} }

View File

@ -7,10 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.linear.linearSpace import space.kscience.kmath.linear.linearSpace
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.*
import space.kscience.kmath.nd.combine
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Norm import space.kscience.kmath.operations.Norm
import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.algebra
@ -22,9 +19,9 @@ import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest { class NumberNDFieldTest {
val algebra = DoubleField.ndAlgebra(3, 3) val algebra = DoubleField.ndAlgebra
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() } val array1 = algebra.structureND(3, 3) { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() } val array2 = algebra.structureND(3, 3) { (i, j) -> (i - j).toDouble() }
@Test @Test
fun testSum() { fun testSum() {
@ -77,7 +74,7 @@ class NumberNDFieldTest {
@Test @Test
fun combineTest() { fun combineTest() {
val division = array1.combine(array2, Double::div) val division = array1.zip(array2, Double::div)
} }
object L2Norm : Norm<StructureND<Number>, Double> { object L2Norm : Norm<StructureND<Number>, Double> {

View File

@ -18,9 +18,9 @@ public object JBigIntegerField : Ring<BigInteger>, NumericAlgebra<BigInteger> {
override val one: BigInteger get() = BigInteger.ONE override val one: BigInteger get() = BigInteger.ONE
override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong()) override fun number(value: Number): BigInteger = BigInteger.valueOf(value.toLong())
override fun add(a: BigInteger, b: BigInteger): BigInteger = a.add(b) override fun add(left: BigInteger, right: BigInteger): BigInteger = left.add(right)
override operator fun BigInteger.minus(b: BigInteger): BigInteger = subtract(b) override operator fun BigInteger.minus(other: BigInteger): BigInteger = subtract(other)
override fun multiply(a: BigInteger, b: BigInteger): BigInteger = a.multiply(b) override fun multiply(left: BigInteger, right: BigInteger): BigInteger = left.multiply(right)
override operator fun BigInteger.unaryMinus(): BigInteger = negate() override operator fun BigInteger.unaryMinus(): BigInteger = negate()
} }
@ -39,15 +39,15 @@ public abstract class JBigDecimalFieldBase internal constructor(
override val one: BigDecimal override val one: BigDecimal
get() = BigDecimal.ONE get() = BigDecimal.ONE
override fun add(a: BigDecimal, b: BigDecimal): BigDecimal = a.add(b) override fun add(left: BigDecimal, right: BigDecimal): BigDecimal = left.add(right)
override operator fun BigDecimal.minus(b: BigDecimal): BigDecimal = subtract(b) override operator fun BigDecimal.minus(other: BigDecimal): BigDecimal = subtract(other)
override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble()) override fun number(value: Number): BigDecimal = BigDecimal.valueOf(value.toDouble())
override fun scale(a: BigDecimal, value: Double): BigDecimal = override fun scale(a: BigDecimal, value: Double): BigDecimal =
a.multiply(value.toBigDecimal(mathContext), mathContext) a.multiply(value.toBigDecimal(mathContext), mathContext)
override fun multiply(a: BigDecimal, b: BigDecimal): BigDecimal = a.multiply(b, mathContext) override fun multiply(left: BigDecimal, right: BigDecimal): BigDecimal = left.multiply(right, mathContext)
override fun divide(a: BigDecimal, b: BigDecimal): BigDecimal = a.divide(b, mathContext) override fun divide(left: BigDecimal, right: BigDecimal): BigDecimal = left.divide(right, mathContext)
override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext) override fun power(arg: BigDecimal, pow: Number): BigDecimal = arg.pow(pow.toInt(), mathContext)
override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext) override fun sqrt(arg: BigDecimal): BigDecimal = arg.sqrt(mathContext)
override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext) override operator fun BigDecimal.unaryMinus(): BigDecimal = negate(mathContext)

View File

@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.runningReduce import kotlinx.coroutines.flow.runningReduce
import kotlinx.coroutines.flow.scan import kotlinx.coroutines.flow.scan
import space.kscience.kmath.operations.GroupOperations import space.kscience.kmath.operations.GroupOps
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
public fun <T> Flow<T>.cumulativeSum(group: GroupOperations<T>): Flow<T> = public fun <T> Flow<T>.cumulativeSum(group: GroupOps<T>): Flow<T> =
group { runningReduce { sum, element -> sum + element } } group { runningReduce { sum, element -> sum + element } }
@ExperimentalCoroutinesApi @ExperimentalCoroutinesApi

View File

@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
* Map one [BufferND] using function without indices. * Map one [BufferND] using function without indices.
*/ */
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> { public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) } val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(strides, DoubleBuffer(array)) return BufferND(indexes, DoubleBuffer(array))
} }
/** /**

View File

@ -104,12 +104,12 @@ public class PolynomialSpace<T, C>(
Polynomial(coefficients.map { -it }) Polynomial(coefficients.map { -it })
} }
override fun add(a: Polynomial<T>, b: Polynomial<T>): Polynomial<T> { override fun add(left: Polynomial<T>, right: Polynomial<T>): Polynomial<T> {
val dim = max(a.coefficients.size, b.coefficients.size) val dim = max(left.coefficients.size, right.coefficients.size)
return ring { return ring {
Polynomial(List(dim) { index -> Polynomial(List(dim) { index ->
a.coefficients.getOrElse(index) { zero } + b.coefficients.getOrElse(index) { zero } left.coefficients.getOrElse(index) { zero } + right.coefficients.getOrElse(index) { zero }
}) })
} }
} }

View File

@ -47,7 +47,7 @@ public object Euclidean2DSpace : GeometrySpace<Vector2D>, ScaleOperations<Vector
override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y) override fun Vector2D.unaryMinus(): Vector2D = Vector2D(-x, -y)
override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm() override fun Vector2D.distanceTo(other: Vector2D): Double = (this - other).norm()
override fun add(a: Vector2D, b: Vector2D): Vector2D = Vector2D(a.x + b.x, a.y + b.y) override fun add(left: Vector2D, right: Vector2D): Vector2D = Vector2D(left.x + right.x, left.y + right.y)
override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value) override fun scale(a: Vector2D, value: Double): Vector2D = Vector2D(a.x * value, a.y * value)
override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y override fun Vector2D.dot(other: Vector2D): Double = x * other.x + y * other.y
} }

View File

@ -47,8 +47,8 @@ public object Euclidean3DSpace : GeometrySpace<Vector3D>, ScaleOperations<Vector
override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm() override fun Vector3D.distanceTo(other: Vector3D): Double = (this - other).norm()
override fun add(a: Vector3D, b: Vector3D): Vector3D = override fun add(left: Vector3D, right: Vector3D): Vector3D =
Vector3D(a.x + b.x, a.y + b.y, a.z + b.z) Vector3D(left.x + right.x, left.y + right.y, left.z + right.z)
override fun scale(a: Vector3D, value: Double): Vector3D = override fun scale(a: Vector3D, value: Double): Vector3D =
Vector3D(a.x * value, a.y * value, a.z * value) Vector3D(a.x * value, a.y * value, a.z * value)

View File

@ -28,10 +28,9 @@ public class DoubleHistogramSpace(
public val dimension: Int get() = lower.size public val dimension: Int get() = lower.size
private val shape = IntArray(binNums.size) { binNums[it] + 2 } override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape) override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
override val strides: Strides get() = histogramValueSpace.strides
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] } private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
/** /**
@ -52,7 +51,7 @@ public class DoubleHistogramSpace(
val lowerBoundary = index.mapIndexed { axis, i -> val lowerBoundary = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> Double.NEGATIVE_INFINITY 0 -> Double.NEGATIVE_INFINITY
strides.shape[axis] - 1 -> upper[axis] shape[axis] - 1 -> upper[axis]
else -> lower[axis] + (i.toDouble()) * binSize[axis] else -> lower[axis] + (i.toDouble()) * binSize[axis]
} }
}.asBuffer() }.asBuffer()
@ -60,7 +59,7 @@ public class DoubleHistogramSpace(
val upperBoundary = index.mapIndexed { axis, i -> val upperBoundary = index.mapIndexed { axis, i ->
when (i) { when (i) {
0 -> lower[axis] 0 -> lower[axis]
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY shape[axis] - 1 -> Double.POSITIVE_INFINITY
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis] else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
} }
}.asBuffer() }.asBuffer()
@ -75,7 +74,7 @@ public class DoubleHistogramSpace(
} }
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> { override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
val ndCounter = StructureND.auto(strides) { Counter.real() } val ndCounter = StructureND.auto(shape) { Counter.real() }
val hBuilder = HistogramBuilder<Double> { point, value -> val hBuilder = HistogramBuilder<Double> { point, value ->
val index = getIndex(point) val index = getIndex(point)
ndCounter[index].add(value.toDouble()) ndCounter[index].add(value.toDouble())

View File

@ -8,8 +8,9 @@ package space.kscience.kmath.histogram
import space.kscience.kmath.domains.Domain import space.kscience.kmath.domains.Domain
import space.kscience.kmath.linear.Point import space.kscience.kmath.linear.Point
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.FieldND import space.kscience.kmath.nd.FieldND
import space.kscience.kmath.nd.Strides import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Group import space.kscience.kmath.operations.Group
import space.kscience.kmath.operations.ScaleOperations import space.kscience.kmath.operations.ScaleOperations
@ -34,10 +35,10 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
return context.produceBin(index, values[index]) return context.produceBin(index, values[index])
} }
override val dimension: Int get() = context.strides.shape.size override val dimension: Int get() = context.shape.size
override val bins: Iterable<Bin<T>> override val bins: Iterable<Bin<T>>
get() = context.strides.indices().map { get() = DefaultStrides(context.shape).indices().map {
context.produceBin(it, values[it]) context.produceBin(it, values[it])
}.asIterable() }.asIterable()
@ -49,7 +50,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any> public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> { : Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
//public val valueSpace: Space<V> //public val valueSpace: Space<V>
public val strides: Strides public val shape: Shape
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape), public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
/** /**
@ -66,10 +67,10 @@ public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V> public fun produce(builder: HistogramBuilder<T>.() -> Unit): IndexedHistogram<T, V>
override fun add(a: IndexedHistogram<T, V>, b: IndexedHistogram<T, V>): IndexedHistogram<T, V> { override fun add(left: IndexedHistogram<T, V>, right: IndexedHistogram<T, V>): IndexedHistogram<T, V> {
require(a.context == this) { "Can't operate on a histogram produced by external space" } require(left.context == this) { "Can't operate on a histogram produced by external space" }
require(b.context == this) { "Can't operate on a histogram produced by external space" } require(right.context == this) { "Can't operate on a histogram produced by external space" }
return IndexedHistogram(this, histogramValueSpace { a.values + b.values }) return IndexedHistogram(this, histogramValueSpace { left.values + right.values })
} }
override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> { override fun scale(a: IndexedHistogram<T, V>, value: Double): IndexedHistogram<T, V> {

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.histogram package space.kscience.kmath.histogram
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.real.DoubleVector import space.kscience.kmath.real.DoubleVector
import kotlin.random.Random import kotlin.random.Random
@ -69,7 +70,7 @@ internal class MultivariateHistogramTest {
} }
val res = histogram1 - histogram2 val res = histogram1 - histogram2
assertTrue { assertTrue {
strides.indices().all { index -> DefaultStrides(shape).indices().all { index ->
res.values[index] <= histogram1.values[index] res.values[index] <= histogram1.values[index]
} }
} }

View File

@ -88,20 +88,20 @@ public class TreeHistogramSpace(
TreeHistogramBuilder(binFactory).apply(block).build() TreeHistogramBuilder(binFactory).apply(block).build()
override fun add( override fun add(
a: UnivariateHistogram, left: UnivariateHistogram,
b: UnivariateHistogram, right: UnivariateHistogram,
): UnivariateHistogram { ): UnivariateHistogram {
// require(a.context == this) { "Histogram $a does not belong to this context" } // require(a.context == this) { "Histogram $a does not belong to this context" }
// require(b.context == this) { "Histogram $b does not belong to this context" } // require(b.context == this) { "Histogram $b does not belong to this context" }
val bins = TreeMap<Double, UnivariateBin>().apply { val bins = TreeMap<Double, UnivariateBin>().apply {
(a.bins.map { it.domain } union b.bins.map { it.domain }).forEach { def -> (left.bins.map { it.domain } union right.bins.map { it.domain }).forEach { def ->
put( put(
def.center, def.center,
UnivariateBin( UnivariateBin(
def, def,
value = (a[def.center]?.value ?: 0.0) + (b[def.center]?.value ?: 0.0), value = (left[def.center]?.value ?: 0.0) + (right[def.center]?.value ?: 0.0),
standardDeviation = (a[def.center]?.standardDeviation standardDeviation = (left[def.center]?.standardDeviation
?: 0.0) + (b[def.center]?.standardDeviation ?: 0.0) ?: 0.0) + (right[def.center]?.standardDeviation ?: 0.0)
) )
) )
} }

View File

@ -28,10 +28,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -57,10 +57,10 @@ public object JafamaDoubleField : ExtendedField<Double>, Norm<Double, Double>, S
override inline fun norm(arg: Double): Double = FastMath.abs(arg) override inline fun norm(arg: Double): Double = FastMath.abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }
/** /**
@ -79,10 +79,10 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
else -> super<ExtendedField>.binaryOperationFunction(operation) else -> super<ExtendedField>.binaryOperationFunction(operation)
} }
override inline fun add(a: Double, b: Double): Double = a + b override inline fun add(left: Double, right: Double): Double = left + right
override inline fun multiply(a: Double, b: Double): Double = a * b override inline fun multiply(left: Double, right: Double): Double = left * right
override inline fun divide(a: Double, b: Double): Double = a / b override inline fun divide(left: Double, right: Double): Double = left / right
override inline fun scale(a: Double, value: Double): Double = a * value override inline fun scale(a: Double, value: Double): Double = a * value
@ -108,8 +108,8 @@ public object StrictJafamaDoubleField : ExtendedField<Double>, Norm<Double, Doub
override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg) override inline fun norm(arg: Double): Double = StrictFastMath.abs(arg)
override inline fun Double.unaryMinus(): Double = -this override inline fun Double.unaryMinus(): Double = -this
override inline fun Double.plus(b: Double): Double = this + b override inline fun Double.plus(other: Double): Double = this + other
override inline fun Double.minus(b: Double): Double = this - b override inline fun Double.minus(other: Double): Double = this - other
override inline fun Double.times(b: Double): Double = this * b override inline fun Double.times(other: Double): Double = this * other
override inline fun Double.div(b: Double): Double = this / b override inline fun Double.div(other: Double): Double = this / other
} }

View File

@ -106,8 +106,8 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
is Symbol -> toSVar() is Symbol -> toSVar()
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> +value.toSFun<X>() GroupOps.PLUS_OPERATION -> +value.toSFun<X>()
GroupOperations.MINUS_OPERATION -> -value.toSFun<X>() GroupOps.MINUS_OPERATION -> -value.toSFun<X>()
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun()) TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun()) TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun()) TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
@ -124,10 +124,10 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
} }
is MST.Binary -> when (operation) { is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun() GroupOps.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
GroupOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun() GroupOps.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun() RingOps.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun() FieldOps.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst() PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }

View File

@ -0,0 +1,14 @@
plugins {
id("ru.mipt.npm.gradle.jvm")
}
description = "JetBrains Multik connector"
dependencies {
api(project(":kmath-tensors"))
api("org.jetbrains.kotlinx:multik-default:0.1.0")
}
readme {
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
}

View File

@ -0,0 +1,137 @@
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.math.cos
import org.jetbrains.kotlinx.multik.api.math.sin
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.RingOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.*
/**
* A ring algebra for Multik operations
*/
public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
public val type: DataType,
override val elementAlgebra: A
) : RingOpsND<T, A> {
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val res = mk.zeros<T, DN>(shape, type).asDNArray()
for (index in res.multiIndices) {
res[index] = elementAlgebra.initializer(index)
}
return res.wrap()
}
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
this
} else {
structureND(shape) { get(it) }
}
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el)
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
val indexIter = array.multiIndices.iterator()
var index = 0
for (item in array) {
if (indexIter.hasNext()) {
data[index++] = elementAlgebra.transform(indexIter.next(), item)
} else {
throw ArithmeticException("Index overflow has happened.")
}
}
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type)
var counter = 0
val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator()
//iterating them together
while (leftIterator.hasNext()) {
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
}
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
}
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
(left.asMultik().array + right.asMultik().array).wrap()
override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
asMultik().array.plus(arg).wrap()
override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
left.asMultik().array.times(right.asMultik().array).wrap()
override fun StructureND<T>.times(arg: T): MultikTensor<T> =
asMultik().array.times(arg).wrap()
override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap()
override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap()
override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap()
}
/**
* A field algebra for multik operations
*/
public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
type: DataType,
elementAlgebra: A
) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
asMultik().array.div(other.asMultik().array).wrap()
}
public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
get() = MultikRingOpsND(DataType.IntDataType, IntRing)
public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
get() = MultikRingOpsND(DataType.LongDataType, LongRing)

View File

@ -0,0 +1,214 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
@JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: IntArray get() = array.shape
override fun get(index: IntArray): T = array[index]
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> =
array.multiIndices.iterator().asSequence().map { it to get(it) }
override fun set(index: IntArray, value: T) {
array[index] = value
}
}
public class MultikTensorAlgebra<T> internal constructor(
public val type: DataType,
public val elementAlgebra: Ring<T>,
public val comparator: Comparator<T>
) : TensorAlgebra<T> {
/**
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
* are not reflected back onto the source
*/
public fun Tensor<T>.asMultik(): MultikTensor<T> {
return if (this is MultikTensor) {
this
} else {
val res = mk.zeros<T, DN>(shape, type).asDNArray()
for (index in res.multiIndices) {
res[index] = this[index]
}
res.wrap()
}
}
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0))
} else null
override fun T.plus(other: Tensor<T>): MultikTensor<T> =
other.plus(this)
override fun Tensor<T>.plus(value: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap()
override fun Tensor<T>.plusAssign(value: T) {
if (this is MultikTensor) {
array.plusAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.add(t, value) }
}
}
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.plusAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.add(t, other[index]) }
}
}
override fun T.minus(other: Tensor<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap()
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap()
override fun Tensor<T>.minusAssign(value: T) {
if (this is MultikTensor) {
array.minusAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.run { t - value } }
}
}
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.minusAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.run { t - other[index] } }
}
}
override fun T.times(other: Tensor<T>): MultikTensor<T> =
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
override fun Tensor<T>.times(value: T): Tensor<T> =
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap()
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap()
override fun Tensor<T>.timesAssign(value: T) {
if (this is MultikTensor) {
array.timesAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.multiply(t, value) }
}
}
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.timesAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.multiply(t, other[index]) }
}
}
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
require(shape.all { it > 0 })
require(shape.fold(1, Int::times) == this.shape.size) {
"Cannot reshape array of size ${this.shape.size} into a new shape ${
shape.joinToString(
prefix = "(",
postfix = ")"
)
}"
}
val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) {
@Suppress("UNCHECKED_CAST")
this as NDArray<T, DN>
} else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap()
}
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
TODO("Diagonal embedding not implemented")
}
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
elementAlgebra.add(acc, t)
}
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.min(): T =
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.max(): T =
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
}
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double>
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) }
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float>
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short>
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) }
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int>
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) }
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long>
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) }

View File

@ -0,0 +1,13 @@
package space.kscience.kmath.multik
import org.junit.jupiter.api.Test
import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
internal class MultikNDTest {
@Test
fun basicAlgebra(): Unit = DoubleField.multikND{
one(2,2) + 1.0
}
}

View File

@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
val arrayShape = array.shape().toIntArray()
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
return array
}
/** /**
* Represents [AlgebraND] over [Nd4jArrayAlgebra]. * Represents [AlgebraND] over [Nd4jArrayAlgebra].
* *
@ -39,33 +32,35 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> { override fun structureND(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap() val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) } struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
return struct return struct
} }
@OptIn(PerformancePitfall::class)
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> { override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
val newStruct = ndArray.dup().wrap() val newStruct = ndArray.dup().wrap()
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) } newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
return newStruct return newStruct
} }
override fun StructureND<T>.mapIndexed( override fun StructureND<T>.mapIndexed(
transform: C.(index: IntArray, T) -> T, transform: C.(index: IntArray, T) -> T,
): Nd4jArrayStructure<T> { ): Nd4jArrayStructure<T> {
val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap() val new = Nd4j.create(*shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) } new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) }
return new return new
} }
override fun combine( override fun zip(
a: StructureND<T>, left: StructureND<T>,
b: StructureND<T>, right: StructureND<T>,
transform: C.(T, T) -> T, transform: C.(T, T) -> T,
): Nd4jArrayStructure<T> { ): Nd4jArrayStructure<T> {
val new = Nd4j.create(*shape).wrap() require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) } val new = Nd4j.create(*left.shape).wrap()
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
return new return new
} }
} }
@ -76,16 +71,13 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param S the type of space of structure elements. * @param S the type of space of structure elements.
*/ */
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> { public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
override val zero: Nd4jArrayStructure<T> override fun add(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
get() = Nd4j.zeros(*shape).wrap() left.ndArray.add(right.ndArray).wrap()
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> = override operator fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.add(b.ndArray).wrap() ndArray.sub(other.ndArray).wrap()
override operator fun StructureND<T>.minus(b: StructureND<T>): Nd4jArrayStructure<T> =
ndArray.sub(b.ndArray).wrap()
override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = override operator fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> =
ndArray.neg().wrap() ndArray.neg().wrap()
@ -101,13 +93,10 @@ public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4j
* @param R the type of ring of structure elements. * @param R the type of ring of structure elements.
*/ */
@OptIn(UnstableKMathAPI::class) @OptIn(UnstableKMathAPI::class)
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> { public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
override val one: Nd4jArrayStructure<T> override fun multiply(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
get() = Nd4j.ones(*shape).wrap() left.ndArray.mul(right.ndArray).wrap()
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.mul(b.ndArray).wrap()
// //
// override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> { // override operator fun Nd4jArrayStructure<T>.minus(b: Number): Nd4jArrayStructure<T> {
// check(this) // check(this)
@ -125,21 +114,12 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
// } // }
public companion object { public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/** /**
* Creates a most suitable implementation of [RingND] using reified class. * Creates a most suitable implementation of [RingND] using reified class.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when { public inline fun <reified T : Number> auto(): Nd4jArrayRingOps<T, Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>> T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Long type.") else -> throw UnsupportedOperationException("This factory method only supports Long type.")
} }
} }
@ -151,38 +131,21 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
* @param T the type of the element contained in ND structure. * @param T the type of the element contained in ND structure.
* @param F the type field of structure elements. * @param F the type field of structure elements.
*/ */
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> { public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.div(b.ndArray).wrap() override fun divide(left: StructureND<T>, right: StructureND<T>): Nd4jArrayStructure<T> =
left.ndArray.div(right.ndArray).wrap()
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap() public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
public companion object { public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [FieldND] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, DoubleField> =
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
/** /**
* Creates a most suitable implementation of [FieldND] using reified class. * Creates a most suitable implementation of [FieldND] using reified class.
*/ */
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when { public inline fun <reified T : Any> auto(): Nd4jArrayField<T, Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>> T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>> T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
} }
} }
@ -191,8 +154,9 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4
/** /**
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure]. * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
*/ */
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>, public sealed interface Nd4jArrayExtendedFieldOps<T, out F : ExtendedField<T>> :
Nd4jArrayField<T, F> { ExtendedFieldOps<StructureND<T>>, Nd4jArrayField<T, F> {
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap() override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap() override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap() override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
@ -221,63 +185,59 @@ public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : Ex
/** /**
* Represents [FieldND] over [Nd4jArrayDoubleStructure]. * Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/ */
public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> { public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Double, DoubleField> {
override val elementContext: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray override val StructureND<Double>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Double> -> checkShape(ndArray) is Nd4jArrayStructure<Double> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
} }
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> { override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> = a.ndArray.mul(value).wrap()
return a.ndArray.mul(value).wrap()
}
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> = ndArray.div(arg).wrap()
return ndArray.div(arg).wrap()
}
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> = ndArray.add(arg).wrap()
return ndArray.add(arg).wrap()
}
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> = ndArray.sub(arg).wrap()
return ndArray.sub(arg).wrap()
}
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> { override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> = ndArray.mul(arg).wrap()
return ndArray.mul(arg).wrap()
}
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> { override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
return arg.ndArray.rdiv(this).wrap() arg.ndArray.rdiv(this).wrap()
}
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> { override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
return arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
}
public companion object : DoubleNd4jArrayFieldOps()
} }
public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape)) public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps
public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND<Double, DoubleField>
public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField =
DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/** /**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float]. * Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/ */
public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> { public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Float, FloatField> {
override val elementContext: FloatField get() = FloatField override val elementAlgebra: FloatField get() = FloatField
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Float> = asFloatStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Float>.ndArray: INDArray override val StructureND<Float>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Float> -> checkShape(ndArray) is Nd4jArrayStructure<Float> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
@ -303,21 +263,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> = override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
public companion object : FloatNd4jArrayFieldOps()
} }
public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND<Float, FloatField>
public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps
public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField =
FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
/** /**
* Represents [RingND] over [Nd4jArrayIntStructure]. * Represents [RingND] over [Nd4jArrayIntStructure].
*/ */
public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> { public open class IntNd4jArrayRingOps : Nd4jArrayRingOps<Int, IntRing> {
override val elementContext: IntRing override val elementAlgebra: IntRing get() = IntRing
get() = IntRing
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Int> = asIntStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Int>.ndArray: INDArray override val StructureND<Int>.ndArray: INDArray
get() = when (this) { get() = when (this) {
is Nd4jArrayStructure<Int> -> checkShape(ndArray) is Nd4jArrayStructure<Int> -> ndArray
else -> Nd4j.zeros(*shape).also { else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) } elements().forEach { (idx, value) -> it.putScalar(idx, value) }
} }
@ -334,4 +302,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int,
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> = override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
arg.ndArray.rsub(this).wrap() arg.ndArray.rsub(this).wrap()
public companion object : IntNd4jArrayRingOps()
} }
public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps
public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND<Int, IntRing>
public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing =
IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest))

View File

@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j
import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.one
import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.math.PI import kotlin.math.PI
import kotlin.test.Test import kotlin.test.Test
@ -19,7 +23,7 @@ import kotlin.test.fail
internal class Nd4jArrayAlgebraTest { internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testProduce() { fun testProduce() {
val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } } val res = DoubleField.nd4j.structureND(2, 2) { it.sum().toDouble() }
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0 expected[intArrayOf(0, 1)] = 1.0
@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testMap() { fun testMap() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } } val res = IntRing.nd4j {
one(2, 2).map { it + it * 2 }
}
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 3 expected[intArrayOf(0, 0)] = 3
expected[intArrayOf(0, 1)] = 3 expected[intArrayOf(0, 1)] = 3
@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testAdd() { fun testAdd() {
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 } val res = IntRing.nd4j { one(2, 2) + 25 }
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
expected[intArrayOf(0, 0)] = 26 expected[intArrayOf(0, 0)] = 26
expected[intArrayOf(0, 1)] = 26 expected[intArrayOf(0, 1)] = 26
@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest {
} }
@Test @Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke { fun testSin() = DoubleField.nd4j{
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 } val initial = structureND(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
val transformed = sin(initial) val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 } val expected = structureND(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed) println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) } assertTrue { StructureND.contentEquals(transformed, expected) }

View File

@ -41,8 +41,8 @@ public class SamplerSpace<T : Any, out S>(public val algebra: S) : Group<Sampler
override val zero: Sampler<T> = ConstantSampler(algebra.zero) override val zero: Sampler<T> = ConstantSampler(algebra.zero)
override fun add(a: Sampler<T>, b: Sampler<T>): Sampler<T> = BasicSampler { generator -> override fun add(left: Sampler<T>, right: Sampler<T>): Sampler<T> = BasicSampler { generator ->
a.sample(generator).zip(b.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } } left.sample(generator).zip(right.sample(generator)) { aValue, bValue -> algebra { aValue + bValue } }
} }
override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator -> override fun scale(a: Sampler<T>, value: Double): Sampler<T> = BasicSampler { generator ->

View File

@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) {
} }
is MST.Unary -> when (operation) { is MST.Unary -> when (operation) {
GroupOperations.PLUS_OPERATION -> value.toIExpr() GroupOps.PLUS_OPERATION -> value.toIExpr()
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr()) GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr())
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr()) TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr()) TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr()) TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) {
} }
is MST.Binary -> when (operation) { is MST.Binary -> when (operation) {
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr() GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr() GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr() RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr()) FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value)) PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
else -> error("Binary operation $operation not defined in $this") else -> error("Binary operation $operation not defined in $this")
} }

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.RingOps
/** /**
* Algebra over a ring on [Tensor]. * Algebra over a ring on [Tensor].
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
* *
* @param T the type of items in the tensors. * @param T the type of items in the tensors.
*/ */
public interface TensorAlgebra<T> : Algebra<Tensor<T>> { public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
/** /**
* Returns a single tensor value of unit dimension if tensor shape equals to [1]. * Returns a single tensor value of unit dimension if tensor shape equals to [1].
* *
@ -53,7 +53,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this tensor and [other]. * @return the sum of this tensor and [other].
*/ */
public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor. * Adds the scalar [value] to each element of this tensor.
@ -93,7 +93,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this tensor and [other]. * @return the difference between this tensor and [other].
*/ */
public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor. * Subtracts the scalar [value] from each element of this tensor.
@ -134,7 +134,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this tensor and [other]. * @return the product of this tensor and [other].
*/ */
public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T> override fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor. * Multiplies the scalar [value] by each element of this tensor.
@ -155,7 +155,7 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* *
* @return tensor negation of the original tensor. * @return tensor negation of the original tensor.
*/ */
public operator fun Tensor<T>.unaryMinus(): Tensor<T> override fun Tensor<T>.unaryMinus(): Tensor<T>
/** /**
* Returns the tensor at index i * Returns the tensor at index i
@ -323,4 +323,8 @@ public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
* @return the index of maximum value of each row of the input tensor in the given dimension [dim]. * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right
} }

View File

@ -22,7 +22,7 @@ import kotlin.math.*
public open class DoubleTensorAlgebra : public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>, TensorPartialDivisionAlgebra<Double>,
AnalyticTensorAlgebra<Double>, AnalyticTensorAlgebra<Double>,
LinearOpsTensorAlgebra<Double> { LinearOpsTensorAlgebra<Double>{
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()
@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int): override fun diagonalEmbedding(
DoubleTensor { diagonalEntries: Tensor<Double>,
offset: Int,
dim1: Int,
dim2: Int
): DoubleTensor {
val n = diagonalEntries.shape.size val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1) val d1 = minusIndexFrom(n + 1, dim1)
val d2 = minusIndexFrom(n + 1, dim2) val d2 = minusIndexFrom(n + 1, dim2)

View File

@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
* *
* @param shape the shape of the tensor. * @param shape the shape of the tensor.
*/ */
internal class TensorLinearStructure(override val shape: IntArray) : Strides { internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
override val strides: IntArray override val strides: IntArray
get() = stridesFromShape(shape) get() = stridesFromShape(shape)
@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides {
override val linearSize: Int override val linearSize: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false
other as TensorLinearStructure
if (!shape.contentEquals(other.shape)) return false
return true
}
override fun hashCode(): Int {
return shape.contentHashCode()
}
} }

View File

@ -26,8 +26,11 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() BufferedTensor(this.shape, this.buffer, 0)
} else {
this.copyToBufferedTensor()
}
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }

View File

@ -5,4 +5,12 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
public fun DoubleTensorAlgebra.ones(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape)) import space.kscience.kmath.nd.Shape
import kotlin.jvm.JvmName
@JvmName("varArgOne")
public fun DoubleTensorAlgebra.one(vararg shape: Int): DoubleTensor = ones(intArrayOf(*shape))
public fun DoubleTensorAlgebra.one(shape: Shape): DoubleTensor = ones(shape)
@JvmName("varArgZero")
public fun DoubleTensorAlgebra.zero(vararg shape: Int): DoubleTensor = zeros(intArrayOf(*shape))
public fun DoubleTensorAlgebra.zero(shape: Shape): DoubleTensor = zeros(shape)

View File

@ -0,0 +1,124 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.viktor
import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedFieldOps
import space.kscience.kmath.operations.NumbersAddOps
@OptIn(UnstableKMathAPI::class)
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public open class ViktorFieldOpsND :
FieldOpsND<Double, DoubleField>,
ExtendedFieldOps<StructureND<Double>> {
public val StructureND<Double>.f64Buffer: F64Array
get() = when (this) {
is ViktorStructureND -> this.f64Buffer
else -> structureND(shape) { this@f64Buffer[it] }.f64Buffer
}
override val elementAlgebra: DoubleField get() = DoubleField
override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
F64Array(*shape).apply {
DefaultStrides(shape).indices().forEach { index ->
set(value = DoubleField.initializer(index), indices = index)
}
}.asStructure()
override fun StructureND<Double>.unaryMinus(): StructureND<Double> = -1 * this
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): ViktorStructureND =
F64Array(*shape).apply {
DefaultStrides(shape).indices().forEach { index ->
set(value = DoubleField.transform(this@map[index]), indices = index)
}
}.asStructure()
override fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): ViktorStructureND = F64Array(*shape).apply {
DefaultStrides(shape).indices().forEach { index ->
set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index)
}
}.asStructure()
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): ViktorStructureND {
require(left.shape.contentEquals(right.shape))
return F64Array(*left.shape).apply {
DefaultStrides(left.shape).indices().forEach { index ->
set(value = DoubleField.transform(left[index], right[index]), indices = index)
}
}.asStructure()
}
override fun add(left: StructureND<Double>, right: StructureND<Double>): ViktorStructureND =
(left.f64Buffer + right.f64Buffer).asStructure()
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
(a.f64Buffer * value).asStructure()
override fun StructureND<Double>.plus(other: StructureND<Double>): ViktorStructureND =
(f64Buffer + other.f64Buffer).asStructure()
override fun StructureND<Double>.minus(other: StructureND<Double>): ViktorStructureND =
(f64Buffer - other.f64Buffer).asStructure()
override fun StructureND<Double>.times(k: Number): ViktorStructureND =
(f64Buffer * k.toDouble()).asStructure()
override fun StructureND<Double>.plus(arg: Double): ViktorStructureND =
(f64Buffer.plus(arg)).asStructure()
override fun sin(arg: StructureND<Double>): ViktorStructureND = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): ViktorStructureND = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): ViktorStructureND = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): ViktorStructureND = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): ViktorStructureND = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): ViktorStructureND = arg.map { atan(it) }
override fun power(arg: StructureND<Double>, pow: Number): ViktorStructureND = arg.map { it.pow(pow) }
override fun exp(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.exp().asStructure()
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
override fun sinh(arg: StructureND<Double>): ViktorStructureND = arg.map { sinh(it) }
override fun cosh(arg: StructureND<Double>): ViktorStructureND = arg.map { cosh(it) }
override fun asinh(arg: StructureND<Double>): ViktorStructureND = arg.map { asinh(it) }
override fun acosh(arg: StructureND<Double>): ViktorStructureND = arg.map { acosh(it) }
override fun atanh(arg: StructureND<Double>): ViktorStructureND = arg.map { atanh(it) }
public companion object : ViktorFieldOpsND()
}
public val DoubleField.viktorAlgebra: ViktorFieldOpsND get() = ViktorFieldOpsND
public open class ViktorFieldND(
override val shape: Shape
) : ViktorFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() }
override fun number(value: Number): ViktorStructureND =
F64Array.full(init = value.toDouble(), shape = shape).asStructure()
}
public fun DoubleField.viktorAlgebra(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)

View File

@ -7,12 +7,8 @@ package space.kscience.kmath.viktor
import org.jetbrains.bio.viktor.F64Array import org.jetbrains.bio.viktor.F64Array
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.MutableStructureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.NumbersAddOperations
import space.kscience.kmath.operations.ScaleOperations
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND<Double> { public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructureND<Double> {
@ -31,96 +27,4 @@ public class ViktorStructureND(public val f64Buffer: F64Array) : MutableStructur
public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this) public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this)
@OptIn(UnstableKMathAPI::class)
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
NumbersAddOperations<StructureND<Double>>, ExtendedField<StructureND<Double>>,
ScaleOperations<StructureND<Double>> {
public val StructureND<Double>.f64Buffer: F64Array
get() = when {
!shape.contentEquals(this@ViktorFieldND.shape) -> throw ShapeMismatchException(
this@ViktorFieldND.shape,
shape
)
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
else -> produce { this@f64Buffer[it] }.f64Buffer
}
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
override val one: ViktorStructureND by lazy { F64Array.full(init = 1.0, shape = shape).asStructure() }
private val strides: Strides = DefaultStrides(shape)
override val elementContext: DoubleField get() = DoubleField
override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.initializer(index), indices = index)
}
}.asStructure()
override fun StructureND<Double>.unaryMinus(): StructureND<Double> = -1 * this
override fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): ViktorStructureND =
F64Array(*this@ViktorFieldND.shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.transform(this@map[index]), indices = index)
}
}.asStructure()
override fun StructureND<Double>.mapIndexed(
transform: DoubleField.(index: IntArray, Double) -> Double,
): ViktorStructureND = F64Array(*this@ViktorFieldND.shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.transform(index, this@mapIndexed[index]), indices = index)
}
}.asStructure()
override fun combine(
a: StructureND<Double>,
b: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double,
): ViktorStructureND = F64Array(*shape).apply {
this@ViktorFieldND.strides.indices().forEach { index ->
set(value = DoubleField.transform(a[index], b[index]), indices = index)
}
}.asStructure()
override fun add(a: StructureND<Double>, b: StructureND<Double>): ViktorStructureND =
(a.f64Buffer + b.f64Buffer).asStructure()
override fun scale(a: StructureND<Double>, value: Double): ViktorStructureND =
(a.f64Buffer * value).asStructure()
override inline fun StructureND<Double>.plus(b: StructureND<Double>): ViktorStructureND =
(f64Buffer + b.f64Buffer).asStructure()
override inline fun StructureND<Double>.minus(b: StructureND<Double>): ViktorStructureND =
(f64Buffer - b.f64Buffer).asStructure()
override inline fun StructureND<Double>.times(k: Number): ViktorStructureND =
(f64Buffer * k.toDouble()).asStructure()
override inline fun StructureND<Double>.plus(arg: Double): ViktorStructureND =
(f64Buffer.plus(arg)).asStructure()
override fun number(value: Number): ViktorStructureND =
F64Array.full(init = value.toDouble(), shape = shape).asStructure()
override fun sin(arg: StructureND<Double>): ViktorStructureND = arg.map { sin(it) }
override fun cos(arg: StructureND<Double>): ViktorStructureND = arg.map { cos(it) }
override fun tan(arg: StructureND<Double>): ViktorStructureND = arg.map { tan(it) }
override fun asin(arg: StructureND<Double>): ViktorStructureND = arg.map { asin(it) }
override fun acos(arg: StructureND<Double>): ViktorStructureND = arg.map { acos(it) }
override fun atan(arg: StructureND<Double>): ViktorStructureND = arg.map { atan(it) }
override fun power(arg: StructureND<Double>, pow: Number): ViktorStructureND = arg.map { it.pow(pow) }
override fun exp(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.exp().asStructure()
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
}
public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)

View File

@ -1,16 +1,18 @@
pluginManagement { pluginManagement {
repositories { repositories {
mavenLocal()
maven("https://repo.kotlin.link") maven("https://repo.kotlin.link")
mavenCentral() mavenCentral()
gradlePluginPortal() gradlePluginPortal()
} }
val kotlinVersion = "1.6.0-M1" val kotlinVersion = "1.6.0-RC"
val toolsVersion = "0.10.5"
plugins { plugins {
id("org.jetbrains.kotlinx.benchmark") version "0.3.1" id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
id("ru.mipt.npm.gradle.project") version "0.10.5" id("ru.mipt.npm.gradle.project") version toolsVersion
id("ru.mipt.npm.gradle.jvm") version toolsVersion
id("ru.mipt.npm.gradle.mpp") version toolsVersion
kotlin("multiplatform") version kotlinVersion kotlin("multiplatform") version kotlinVersion
kotlin("plugin.allopen") version kotlinVersion kotlin("plugin.allopen") version kotlinVersion
} }
@ -30,6 +32,7 @@ include(
":kmath-histograms", ":kmath-histograms",
":kmath-commons", ":kmath-commons",
":kmath-viktor", ":kmath-viktor",
":kmath-multik",
":kmath-optimization", ":kmath-optimization",
":kmath-stat", ":kmath-stat",
":kmath-nd4j", ":kmath-nd4j",