From 29a90efca5eaa3c7669cae8c6dd0f24a8f31058f Mon Sep 17 00:00:00 2001
From: Alexander Nozik <altavir@gmail.com>
Date: Wed, 27 Oct 2021 14:48:36 +0300
Subject: [PATCH] Tensor algebra generified

---
 CHANGELOG.md                                  |   1 +
 .../kmath/benchmarks/NDFieldBenchmark.kt      |   7 +-
 .../space/kscience/kmath/tensors/multik.kt    |   6 +-
 .../kscience/kmath/multik/MultikOpsND.kt      | 137 -------------
 .../kmath/multik/MultikTensorAlgebra.kt       | 187 ++++++++++++++----
 .../kscience/kmath/multik/MultikNDTest.kt     |   2 +-
 .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt  | 113 ++++++-----
 .../tensors/api/AnalyticTensorAlgebra.kt      |  47 ++---
 .../tensors/api/LinearOpsTensorAlgebra.kt     |  15 +-
 .../kmath/tensors/api/TensorAlgebra.kt        |  36 ++--
 .../core/BroadcastDoubleTensorAlgebra.kt      |   2 +-
 .../kmath/tensors/core/DoubleTensorAlgebra.kt | 117 ++++++-----
 12 files changed, 323 insertions(+), 347 deletions(-)
 delete mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt
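
The theme of this patch: tensor operations now accept read-only `StructureND` receivers and arguments instead of mutable `Tensor`, and `TensorAlgebra` gains an element-algebra type parameter so that it can inherit the ND algebra hierarchy. A minimal sketch of the kind of code the generified API admits (the `gram` helper is hypothetical; only `TensorAlgebra` members shown in this patch are assumed):

```kotlin
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.TensorAlgebra

// Hypothetical helper: works for any element type T and any TensorAlgebra
// implementation, because the input is a read-only StructureND, not a Tensor.
fun <T, A : Ring<T>> TensorAlgebra<T, A>.gram(m: StructureND<T>): StructureND<T> =
    m dot m.transpose()
```
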
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb267744e..6733c1211 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -45,6 +45,7 @@
 - Buffer algebra does not require size anymore
 - Operations -> Ops
 - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
+- Tensor algebra takes read-only structures as input and inherits AlgebraND
 
 ### Deprecated
 - Specialized `DoubleBufferAlgebra`
diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt
index 8f9a3e2b8..b5af5aa19 100644
--- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt
+++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt
@@ -13,8 +13,7 @@ import org.jetbrains.kotlinx.multik.api.Multik
 import org.jetbrains.kotlinx.multik.api.ones
 import org.jetbrains.kotlinx.multik.ndarray.data.DN
 import org.jetbrains.kotlinx.multik.ndarray.data.DataType
-import space.kscience.kmath.multik.multikND
-import space.kscience.kmath.multik.multikTensorAlgebra
+import space.kscience.kmath.multik.multikAlgebra
 import space.kscience.kmath.nd.BufferedFieldOpsND
 import space.kscience.kmath.nd.StructureND
 import space.kscience.kmath.nd.ndAlgebra
@@ -79,7 +78,7 @@ internal class NDFieldBenchmark {
     }
 
     @Benchmark
-    fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) {
+    fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
         val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
         repeat(n) { res += 1.0 }
         blackhole.consume(res)
@@ -100,7 +99,7 @@ internal class NDFieldBenchmark {
         private val specializedField = DoubleField.ndAlgebra
         private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
         private val nd4jField = DoubleField.nd4j
-        private val multikField = DoubleField.multikND
+        private val multikField = DoubleField.multikAlgebra
         private val viktorField = DoubleField.viktorAlgebra
     }
 }
diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt
index f0b776f75..fad68fa96 100644
--- a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt
+++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt
@@ -7,12 +7,12 @@ package space.kscience.kmath.tensors
 
 import org.jetbrains.kotlinx.multik.api.Multik
 import org.jetbrains.kotlinx.multik.api.ndarray
-import space.kscience.kmath.multik.multikND
+import space.kscience.kmath.multik.multikAlgebra
 import space.kscience.kmath.nd.one
 import space.kscience.kmath.operations.DoubleField
 
-fun main(): Unit = with(DoubleField.multikND) {
+fun main(): Unit = with(DoubleField.multikAlgebra) {
     val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
     val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
-    one(a.shape) - a + b * 3
+    one(a.shape) - a + b * 3.0
 }
diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt
deleted file mode 100644
index 4068ba9b9..000000000
--- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt
+++ /dev/null
@@ -1,137 +0,0 @@
-package space.kscience.kmath.multik
-
-import org.jetbrains.kotlinx.multik.api.math.cos
-import org.jetbrains.kotlinx.multik.api.math.sin
-import org.jetbrains.kotlinx.multik.api.mk
-import org.jetbrains.kotlinx.multik.api.zeros
-import org.jetbrains.kotlinx.multik.ndarray.data.*
-import org.jetbrains.kotlinx.multik.ndarray.operations.*
-import space.kscience.kmath.nd.FieldOpsND
-import space.kscience.kmath.nd.RingOpsND
-import space.kscience.kmath.nd.Shape
-import space.kscience.kmath.nd.StructureND
-import space.kscience.kmath.operations.*
-
-/**
- * A ring algebra for Multik operations
- */
-public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
-    public val type: DataType,
-    override val elementAlgebra: A
-) : RingOpsND<T, A> {
-
-    public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
-
-    override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
-        val res = mk.zeros<T, DN>(shape, type).asDNArray()
-        for (index in res.multiIndices) {
-            res[index] = elementAlgebra.initializer(index)
-        }
-        return res.wrap()
-    }
-
-    public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
-        this
-    } else {
-        structureND(shape) { get(it) }
-    }
-
-    override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
-        //taken directly from Multik sources
-        val array = asMultik().array
-        val data = initMemoryView<T>(array.size, type)
-        var count = 0
-        for (el in array) data[count++] = elementAlgebra.transform(el)
-        return NDArray(data, shape = array.shape, dim = array.dim).wrap()
-    }
-
-    override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
-        //taken directly from Multik sources
-        val array = asMultik().array
-        val data = initMemoryView<T>(array.size, type)
-        val indexIter = array.multiIndices.iterator()
-        var index = 0
-        for (item in array) {
-            if (indexIter.hasNext()) {
-                data[index++] = elementAlgebra.transform(indexIter.next(), item)
-            } else {
-                throw ArithmeticException("Index overflow has happened.")
-            }
-        }
-        return NDArray(data, shape = array.shape, dim = array.dim).wrap()
-    }
-
-    override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
-        require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
-        val leftArray = left.asMultik().array
-        val rightArray = right.asMultik().array
-        val data = initMemoryView<T>(leftArray.size, type)
-        var counter = 0
-        val leftIterator = leftArray.iterator()
-        val rightIterator = rightArray.iterator()
-        //iterating them together
-        while (leftIterator.hasNext()) {
-            data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
-        }
-        return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
-    }
-
-    override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
-
-    override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
-        (left.asMultik().array + right.asMultik().array).wrap()
-
-    override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
-        asMultik().array.plus(arg).wrap()
-
-    override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
-
-    override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
-
-    override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
-
-    override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
-        left.asMultik().array.times(right.asMultik().array).wrap()
-
-    override fun StructureND<T>.times(arg: T): MultikTensor<T> =
-        asMultik().array.times(arg).wrap()
-
-    override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
-
-    override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
-
-    override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
-        asMultik().array.plus(other.asMultik().array).wrap()
-
-    override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
-        asMultik().array.minus(other.asMultik().array).wrap()
-
-    override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
-        asMultik().array.times(other.asMultik().array).wrap()
-}
-
-/**
- * A field algebra for multik operations
- */
-public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
-    type: DataType,
-    elementAlgebra: A
-) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
-    override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
-        asMultik().array.div(other.asMultik().array).wrap()
-}
-
-public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
-    get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
-
-public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
-    get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
-
-public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
-    get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
-
-public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
-    get() = MultikRingOpsND(DataType.IntDataType, IntRing)
-
-public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
-    get() = MultikRingOpsND(DataType.LongDataType, LongRing)
\ No newline at end of file
diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt
index 3bbe8e672..ec0d4c6d9 100644
--- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt
+++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt
@@ -15,12 +15,14 @@ import org.jetbrains.kotlinx.multik.api.zeros
 import org.jetbrains.kotlinx.multik.ndarray.data.*
 import org.jetbrains.kotlinx.multik.ndarray.operations.*
 import space.kscience.kmath.misc.PerformancePitfall
+import space.kscience.kmath.nd.DefaultStrides
 import space.kscience.kmath.nd.Shape
 import space.kscience.kmath.nd.StructureND
 import space.kscience.kmath.nd.mapInPlace
 import space.kscience.kmath.operations.*
 import space.kscience.kmath.tensors.api.Tensor
 import space.kscience.kmath.tensors.api.TensorAlgebra
+import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
 
 @JvmInline
 public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
@@ -50,29 +52,80 @@ private fun <T> MultiArray<T, *>.asD2Array(): D2Array<T> {
     else throw ClassCastException("Cannot cast MultiArray to NDArray.")
 }
 
-public class MultikTensorAlgebra<T, A : Ring<T>> internal constructor(
-    public val type: DataType,
-    override val elementAlgebra: A,
-    public val comparator: Comparator<T>
-) : TensorAlgebra<T> {
+public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
+
+    public abstract val type: DataType
+
+    override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
+        val strides = DefaultStrides(shape)
+        val memoryView = initMemoryView<T>(strides.linearSize, type)
+        strides.indices().forEachIndexed { linearIndex, tensorIndex ->
+            memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
+        }
+        return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
+    }
+
+    override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
+        val data = initMemoryView<T>(array.size, type)
+        var count = 0
+        for (el in array) data[count++] = elementAlgebra.transform(el)
+        NDArray(data, shape = shape, dim = array.dim).wrap()
+    } else {
+        structureND(shape) { index ->
+            transform(get(index))
+        }
+    }
+
+    override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
+        if (this is MultikTensor) {
+            val array = asMultik().array
+            val data = initMemoryView<T>(array.size, type)
+            val indexIter = array.multiIndices.iterator()
+            var index = 0
+            for (item in array) {
+                if (indexIter.hasNext()) {
+                    data[index++] = elementAlgebra.transform(indexIter.next(), item)
+                } else {
+                    throw ArithmeticException("Index overflow has happened.")
+                }
+            }
+            NDArray(data, shape = array.shape, dim = array.dim).wrap()
+        } else {
+            structureND(shape) { index ->
+                transform(index, get(index))
+            }
+        }
+
+    override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
+        require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
+        val leftArray = left.asMultik().array
+        val rightArray = right.asMultik().array
+        val data = initMemoryView<T>(leftArray.size, type)
+        var counter = 0
+        val leftIterator = leftArray.iterator()
+        val rightIterator = rightArray.iterator()
+        //iterating them together
+        while (leftIterator.hasNext()) {
+            data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
+        }
+        return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
+    }
 
     /**
      * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
      * are not reflected back onto the source
      */
-    public fun StructureND<T>.asMultik(): MultikTensor<T> {
-        return if (this is MultikTensor) {
-            this
-        } else {
-            val res = mk.zeros<T, DN>(shape, type).asDNArray()
-            for (index in res.multiIndices) {
-                res[index] = this[index]
-            }
-            res.wrap()
+    public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
+        this
+    } else {
+        val res = mk.zeros<T, DN>(shape, type).asDNArray()
+        for (index in res.multiIndices) {
+            res[index] = this[index]
         }
+        res.wrap()
     }
 
-    public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
+    public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
 
     override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
         get(intArrayOf(0))
@@ -81,8 +134,8 @@ public class MultikTensorAlgebra<T, A : Ring<T>> internal constructor(
 
     override fun T.plus(other: StructureND<T>): MultikTensor<T> = other.plus(this)
 
-    override fun StructureND<T>.plus(value: T): MultikTensor<T> =
-        asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
+    override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
+        asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap()
 
     override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
         asMultik().array.plus(other.asMultik().array).wrap()
@@ -155,11 +208,11 @@ public class MultikTensorAlgebra<T, A : Ring<T>> internal constructor(
 
     override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
 
-    override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
+    override fun StructureND<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
 
-    override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
+    override fun StructureND<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
 
-    override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
+    override fun StructureND<T>.view(shape: IntArray): MultikTensor<T> {
         require(shape.all { it > 0 })
         require(shape.fold(1, Int::times) == this.shape.size) {
             "Cannot reshape array of size ${this.shape.size} into a new shape ${
@@ -178,9 +231,9 @@ public class MultikTensorAlgebra<T, A : Ring<T>> internal constructor(
         }.wrap()
     }
 
-    override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape)
+    override fun StructureND<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
 
-    override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> =
+    override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
         if (this.shape.size == 1 && other.shape.size == 1) {
             Multik.ndarrayOf(
                 asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
@@ -197,45 +250,95 @@ public class MultikTensorAlgebra<T, A : Ring<T>> internal constructor(
         TODO("Diagonal embedding not implemented")
     }
 
-    override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
+    override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
         elementAlgebra.add(acc, t)
     }
 
-    override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
+    override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
         TODO("Not yet implemented")
     }
 
-    override fun Tensor<T>.min(): T =
-        asMultik().array.minWith(comparator) ?: error("No elements in tensor")
+    override fun StructureND<T>.min(): T? = asMultik().array.min()
 
-    override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
+    override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
         TODO("Not yet implemented")
     }
 
-    override fun Tensor<T>.max(): T =
-        asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
+    override fun StructureND<T>.max(): T? = asMultik().array.max()
 
-
-    override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
+    override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
         TODO("Not yet implemented")
     }
 
-    override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
+    override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
         TODO("Not yet implemented")
     }
 }
 
-public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double, DoubleField>
-    get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) }
+public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
+    : MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
 
-public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float, FloatField>
-    get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
+    override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
 
-public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short, ShortRing>
-    get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) }
+    override fun StructureND<T>.div(arg: T): MultikTensor<T> =
+        asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
 
-public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int, IntRing>
-    get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) }
+    override fun StructureND<T>.div(other: StructureND<T>): MultikTensor<T> =
+        asMultik().array.div(other.asMultik().array).wrap()
 
-public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long, LongRing>
-    get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) }
\ No newline at end of file
+    override fun Tensor<T>.divAssign(value: T) {
+        if (this is MultikTensor) {
+            array.divAssign(value)
+        } else {
+            mapInPlace { _, t -> elementAlgebra.divide(t, value) }
+        }
+    }
+
+    override fun Tensor<T>.divAssign(other: StructureND<T>) {
+        if (this is MultikTensor) {
+            array.divAssign(other.asMultik().array)
+        } else {
+            mapInPlace { index, t -> elementAlgebra.divide(t, other[index]) }
+        }
+    }
+}
+
+public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
+    override val elementAlgebra: DoubleField get() = DoubleField
+    override val type: DataType get() = DataType.DoubleDataType
+}
+
+public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
+public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
+
+public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
+    override val elementAlgebra: FloatField get() = FloatField
+    override val type: DataType get() = DataType.FloatDataType
+}
+
+public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
+public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
+
+public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
+    override val elementAlgebra: ShortRing get() = ShortRing
+    override val type: DataType get() = DataType.ShortDataType
+}
+
+public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
+public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
+
+public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
+    override val elementAlgebra: IntRing get() = IntRing
+    override val type: DataType get() = DataType.IntDataType
+}
+
+public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
+public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
+
+public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
+    override val elementAlgebra: LongRing get() = LongRing
+    override val type: DataType get() = DataType.LongDataType
+}
+
+public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
+public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
\ No newline at end of file
diff --git a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt
index 9dd9854c7..66ba7db2d 100644
--- a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt
+++ b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt
@@ -7,7 +7,7 @@ import space.kscience.kmath.operations.invoke
 
 internal class MultikNDTest {
     @Test
-    fun basicAlgebra(): Unit = DoubleField.multikND{
+    fun basicAlgebra(): Unit = DoubleField.multikAlgebra{
        one(2,2) + 1.0
    }
 }
\ No newline at end of file
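
The Multik binding above trades the comparator-carrying constructor for one singleton object per element type, all reachable through `multikAlgebra` extension properties. A hedged usage sketch (assuming multik's `Multik.ndarray` factories; `wrap`, `structureND` and `sum` are the members introduced in this patch):

```kotlin
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarray
import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.operations.invoke

fun multikDemo() = Double.multikAlgebra {
    // wrap() adopts an existing multik array; structureND() builds one from scratch
    val a = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
    val ones = structureND(intArrayOf(3)) { 1.0 }
    (a + ones).sum() // 9.0
}
```
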
diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
index e40ea3dbc..3fb03c964 100644
--- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
+++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt
@@ -13,6 +13,7 @@ import org.nd4j.linalg.factory.Nd4j
 import org.nd4j.linalg.factory.ops.NDBase
 import org.nd4j.linalg.ops.transforms.Transforms
 import space.kscience.kmath.misc.PerformancePitfall
+import space.kscience.kmath.nd.DefaultStrides
 import space.kscience.kmath.nd.Shape
 import space.kscience.kmath.nd.StructureND
 import space.kscience.kmath.operations.DoubleField
@@ -27,22 +28,6 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
  */
 public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
 
-    override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T> {
-        val array =
-    }
-
-    override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> {
-        TODO("Not yet implemented")
-    }
-
-    override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> {
-        TODO("Not yet implemented")
-    }
-
-    override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
-        TODO("Not yet implemented")
-    }
-
     /**
      * Wraps [INDArray] to [Nd4jArrayStructure].
      */
@@ -53,8 +38,21 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
      */
     public val StructureND<T>.ndArray: INDArray
 
+    override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T>
+
+    override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> =
+        structureND(shape) { index -> elementAlgebra.transform(get(index)) }
+
+    override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> =
+        structureND(shape) { index -> elementAlgebra.transform(index, get(index)) }
+
+    override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
+        require(left.shape.contentEquals(right.shape))
+        return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
+    }
+
     override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
-    override fun StructureND<T>.plus(value: T): Nd4jArrayStructure<T> = ndArray.add(value).wrap()
+    override fun StructureND<T>.plus(arg: T): Nd4jArrayStructure<T> = ndArray.add(arg).wrap()
 
     override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
 
@@ -94,51 +92,52 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
     }
 
     override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
-    override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
-    override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
-    override fun Tensor<T>.dot(other: Tensor<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
+    override fun StructureND<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
+    override fun StructureND<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
+    override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
 
-    override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         ndArray.min(keepDim, dim).wrap()
 
-    override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         ndArray.sum(keepDim, dim).wrap()
 
-    override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         ndArray.max(keepDim, dim).wrap()
 
-    override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
-    override fun Tensor<T>.viewAs(other: Tensor<T>): Nd4jArrayStructure<T> = view(other.shape)
+    override fun StructureND<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
+    override fun StructureND<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
 
-    override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         ndBase.get().argmax(ndArray, keepDim, dim).wrap()
 
-    override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = ndArray.mean(keepDim, dim).wrap()
+    override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+        ndArray.mean(keepDim, dim).wrap()
 
-    override fun Tensor<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
-    override fun Tensor<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
-    override fun Tensor<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
-    override fun Tensor<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
-    override fun Tensor<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
-    override fun Tensor<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
+    override fun StructureND<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
+    override fun StructureND<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
+    override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
+    override fun StructureND<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
+    override fun StructureND<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
+    override fun StructureND<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
 
-    override fun Tensor<T>.acosh(): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.acosh(): Nd4jArrayStructure<T> =
         Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
 
-    override fun Tensor<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
-    override fun Tensor<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
-    override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
+    override fun StructureND<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
+    override fun StructureND<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
+    override fun StructureND<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
 
-    override fun Tensor<T>.asinh(): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.asinh(): Nd4jArrayStructure<T> =
         Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
 
-    override fun Tensor<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
-    override fun Tensor<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
-    override fun Tensor<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
-    override fun Tensor<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
-    override fun Tensor<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
-    override fun Tensor<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
-    override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
+    override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
+    override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
+    override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
+    override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
+    override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
+    override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         ndArray.std(true, keepDim, dim).wrap()
 
     override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
@@ -153,7 +152,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
         ndArray.divi(other.ndArray)
     }
 
-    override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
+    override fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
         Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
 
     private companion object {
@@ -170,6 +169,16 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
 
     override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
 
+    override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure<Double> {
+        val array: INDArray = Nd4j.zeros(*shape)
+        val indices = DefaultStrides(shape)
+        indices.indices().forEach { index ->
+            array.putScalar(index, elementAlgebra.initializer(index))
+        }
+        return array.wrap()
+    }
+
+
     @OptIn(PerformancePitfall::class)
     override val StructureND<Double>.ndArray: INDArray
         get() = when (this) {
@@ -190,10 +199,10 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
         dim2: Int,
     ): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
 
-    override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble()
-    override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble()
-    override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble()
-    override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble()
-    override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble()
-    override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble()
+    override fun StructureND<Double>.sum(): Double = ndArray.sumNumber().toDouble()
+    override fun StructureND<Double>.min(): Double = ndArray.minNumber().toDouble()
+    override fun StructureND<Double>.max(): Double = ndArray.maxNumber().toDouble()
+    override fun StructureND<Double>.mean(): Double = ndArray.meanNumber().toDouble()
+    override fun StructureND<Double>.std(): Double = ndArray.stdNumber().toDouble()
+    override fun StructureND<Double>.variance(): Double = ndArray.varNumber().toDouble()
 }
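
Note how `map`, `mapIndexed` and `zip` went from unimplemented TODO stubs to interface-level defaults that delegate to `structureND`, so a concrete object such as `DoubleNd4jTensorAlgebra` only has to supply `structureND`, `wrap` and `ndArray`. A sketch of the resulting usage (a hypothetical demo, assuming the ND4J backend is on the classpath):

```kotlin
import space.kscience.kmath.nd4j.DoubleNd4jTensorAlgebra
import space.kscience.kmath.operations.invoke

fun nd4jDemo() = DoubleNd4jTensorAlgebra {
    val t = structureND(intArrayOf(2, 2)) { (i, j) -> (i + j).toDouble() }
    // map() is inherited from the interface and routes through structureND()
    val doubled = t.map { it * 2 }
    doubled.sum() // 8.0
}
```
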
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt
index caafcc7c1..debfb3ef0 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt
@@ -5,6 +5,7 @@
 
 package space.kscience.kmath.tensors.api
 
+import space.kscience.kmath.nd.StructureND
 import space.kscience.kmath.operations.Field
 
 
@@ -18,7 +19,7 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
     /**
      * @return the mean of all elements in the input tensor.
      */
-    public fun Tensor<T>.mean(): T
+    public fun StructureND<T>.mean(): T
 
     /**
      * Returns the mean of each row of the input tensor in the given dimension [dim].
@@ -31,12 +32,12 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the mean of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
 
     /**
      * @return the standard deviation of all elements in the input tensor.
     */
-    public fun Tensor<T>.std(): T
+    public fun StructureND<T>.std(): T
 
     /**
      * Returns the standard deviation of each row of the input tensor in the given dimension [dim].
@@ -49,12 +50,12 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the standard deviation of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
 
     /**
      * @return the variance of all elements in the input tensor.
      */
-    public fun Tensor<T>.variance(): T
+    public fun StructureND<T>.variance(): T
 
     /**
      * Returns the variance of each row of the input tensor in the given dimension [dim].
@@ -67,57 +68,57 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the variance of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
-    public fun Tensor<T>.exp(): Tensor<T>
+    public fun StructureND<T>.exp(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.log.html
-    public fun Tensor<T>.ln(): Tensor<T>
+    public fun StructureND<T>.ln(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
-    public fun Tensor<T>.sqrt(): Tensor<T>
+    public fun StructureND<T>.sqrt(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
-    public fun Tensor<T>.cos(): Tensor<T>
+    public fun StructureND<T>.cos(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
-    public fun Tensor<T>.acos(): Tensor<T>
+    public fun StructureND<T>.acos(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
-    public fun Tensor<T>.cosh(): Tensor<T>
+    public fun StructureND<T>.cosh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
-    public fun Tensor<T>.acosh(): Tensor<T>
+    public fun StructureND<T>.acosh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
-    public fun Tensor<T>.sin(): Tensor<T>
+    public fun StructureND<T>.sin(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
-    public fun Tensor<T>.asin(): Tensor<T>
+    public fun StructureND<T>.asin(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
-    public fun Tensor<T>.sinh(): Tensor<T>
+    public fun StructureND<T>.sinh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
-    public fun Tensor<T>.asinh(): Tensor<T>
+    public fun StructureND<T>.asinh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
-    public fun Tensor<T>.tan(): Tensor<T>
+    public fun StructureND<T>.tan(): Tensor<T>
 
     //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
-    public fun Tensor<T>.atan(): Tensor<T>
+    public fun StructureND<T>.atan(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
-    public fun Tensor<T>.tanh(): Tensor<T>
+    public fun StructureND<T>.tanh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
-    public fun Tensor<T>.atanh(): Tensor<T>
+    public fun StructureND<T>.atanh(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
-    public fun Tensor<T>.ceil(): Tensor<T>
+    public fun StructureND<T>.ceil(): Tensor<T>
 
     //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
-    public fun Tensor<T>.floor(): Tensor<T>
+    public fun StructureND<T>.floor(): Tensor<T>
 
 }
\ No newline at end of file
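
Since every analytic operation now takes a `StructureND` receiver, statistics can be computed on any read-only structure without first converting it to a `Tensor`. A sketch against `DoubleTensorAlgebra` (its `fromArray` factory is assumed from the existing core API):

```kotlin
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun statsDemo() = DoubleTensorAlgebra {
    val x = fromArray(intArrayOf(4), doubleArrayOf(1.0, 2.0, 3.0, 4.0))
    x.mean()            // 2.5
    x.variance()        // 5/3, with the (n - 1) normalization used in the core implementation
    x.exp().ln().mean() // 2.5 up to floating-point rounding
}
```
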
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt
index 3f32eb9ca..e8fa7dacd 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt
@@ -5,6 +5,7 @@
 
 package space.kscience.kmath.tensors.api
 
+import space.kscience.kmath.nd.StructureND
 import space.kscience.kmath.operations.Field
 
 /**
@@ -20,7 +21,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      *
      * @return the determinant.
      */
-    public fun Tensor<T>.det(): Tensor<T>
+    public fun StructureND<T>.det(): Tensor<T>
 
     /**
      * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
@@ -30,7 +31,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      *
      * @return the multiplicative inverse of a matrix.
      */
-    public fun Tensor<T>.inv(): Tensor<T>
+    public fun StructureND<T>.inv(): Tensor<T>
 
     /**
      * Cholesky decomposition.
@@ -46,7 +47,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      * @receiver the `input`.
      * @return the batch of `L` matrices.
      */
-    public fun Tensor<T>.cholesky(): Tensor<T>
+    public fun StructureND<T>.cholesky(): Tensor<T>
 
     /**
      * QR decomposition.
@@ -60,7 +61,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      * @receiver the `input`.
      * @return pair of `Q` and `R` tensors.
      */
-    public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>>
+    public fun StructureND<T>.qr(): Pair<Tensor<T>, Tensor<T>>
 
     /**
      * LUP decomposition
@@ -74,7 +75,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      * @receiver the `input`.
      * @return triple of P, L and U tensors
      */
-    public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
+    public fun StructureND<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
 
     /**
      * Singular Value Decomposition.
@@ -90,7 +91,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      * @receiver the `input`.
      * @return triple `Triple(U, S, V)`.
      */
-    public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
+    public fun StructureND<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
 
     /**
      * Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices,
@@ -100,6 +101,6 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
      * @receiver the `input`.
      * @return a pair `eigenvalues to eigenvectors`
      */
-    public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
+    public fun StructureND<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
 
 }
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
index e910c5c31..28d7c4677 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt
@@ -41,12 +41,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
     override operator fun T.plus(other: StructureND<T>): Tensor<T>
 
     /**
-     * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
+     * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor.
      *
-     * @param value the number to be added to each element of this tensor.
-     * @return the sum of this tensor and [value].
+     * @param arg the number to be added to each element of this tensor.
+     * @return the sum of this tensor and [arg].
      */
-    override operator fun StructureND<T>.plus(value: T): Tensor<T>
+    override operator fun StructureND<T>.plus(arg: T): Tensor<T>
 
     /**
      * Each element of the tensor [other] is added to each element of this tensor.
@@ -166,7 +166,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param i index of the extractable tensor
      * @return subtensor of the original tensor with index [i]
      */
-    public operator fun Tensor<T>.get(i: Int): Tensor<T>
+    public operator fun StructureND<T>.get(i: Int): Tensor<T>
 
     /**
      * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
@@ -176,7 +176,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param j the second dimension to be transposed
      * @return transposed tensor
      */
-    public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
+    public fun StructureND<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
 
     /**
      * Returns a new tensor with the same data as the self tensor but of a different shape.
@@ -186,7 +186,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param shape the desired size
      * @return tensor with new shape
      */
-    public fun Tensor<T>.view(shape: IntArray): Tensor<T>
+    public fun StructureND<T>.view(shape: IntArray): Tensor<T>
 
     /**
     * View this tensor as the same size as [other].
@@ -196,7 +196,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param other the result tensor has the same size as other.
      * @return the result tensor with the same size as other.
      */
-    public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T>
+    public fun StructureND<T>.viewAs(other: StructureND<T>): Tensor<T>
 
     /**
      * Matrix product of two tensors.
@@ -227,7 +227,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param other tensor to be multiplied.
      * @return a mathematical product of two tensors.
      */
-    public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T>
+    public infix fun StructureND<T>.dot(other: StructureND<T>): Tensor<T>
 
     /**
      * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
@@ -262,7 +262,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
     /**
      * @return the sum of all elements in the input tensor.
      */
-    public fun Tensor<T>.sum(): T
+    public fun StructureND<T>.sum(): T
 
     /**
      * Returns the sum of each row of the input tensor in the given dimension [dim].
@@ -275,12 +275,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the sum of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
 
     /**
-     * @return the minimum value of all elements in the input tensor.
+     * @return the minimum value of all elements in the input tensor, or `null` if there are no values.
      */
-    public fun Tensor<T>.min(): T
+    public fun StructureND<T>.min(): T?
 
     /**
      * Returns the minimum value of each row of the input tensor in the given dimension [dim].
@@ -293,12 +293,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the minimum value of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
 
     /**
-     * Returns the maximum value of all elements in the input tensor.
+     * Returns the maximum value of all elements in the input tensor, or `null` if there are no values.
      */
-    public fun Tensor<T>.max(): T
+    public fun StructureND<T>.max(): T?
 
     /**
     * Returns the maximum value of each row of the input tensor in the given dimension [dim].
@@ -311,7 +311,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the maximum value of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
 
     /**
      * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
@@ -324,7 +324,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * @param keepDim whether the output tensor has [dim] retained or not.
      * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
      */
-    public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
+    public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
 
     override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
 
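
With `min()` and `max()` now returning nullable values, the empty-tensor case moves into the type system instead of an error path. A hedged sketch of a caller that works for any element algebra (the `rangeOrNull` helper is hypothetical):

```kotlin
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.TensorAlgebra

// min()/max() are nullable at this level, so an empty structure yields
// null instead of throwing, and the caller decides what to do about it.
fun <T, A : Ring<T>> TensorAlgebra<T, A>.rangeOrNull(t: StructureND<T>): Pair<T, T>? {
    val lo = t.min() ?: return null
    val hi = t.max() ?: return null
    return lo to hi
}
```
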
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt
index f3c5d2540..f53a0ca55 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt
@@ -85,7 +85,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
         return DoubleTensor(newThis.shape, resBuffer)
     }
 
-    override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
+    override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
         val newOther = broadcastTo(other.tensor, tensor.shape)
         for (i in 0 until tensor.indices.linearSize) {
             tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
index 3d343604b..be87f8019 100644
--- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
+++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
@@ -115,7 +115,7 @@ public open class DoubleTensorAlgebra :
         TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
     )
 
-    override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
+    override operator fun StructureND<Double>.get(i: Int): DoubleTensor {
         val lastShape = tensor.shape.drop(1).toIntArray()
         val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
         val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
@@ -160,7 +160,7 @@ public open class DoubleTensorAlgebra :
      *
      * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor.
      */
-    public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
+    public fun StructureND<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
 
     /**
     * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape].
@@ -198,9 +198,8 @@ public open class DoubleTensorAlgebra :
      *
      * @return a copy of the `input` tensor with a copied buffer.
      */
-    public fun Tensor<Double>.copy(): DoubleTensor {
-        return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
-    }
+    public fun StructureND<Double>.copy(): DoubleTensor =
+        DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
 
     override fun Double.plus(other: StructureND<Double>): DoubleTensor {
         val resBuffer = DoubleArray(other.tensor.numElements) { i ->
@@ -209,7 +208,7 @@ public open class DoubleTensorAlgebra :
         return DoubleTensor(other.shape, resBuffer)
     }
 
-    override fun StructureND<Double>.plus(value: Double): DoubleTensor = value + tensor
+    override fun StructureND<Double>.plus(arg: Double): DoubleTensor = arg + tensor
 
     override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
         checkShapesCompatible(tensor, other.tensor)
@@ -345,7 +344,7 @@ public open class DoubleTensorAlgebra :
         return DoubleTensor(tensor.shape, resBuffer)
     }
 
-    override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
+    override fun StructureND<Double>.transpose(i: Int, j: Int): DoubleTensor {
         val ii = tensor.minusIndex(i)
         val jj = tensor.minusIndex(j)
         checkTranspose(tensor.dimension, ii, jj)
@@ -369,15 +368,15 @@ public open class DoubleTensorAlgebra :
         return resTensor
     }
 
-    override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
+    override fun StructureND<Double>.view(shape: IntArray): DoubleTensor {
         checkView(tensor, shape)
         return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
     }
 
-    override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor =
+    override fun StructureND<Double>.viewAs(other: StructureND<Double>): DoubleTensor =
         tensor.view(other.shape)
 
-    override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor {
+    override infix fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
         if (tensor.shape.size == 1 && other.shape.size == 1) {
             return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
         }
@@ -569,10 +568,10 @@ public open class DoubleTensorAlgebra :
      */
     public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
 
-    internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
+    internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
         foldFunction(tensor.copyArray())
 
-    internal inline fun Tensor<Double>.foldDim(
+    internal inline fun StructureND<Double>.foldDim(
         foldFunction: (DoubleArray) -> Double,
         dim: Int,
         keepDim: Boolean,
@@ -596,30 +595,30 @@ public open class DoubleTensorAlgebra :
         return resTensor
     }
 
-    override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() }
+    override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
 
-    override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
         foldDim({ x -> x.sum() }, dim, keepDim)
 
-    override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! }
+    override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
 
-    override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
         foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
 
-    override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! }
+    override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
 
-    override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
         foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
 
-    override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
         foldDim({ x ->
             x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
         }, dim, keepDim)
 
-    override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
+    override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
 
-    override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
+    override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
         foldDim(
             { arr ->
                 check(dim < dimension) { "Dimension $dim out of range $dimension" }
                 arr.sum() / shape[dim]
@@ -629,12 +628,12 @@ public open class DoubleTensorAlgebra :
             keepDim
         )
 
-    override fun Tensor<Double>.std(): Double = this.fold { arr ->
+    override fun StructureND<Double>.std(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
         sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
     }
 
-    override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
+    override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
        { arr ->
            check(dim < dimension) { "Dimension $dim out of range $dimension" }
            val mean = arr.sum() / shape[dim]
@@ -644,12 +643,12 @@ public open class DoubleTensorAlgebra :
         keepDim
     )
 
-    override fun Tensor<Double>.variance(): Double = this.fold { arr ->
+    override fun StructureND<Double>.variance(): Double = this.fold { arr ->
         val mean = arr.sum() / tensor.numElements
         arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
     }
 
-    override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
+    override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
         { arr ->
             check(dim < dimension) { "Dimension $dim out of range $dimension" }
             val mean = arr.sum() / shape[dim]
@@ -672,7 +671,7 @@ public open class DoubleTensorAlgebra :
      * @param tensors the [List] of 1-dimensional tensors with same shape
     * @return `M`.
      */
-    public fun cov(tensors: List<Tensor<Double>>): DoubleTensor {
+    public fun cov(tensors: List<StructureND<Double>>): DoubleTensor {
         check(tensors.isNotEmpty()) { "List must have at least 1 element" }
         val n = tensors.size
         val m = tensors[0].shape[0]
@@ -689,43 +688,43 @@ public open class DoubleTensorAlgebra :
         return resTensor
     }
 
-    override fun Tensor<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
+    override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
 
-    override fun Tensor<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
+    override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
 
-    override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
+    override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
 
-    override fun Tensor<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
+    override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
 
-    override fun Tensor<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
+    override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
 
-    override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
+    override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
 
-    override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
+    override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
 
-    override fun Tensor<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
+    override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
 
-    override fun Tensor<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
+    override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
 
-    override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
+    override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
 
-    override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
+    override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
 
-    override fun Tensor<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
+    override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
 
-    override fun Tensor<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
+    override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
 
-    override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
+    override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
 
-    override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
+    override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
 
-    override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
+    override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
 
-    override fun Tensor<Double>.floor(): DoubleTensor = tensor.map { floor(it) }
+    override fun StructureND<Double>.floor(): DoubleTensor = tensor.map { floor(it) }
 
-    override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)
+    override fun StructureND<Double>.inv(): DoubleTensor = invLU(1e-9)
 
-    override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
+    override fun StructureND<Double>.det(): DoubleTensor = detLU(1e-9)
 
     /**
      * Computes the LU factorization of a matrix or batches of matrices `input`.
     * Returns a tuple containing the LU factorization and pivots of `input`.
      * The `factorization` has the shape ``(*, m, n)``, where ``(*, m, n)`` is the shape of the `input` tensor.
      * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
      */
-    public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
+    public fun StructureND<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
         computeLU(tensor, epsilon)
             ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
 
@@ -749,7 +748,7 @@ public open class DoubleTensorAlgebra :
      * The `factorization` has the shape ``(*, m, n)``, where ``(*, m, n)`` is the shape of the `input` tensor.
      * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
      */
-    public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
+    public fun StructureND<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
 
     /**
      * Unpacks the data and pivots from a LU factorization of a tensor.
@@ -763,7 +762,7 @@ public open class DoubleTensorAlgebra :
      * @return triple of `P`, `L` and `U` tensors
      */
     public fun luPivot(
-        luTensor: Tensor<Double>,
+        luTensor: StructureND<Double>,
         pivotsTensor: Tensor<Int>,
     ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
         checkSquareMatrix(luTensor.shape)
@@ -806,7 +805,7 @@ public open class DoubleTensorAlgebra :
      * Used when checking the positive definiteness of the input matrix or matrices.
      * @return a pair of `Q` and `R` tensors.
      */
-    public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
+    public fun StructureND<Double>.cholesky(epsilon: Double): DoubleTensor {
         checkSquareMatrix(shape)
         checkPositiveDefinite(tensor, epsilon)
 
@@ -819,9 +818,9 @@ public open class DoubleTensorAlgebra :
         return lTensor
     }
 
-    override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
+    override fun StructureND<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
 
-    override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
+    override fun StructureND<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
         checkSquareMatrix(shape)
         val qTensor = zeroesLike()
         val rTensor = zeroesLike()
@@ -837,7 +836,7 @@ public open class DoubleTensorAlgebra :
         return qTensor to rTensor
     }
 
-    override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
+    override fun StructureND<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
         svd(epsilon = 1e-10)
 
     /**
     * i.e., the precision with which the cosine approaches 1 in an iterative algorithm.
      * @return a triple `Triple(U, S, V)`.
      */
-    public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
+    public fun StructureND<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
         val size = tensor.dimension
         val commonShape = tensor.shape.sliceArray(0 until size - 2)
         val (n, m) = tensor.shape.sliceArray(size - 2 until size)
@@ -886,7 +885,7 @@ public open class DoubleTensorAlgebra :
         return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
     }
 
-    override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15)
+    override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15)
 
     /**
      * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
     * and when the cosine approaches 1 in the SVD algorithm.
      * @return a pair `eigenvalues to eigenvectors`.
      */
-    public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
+    public fun StructureND<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
         checkSymmetric(tensor, epsilon)
 
         fun MutableStructure2D<Double>.cleanSym(n: Int) {
@@ -931,7 +930,7 @@ public open class DoubleTensorAlgebra :
      * with zero.
      * @return the determinant.
      */
-    public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
+    public fun StructureND<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
         checkSquareMatrix(tensor.shape)
         val luTensor = tensor.copy()
         val pivotsTensor = tensor.setUpPivots()
@@ -964,7 +963,7 @@ public open class DoubleTensorAlgebra :
      * @param epsilon error in the LU algorithm—permissible error when comparing the determinant of a matrix with zero
      * @return the multiplicative inverse of a matrix.
      */
-    public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
+    public fun StructureND<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
         val (luTensor, pivotsTensor) = luFactor(epsilon)
         val invTensor = luTensor.zeroesLike()
 
@@ -989,12 +988,12 @@ public open class DoubleTensorAlgebra :
      * @param epsilon permissible error when comparing the determinant of a matrix with zero.
      * @return triple of `P`, `L` and `U` tensors.
      */
-    public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
+    public fun StructureND<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
         val (lu, pivots) = tensor.luFactor(epsilon)
         return luPivot(lu, pivots)
     }
 
-    override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
+    override fun StructureND<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
 }
 
 public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
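
Finally, `Double.Companion.tensorAlgebra` gives the core double implementation the same entry-point style as the Multik and ND4J bindings. An end-to-end sketch using the LU-based helpers from this file (values arbitrary; `fromArray` assumed from the existing core API):

```kotlin
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.tensorAlgebra

fun luDemo() = Double.tensorAlgebra {
    val m = fromArray(intArrayOf(2, 2), doubleArrayOf(4.0, 3.0, 6.0, 3.0))
    val det = m.detLU() // tensor holding the determinant, -6.0 for these values
    val inv = m.invLU() // multiplicative inverse via the same LU machinery
    m dot inv           // approximately the 2x2 identity
}
```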