From 7e59ec5804370829d85b858362a5187b3cec1740 Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 26 Oct 2021 09:16:24 +0300 Subject: [PATCH 1/3] Refactor TensorAlgebra to take StructureND and inherit AlgebraND --- .../kmath/structures/StreamDoubleFieldND.kt | 2 +- .../kscience/kmath/tensors/neuralNetwork.kt | 4 +- .../space/kscience/kmath/nd/AlgebraND.kt | 8 +- .../kscience/kmath/nd/BufferAlgebraND.kt | 8 +- .../space/kscience/kmath/nd/BufferND.kt | 16 +- .../space/kscience/kmath/nd/DoubleFieldND.kt | 6 +- .../space/kscience/kmath/nd/StructureND.kt | 4 +- .../space/kscience/kmath/real/realND.kt | 4 +- .../kmath/multik/MultikTensorAlgebra.kt | 53 +++--- .../tensors/api/AnalyticTensorAlgebra.kt | 4 +- .../tensors/api/LinearOpsTensorAlgebra.kt | 4 +- .../kmath/tensors/api/TensorAlgebra.kt | 58 +++--- .../api/TensorPartialDivisionAlgebra.kt | 26 ++- .../core/BroadcastDoubleTensorAlgebra.kt | 31 +-- .../kmath/tensors/core/BufferedTensor.kt | 10 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 180 +++++++++++------- .../tensors/core/internal/broadcastUtils.kt | 14 +- .../kmath/tensors/core/internal/checks.kt | 3 +- .../tensors/core/internal/tensorCastsUtils.kt | 9 +- .../kmath/tensors/core/internal/utils.kt | 2 +- .../kmath/tensors/core/tensorCasts.kt | 9 +- 21 files changed, 258 insertions(+), 197 deletions(-) diff --git a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt index 2b3e72136..05a13f5d2 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/structures/StreamDoubleFieldND.kt @@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND this.buffer as DoubleBuffer + this is BufferND && this.indices == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt index 3025ff8a3..cdfc06922 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/neuralNetwork.kt @@ -9,7 +9,7 @@ import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensorAlgebra -import space.kscience.kmath.tensors.core.toDoubleArray +import space.kscience.kmath.tensors.core.copyArray import kotlin.math.sqrt const val seed = 100500L @@ -111,7 +111,7 @@ class NeuralNetwork(private val layers: List) { private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra { val onesForAnswers = yPred.zeroesLike() - yTrue.toDoubleArray().forEachIndexed { index, labelDouble -> + yTrue.copyArray().forEachIndexed { index, labelDouble -> val label = labelDouble.toInt() onesForAnswers[intArrayOf(index, label)] = 1.0 } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 30cb01146..3d2d08fac 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -131,19 +131,19 @@ public interface GroupOpsND> : GroupOps>, * Adds an element to ND structure of it. * * @receiver the augend. - * @param arg the addend. + * @param other the addend. * @return the sum. */ - public operator fun T.plus(arg: StructureND): StructureND = arg + this + public operator fun T.plus(other: StructureND): StructureND = other.map { value -> add(this@plus, value) } /** * Subtracts an ND structure from an element of it. * * @receiver the dividend. - * @param arg the divisor. + * @param other the divisor. * @return the quotient. */ - public operator fun T.minus(arg: StructureND): StructureND = arg.map { value -> add(-this@minus, value) } + public operator fun T.minus(other: StructureND): StructureND = other.map { value -> add(-this@minus, value) } public companion object } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt index 859edefb8..cf007e7c9 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferAlgebraND.kt @@ -51,7 +51,7 @@ public inline fun > BufferAlgebraND.mapInline( arg: BufferND, crossinline transform: A.(T) -> T ): BufferND { - val indexes = arg.indexes + val indexes = arg.indices return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) } @@ -59,7 +59,7 @@ internal inline fun > BufferAlgebraND.mapIndexedInline( arg: BufferND, crossinline transform: A.(index: IntArray, arg: T) -> T ): BufferND { - val indexes = arg.indexes + val indexes = arg.indices return BufferND( indexes, bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> @@ -73,8 +73,8 @@ internal inline fun > BufferAlgebraND.zipInline( r: BufferND, crossinline block: A.(l: T, r: T) -> T ): BufferND { - require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } - val indexes = l.indexes + require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indices return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt index afa8f8250..2b6fd3693 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/BufferND.kt @@ -15,20 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory * Represents [StructureND] over [Buffer]. * * @param T the type of items. - * @param indexes The strides to access elements of [Buffer] by linear indices. + * @param indices The strides to access elements of [Buffer] by linear indices. * @param buffer The underlying buffer. */ public open class BufferND( - public val indexes: ShapeIndexer, + public val indices: ShapeIndexer, public open val buffer: Buffer, ) : StructureND { - override operator fun get(index: IntArray): T = buffer[indexes.offset(index)] + override operator fun get(index: IntArray): T = buffer[indices.offset(index)] - override val shape: IntArray get() = indexes.shape + override val shape: IntArray get() = indices.shape @PerformancePitfall - override fun elements(): Sequence> = indexes.indices().map { + override fun elements(): Sequence> = indices.indices().map { it to this[it] } @@ -43,7 +43,7 @@ public inline fun StructureND.mapToBuffer( crossinline transform: (T) -> R, ): BufferND { return if (this is BufferND) - BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) + BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) @@ -62,7 +62,7 @@ public class MutableBufferND( override val buffer: MutableBuffer, ) : MutableStructureND, BufferND(strides, buffer) { override fun set(index: IntArray, value: T) { - buffer[indexes.offset(index)] = value + buffer[indices.offset(index)] = value } } @@ -74,7 +74,7 @@ public inline fun MutableStructureND.mapToMutableBuffer( crossinline transform: (T) -> R, ): MutableBufferND { return if (this is MutableBufferND) - MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) + MutableBufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) }) else { val strides = DefaultStrides(shape) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index 853ac43b0..abb8e46ea 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -33,7 +33,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D arg: DoubleBufferND, transform: (Double) -> Double ): DoubleBufferND { - val indexes = arg.indexes + val indexes = arg.indices val array = arg.buffer.array return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) }) } @@ -43,8 +43,8 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D r: DoubleBufferND, block: (l: Double, r: Double) -> Double ): DoubleBufferND { - require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } - val indexes = l.indexes + require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } + val indexes = l.indices val lArray = l.buffer.array val rArray = r.buffer.array return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) }) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt index b4e62366a..496abf60f 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/StructureND.kt @@ -71,7 +71,7 @@ public interface StructureND : Featured { if (st1 === st2) return true // fast comparison of buffers if possible - if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) + if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices) return Buffer.contentEquals(st1.buffer, st2.buffer) //element by element comparison if it could not be avoided @@ -87,7 +87,7 @@ public interface StructureND : Featured { if (st1 === st2) return true // fast comparison of buffers if possible - if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) + if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices) return Buffer.contentEquals(st1.buffer, st2.buffer) //element by element comparison if it could not be avoided diff --git a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt index 0edd51be2..56f50acbc 100644 --- a/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt +++ b/kmath-for-real/src/commonMain/kotlin/space/kscience/kmath/real/realND.kt @@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer * Map one [BufferND] using function without indices. */ public inline fun BufferND.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND { - val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) } - return BufferND(indexes, DoubleBuffer(array)) + val array = DoubleArray(indices.linearSize) { offset -> DoubleField.transform(buffer[offset]) } + return BufferND(indices, DoubleBuffer(array)) } /** diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index ed15199a6..3bbe8e672 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.operations.* import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.Shape +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.api.Tensor @@ -49,17 +50,17 @@ private fun MultiArray.asD2Array(): D2Array { else throw ClassCastException("Cannot cast MultiArray to NDArray.") } -public class MultikTensorAlgebra internal constructor( +public class MultikTensorAlgebra> internal constructor( public val type: DataType, - public val elementAlgebra: Ring, + override val elementAlgebra: A, public val comparator: Comparator -) : TensorAlgebra { +) : TensorAlgebra { /** * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * are not reflected back onto the source */ - public fun Tensor.asMultik(): MultikTensor { + public fun StructureND.asMultik(): MultikTensor { return if (this is MultikTensor) { this } else { @@ -73,17 +74,17 @@ public class MultikTensorAlgebra internal constructor( public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this) - override fun Tensor.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { + override fun StructureND.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { get(intArrayOf(0)) } else null - override fun T.plus(other: Tensor): MultikTensor = + override fun T.plus(other: StructureND): MultikTensor = other.plus(this) - override fun Tensor.plus(value: T): MultikTensor = + override fun StructureND.plus(value: T): MultikTensor = asMultik().array.deepCopy().apply { plusAssign(value) }.wrap() - override fun Tensor.plus(other: Tensor): MultikTensor = + override fun StructureND.plus(other: StructureND): MultikTensor = asMultik().array.plus(other.asMultik().array).wrap() override fun Tensor.plusAssign(value: T) { @@ -94,7 +95,7 @@ public class MultikTensorAlgebra internal constructor( } } - override fun Tensor.plusAssign(other: Tensor) { + override fun Tensor.plusAssign(other: StructureND) { if (this is MultikTensor) { array.plusAssign(other.asMultik().array) } else { @@ -102,12 +103,12 @@ public class MultikTensorAlgebra internal constructor( } } - override fun T.minus(other: Tensor): MultikTensor = (-(other.asMultik().array - this)).wrap() + override fun T.minus(other: StructureND): MultikTensor = (-(other.asMultik().array - this)).wrap() - override fun Tensor.minus(value: T): MultikTensor = - asMultik().array.deepCopy().apply { minusAssign(value) }.wrap() + override fun StructureND.minus(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap() - override fun Tensor.minus(other: Tensor): MultikTensor = + override fun StructureND.minus(other: StructureND): MultikTensor = asMultik().array.minus(other.asMultik().array).wrap() override fun Tensor.minusAssign(value: T) { @@ -118,7 +119,7 @@ public class MultikTensorAlgebra internal constructor( } } - override fun Tensor.minusAssign(other: Tensor) { + override fun Tensor.minusAssign(other: StructureND) { if (this is MultikTensor) { array.minusAssign(other.asMultik().array) } else { @@ -126,13 +127,13 @@ public class MultikTensorAlgebra internal constructor( } } - override fun T.times(other: Tensor): MultikTensor = - other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() + override fun T.times(arg: StructureND): MultikTensor = + arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() - override fun Tensor.times(value: T): Tensor = - asMultik().array.deepCopy().apply { timesAssign(value) }.wrap() + override fun StructureND.times(arg: T): Tensor = + asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap() - override fun Tensor.times(other: Tensor): MultikTensor = + override fun StructureND.times(other: StructureND): MultikTensor = asMultik().array.times(other.asMultik().array).wrap() override fun Tensor.timesAssign(value: T) { @@ -143,7 +144,7 @@ public class MultikTensorAlgebra internal constructor( } } - override fun Tensor.timesAssign(other: Tensor) { + override fun Tensor.timesAssign(other: StructureND) { if (this is MultikTensor) { array.timesAssign(other.asMultik().array) } else { @@ -151,7 +152,7 @@ public class MultikTensorAlgebra internal constructor( } } - override fun Tensor.unaryMinus(): MultikTensor = + override fun StructureND.unaryMinus(): MultikTensor = asMultik().array.unaryMinus().wrap() override fun Tensor.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() @@ -224,17 +225,17 @@ public class MultikTensorAlgebra internal constructor( } } -public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra +public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) } -public val FloatField.multikTensorAlgebra: MultikTensorAlgebra +public val FloatField.multikTensorAlgebra: MultikTensorAlgebra get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) } -public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra +public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) } -public val IntRing.multikTensorAlgebra: MultikTensorAlgebra +public val IntRing.multikTensorAlgebra: MultikTensorAlgebra get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) } -public val LongRing.multikTensorAlgebra: MultikTensorAlgebra +public val LongRing.multikTensorAlgebra: MultikTensorAlgebra get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index 851810c8d..caafcc7c1 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -5,13 +5,15 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.operations.Field + /** * Analytic operations on [Tensor]. * * @param T the type of items closed under analytic functions in the tensors. */ -public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { +public interface AnalyticTensorAlgebra> : TensorPartialDivisionAlgebra { /** * @return the mean of all elements in the input tensor. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index 78aad2189..3f32eb9ca 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -5,12 +5,14 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.operations.Field + /** * Common linear algebra operations. Operates on [Tensor]. * * @param T the type of items closed under division in the tensors. */ -public interface LinearOpsTensorAlgebra : TensorPartialDivisionAlgebra { +public interface LinearOpsTensorAlgebra> : TensorPartialDivisionAlgebra { /** * Computes the determinant of a square matrix input, or of each square matrix in a batched input. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 810ebe777..e910c5c31 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -5,7 +5,9 @@ package space.kscience.kmath.tensors.api -import space.kscience.kmath.operations.RingOps +import space.kscience.kmath.nd.RingOpsND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Ring /** * Algebra over a ring on [Tensor]. @@ -13,20 +15,20 @@ import space.kscience.kmath.operations.RingOps * * @param T the type of items in the tensors. */ -public interface TensorAlgebra : RingOps> { +public interface TensorAlgebra> : RingOpsND { /** * Returns a single tensor value of unit dimension if tensor shape equals to [1]. * * @return a nullable value of a potentially scalar tensor. */ - public fun Tensor.valueOrNull(): T? + public fun StructureND.valueOrNull(): T? /** * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * * @return the value of a scalar tensor. */ - public fun Tensor.value(): T = + public fun StructureND.value(): T = valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") /** @@ -36,7 +38,7 @@ public interface TensorAlgebra : RingOps> { * @param other tensor to be added. * @return the sum of this value and tensor [other]. */ - public operator fun T.plus(other: Tensor): Tensor + override operator fun T.plus(other: StructureND): Tensor /** * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. @@ -44,7 +46,7 @@ public interface TensorAlgebra : RingOps> { * @param value the number to be added to each element of this tensor. * @return the sum of this tensor and [value]. */ - public operator fun Tensor.plus(value: T): Tensor + override operator fun StructureND.plus(value: T): Tensor /** * Each element of the tensor [other] is added to each element of this tensor. @@ -53,7 +55,7 @@ public interface TensorAlgebra : RingOps> { * @param other tensor to be added. * @return the sum of this tensor and [other]. */ - override fun Tensor.plus(other: Tensor): Tensor + override operator fun StructureND.plus(other: StructureND): Tensor /** * Adds the scalar [value] to each element of this tensor. @@ -67,7 +69,7 @@ public interface TensorAlgebra : RingOps> { * * @param other tensor to be added. */ - public operator fun Tensor.plusAssign(other: Tensor) + public operator fun Tensor.plusAssign(other: StructureND) /** * Each element of the tensor [other] is subtracted from this value. @@ -76,15 +78,15 @@ public interface TensorAlgebra : RingOps> { * @param other tensor to be subtracted. * @return the difference between this value and tensor [other]. */ - public operator fun T.minus(other: Tensor): Tensor + override operator fun T.minus(other: StructureND): Tensor /** - * Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. + * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor. * - * @param value the number to be subtracted from each element of this tensor. - * @return the difference between this tensor and [value]. + * @param arg the number to be subtracted from each element of this tensor. + * @return the difference between this tensor and [arg]. */ - public operator fun Tensor.minus(value: T): Tensor + override operator fun StructureND.minus(arg: T): Tensor /** * Each element of the tensor [other] is subtracted from each element of this tensor. @@ -93,7 +95,7 @@ public interface TensorAlgebra : RingOps> { * @param other tensor to be subtracted. * @return the difference between this tensor and [other]. */ - override fun Tensor.minus(other: Tensor): Tensor + override operator fun StructureND.minus(other: StructureND): Tensor /** * Subtracts the scalar [value] from each element of this tensor. @@ -107,25 +109,25 @@ public interface TensorAlgebra : RingOps> { * * @param other tensor to be subtracted. */ - public operator fun Tensor.minusAssign(other: Tensor) + public operator fun Tensor.minusAssign(other: StructureND) /** - * Each element of the tensor [other] is multiplied by this value. + * Each element of the tensor [arg] is multiplied by this value. * The resulting tensor is returned. * - * @param other tensor to be multiplied. - * @return the product of this value and tensor [other]. + * @param arg tensor to be multiplied. + * @return the product of this value and tensor [arg]. */ - public operator fun T.times(other: Tensor): Tensor + override operator fun T.times(arg: StructureND): Tensor /** - * Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. + * Multiplies the scalar [arg] by each element of this tensor and returns a new resulting tensor. * - * @param value the number to be multiplied by each element of this tensor. - * @return the product of this tensor and [value]. + * @param arg the number to be multiplied by each element of this tensor. + * @return the product of this tensor and [arg]. */ - public operator fun Tensor.times(value: T): Tensor + override operator fun StructureND.times(arg: T): Tensor /** * Each element of the tensor [other] is multiplied by each element of this tensor. @@ -134,7 +136,7 @@ public interface TensorAlgebra : RingOps> { * @param other tensor to be multiplied. * @return the product of this tensor and [other]. */ - override fun Tensor.times(other: Tensor): Tensor + override operator fun StructureND.times(other: StructureND): Tensor /** * Multiplies the scalar [value] by each element of this tensor. @@ -148,14 +150,14 @@ public interface TensorAlgebra : RingOps> { * * @param other tensor to be multiplied. */ - public operator fun Tensor.timesAssign(other: Tensor) + public operator fun Tensor.timesAssign(other: StructureND) /** * Numerical negative, element-wise. * * @return tensor negation of the original tensor. */ - override fun Tensor.unaryMinus(): Tensor + override operator fun StructureND.unaryMinus(): Tensor /** * Returns the tensor at index i @@ -324,7 +326,7 @@ public interface TensorAlgebra : RingOps> { */ public fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor - override fun add(left: Tensor, right: Tensor): Tensor = left + right + override fun add(left: StructureND, right: StructureND): Tensor = left + right - override fun multiply(left: Tensor, right: Tensor): Tensor = left * right + override fun multiply(left: StructureND, right: StructureND): Tensor = left * right } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index ce519288b..43c901ed5 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -5,30 +5,34 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.FieldOpsND +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.Field + /** * Algebra over a field with partial division on [Tensor]. * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * * @param T the type of items closed under division in the tensors. */ -public interface TensorPartialDivisionAlgebra : TensorAlgebra { +public interface TensorPartialDivisionAlgebra> : TensorAlgebra, FieldOpsND { /** - * Each element of the tensor [other] is divided by this value. + * Each element of the tensor [arg] is divided by this value. * The resulting tensor is returned. * - * @param other tensor to divide by. - * @return the division of this value by the tensor [other]. + * @param arg tensor to divide by. + * @return the division of this value by the tensor [arg]. */ - public operator fun T.div(other: Tensor): Tensor + override operator fun T.div(arg: StructureND): Tensor /** - * Divide by the scalar [value] each element of this tensor returns a new resulting tensor. + * Divide by the scalar [arg] each element of this tensor returns a new resulting tensor. * - * @param value the number to divide by each element of this tensor. - * @return the division of this tensor by the [value]. + * @param arg the number to divide by each element of this tensor. + * @return the division of this tensor by the [arg]. */ - public operator fun Tensor.div(value: T): Tensor + override operator fun StructureND.div(arg: T): Tensor /** * Each element of the tensor [other] is divided by each element of this tensor. @@ -37,7 +41,9 @@ public interface TensorPartialDivisionAlgebra : TensorAlgebra { * @param other tensor to be divided by. * @return the division of this tensor by [other]. */ - public operator fun Tensor.div(other: Tensor): Tensor + override operator fun StructureND.div(other: StructureND): Tensor + + override fun divide(left: StructureND, right: StructureND): StructureND = left.div(right) /** * Divides by the scalar [value] each element of this tensor. diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index 3d73fd53b..f3c5d2540 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.misc.UnstableKMathAPI +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.broadcastTensors @@ -18,66 +19,66 @@ import space.kscience.kmath.tensors.core.internal.tensor */ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { - override fun Tensor.plus(other: Tensor): DoubleTensor { + override fun StructureND.plus(other: StructureND): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.plusAssign(other: Tensor) { + override fun Tensor.plusAssign(other: StructureND) { val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.minus(other: Tensor): DoubleTensor { + override fun StructureND.minus(other: StructureND): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.minusAssign(other: Tensor) { + override fun Tensor.minusAssign(other: StructureND) { val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.times(other: Tensor): DoubleTensor { + override fun StructureND.times(other: StructureND): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[newThis.bufferStart + i] * newOther.mutableBuffer.array()[newOther.bufferStart + i] } return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.timesAssign(other: Tensor) { + override fun Tensor.timesAssign(other: StructureND) { val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= newOther.mutableBuffer.array()[tensor.bufferStart + i] } } - override fun Tensor.div(other: Tensor): DoubleTensor { + override fun StructureND.div(other: StructureND): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] - val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> + val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> newThis.mutableBuffer.array()[newOther.bufferStart + i] / newOther.mutableBuffer.array()[newOther.bufferStart + i] } @@ -86,7 +87,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { override fun Tensor.divAssign(other: Tensor) { val newOther = broadcastTo(other.tensor, tensor.shape) - for (i in 0 until tensor.linearStructure.linearSize) { + for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= newOther.mutableBuffer.array()[tensor.bufferStart + i] } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index bf9a9f7f7..ba3331067 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -23,23 +23,23 @@ public open class BufferedTensor internal constructor( /** * Buffer strides based on [TensorLinearStructure] implementation */ - public val linearStructure: Strides + public val indices: Strides get() = TensorLinearStructure(shape) /** * Number of elements in tensor */ public val numElements: Int - get() = linearStructure.linearSize + get() = indices.linearSize - override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] + override fun get(index: IntArray): T = mutableBuffer[bufferStart + indices.offset(index)] override fun set(index: IntArray, value: T) { - mutableBuffer[bufferStart + linearStructure.offset(index)] = value + mutableBuffer[bufferStart + indices.offset(index)] = value } @PerformancePitfall - override fun elements(): Sequence> = linearStructure.indices().map { + override fun elements(): Sequence> = indices.indices().map { it to get(it) } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 16ed4b834..472db8f5f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -6,8 +6,10 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.nd.MutableStructure2D +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D +import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.structures.indices import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra @@ -20,16 +22,71 @@ import kotlin.math.* * Implementation of basic operations over double tensors and basic algebra operations on them. */ public open class DoubleTensorAlgebra : - TensorPartialDivisionAlgebra, - AnalyticTensorAlgebra, - LinearOpsTensorAlgebra{ + TensorPartialDivisionAlgebra, + AnalyticTensorAlgebra, + LinearOpsTensorAlgebra { public companion object : DoubleTensorAlgebra() - override fun Tensor.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) + override val elementAlgebra: DoubleField + get() = DoubleField + + + /** + * Applies the [transform] function to each element of the tensor and returns the resulting modified tensor. + * + * @param transform the function to be applied to each element of the tensor. + * @return the resulting tensor after applying the function. + */ + @Suppress("OVERRIDE_BY_INLINE") + final override inline fun StructureND.map(transform: DoubleField.(Double) -> Double): DoubleTensor { + val tensor = this.tensor + //TODO remove additional copy + val sourceArray = tensor.copyArray() + val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) } + return DoubleTensor( + tensor.shape, + array, + tensor.bufferStart + ) + } + + @Suppress("OVERRIDE_BY_INLINE") + final override inline fun StructureND.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor { + val tensor = this.tensor + //TODO remove additional copy + val sourceArray = tensor.copyArray() + val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) } + return DoubleTensor( + tensor.shape, + array, + tensor.bufferStart + ) + } + + override fun zip( + left: StructureND, + right: StructureND, + transform: DoubleField.(Double, Double) -> Double + ): DoubleTensor { + require(left.shape.contentEquals(right.shape)){ + "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" + } + val leftTensor = left.tensor + val leftArray = leftTensor.copyArray() + val rightTensor = right.tensor + val rightArray = rightTensor.copyArray() + val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) } + return DoubleTensor( + leftTensor.shape, + array + ) + } + + override fun StructureND.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) tensor.mutableBuffer.array()[tensor.bufferStart] else null - override fun Tensor.value(): Double = valueOrNull() + override fun StructureND.value(): Double = valueOrNull() ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") /** @@ -53,11 +110,10 @@ public open class DoubleTensorAlgebra : * @param initializer mapping tensor indices to values. * @return tensor with the [shape] shape and data generated by the [initializer]. */ - public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor = - fromArray( - shape, - TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray() - ) + override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray( + shape, + TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray() + ) override operator fun Tensor.get(i: Int): DoubleTensor { val lastShape = tensor.shape.drop(1).toIntArray() @@ -146,16 +202,16 @@ public open class DoubleTensorAlgebra : return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) } - override fun Double.plus(other: Tensor): DoubleTensor { + override fun Double.plus(other: StructureND): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this } return DoubleTensor(other.shape, resBuffer) } - override fun Tensor.plus(value: Double): DoubleTensor = value + tensor + override fun StructureND.plus(value: Double): DoubleTensor = value + tensor - override fun Tensor.plus(other: Tensor): DoubleTensor { + override fun StructureND.plus(other: StructureND): DoubleTensor { checkShapesCompatible(tensor, other.tensor) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] @@ -169,7 +225,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.plusAssign(other: Tensor) { + override fun Tensor.plusAssign(other: StructureND) { checkShapesCompatible(tensor, other.tensor) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += @@ -177,21 +233,21 @@ public open class DoubleTensorAlgebra : } } - override fun Double.minus(other: Tensor): DoubleTensor { + override fun Double.minus(other: StructureND): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } - override fun Tensor.minus(value: Double): DoubleTensor { + override fun StructureND.minus(arg: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] - value + tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg } return DoubleTensor(tensor.shape, resBuffer) } - override fun Tensor.minus(other: Tensor): DoubleTensor { + override fun StructureND.minus(other: StructureND): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] @@ -205,7 +261,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.minusAssign(other: Tensor) { + override fun Tensor.minusAssign(other: StructureND) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= @@ -213,16 +269,16 @@ public open class DoubleTensorAlgebra : } } - override fun Double.times(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this + override fun Double.times(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.times(value: Double): DoubleTensor = value * tensor + override fun StructureND.times(arg: Double): DoubleTensor = arg * tensor - override fun Tensor.times(other: Tensor): DoubleTensor { + override fun StructureND.times(other: StructureND): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i] * @@ -237,7 +293,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.timesAssign(other: Tensor) { + override fun Tensor.timesAssign(other: StructureND) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= @@ -245,21 +301,21 @@ public open class DoubleTensorAlgebra : } } - override fun Double.div(other: Tensor): DoubleTensor { - val resBuffer = DoubleArray(other.tensor.numElements) { i -> - this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + override fun Double.div(arg: StructureND): DoubleTensor { + val resBuffer = DoubleArray(arg.tensor.numElements) { i -> + this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] } - return DoubleTensor(other.shape, resBuffer) + return DoubleTensor(arg.shape, resBuffer) } - override fun Tensor.div(value: Double): DoubleTensor { + override fun StructureND.div(arg: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] / value + tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg } return DoubleTensor(shape, resBuffer) } - override fun Tensor.div(other: Tensor): DoubleTensor { + override fun StructureND.div(other: StructureND): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / @@ -282,7 +338,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.unaryMinus(): DoubleTensor { + override fun StructureND.unaryMinus(): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() } @@ -302,11 +358,11 @@ public open class DoubleTensorAlgebra : val resTensor = DoubleTensor(resShape, resBuffer) for (offset in 0 until n) { - val oldMultiIndex = tensor.linearStructure.index(offset) + val oldMultiIndex = tensor.indices.index(offset) val newMultiIndex = oldMultiIndex.copyOf() newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } - val linearIndex = resTensor.linearStructure.offset(newMultiIndex) + val linearIndex = resTensor.indices.offset(newMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = tensor.mutableBuffer.array()[tensor.bufferStart + offset] } @@ -406,7 +462,7 @@ public open class DoubleTensorAlgebra : val resTensor = zeros(resShape) for (i in 0 until diagonalEntries.tensor.numElements) { - val multiIndex = diagonalEntries.tensor.linearStructure.index(i) + val multiIndex = diagonalEntries.tensor.indices.index(i) var offset1 = 0 var offset2 = abs(realOffset) @@ -425,18 +481,6 @@ public open class DoubleTensorAlgebra : return resTensor.tensor } - /** - * Applies the [transform] function to each element of the tensor and returns the resulting modified tensor. - * - * @param transform the function to be applied to each element of the tensor. - * @return the resulting tensor after applying the function. - */ - public inline fun Tensor.map(transform: (Double) -> Double): DoubleTensor = DoubleTensor( - tensor.shape, - tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(), - tensor.bufferStart - ) - /** * Compares element-wise two tensors with a specified precision. * @@ -526,7 +570,7 @@ public open class DoubleTensorAlgebra : public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) internal inline fun Tensor.fold(foldFunction: (DoubleArray) -> Double): Double = - foldFunction(tensor.toDoubleArray()) + foldFunction(tensor.copyArray()) internal inline fun Tensor.foldDim( foldFunction: (DoubleArray) -> Double, @@ -541,7 +585,7 @@ public open class DoubleTensorAlgebra : } val resNumElements = resShape.reduce(Int::times) val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0) - for (index in resTensor.linearStructure.indices()) { + for (index in resTensor.indices.indices()) { val prefix = index.take(dim).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray() resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> @@ -645,39 +689,39 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.exp(): DoubleTensor = tensor.map(::exp) + override fun Tensor.exp(): DoubleTensor = tensor.map { exp(it) } - override fun Tensor.ln(): DoubleTensor = tensor.map(::ln) + override fun Tensor.ln(): DoubleTensor = tensor.map { ln(it) } - override fun Tensor.sqrt(): DoubleTensor = tensor.map(::sqrt) + override fun Tensor.sqrt(): DoubleTensor = tensor.map { sqrt(it) } - override fun Tensor.cos(): DoubleTensor = tensor.map(::cos) + override fun Tensor.cos(): DoubleTensor = tensor.map { cos(it) } - override fun Tensor.acos(): DoubleTensor = tensor.map(::acos) + override fun Tensor.acos(): DoubleTensor = tensor.map { acos(it) } - override fun Tensor.cosh(): DoubleTensor = tensor.map(::cosh) + override fun Tensor.cosh(): DoubleTensor = tensor.map { cosh(it) } - override fun Tensor.acosh(): DoubleTensor = tensor.map(::acosh) + override fun Tensor.acosh(): DoubleTensor = tensor.map { acosh(it) } - override fun Tensor.sin(): DoubleTensor = tensor.map(::sin) + override fun Tensor.sin(): DoubleTensor = tensor.map { sin(it) } - override fun Tensor.asin(): DoubleTensor = tensor.map(::asin) + override fun Tensor.asin(): DoubleTensor = tensor.map { asin(it) } - override fun Tensor.sinh(): DoubleTensor = tensor.map(::sinh) + override fun Tensor.sinh(): DoubleTensor = tensor.map { sinh(it) } - override fun Tensor.asinh(): DoubleTensor = tensor.map(::asinh) + override fun Tensor.asinh(): DoubleTensor = tensor.map { asinh(it) } - override fun Tensor.tan(): DoubleTensor = tensor.map(::tan) + override fun Tensor.tan(): DoubleTensor = tensor.map { tan(it) } - override fun Tensor.atan(): DoubleTensor = tensor.map(::atan) + override fun Tensor.atan(): DoubleTensor = tensor.map { atan(it) } - override fun Tensor.tanh(): DoubleTensor = tensor.map(::tanh) + override fun Tensor.tanh(): DoubleTensor = tensor.map { tanh(it) } - override fun Tensor.atanh(): DoubleTensor = tensor.map(::atanh) + override fun Tensor.atanh(): DoubleTensor = tensor.map { atanh(it) } - override fun Tensor.ceil(): DoubleTensor = tensor.map(::ceil) + override fun Tensor.ceil(): DoubleTensor = tensor.map { ceil(it) } - override fun Tensor.floor(): DoubleTensor = tensor.map(::floor) + override fun Tensor.floor(): DoubleTensor = tensor.map { floor(it) } override fun Tensor.inv(): DoubleTensor = invLU(1e-9) diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt index 4b9c0c382..3787c0972 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/broadcastUtils.kt @@ -10,7 +10,7 @@ import kotlin.math.max internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) { for (linearIndex in 0 until linearSize) { - val totalMultiIndex = resTensor.linearStructure.index(linearIndex) + val totalMultiIndex = resTensor.indices.index(linearIndex) val curMultiIndex = tensor.shape.copyOf() val offset = totalMultiIndex.size - curMultiIndex.size @@ -23,7 +23,7 @@ internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTenso } } - val curLinearIndex = tensor.linearStructure.offset(curMultiIndex) + val curLinearIndex = tensor.indices.offset(curMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex] } @@ -112,7 +112,7 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List checkShapesCompatible(a: Tensor, b: Tensor) = +internal fun checkShapesCompatible(a: StructureND, b: StructureND) = check(a.shape contentEquals b.shape) { "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt index 1f5778d5e..4123ff923 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt @@ -6,6 +6,7 @@ package space.kscience.kmath.tensors.core.internal import space.kscience.kmath.nd.MutableBufferND +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.structures.asMutableBuffer import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.BufferedTensor @@ -18,15 +19,15 @@ internal fun BufferedTensor.asTensor(): IntTensor = internal fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) -internal fun Tensor.copyToBufferedTensor(): BufferedTensor = +internal fun StructureND.copyToBufferedTensor(): BufferedTensor = BufferedTensor( this.shape, TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 ) -internal fun Tensor.toBufferedTensor(): BufferedTensor = when (this) { +internal fun StructureND.toBufferedTensor(): BufferedTensor = when (this) { is BufferedTensor -> this - is MutableBufferND -> if (this.indexes == TensorLinearStructure(this.shape)) { + is MutableBufferND -> if (this.indices == TensorLinearStructure(this.shape)) { BufferedTensor(this.shape, this.buffer, 0) } else { this.copyToBufferedTensor() @@ -35,7 +36,7 @@ internal fun Tensor.toBufferedTensor(): BufferedTensor = when (this) { } @PublishedApi -internal val Tensor.tensor: DoubleTensor +internal val StructureND.tensor: DoubleTensor get() = when (this) { is DoubleTensor -> this else -> this.toBufferedTensor().asTensor() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt index 8428dae5c..553ed6add 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/utils.kt @@ -85,7 +85,7 @@ internal fun format(value: Double, digits: Int = 4): String = buildString { internal fun DoubleTensor.toPrettyString(): String = buildString { var offset = 0 val shape = this@toPrettyString.shape - val linearStructure = this@toPrettyString.linearStructure + val linearStructure = this@toPrettyString.indices val vectorSize = shape.last() append("DoubleTensor(\n") var charOffset = 3 diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt index 021ca539c..feade56de 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt @@ -19,18 +19,19 @@ public fun Tensor.toDoubleTensor(): DoubleTensor = this.tensor public fun Tensor.toIntTensor(): IntTensor = this.tensor /** - * Returns [DoubleArray] of tensor elements + * Returns a copy-protected [DoubleArray] of tensor elements */ -public fun DoubleTensor.toDoubleArray(): DoubleArray { +public fun DoubleTensor.copyArray(): DoubleArray { + //TODO use ArrayCopy return DoubleArray(numElements) { i -> mutableBuffer[bufferStart + i] } } /** - * Returns [IntArray] of tensor elements + * Returns a copy-protected [IntArray] of tensor elements */ -public fun IntTensor.toIntArray(): IntArray { +public fun IntTensor.copyArray(): IntArray { return IntArray(numElements) { i -> mutableBuffer[bufferStart + i] } From 4635cd3fb3eaab8aac97936b50a08d463e7f91c6 Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 26 Oct 2021 16:08:02 +0300 Subject: [PATCH 2/3] Refactor TensorAlgebra to take StructureND and inherit AlgebraND --- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 123 +++++++++++------- .../api/TensorPartialDivisionAlgebra.kt | 2 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 2 +- 3 files changed, 76 insertions(+), 51 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index ee9251dd1..e40ea3dbc 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -13,7 +13,10 @@ import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.ops.transforms.Transforms import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.Field import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra @@ -22,7 +25,24 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra /** * ND4J based [TensorAlgebra] implementation. */ -public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra { +public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { + + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure { + val array = + } + + override fun StructureND.map(transform: A.(T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + /** * Wraps [INDArray] to [Nd4jArrayStructure]. */ @@ -33,105 +53,107 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra */ public val StructureND.ndArray: INDArray - override fun T.plus(other: Tensor): Tensor = other.ndArray.add(this).wrap() - override fun Tensor.plus(value: T): Tensor = ndArray.add(value).wrap() + override fun T.plus(other: StructureND): Nd4jArrayStructure = other.ndArray.add(this).wrap() + override fun StructureND.plus(value: T): Nd4jArrayStructure = ndArray.add(value).wrap() - override fun Tensor.plus(other: Tensor): Tensor = ndArray.add(other.ndArray).wrap() + override fun StructureND.plus(other: StructureND): Nd4jArrayStructure = ndArray.add(other.ndArray).wrap() override fun Tensor.plusAssign(value: T) { ndArray.addi(value) } - override fun Tensor.plusAssign(other: Tensor) { + override fun Tensor.plusAssign(other: StructureND) { ndArray.addi(other.ndArray) } - override fun T.minus(other: Tensor): Tensor = other.ndArray.rsub(this).wrap() - override fun Tensor.minus(value: T): Tensor = ndArray.sub(value).wrap() - override fun Tensor.minus(other: Tensor): Tensor = ndArray.sub(other.ndArray).wrap() + override fun T.minus(other: StructureND): Nd4jArrayStructure = other.ndArray.rsub(this).wrap() + override fun StructureND.minus(arg: T): Nd4jArrayStructure = ndArray.sub(arg).wrap() + override fun StructureND.minus(other: StructureND): Nd4jArrayStructure = ndArray.sub(other.ndArray).wrap() override fun Tensor.minusAssign(value: T) { ndArray.rsubi(value) } - override fun Tensor.minusAssign(other: Tensor) { + override fun Tensor.minusAssign(other: StructureND) { ndArray.subi(other.ndArray) } - override fun T.times(other: Tensor): Tensor = other.ndArray.mul(this).wrap() + override fun T.times(arg: StructureND): Nd4jArrayStructure = arg.ndArray.mul(this).wrap() - override fun Tensor.times(value: T): Tensor = - ndArray.mul(value).wrap() + override fun StructureND.times(arg: T): Nd4jArrayStructure = + ndArray.mul(arg).wrap() - override fun Tensor.times(other: Tensor): Tensor = ndArray.mul(other.ndArray).wrap() + override fun StructureND.times(other: StructureND): Nd4jArrayStructure = ndArray.mul(other.ndArray).wrap() override fun Tensor.timesAssign(value: T) { ndArray.muli(value) } - override fun Tensor.timesAssign(other: Tensor) { + override fun Tensor.timesAssign(other: StructureND) { ndArray.mmuli(other.ndArray) } - override fun Tensor.unaryMinus(): Tensor = ndArray.neg().wrap() - override fun Tensor.get(i: Int): Tensor = ndArray.slice(i.toLong()).wrap() - override fun Tensor.transpose(i: Int, j: Int): Tensor = ndArray.swapAxes(i, j).wrap() - override fun Tensor.dot(other: Tensor): Tensor = ndArray.mmul(other.ndArray).wrap() + override fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() + override fun Tensor.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() + override fun Tensor.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() + override fun Tensor.dot(other: Tensor): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() - override fun Tensor.min(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.min(keepDim, dim).wrap() - override fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.sum(keepDim, dim).wrap() - override fun Tensor.max(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.max(keepDim, dim).wrap() - override fun Tensor.view(shape: IntArray): Tensor = ndArray.reshape(shape).wrap() - override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) + override fun Tensor.view(shape: IntArray): Nd4jArrayStructure = ndArray.reshape(shape).wrap() + override fun Tensor.viewAs(other: Tensor): Nd4jArrayStructure = view(other.shape) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndBase.get().argmax(ndArray, keepDim, dim).wrap() - override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() + override fun Tensor.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.mean(keepDim, dim).wrap() - override fun Tensor.exp(): Tensor = Transforms.exp(ndArray).wrap() - override fun Tensor.ln(): Tensor = Transforms.log(ndArray).wrap() - override fun Tensor.sqrt(): Tensor = Transforms.sqrt(ndArray).wrap() - override fun Tensor.cos(): Tensor = Transforms.cos(ndArray).wrap() - override fun Tensor.acos(): Tensor = Transforms.acos(ndArray).wrap() - override fun Tensor.cosh(): Tensor = Transforms.cosh(ndArray).wrap() + override fun Tensor.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() + override fun Tensor.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() + override fun Tensor.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() + override fun Tensor.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() + override fun Tensor.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() + override fun Tensor.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() - override fun Tensor.acosh(): Tensor = + override fun Tensor.acosh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.sin(): Tensor = Transforms.sin(ndArray).wrap() - override fun Tensor.asin(): Tensor = Transforms.asin(ndArray).wrap() + override fun Tensor.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() + override fun Tensor.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() override fun Tensor.sinh(): Tensor = Transforms.sinh(ndArray).wrap() - override fun Tensor.asinh(): Tensor = + override fun Tensor.asinh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.tan(): Tensor = Transforms.tan(ndArray).wrap() - override fun Tensor.atan(): Tensor = Transforms.atan(ndArray).wrap() - override fun Tensor.tanh(): Tensor = Transforms.tanh(ndArray).wrap() - override fun Tensor.atanh(): Tensor = Transforms.atanh(ndArray).wrap() - override fun Tensor.ceil(): Tensor = Transforms.ceil(ndArray).wrap() - override fun Tensor.floor(): Tensor = Transforms.floor(ndArray).wrap() - override fun Tensor.std(dim: Int, keepDim: Boolean): Tensor = ndArray.std(true, keepDim, dim).wrap() - override fun T.div(other: Tensor): Tensor = other.ndArray.rdiv(this).wrap() - override fun Tensor.div(value: T): Tensor = ndArray.div(value).wrap() - override fun Tensor.div(other: Tensor): Tensor = ndArray.div(other.ndArray).wrap() + override fun Tensor.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() + override fun Tensor.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() + override fun Tensor.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() + override fun Tensor.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() + override fun Tensor.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() + override fun Tensor.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() + override fun Tensor.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + ndArray.std(true, keepDim, dim).wrap() + + override fun T.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() + override fun StructureND.div(arg: T): Nd4jArrayStructure = ndArray.div(arg).wrap() + override fun StructureND.div(other: StructureND): Nd4jArrayStructure = ndArray.div(other.ndArray).wrap() override fun Tensor.divAssign(value: T) { ndArray.divi(value) } - override fun Tensor.divAssign(other: Tensor) { + override fun Tensor.divAssign(other: StructureND) { ndArray.divi(other.ndArray) } - override fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure = Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() private companion object { @@ -142,7 +164,10 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra /** * [Double] specialization of [Nd4jTensorAlgebra]. */ -public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { +public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { + + override val elementAlgebra: DoubleField get() = DoubleField + override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() @OptIn(PerformancePitfall::class) @@ -154,7 +179,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { } } - override fun Tensor.valueOrNull(): Double? = + override fun StructureND.valueOrNull(): Double? = if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null // TODO rewrite diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index 43c901ed5..dece54834 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -57,5 +57,5 @@ public interface TensorPartialDivisionAlgebra> : TensorAlgebra.divAssign(other: Tensor) + public operator fun Tensor.divAssign(other: StructureND) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 472db8f5f..3d343604b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -330,7 +330,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.divAssign(other: Tensor) { + override fun Tensor.divAssign(other: StructureND) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= From 29a90efca5eaa3c7669cae8c6dd0f24a8f31058f Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 27 Oct 2021 14:48:36 +0300 Subject: [PATCH 3/3] Tensor algebra generified --- CHANGELOG.md | 1 + .../kmath/benchmarks/NDFieldBenchmark.kt | 7 +- .../space/kscience/kmath/tensors/multik.kt | 6 +- .../kscience/kmath/multik/MultikOpsND.kt | 137 ------------- .../kmath/multik/MultikTensorAlgebra.kt | 187 ++++++++++++++---- .../kscience/kmath/multik/MultikNDTest.kt | 2 +- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 113 ++++++----- .../tensors/api/AnalyticTensorAlgebra.kt | 47 ++--- .../tensors/api/LinearOpsTensorAlgebra.kt | 15 +- .../kmath/tensors/api/TensorAlgebra.kt | 36 ++-- .../core/BroadcastDoubleTensorAlgebra.kt | 2 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 117 ++++++----- 12 files changed, 323 insertions(+), 347 deletions(-) delete mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index bb267744e..6733c1211 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ - Buffer algebra does not require size anymore - Operations -> Ops - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. +- Tensor algebra takes read-only structures as input and inherits AlgebraND ### Deprecated - Specialized `DoubleBufferAlgebra` diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt index 8f9a3e2b8..b5af5aa19 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt @@ -13,8 +13,7 @@ import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ones import org.jetbrains.kotlinx.multik.ndarray.data.DN import org.jetbrains.kotlinx.multik.ndarray.data.DataType -import space.kscience.kmath.multik.multikND -import space.kscience.kmath.multik.multikTensorAlgebra +import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.ndAlgebra @@ -79,7 +78,7 @@ internal class NDFieldBenchmark { } @Benchmark - fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) { + fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) { val res = Multik.ones(shape, DataType.DoubleDataType).wrap() repeat(n) { res += 1.0 } blackhole.consume(res) @@ -100,7 +99,7 @@ internal class NDFieldBenchmark { private val specializedField = DoubleField.ndAlgebra private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) private val nd4jField = DoubleField.nd4j - private val multikField = DoubleField.multikND + private val multikField = DoubleField.multikAlgebra private val viktorField = DoubleField.viktorAlgebra } } diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt index f0b776f75..fad68fa96 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt @@ -7,12 +7,12 @@ package space.kscience.kmath.tensors import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarray -import space.kscience.kmath.multik.multikND +import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.nd.one import space.kscience.kmath.operations.DoubleField -fun main(): Unit = with(DoubleField.multikND) { +fun main(): Unit = with(DoubleField.multikAlgebra) { val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType().wrap() val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap() - one(a.shape) - a + b * 3 + one(a.shape) - a + b * 3.0 } diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt deleted file mode 100644 index 4068ba9b9..000000000 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikOpsND.kt +++ /dev/null @@ -1,137 +0,0 @@ -package space.kscience.kmath.multik - -import org.jetbrains.kotlinx.multik.api.math.cos -import org.jetbrains.kotlinx.multik.api.math.sin -import org.jetbrains.kotlinx.multik.api.mk -import org.jetbrains.kotlinx.multik.api.zeros -import org.jetbrains.kotlinx.multik.ndarray.data.* -import org.jetbrains.kotlinx.multik.ndarray.operations.* -import space.kscience.kmath.nd.FieldOpsND -import space.kscience.kmath.nd.RingOpsND -import space.kscience.kmath.nd.Shape -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.operations.* - -/** - * A ring algebra for Multik operations - */ -public open class MultikRingOpsND> internal constructor( - public val type: DataType, - override val elementAlgebra: A -) : RingOpsND { - - public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this.asDNArray()) - - override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor { - val res = mk.zeros(shape, type).asDNArray() - for (index in res.multiIndices) { - res[index] = elementAlgebra.initializer(index) - } - return res.wrap() - } - - public fun StructureND.asMultik(): MultikTensor = if (this is MultikTensor) { - this - } else { - structureND(shape) { get(it) } - } - - override fun StructureND.map(transform: A.(T) -> T): MultikTensor { - //taken directly from Multik sources - val array = asMultik().array - val data = initMemoryView(array.size, type) - var count = 0 - for (el in array) data[count++] = elementAlgebra.transform(el) - return NDArray(data, shape = array.shape, dim = array.dim).wrap() - } - - override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor { - //taken directly from Multik sources - val array = asMultik().array - val data = initMemoryView(array.size, type) - val indexIter = array.multiIndices.iterator() - var index = 0 - for (item in array) { - if (indexIter.hasNext()) { - data[index++] = elementAlgebra.transform(indexIter.next(), item) - } else { - throw ArithmeticException("Index overflow has happened.") - } - } - return NDArray(data, shape = array.shape, dim = array.dim).wrap() - } - - override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): MultikTensor { - require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException - val leftArray = left.asMultik().array - val rightArray = right.asMultik().array - val data = initMemoryView(leftArray.size, type) - var counter = 0 - val leftIterator = leftArray.iterator() - val rightIterator = rightArray.iterator() - //iterating them together - while (leftIterator.hasNext()) { - data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next()) - } - return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap() - } - - override fun StructureND.unaryMinus(): MultikTensor = asMultik().array.unaryMinus().wrap() - - override fun add(left: StructureND, right: StructureND): MultikTensor = - (left.asMultik().array + right.asMultik().array).wrap() - - override fun StructureND.plus(arg: T): MultikTensor = - asMultik().array.plus(arg).wrap() - - override fun StructureND.minus(arg: T): MultikTensor = asMultik().array.minus(arg).wrap() - - override fun T.plus(arg: StructureND): MultikTensor = arg + this - - override fun T.minus(arg: StructureND): MultikTensor = arg.map { this@minus - it } - - override fun multiply(left: StructureND, right: StructureND): MultikTensor = - left.asMultik().array.times(right.asMultik().array).wrap() - - override fun StructureND.times(arg: T): MultikTensor = - asMultik().array.times(arg).wrap() - - override fun T.times(arg: StructureND): MultikTensor = arg * this - - override fun StructureND.unaryPlus(): MultikTensor = asMultik() - - override fun StructureND.plus(other: StructureND): MultikTensor = - asMultik().array.plus(other.asMultik().array).wrap() - - override fun StructureND.minus(other: StructureND): MultikTensor = - asMultik().array.minus(other.asMultik().array).wrap() - - override fun StructureND.times(other: StructureND): MultikTensor = - asMultik().array.times(other.asMultik().array).wrap() -} - -/** - * A field algebra for multik operations - */ -public class MultikFieldOpsND> internal constructor( - type: DataType, - elementAlgebra: A -) : MultikRingOpsND(type, elementAlgebra), FieldOpsND { - override fun StructureND.div(other: StructureND): StructureND = - asMultik().array.div(other.asMultik().array).wrap() -} - -public val DoubleField.multikND: MultikFieldOpsND - get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField) - -public val FloatField.multikND: MultikFieldOpsND - get() = MultikFieldOpsND(DataType.FloatDataType, FloatField) - -public val ShortRing.multikND: MultikRingOpsND - get() = MultikRingOpsND(DataType.ShortDataType, ShortRing) - -public val IntRing.multikND: MultikRingOpsND - get() = MultikRingOpsND(DataType.IntDataType, IntRing) - -public val LongRing.multikND: MultikRingOpsND - get() = MultikRingOpsND(DataType.LongDataType, LongRing) \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index 3bbe8e672..ec0d4c6d9 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -15,12 +15,14 @@ import org.jetbrains.kotlinx.multik.api.zeros import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.operations.* import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra +import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra @JvmInline public value class MultikTensor(public val array: MutableMultiArray) : Tensor { @@ -50,29 +52,80 @@ private fun MultiArray.asD2Array(): D2Array { else throw ClassCastException("Cannot cast MultiArray to NDArray.") } -public class MultikTensorAlgebra> internal constructor( - public val type: DataType, - override val elementAlgebra: A, - public val comparator: Comparator -) : TensorAlgebra { +public abstract class MultikTensorAlgebra> : TensorAlgebra where T : Number, T : Comparable { + + public abstract val type: DataType + + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor { + val strides = DefaultStrides(shape) + val memoryView = initMemoryView(strides.linearSize, type) + strides.indices().forEachIndexed { linearIndex, tensorIndex -> + memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex) + } + return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size))) + } + + override fun StructureND.map(transform: A.(T) -> T): MultikTensor = if (this is MultikTensor) { + val data = initMemoryView(array.size, type) + var count = 0 + for (el in array) data[count++] = elementAlgebra.transform(el) + NDArray(data, shape = shape, dim = array.dim).wrap() + } else { + structureND(shape) { index -> + transform(get(index)) + } + } + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor = + if (this is MultikTensor) { + val array = asMultik().array + val data = initMemoryView(array.size, type) + val indexIter = array.multiIndices.iterator() + var index = 0 + for (item in array) { + if (indexIter.hasNext()) { + data[index++] = elementAlgebra.transform(indexIter.next(), item) + } else { + throw ArithmeticException("Index overflow has happened.") + } + } + NDArray(data, shape = array.shape, dim = array.dim).wrap() + } else { + structureND(shape) { index -> + transform(index, get(index)) + } + } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): MultikTensor { + require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException + val leftArray = left.asMultik().array + val rightArray = right.asMultik().array + val data = initMemoryView(leftArray.size, type) + var counter = 0 + val leftIterator = leftArray.iterator() + val rightIterator = rightArray.iterator() + //iterating them together + while (leftIterator.hasNext()) { + data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next()) + } + return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap() + } /** * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * are not reflected back onto the source */ - public fun StructureND.asMultik(): MultikTensor { - return if (this is MultikTensor) { - this - } else { - val res = mk.zeros(shape, type).asDNArray() - for (index in res.multiIndices) { - res[index] = this[index] - } - res.wrap() + public fun StructureND.asMultik(): MultikTensor = if (this is MultikTensor) { + this + } else { + val res = mk.zeros(shape, type).asDNArray() + for (index in res.multiIndices) { + res[index] = this[index] } + res.wrap() } - public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this) + public fun MutableMultiArray.wrap(): MultikTensor = MultikTensor(this.asDNArray()) override fun StructureND.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { get(intArrayOf(0)) @@ -81,8 +134,8 @@ public class MultikTensorAlgebra> internal constructor( override fun T.plus(other: StructureND): MultikTensor = other.plus(this) - override fun StructureND.plus(value: T): MultikTensor = - asMultik().array.deepCopy().apply { plusAssign(value) }.wrap() + override fun StructureND.plus(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap() override fun StructureND.plus(other: StructureND): MultikTensor = asMultik().array.plus(other.asMultik().array).wrap() @@ -155,11 +208,11 @@ public class MultikTensorAlgebra> internal constructor( override fun StructureND.unaryMinus(): MultikTensor = asMultik().array.unaryMinus().wrap() - override fun Tensor.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() + override fun StructureND.get(i: Int): MultikTensor = asMultik().array.mutableView(i).wrap() - override fun Tensor.transpose(i: Int, j: Int): MultikTensor = asMultik().array.transpose(i, j).wrap() + override fun StructureND.transpose(i: Int, j: Int): MultikTensor = asMultik().array.transpose(i, j).wrap() - override fun Tensor.view(shape: IntArray): MultikTensor { + override fun StructureND.view(shape: IntArray): MultikTensor { require(shape.all { it > 0 }) require(shape.fold(1, Int::times) == this.shape.size) { "Cannot reshape array of size ${this.shape.size} into a new shape ${ @@ -178,9 +231,9 @@ public class MultikTensorAlgebra> internal constructor( }.wrap() } - override fun Tensor.viewAs(other: Tensor): MultikTensor = view(other.shape) + override fun StructureND.viewAs(other: StructureND): MultikTensor = view(other.shape) - override fun Tensor.dot(other: Tensor): MultikTensor = + override fun StructureND.dot(other: StructureND): MultikTensor = if (this.shape.size == 1 && other.shape.size == 1) { Multik.ndarrayOf( asMultik().array.asD1Array() dot other.asMultik().array.asD1Array() @@ -197,45 +250,95 @@ public class MultikTensorAlgebra> internal constructor( TODO("Diagonal embedding not implemented") } - override fun Tensor.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> + override fun StructureND.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> elementAlgebra.add(acc, t) } - override fun Tensor.sum(dim: Int, keepDim: Boolean): MultikTensor { + override fun StructureND.sum(dim: Int, keepDim: Boolean): MultikTensor { TODO("Not yet implemented") } - override fun Tensor.min(): T = - asMultik().array.minWith(comparator) ?: error("No elements in tensor") + override fun StructureND.min(): T? = asMultik().array.min() - override fun Tensor.min(dim: Int, keepDim: Boolean): MultikTensor { + override fun StructureND.min(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } - override fun Tensor.max(): T = - asMultik().array.maxWith(comparator) ?: error("No elements in tensor") + override fun StructureND.max(): T? = asMultik().array.max() - - override fun Tensor.max(dim: Int, keepDim: Boolean): MultikTensor { + override fun StructureND.max(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } - override fun Tensor.argMax(dim: Int, keepDim: Boolean): MultikTensor { + override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor { TODO("Not yet implemented") } } -public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra - get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) } +public abstract class MultikDivisionTensorAlgebra> + : MultikTensorAlgebra(), TensorPartialDivisionAlgebra where T : Number, T : Comparable { -public val FloatField.multikTensorAlgebra: MultikTensorAlgebra - get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) } + override fun T.div(arg: StructureND): MultikTensor = arg.map { elementAlgebra.divide(this@div, it) } -public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra - get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) } + override fun StructureND.div(arg: T): MultikTensor = + asMultik().array.deepCopy().apply { divAssign(arg) }.wrap() -public val IntRing.multikTensorAlgebra: MultikTensorAlgebra - get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) } + override fun StructureND.div(other: StructureND): MultikTensor = + asMultik().array.div(other.asMultik().array).wrap() -public val LongRing.multikTensorAlgebra: MultikTensorAlgebra - get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) } \ No newline at end of file + override fun Tensor.divAssign(value: T) { + if (this is MultikTensor) { + array.divAssign(value) + } else { + mapInPlace { _, t -> elementAlgebra.divide(t, value) } + } + } + + override fun Tensor.divAssign(other: StructureND) { + if (this is MultikTensor) { + array.divAssign(other.asMultik().array) + } else { + mapInPlace { index, t -> elementAlgebra.divide(t, other[index]) } + } + } +} + +public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra() { + override val elementAlgebra: DoubleField get() = DoubleField + override val type: DataType get() = DataType.DoubleDataType +} + +public val Double.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra +public val DoubleField.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra + +public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { + override val elementAlgebra: FloatField get() = FloatField + override val type: DataType get() = DataType.FloatDataType +} + +public val Float.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra +public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra + +public object MultikShortAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: ShortRing get() = ShortRing + override val type: DataType get() = DataType.ShortDataType +} + +public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra +public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra + +public object MultikIntAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: IntRing get() = IntRing + override val type: DataType get() = DataType.IntDataType +} + +public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra +public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra + +public object MultikLongAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: LongRing get() = LongRing + override val type: DataType get() = DataType.LongDataType +} + +public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra +public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file diff --git a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt index 9dd9854c7..66ba7db2d 100644 --- a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt +++ b/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt @@ -7,7 +7,7 @@ import space.kscience.kmath.operations.invoke internal class MultikNDTest { @Test - fun basicAlgebra(): Unit = DoubleField.multikND{ + fun basicAlgebra(): Unit = DoubleField.multikAlgebra{ one(2,2) + 1.0 } } \ No newline at end of file diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index e40ea3dbc..3fb03c964 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -13,6 +13,7 @@ import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.ops.transforms.Transforms import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.DefaultStrides import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.DoubleField @@ -27,22 +28,6 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra */ public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { - override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure { - val array = - } - - override fun StructureND.map(transform: A.(T) -> T): Nd4jArrayStructure { - TODO("Not yet implemented") - } - - override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure { - TODO("Not yet implemented") - } - - override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { - TODO("Not yet implemented") - } - /** * Wraps [INDArray] to [Nd4jArrayStructure]. */ @@ -53,8 +38,21 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe */ public val StructureND.ndArray: INDArray + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure + + override fun StructureND.map(transform: A.(T) -> T): Nd4jArrayStructure = + structureND(shape) { index -> elementAlgebra.transform(get(index)) } + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure = + structureND(shape) { index -> elementAlgebra.transform(index, get(index)) } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { + require(left.shape.contentEquals(right.shape)) + return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) } + } + override fun T.plus(other: StructureND): Nd4jArrayStructure = other.ndArray.add(this).wrap() - override fun StructureND.plus(value: T): Nd4jArrayStructure = ndArray.add(value).wrap() + override fun StructureND.plus(arg: T): Nd4jArrayStructure = ndArray.add(arg).wrap() override fun StructureND.plus(other: StructureND): Nd4jArrayStructure = ndArray.add(other.ndArray).wrap() @@ -94,51 +92,52 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe } override fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() - override fun Tensor.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() - override fun Tensor.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() - override fun Tensor.dot(other: Tensor): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() + override fun StructureND.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() + override fun StructureND.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() + override fun StructureND.dot(other: StructureND): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() - override fun Tensor.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.min(keepDim, dim).wrap() - override fun Tensor.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.sum(keepDim, dim).wrap() - override fun Tensor.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.max(keepDim, dim).wrap() - override fun Tensor.view(shape: IntArray): Nd4jArrayStructure = ndArray.reshape(shape).wrap() - override fun Tensor.viewAs(other: Tensor): Nd4jArrayStructure = view(other.shape) + override fun StructureND.view(shape: IntArray): Nd4jArrayStructure = ndArray.reshape(shape).wrap() + override fun StructureND.viewAs(other: StructureND): Nd4jArrayStructure = view(other.shape) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndBase.get().argmax(ndArray, keepDim, dim).wrap() - override fun Tensor.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.mean(keepDim, dim).wrap() + override fun StructureND.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + ndArray.mean(keepDim, dim).wrap() - override fun Tensor.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() - override fun Tensor.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() - override fun Tensor.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() - override fun Tensor.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() - override fun Tensor.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() - override fun Tensor.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() + override fun StructureND.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() + override fun StructureND.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() + override fun StructureND.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() + override fun StructureND.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() + override fun StructureND.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() + override fun StructureND.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() - override fun Tensor.acosh(): Nd4jArrayStructure = + override fun StructureND.acosh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() - override fun Tensor.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() - override fun Tensor.sinh(): Tensor = Transforms.sinh(ndArray).wrap() + override fun StructureND.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() + override fun StructureND.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() + override fun StructureND.sinh(): Tensor = Transforms.sinh(ndArray).wrap() - override fun Tensor.asinh(): Nd4jArrayStructure = + override fun StructureND.asinh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() - override fun Tensor.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() - override fun Tensor.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() - override fun Tensor.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() - override fun Tensor.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() - override fun Tensor.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() - override fun Tensor.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() + override fun StructureND.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() + override fun StructureND.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() + override fun StructureND.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() + override fun StructureND.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() + override fun StructureND.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() + override fun StructureND.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.std(true, keepDim, dim).wrap() override fun T.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() @@ -153,7 +152,7 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe ndArray.divi(other.ndArray) } - override fun Tensor.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + override fun StructureND.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure = Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() private companion object { @@ -170,6 +169,16 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() + override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure { + val array: INDArray = Nd4j.zeros(*shape) + val indices = DefaultStrides(shape) + indices.indices().forEach { index -> + array.putScalar(index, elementAlgebra.initializer(index)) + } + return array.wrap() + } + + @OptIn(PerformancePitfall::class) override val StructureND.ndArray: INDArray get() = when (this) { @@ -190,10 +199,10 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { dim2: Int, ): Tensor = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2) - override fun Tensor.sum(): Double = ndArray.sumNumber().toDouble() - override fun Tensor.min(): Double = ndArray.minNumber().toDouble() - override fun Tensor.max(): Double = ndArray.maxNumber().toDouble() - override fun Tensor.mean(): Double = ndArray.meanNumber().toDouble() - override fun Tensor.std(): Double = ndArray.stdNumber().toDouble() - override fun Tensor.variance(): Double = ndArray.varNumber().toDouble() + override fun StructureND.sum(): Double = ndArray.sumNumber().toDouble() + override fun StructureND.min(): Double = ndArray.minNumber().toDouble() + override fun StructureND.max(): Double = ndArray.maxNumber().toDouble() + override fun StructureND.mean(): Double = ndArray.meanNumber().toDouble() + override fun StructureND.std(): Double = ndArray.stdNumber().toDouble() + override fun StructureND.variance(): Double = ndArray.varNumber().toDouble() } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index caafcc7c1..debfb3ef0 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.Field @@ -18,7 +19,7 @@ public interface AnalyticTensorAlgebra> : TensorPartialDivisionA /** * @return the mean of all elements in the input tensor. */ - public fun Tensor.mean(): T + public fun StructureND.mean(): T /** * Returns the mean of each row of the input tensor in the given dimension [dim]. @@ -31,12 +32,12 @@ public interface AnalyticTensorAlgebra> : TensorPartialDivisionA * @param keepDim whether the output tensor has [dim] retained or not. * @return the mean of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.mean(dim: Int, keepDim: Boolean): Tensor /** * @return the standard deviation of all elements in the input tensor. */ - public fun Tensor.std(): T + public fun StructureND.std(): T /** * Returns the standard deviation of each row of the input tensor in the given dimension [dim]. @@ -49,12 +50,12 @@ public interface AnalyticTensorAlgebra> : TensorPartialDivisionA * @param keepDim whether the output tensor has [dim] retained or not. * @return the standard deviation of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.std(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.std(dim: Int, keepDim: Boolean): Tensor /** * @return the variance of all elements in the input tensor. */ - public fun Tensor.variance(): T + public fun StructureND.variance(): T /** * Returns the variance of each row of the input tensor in the given dimension [dim]. @@ -67,57 +68,57 @@ public interface AnalyticTensorAlgebra> : TensorPartialDivisionA * @param keepDim whether the output tensor has [dim] retained or not. * @return the variance of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.variance(dim: Int, keepDim: Boolean): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.exp.html - public fun Tensor.exp(): Tensor + public fun StructureND.exp(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.log.html - public fun Tensor.ln(): Tensor + public fun StructureND.ln(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html - public fun Tensor.sqrt(): Tensor + public fun StructureND.sqrt(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos - public fun Tensor.cos(): Tensor + public fun StructureND.cos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos - public fun Tensor.acos(): Tensor + public fun StructureND.acos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh - public fun Tensor.cosh(): Tensor + public fun StructureND.cosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh - public fun Tensor.acosh(): Tensor + public fun StructureND.acosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin - public fun Tensor.sin(): Tensor + public fun StructureND.sin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin - public fun Tensor.asin(): Tensor + public fun StructureND.asin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh - public fun Tensor.sinh(): Tensor + public fun StructureND.sinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh - public fun Tensor.asinh(): Tensor + public fun StructureND.asinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan - public fun Tensor.tan(): Tensor + public fun StructureND.tan(): Tensor //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan - public fun Tensor.atan(): Tensor + public fun StructureND.atan(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh - public fun Tensor.tanh(): Tensor + public fun StructureND.tanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh - public fun Tensor.atanh(): Tensor + public fun StructureND.atanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil - public fun Tensor.ceil(): Tensor + public fun StructureND.ceil(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor - public fun Tensor.floor(): Tensor + public fun StructureND.floor(): Tensor } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index 3f32eb9ca..e8fa7dacd 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.tensors.api +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.operations.Field /** @@ -20,7 +21,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * * @return the determinant. */ - public fun Tensor.det(): Tensor + public fun StructureND.det(): Tensor /** * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. @@ -30,7 +31,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * * @return the multiplicative inverse of a matrix. */ - public fun Tensor.inv(): Tensor + public fun StructureND.inv(): Tensor /** * Cholesky decomposition. @@ -46,7 +47,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * @receiver the `input`. * @return the batch of `L` matrices. */ - public fun Tensor.cholesky(): Tensor + public fun StructureND.cholesky(): Tensor /** * QR decomposition. @@ -60,7 +61,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * @receiver the `input`. * @return pair of `Q` and `R` tensors. */ - public fun Tensor.qr(): Pair, Tensor> + public fun StructureND.qr(): Pair, Tensor> /** * LUP decomposition @@ -74,7 +75,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * @receiver the `input`. * @return triple of P, L and U tensors */ - public fun Tensor.lu(): Triple, Tensor, Tensor> + public fun StructureND.lu(): Triple, Tensor, Tensor> /** * Singular Value Decomposition. @@ -90,7 +91,7 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * @receiver the `input`. * @return triple `Triple(U, S, V)`. */ - public fun Tensor.svd(): Triple, Tensor, Tensor> + public fun StructureND.svd(): Triple, Tensor, Tensor> /** * Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices, @@ -100,6 +101,6 @@ public interface LinearOpsTensorAlgebra> : TensorPartialDivision * @receiver the `input`. * @return a pair `eigenvalues to eigenvectors` */ - public fun Tensor.symEig(): Pair, Tensor> + public fun StructureND.symEig(): Pair, Tensor> } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index e910c5c31..28d7c4677 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -41,12 +41,12 @@ public interface TensorAlgebra> : RingOpsND { override operator fun T.plus(other: StructureND): Tensor /** - * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. + * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor. * - * @param value the number to be added to each element of this tensor. - * @return the sum of this tensor and [value]. + * @param arg the number to be added to each element of this tensor. + * @return the sum of this tensor and [arg]. */ - override operator fun StructureND.plus(value: T): Tensor + override operator fun StructureND.plus(arg: T): Tensor /** * Each element of the tensor [other] is added to each element of this tensor. @@ -166,7 +166,7 @@ public interface TensorAlgebra> : RingOpsND { * @param i index of the extractable tensor * @return subtensor of the original tensor with index [i] */ - public operator fun Tensor.get(i: Int): Tensor + public operator fun StructureND.get(i: Int): Tensor /** * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. @@ -176,7 +176,7 @@ public interface TensorAlgebra> : RingOpsND { * @param j the second dimension to be transposed * @return transposed tensor */ - public fun Tensor.transpose(i: Int = -2, j: Int = -1): Tensor + public fun StructureND.transpose(i: Int = -2, j: Int = -1): Tensor /** * Returns a new tensor with the same data as the self tensor but of a different shape. @@ -186,7 +186,7 @@ public interface TensorAlgebra> : RingOpsND { * @param shape the desired size * @return tensor with new shape */ - public fun Tensor.view(shape: IntArray): Tensor + public fun StructureND.view(shape: IntArray): Tensor /** * View this tensor as the same size as [other]. @@ -196,7 +196,7 @@ public interface TensorAlgebra> : RingOpsND { * @param other the result tensor has the same size as other. * @return the result tensor with the same size as other. */ - public fun Tensor.viewAs(other: Tensor): Tensor + public fun StructureND.viewAs(other: StructureND): Tensor /** * Matrix product of two tensors. @@ -227,7 +227,7 @@ public interface TensorAlgebra> : RingOpsND { * @param other tensor to be multiplied. * @return a mathematical product of two tensors. */ - public infix fun Tensor.dot(other: Tensor): Tensor + public infix fun StructureND.dot(other: StructureND): Tensor /** * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) @@ -262,7 +262,7 @@ public interface TensorAlgebra> : RingOpsND { /** * @return the sum of all elements in the input tensor. */ - public fun Tensor.sum(): T + public fun StructureND.sum(): T /** * Returns the sum of each row of the input tensor in the given dimension [dim]. @@ -275,12 +275,12 @@ public interface TensorAlgebra> : RingOpsND { * @param keepDim whether the output tensor has [dim] retained or not. * @return the sum of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.sum(dim: Int, keepDim: Boolean): Tensor /** - * @return the minimum value of all elements in the input tensor. + * @return the minimum value of all elements in the input tensor or null if there are no values */ - public fun Tensor.min(): T + public fun StructureND.min(): T? /** * Returns the minimum value of each row of the input tensor in the given dimension [dim]. @@ -293,12 +293,12 @@ public interface TensorAlgebra> : RingOpsND { * @param keepDim whether the output tensor has [dim] retained or not. * @return the minimum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.min(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.min(dim: Int, keepDim: Boolean): Tensor /** - * Returns the maximum value of all elements in the input tensor. + * Returns the maximum value of all elements in the input tensor or null if there are no values */ - public fun Tensor.max(): T + public fun StructureND.max(): T? /** * Returns the maximum value of each row of the input tensor in the given dimension [dim]. @@ -311,7 +311,7 @@ public interface TensorAlgebra> : RingOpsND { * @param keepDim whether the output tensor has [dim] retained or not. * @return the maximum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.max(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.max(dim: Int, keepDim: Boolean): Tensor /** * Returns the index of maximum value of each row of the input tensor in the given dimension [dim]. @@ -324,7 +324,7 @@ public interface TensorAlgebra> : RingOpsND { * @param keepDim whether the output tensor has [dim] retained or not. * @return the index of maximum value of each row of the input tensor in the given dimension [dim]. */ - public fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor + public fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor override fun add(left: StructureND, right: StructureND): Tensor = left + right diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index f3c5d2540..f53a0ca55 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -85,7 +85,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun Tensor.divAssign(other: Tensor) { + override fun Tensor.divAssign(other: StructureND) { val newOther = broadcastTo(other.tensor, tensor.shape) for (i in 0 until tensor.indices.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 3d343604b..be87f8019 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -115,7 +115,7 @@ public open class DoubleTensorAlgebra : TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray() ) - override operator fun Tensor.get(i: Int): DoubleTensor { + override operator fun StructureND.get(i: Int): DoubleTensor { val lastShape = tensor.shape.drop(1).toIntArray() val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart @@ -160,7 +160,7 @@ public open class DoubleTensorAlgebra : * * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. */ - public fun Tensor.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) + public fun StructureND.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) /** * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. @@ -198,9 +198,8 @@ public open class DoubleTensorAlgebra : * * @return a copy of the `input` tensor with a copied buffer. */ - public fun Tensor.copy(): DoubleTensor { - return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) - } + public fun StructureND.copy(): DoubleTensor = + DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) override fun Double.plus(other: StructureND): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> @@ -209,7 +208,7 @@ public open class DoubleTensorAlgebra : return DoubleTensor(other.shape, resBuffer) } - override fun StructureND.plus(value: Double): DoubleTensor = value + tensor + override fun StructureND.plus(arg: Double): DoubleTensor = arg + tensor override fun StructureND.plus(other: StructureND): DoubleTensor { checkShapesCompatible(tensor, other.tensor) @@ -345,7 +344,7 @@ public open class DoubleTensorAlgebra : return DoubleTensor(tensor.shape, resBuffer) } - override fun Tensor.transpose(i: Int, j: Int): DoubleTensor { + override fun StructureND.transpose(i: Int, j: Int): DoubleTensor { val ii = tensor.minusIndex(i) val jj = tensor.minusIndex(j) checkTranspose(tensor.dimension, ii, jj) @@ -369,15 +368,15 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.view(shape: IntArray): DoubleTensor { + override fun StructureND.view(shape: IntArray): DoubleTensor { checkView(tensor, shape) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) } - override fun Tensor.viewAs(other: Tensor): DoubleTensor = + override fun StructureND.viewAs(other: StructureND): DoubleTensor = tensor.view(other.shape) - override infix fun Tensor.dot(other: Tensor): DoubleTensor { + override infix fun StructureND.dot(other: StructureND): DoubleTensor { if (tensor.shape.size == 1 && other.shape.size == 1) { return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) } @@ -569,10 +568,10 @@ public open class DoubleTensorAlgebra : */ public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) - internal inline fun Tensor.fold(foldFunction: (DoubleArray) -> Double): Double = + internal inline fun StructureND.fold(foldFunction: (DoubleArray) -> Double): Double = foldFunction(tensor.copyArray()) - internal inline fun Tensor.foldDim( + internal inline fun StructureND.foldDim( foldFunction: (DoubleArray) -> Double, dim: Int, keepDim: Boolean, @@ -596,30 +595,30 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.sum(): Double = tensor.fold { it.sum() } + override fun StructureND.sum(): Double = tensor.fold { it.sum() } - override fun Tensor.sum(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.sum(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.sum() }, dim, keepDim) - override fun Tensor.min(): Double = this.fold { it.minOrNull()!! } + override fun StructureND.min(): Double = this.fold { it.minOrNull()!! } - override fun Tensor.min(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.min(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.minOrNull()!! }, dim, keepDim) - override fun Tensor.max(): Double = this.fold { it.maxOrNull()!! } + override fun StructureND.max(): Double = this.fold { it.maxOrNull()!! } - override fun Tensor.max(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.max(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.maxOrNull()!! }, dim, keepDim) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.argMax(dim: Int, keepDim: Boolean): DoubleTensor = foldDim({ x -> x.withIndex().maxByOrNull { it.value }?.index!!.toDouble() }, dim, keepDim) - override fun Tensor.mean(): Double = this.fold { it.sum() / tensor.numElements } + override fun StructureND.mean(): Double = this.fold { it.sum() / tensor.numElements } - override fun Tensor.mean(dim: Int, keepDim: Boolean): DoubleTensor = + override fun StructureND.mean(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } @@ -629,12 +628,12 @@ public open class DoubleTensorAlgebra : keepDim ) - override fun Tensor.std(): Double = this.fold { arr -> + override fun StructureND.std(): Double = this.fold { arr -> val mean = arr.sum() / tensor.numElements sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) } - override fun Tensor.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( + override fun StructureND.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] @@ -644,12 +643,12 @@ public open class DoubleTensorAlgebra : keepDim ) - override fun Tensor.variance(): Double = this.fold { arr -> + override fun StructureND.variance(): Double = this.fold { arr -> val mean = arr.sum() / tensor.numElements arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) } - override fun Tensor.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( + override fun StructureND.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] @@ -672,7 +671,7 @@ public open class DoubleTensorAlgebra : * @param tensors the [List] of 1-dimensional tensors with same shape * @return `M`. */ - public fun cov(tensors: List>): DoubleTensor { + public fun cov(tensors: List>): DoubleTensor { check(tensors.isNotEmpty()) { "List must have at least 1 element" } val n = tensors.size val m = tensors[0].shape[0] @@ -689,43 +688,43 @@ public open class DoubleTensorAlgebra : return resTensor } - override fun Tensor.exp(): DoubleTensor = tensor.map { exp(it) } + override fun StructureND.exp(): DoubleTensor = tensor.map { exp(it) } - override fun Tensor.ln(): DoubleTensor = tensor.map { ln(it) } + override fun StructureND.ln(): DoubleTensor = tensor.map { ln(it) } - override fun Tensor.sqrt(): DoubleTensor = tensor.map { sqrt(it) } + override fun StructureND.sqrt(): DoubleTensor = tensor.map { sqrt(it) } - override fun Tensor.cos(): DoubleTensor = tensor.map { cos(it) } + override fun StructureND.cos(): DoubleTensor = tensor.map { cos(it) } - override fun Tensor.acos(): DoubleTensor = tensor.map { acos(it) } + override fun StructureND.acos(): DoubleTensor = tensor.map { acos(it) } - override fun Tensor.cosh(): DoubleTensor = tensor.map { cosh(it) } + override fun StructureND.cosh(): DoubleTensor = tensor.map { cosh(it) } - override fun Tensor.acosh(): DoubleTensor = tensor.map { acosh(it) } + override fun StructureND.acosh(): DoubleTensor = tensor.map { acosh(it) } - override fun Tensor.sin(): DoubleTensor = tensor.map { sin(it) } + override fun StructureND.sin(): DoubleTensor = tensor.map { sin(it) } - override fun Tensor.asin(): DoubleTensor = tensor.map { asin(it) } + override fun StructureND.asin(): DoubleTensor = tensor.map { asin(it) } - override fun Tensor.sinh(): DoubleTensor = tensor.map { sinh(it) } + override fun StructureND.sinh(): DoubleTensor = tensor.map { sinh(it) } - override fun Tensor.asinh(): DoubleTensor = tensor.map { asinh(it) } + override fun StructureND.asinh(): DoubleTensor = tensor.map { asinh(it) } - override fun Tensor.tan(): DoubleTensor = tensor.map { tan(it) } + override fun StructureND.tan(): DoubleTensor = tensor.map { tan(it) } - override fun Tensor.atan(): DoubleTensor = tensor.map { atan(it) } + override fun StructureND.atan(): DoubleTensor = tensor.map { atan(it) } - override fun Tensor.tanh(): DoubleTensor = tensor.map { tanh(it) } + override fun StructureND.tanh(): DoubleTensor = tensor.map { tanh(it) } - override fun Tensor.atanh(): DoubleTensor = tensor.map { atanh(it) } + override fun StructureND.atanh(): DoubleTensor = tensor.map { atanh(it) } - override fun Tensor.ceil(): DoubleTensor = tensor.map { ceil(it) } + override fun StructureND.ceil(): DoubleTensor = tensor.map { ceil(it) } - override fun Tensor.floor(): DoubleTensor = tensor.map { floor(it) } + override fun StructureND.floor(): DoubleTensor = tensor.map { floor(it) } - override fun Tensor.inv(): DoubleTensor = invLU(1e-9) + override fun StructureND.inv(): DoubleTensor = invLU(1e-9) - override fun Tensor.det(): DoubleTensor = detLU(1e-9) + override fun StructureND.det(): DoubleTensor = detLU(1e-9) /** * Computes the LU factorization of a matrix or batches of matrices `input`. @@ -736,7 +735,7 @@ public open class DoubleTensorAlgebra : * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. */ - public fun Tensor.luFactor(epsilon: Double): Pair = + public fun StructureND.luFactor(epsilon: Double): Pair = computeLU(tensor, epsilon) ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") @@ -749,7 +748,7 @@ public open class DoubleTensorAlgebra : * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. */ - public fun Tensor.luFactor(): Pair = luFactor(1e-9) + public fun StructureND.luFactor(): Pair = luFactor(1e-9) /** * Unpacks the data and pivots from a LU factorization of a tensor. @@ -763,7 +762,7 @@ public open class DoubleTensorAlgebra : * @return triple of `P`, `L` and `U` tensors */ public fun luPivot( - luTensor: Tensor, + luTensor: StructureND, pivotsTensor: Tensor, ): Triple { checkSquareMatrix(luTensor.shape) @@ -806,7 +805,7 @@ public open class DoubleTensorAlgebra : * Used when checking the positive definiteness of the input matrix or matrices. * @return a pair of `Q` and `R` tensors. */ - public fun Tensor.cholesky(epsilon: Double): DoubleTensor { + public fun StructureND.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) checkPositiveDefinite(tensor, epsilon) @@ -819,9 +818,9 @@ public open class DoubleTensorAlgebra : return lTensor } - override fun Tensor.cholesky(): DoubleTensor = cholesky(1e-6) + override fun StructureND.cholesky(): DoubleTensor = cholesky(1e-6) - override fun Tensor.qr(): Pair { + override fun StructureND.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() val rTensor = zeroesLike() @@ -837,7 +836,7 @@ public open class DoubleTensorAlgebra : return qTensor to rTensor } - override fun Tensor.svd(): Triple = + override fun StructureND.svd(): Triple = svd(epsilon = 1e-10) /** @@ -853,7 +852,7 @@ public open class DoubleTensorAlgebra : * i.e., the precision with which the cosine approaches 1 in an iterative algorithm. * @return a triple `Triple(U, S, V)`. */ - public fun Tensor.svd(epsilon: Double): Triple { + public fun StructureND.svd(epsilon: Double): Triple { val size = tensor.dimension val commonShape = tensor.shape.sliceArray(0 until size - 2) val (n, m) = tensor.shape.sliceArray(size - 2 until size) @@ -886,7 +885,7 @@ public open class DoubleTensorAlgebra : return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) } - override fun Tensor.symEig(): Pair = symEig(epsilon = 1e-15) + override fun StructureND.symEig(): Pair = symEig(epsilon = 1e-15) /** * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, @@ -896,7 +895,7 @@ public open class DoubleTensorAlgebra : * and when the cosine approaches 1 in the SVD algorithm. * @return a pair `eigenvalues to eigenvectors`. */ - public fun Tensor.symEig(epsilon: Double): Pair { + public fun StructureND.symEig(epsilon: Double): Pair { checkSymmetric(tensor, epsilon) fun MutableStructure2D.cleanSym(n: Int) { @@ -931,7 +930,7 @@ public open class DoubleTensorAlgebra : * with zero. * @return the determinant. */ - public fun Tensor.detLU(epsilon: Double = 1e-9): DoubleTensor { + public fun StructureND.detLU(epsilon: Double = 1e-9): DoubleTensor { checkSquareMatrix(tensor.shape) val luTensor = tensor.copy() val pivotsTensor = tensor.setUpPivots() @@ -964,7 +963,7 @@ public open class DoubleTensorAlgebra : * @param epsilon error in the LU algorithm—permissible error when comparing the determinant of a matrix with zero * @return the multiplicative inverse of a matrix. */ - public fun Tensor.invLU(epsilon: Double = 1e-9): DoubleTensor { + public fun StructureND.invLU(epsilon: Double = 1e-9): DoubleTensor { val (luTensor, pivotsTensor) = luFactor(epsilon) val invTensor = luTensor.zeroesLike() @@ -989,12 +988,12 @@ public open class DoubleTensorAlgebra : * @param epsilon permissible error when comparing the determinant of a matrix with zero. * @return triple of `P`, `L` and `U` tensors. */ - public fun Tensor.lu(epsilon: Double = 1e-9): Triple { + public fun StructureND.lu(epsilon: Double = 1e-9): Triple { val (lu, pivots) = tensor.luFactor(epsilon) return luPivot(lu, pivots) } - override fun Tensor.lu(): Triple = lu(1e-9) + override fun StructureND.lu(): Triple = lu(1e-9) } public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra