From 15a0258b7d860c861239d26d33528935ce98083c Mon Sep 17 00:00:00 2001 From: Iaroslav Postovalov Date: Fri, 14 May 2021 00:37:33 +0700 Subject: [PATCH] Nd4j based TensorAlgebra implementation, drop Nd4jArrayLongStructure --- .../space/kscience/kmath/nd/AlgebraND.kt | 1 - kmath-nd4j/build.gradle.kts | 20 +- .../kscience/kmath/nd4j/Nd4jArrayAlgebra.kt | 179 ++++++++---------- .../kscience/kmath/nd4j/Nd4jArrayIterator.kt | 6 - .../kscience/kmath/nd4j/Nd4jArrayStructure.kt | 15 +- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 175 +++++++++++++++++ .../space/kscience/kmath/nd4j/arrays.kt | 1 - .../kmath/nd4j/Nd4jArrayAlgebraTest.kt | 2 +- .../kmath/nd4j/Nd4jArrayStructureTest.kt | 2 +- .../kmath/tensors/api/TensorAlgebra.kt | 22 +-- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 23 +-- 11 files changed, 280 insertions(+), 166 deletions(-) create mode 100644 kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt index 93187eb41..35bbc44f6 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/AlgebraND.kt @@ -227,7 +227,6 @@ public interface RingND> : Ring>, GroupND> : Field>, RingND, ScaleOperations> { diff --git a/kmath-nd4j/build.gradle.kts b/kmath-nd4j/build.gradle.kts index bc61060db..12ea5c130 100644 --- a/kmath-nd4j/build.gradle.kts +++ b/kmath-nd4j/build.gradle.kts @@ -4,7 +4,7 @@ plugins { } dependencies { - api(project(":kmath-core")) + api(project(":kmath-tensors")) api("org.nd4j:nd4j-api:1.0.0-beta7") testImplementation("org.nd4j:nd4j-native:1.0.0-beta7") testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7") @@ -15,19 +15,7 @@ readme { description = "ND4J NDStructure implementation and according NDAlgebra classes" maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md")) - - feature( - id = "nd4jarraystructure", - description = "NDStructure wrapper for INDArray" - ) - - feature( - id = "nd4jarrayrings", - description = "Rings over Nd4jArrayStructure of Int and Long" - ) - - feature( - id = "nd4jarrayfields", - description = "Fields over Nd4jArrayStructure of Float and Double" - ) + feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" } + feature(id = "nd4jarrayrings") { "Rings over Nd4jArrayStructure of Int and Long" } + feature(id = "nd4jarrayfields") { "Fields over Nd4jArrayStructure of Float and Double" } } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt index 3cac44c47..e94bda12a 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebra.kt @@ -6,15 +6,14 @@ package space.kscience.kmath.nd4j import org.nd4j.linalg.api.ndarray.INDArray -import org.nd4j.linalg.api.ops.impl.scalar.Pow -import org.nd4j.linalg.api.ops.impl.transforms.strict.* +import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh +import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.ops.transforms.Transforms import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.nd.* import 
space.kscience.kmath.operations.* -import space.kscience.kmath.structures.* internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray { val arrayShape = array.shape().toIntArray() @@ -29,23 +28,16 @@ internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray { * @param T the type of ND-structure element. * @param C the type of the element context. */ -public interface Nd4jArrayAlgebra> : AlgebraND { +public sealed interface Nd4jArrayAlgebra> : AlgebraND { /** - * Wraps [INDArray] to [N]. + * Wraps [INDArray] to [Nd4jArrayStructure]. */ public fun INDArray.wrap(): Nd4jArrayStructure + /** + * Unwraps to or acquires [INDArray] from [StructureND]. + */ public val StructureND.ndArray: INDArray - get() = when { - !shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException( - this@Nd4jArrayAlgebra.shape, - shape - ) - this is Nd4jArrayStructure -> ndArray //TODO check strides - else -> { - TODO() - } - } public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure { val struct = Nd4j.create(*shape)!!.wrap() @@ -85,7 +77,7 @@ public interface Nd4jArrayAlgebra> : AlgebraND { * @param T the type of the element contained in ND structure. * @param S the type of space of structure elements. */ -public interface Nd4JArrayGroup> : GroupND, Nd4jArrayAlgebra { +public sealed interface Nd4jArrayGroup> : GroupND, Nd4jArrayAlgebra { public override val zero: Nd4jArrayStructure get() = Nd4j.zeros(*shape).wrap() @@ -110,7 +102,7 @@ public interface Nd4JArrayGroup> : GroupND, Nd4jArrayAlgebr * @param R the type of ring of structure elements. */ @OptIn(UnstableKMathAPI::class) -public interface Nd4jArrayRing> : RingND, Nd4JArrayGroup { +public sealed interface Nd4jArrayRing> : RingND, Nd4jArrayGroup { public override val one: Nd4jArrayStructure get() = Nd4j.ones(*shape).wrap() @@ -135,10 +127,7 @@ public interface Nd4jArrayRing> : RingND, Nd4JArrayGroup> = - ThreadLocal.withInitial { hashMapOf() } - - private val longNd4jArrayRingCache: ThreadLocal> = - ThreadLocal.withInitial { hashMapOf() } + ThreadLocal.withInitial(::HashMap) /** * Creates an [RingND] for [Int] values or pull it from cache if it was created previously. @@ -146,20 +135,13 @@ public interface Nd4jArrayRing> : RingND, Nd4JArrayGroup = intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) } - /** - * Creates an [RingND] for [Long] values or pull it from cache if it was created previously. - */ - public fun long(vararg shape: Int): Nd4jArrayRing = - longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) } - /** * Creates a most suitable implementation of [RingND] using reified class. 
*/ @Suppress("UNCHECKED_CAST") - public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { - T::class == Int::class -> int(*shape) as Nd4jArrayRing> - T::class == Long::class -> long(*shape) as Nd4jArrayRing> - else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.") + public inline fun auto(vararg shape: Int): Nd4jArrayRing> = when { + T::class == Int::class -> int(*shape) as Nd4jArrayRing> + else -> throw UnsupportedOperationException("This factory method only supports Int type.") } } } @@ -168,11 +150,9 @@ public interface Nd4jArrayRing> : RingND, Nd4JArrayGroup> : FieldND, Nd4jArrayRing { - +public sealed interface Nd4jArrayField> : FieldND, Nd4jArrayRing { public override fun divide(a: StructureND, b: StructureND): Nd4jArrayStructure = a.ndArray.div(b.ndArray).wrap() @@ -180,10 +160,10 @@ public interface Nd4jArrayField> : FieldND, Nd4jArrayRing< public companion object { private val floatNd4jArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial { hashMapOf() } + ThreadLocal.withInitial(::HashMap) private val doubleNd4JArrayFieldCache: ThreadLocal> = - ThreadLocal.withInitial { hashMapOf() } + ThreadLocal.withInitial(::HashMap) /** * Creates an [FieldND] for [Float] values or pull it from cache if it was created previously. @@ -198,26 +178,64 @@ public interface Nd4jArrayField> : FieldND, Nd4jArrayRing< doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) } /** - * Creates a most suitable implementation of [RingND] using reified class. + * Creates a most suitable implementation of [FieldND] using reified class. */ @Suppress("UNCHECKED_CAST") - public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { - T::class == Float::class -> float(*shape) as Nd4jArrayField> - T::class == Double::class -> real(*shape) as Nd4jArrayField> + public inline fun auto(vararg shape: Int): Nd4jArrayField> = when { + T::class == Float::class -> float(*shape) as Nd4jArrayField> + T::class == Double::class -> real(*shape) as Nd4jArrayField> else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.") } } } +/** + * Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
+ */ +public sealed interface Nd4jArrayExtendedField> : ExtendedField>, + Nd4jArrayField { + public override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() + public override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() + public override fun asin(arg: StructureND): StructureND = Transforms.asin(arg.ndArray).wrap() + public override fun acos(arg: StructureND): StructureND = Transforms.acos(arg.ndArray).wrap() + public override fun atan(arg: StructureND): StructureND = Transforms.atan(arg.ndArray).wrap() + + public override fun power(arg: StructureND, pow: Number): StructureND = + Transforms.pow(arg.ndArray, pow).wrap() + + public override fun exp(arg: StructureND): StructureND = Transforms.exp(arg.ndArray).wrap() + public override fun ln(arg: StructureND): StructureND = Transforms.log(arg.ndArray).wrap() + public override fun sqrt(arg: StructureND): StructureND = Transforms.sqrt(arg.ndArray).wrap() + public override fun sinh(arg: StructureND): StructureND = Transforms.sinh(arg.ndArray).wrap() + public override fun cosh(arg: StructureND): StructureND = Transforms.cosh(arg.ndArray).wrap() + public override fun tanh(arg: StructureND): StructureND = Transforms.tanh(arg.ndArray).wrap() + + public override fun asinh(arg: StructureND): StructureND = + Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap() + + public override fun acosh(arg: StructureND): StructureND = + Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap() + + public override fun atanh(arg: StructureND): StructureND = Transforms.atanh(arg.ndArray).wrap() +} + /** * Represents [FieldND] over [Nd4jArrayDoubleStructure]. */ -public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField, - ExtendedField> { +public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField { public override val elementContext: DoubleField get() = DoubleField public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asDoubleStructure() + @OptIn(PerformancePitfall::class) + override val StructureND.ndArray: INDArray + get() = when (this) { + is Nd4jArrayStructure -> checkShape(ndArray) + else -> Nd4j.zeros(*shape).also { + elements().forEach { (idx, value) -> it.putScalar(idx, value) } + } + } + override fun scale(a: StructureND, value: Double): Nd4jArrayStructure { return a.ndArray.mul(value).wrap() } @@ -245,34 +263,25 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr public override operator fun Double.minus(arg: StructureND): Nd4jArrayStructure { return arg.ndArray.rsub(this).wrap() } - - override fun sin(arg: StructureND): StructureND = Transforms.sin(arg.ndArray).wrap() - - override fun cos(arg: StructureND): StructureND = Transforms.cos(arg.ndArray).wrap() - - override fun asin(arg: StructureND): StructureND = Transforms.asin(arg.ndArray).wrap() - - override fun acos(arg: StructureND): StructureND = Transforms.acos(arg.ndArray).wrap() - - override fun atan(arg: StructureND): StructureND = Transforms.atan(arg.ndArray).wrap() - - override fun power(arg: StructureND, pow: Number): StructureND = - Transforms.pow(arg.ndArray,pow).wrap() - - override fun exp(arg: StructureND): StructureND = Transforms.exp(arg.ndArray).wrap() - - override fun ln(arg: StructureND): StructureND = Transforms.log(arg.ndArray).wrap() } /** * Represents [FieldND] over [Nd4jArrayStructure] of [Float]. 
*/ -public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField, - ExtendedField> { +public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField { public override val elementContext: FloatField get() = FloatField public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asFloatStructure() + @OptIn(PerformancePitfall::class) + public override val StructureND.ndArray: INDArray + get() = when (this) { + is Nd4jArrayStructure -> checkShape(ndArray) + else -> Nd4j.zeros(*shape).also { + elements().forEach { (idx, value) -> it.putScalar(idx, value) } + } + } + override fun scale(a: StructureND, value: Double): StructureND = a.ndArray.mul(value).wrap() @@ -293,23 +302,6 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra public override operator fun Float.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() - - override fun sin(arg: StructureND): StructureND = Sin(arg.ndArray).z().wrap() - - override fun cos(arg: StructureND): StructureND = Cos(arg.ndArray).z().wrap() - - override fun asin(arg: StructureND): StructureND = ASin(arg.ndArray).z().wrap() - - override fun acos(arg: StructureND): StructureND = ACos(arg.ndArray).z().wrap() - - override fun atan(arg: StructureND): StructureND = ATan(arg.ndArray).z().wrap() - - override fun power(arg: StructureND, pow: Number): StructureND = - Pow(arg.ndArray, pow.toDouble()).z().wrap() - - override fun exp(arg: StructureND): StructureND = Exp(arg.ndArray).z().wrap() - - override fun ln(arg: StructureND): StructureND = Log(arg.ndArray).z().wrap() } /** @@ -321,6 +313,15 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asIntStructure() + @OptIn(PerformancePitfall::class) + public override val StructureND.ndArray: INDArray + get() = when (this) { + is Nd4jArrayStructure -> checkShape(ndArray) + else -> Nd4j.zeros(*shape).also { + elements().forEach { (idx, value) -> it.putScalar(idx, value) } + } + } + public override operator fun StructureND.plus(arg: Int): Nd4jArrayStructure = ndArray.add(arg).wrap() @@ -333,25 +334,3 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi public override operator fun Int.minus(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rsub(this).wrap() } - -/** - * Represents [RingND] over [Nd4jArrayStructure] of [Long]. 
- */ -public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing { - public override val elementContext: LongRing - get() = LongRing - - public override fun INDArray.wrap(): Nd4jArrayStructure = checkShape(this).asLongStructure() - - public override operator fun StructureND.plus(arg: Long): Nd4jArrayStructure = - ndArray.add(arg).wrap() - - public override operator fun StructureND.minus(arg: Long): Nd4jArrayStructure = - ndArray.sub(arg).wrap() - - public override operator fun StructureND.times(arg: Long): Nd4jArrayStructure = - ndArray.mul(arg).wrap() - - public override operator fun Long.minus(arg: StructureND): Nd4jArrayStructure = - arg.ndArray.rsub(this).wrap() -} diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt index 71887ca3c..140a212f8 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayIterator.kt @@ -48,12 +48,6 @@ private class Nd4jArrayDoubleIterator(iterateOver: INDArray) : Nd4jArrayIterator internal fun INDArray.realIterator(): Iterator> = Nd4jArrayDoubleIterator(this) -private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { - override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices) -} - -internal fun INDArray.longIterator(): Iterator> = Nd4jArrayLongIterator(this) - private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase(iterateOver) { override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray()) } diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt index d37c91cb6..ffddcef90 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructure.kt @@ -17,7 +17,8 @@ import space.kscience.kmath.nd.StructureND */ public sealed class Nd4jArrayStructure : MutableStructureND { /** - * The wrapped [INDArray]. + * The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming that the size of [INDArray] is less or equal to + * [Int.MAX_VALUE]. */ public abstract val ndArray: INDArray @@ -25,6 +26,7 @@ public sealed class Nd4jArrayStructure : MutableStructureND { internal abstract fun elementsIterator(): Iterator> internal fun indicesIterator(): Iterator = ndArray.indicesIterator() + @PerformancePitfall public override fun elements(): Sequence> = Sequence(::elementsIterator) } @@ -40,17 +42,6 @@ private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jA */ public fun INDArray.asIntStructure(): Nd4jArrayStructure = Nd4jArrayIntStructure(this) -private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { - override fun elementsIterator(): Iterator> = ndArray.longIterator() - override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray()) - override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) } -} - -/** - * Wraps this [INDArray] to [Nd4jArrayStructure]. 
- */ -public fun INDArray.asLongStructure(): Nd4jArrayStructure = Nd4jArrayLongStructure(this) - private data class Nd4jArrayDoubleStructure(override val ndArray: INDArray) : Nd4jArrayStructure() { override fun elementsIterator(): Iterator> = ndArray.realIterator() override fun get(index: IntArray): Double = ndArray.getDouble(*index) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt new file mode 100644 index 000000000..456f7c2a9 --- /dev/null +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -0,0 +1,175 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.nd4j + +import org.nd4j.linalg.api.ndarray.INDArray +import org.nd4j.linalg.api.ops.impl.summarystats.Variance +import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh +import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh +import org.nd4j.linalg.factory.Nd4j +import org.nd4j.linalg.factory.ops.NDBase +import org.nd4j.linalg.ops.transforms.Transforms +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra +import space.kscience.kmath.tensors.api.Tensor +import space.kscience.kmath.tensors.api.TensorAlgebra +import space.kscience.kmath.tensors.core.DoubleTensorAlgebra + +/** + * ND4J based [TensorAlgebra] implementation. + */ +public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra { + /** + * Wraps [INDArray] to [Nd4jArrayStructure]. + */ + public fun INDArray.wrap(): Nd4jArrayStructure + + /** + * Unwraps to or acquires [INDArray] from [StructureND]. 
+ */ + public val StructureND.ndArray: INDArray + + public override fun T.plus(other: Tensor): Tensor = other.ndArray.add(this).wrap() + public override fun Tensor.plus(value: T): Tensor = ndArray.add(value).wrap() + + public override fun Tensor.plus(other: Tensor): Tensor = ndArray.add(other.ndArray).wrap() + + public override fun Tensor.plusAssign(value: T) { + ndArray.addi(value) + } + + public override fun Tensor.plusAssign(other: Tensor) { + ndArray.addi(other.ndArray) + } + + public override fun T.minus(other: Tensor): Tensor = other.ndArray.rsub(this).wrap() + public override fun Tensor.minus(value: T): Tensor = ndArray.sub(value).wrap() + public override fun Tensor.minus(other: Tensor): Tensor = ndArray.sub(other.ndArray).wrap() + + public override fun Tensor.minusAssign(value: T) { + ndArray.rsubi(value) + } + + public override fun Tensor.minusAssign(other: Tensor) { + ndArray.subi(other.ndArray) + } + + public override fun T.times(other: Tensor): Tensor = other.ndArray.mul(this).wrap() + + public override fun Tensor.times(value: T): Tensor = + ndArray.mul(value).wrap() + + public override fun Tensor.times(other: Tensor): Tensor = ndArray.mul(other.ndArray).wrap() + + public override fun Tensor.timesAssign(value: T) { + ndArray.muli(value) + } + + public override fun Tensor.timesAssign(other: Tensor) { + ndArray.mmuli(other.ndArray) + } + + public override fun Tensor.unaryMinus(): Tensor = ndArray.neg().wrap() + public override fun Tensor.get(i: Int): Tensor = ndArray.slice(i.toLong()).wrap() + public override fun Tensor.transpose(i: Int, j: Int): Tensor = ndArray.swapAxes(i, j).wrap() + public override fun Tensor.dot(other: Tensor): Tensor = ndArray.mmul(other.ndArray).wrap() + + public override fun Tensor.min(dim: Int, keepDim: Boolean): Tensor = + ndArray.min(keepDim, dim).wrap() + + public override fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor = + ndArray.sum(keepDim, dim).wrap() + + public override fun Tensor.max(dim: Int, keepDim: Boolean): Tensor = + ndArray.max(keepDim, dim).wrap() + + public override fun Tensor.view(shape: IntArray): Tensor = ndArray.reshape(shape).wrap() + public override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) + + public override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = + ndBase.get().argmax(ndArray, keepDim, dim).wrap() + + public override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() + + public override fun Tensor.exp(): Tensor = Transforms.exp(ndArray).wrap() + public override fun Tensor.ln(): Tensor = Transforms.log(ndArray).wrap() + public override fun Tensor.sqrt(): Tensor = Transforms.sqrt(ndArray).wrap() + public override fun Tensor.cos(): Tensor = Transforms.cos(ndArray).wrap() + public override fun Tensor.acos(): Tensor = Transforms.acos(ndArray).wrap() + public override fun Tensor.cosh(): Tensor = Transforms.cosh(ndArray).wrap() + + public override fun Tensor.acosh(): Tensor = + Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() + + public override fun Tensor.sin(): Tensor = Transforms.sin(ndArray).wrap() + public override fun Tensor.asin(): Tensor = Transforms.asin(ndArray).wrap() + public override fun Tensor.sinh(): Tensor = Transforms.sinh(ndArray).wrap() + + public override fun Tensor.asinh(): Tensor = + Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() + + public override fun Tensor.tan(): Tensor = Transforms.tan(ndArray).wrap() + public override fun Tensor.atan(): Tensor = Transforms.atan(ndArray).wrap() + public override 
fun Tensor.tanh(): Tensor = Transforms.tanh(ndArray).wrap() + public override fun Tensor.atanh(): Tensor = Transforms.atanh(ndArray).wrap() + public override fun Tensor.ceil(): Tensor = Transforms.ceil(ndArray).wrap() + public override fun Tensor.floor(): Tensor = Transforms.floor(ndArray).wrap() + public override fun Tensor.std(dim: Int, keepDim: Boolean): Tensor = ndArray.std(true, keepDim, dim).wrap() + public override fun T.div(other: Tensor): Tensor = other.ndArray.rdiv(this).wrap() + public override fun Tensor.div(value: T): Tensor = ndArray.div(value).wrap() + public override fun Tensor.div(other: Tensor): Tensor = ndArray.div(other.ndArray).wrap() + + public override fun Tensor.divAssign(value: T) { + ndArray.divi(value) + } + + public override fun Tensor.divAssign(other: Tensor) { + ndArray.divi(other.ndArray) + } + + public override fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor = + Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() + + private companion object { + private val ndBase: ThreadLocal = ThreadLocal.withInitial(::NDBase) + } +} + +/** + * [Double] specialization of [Nd4jTensorAlgebra]. + */ +public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { + public override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() + + @OptIn(PerformancePitfall::class) + public override val StructureND.ndArray: INDArray + get() = when (this) { + is Nd4jArrayStructure -> ndArray + else -> Nd4j.zeros(*shape).also { + elements().forEach { (idx, value) -> it.putScalar(idx, value) } + } + } + + public override fun Tensor.valueOrNull(): Double? = + if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null + + // TODO rewrite + @PerformancePitfall + public override fun diagonalEmbedding( + diagonalEntries: Tensor, + offset: Int, + dim1: Int, + dim2: Int, + ): Tensor = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2) + + public override fun Tensor.sum(): Double = ndArray.sumNumber().toDouble() + public override fun Tensor.min(): Double = ndArray.minNumber().toDouble() + public override fun Tensor.max(): Double = ndArray.maxNumber().toDouble() + public override fun Tensor.mean(): Double = ndArray.meanNumber().toDouble() + public override fun Tensor.std(): Double = ndArray.stdNumber().toDouble() + public override fun Tensor.variance(): Double = ndArray.varNumber().toDouble() +} diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt index 6c414cc13..75a334ca7 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/arrays.kt @@ -5,5 +5,4 @@ package space.kscience.kmath.nd4j -internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() } internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() } diff --git a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt index c80a1d771..40da22763 100644 --- a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt +++ b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayAlgebraTest.kt @@ -52,7 +52,7 @@ internal class Nd4jArrayAlgebraTest { @Test fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke { - val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 } + val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 } 
val transformed = sin(initial) val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 } diff --git a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt index bdfec2ba6..30d01338f 100644 --- a/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt +++ b/kmath-nd4j/src/test/kotlin/space/kscience/kmath/nd4j/Nd4jArrayStructureTest.kt @@ -72,7 +72,7 @@ internal class Nd4jArrayStructureTest { @Test fun testSet() { val nd = Nd4j.rand(17, 12, 4, 8)!! - val struct = nd.asLongStructure() + val struct = nd.asIntStructure() struct[intArrayOf(1, 2, 3, 4)] = 777 assertEquals(777, struct[1, 2, 3, 4]) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 2eb18ada6..62b8ef046 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra * * @param T the type of items in the tensors. */ -public interface TensorAlgebra: Algebra> { +public interface TensorAlgebra : Algebra> { /** * Returns a single tensor value of unit dimension if tensor shape equals to [1]. @@ -27,7 +27,8 @@ public interface TensorAlgebra: Algebra> { * * @return the value of a scalar tensor. */ - public fun Tensor.value(): T + public fun Tensor.value(): T = + valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") /** * Each element of the tensor [other] is added to this value. @@ -60,15 +61,14 @@ public interface TensorAlgebra: Algebra> { * * @param value the number to be added to each element of this tensor. */ - public operator fun Tensor.plusAssign(value: T): Unit + public operator fun Tensor.plusAssign(value: T) /** * Each element of the tensor [other] is added to each element of this tensor. * * @param other tensor to be added. */ - public operator fun Tensor.plusAssign(other: Tensor): Unit - + public operator fun Tensor.plusAssign(other: Tensor) /** * Each element of the tensor [other] is subtracted from this value. @@ -101,14 +101,14 @@ public interface TensorAlgebra: Algebra> { * * @param value the number to be subtracted from each element of this tensor. */ - public operator fun Tensor.minusAssign(value: T): Unit + public operator fun Tensor.minusAssign(value: T) /** * Each element of the tensor [other] is subtracted from each element of this tensor. * * @param other tensor to be subtracted. */ - public operator fun Tensor.minusAssign(other: Tensor): Unit + public operator fun Tensor.minusAssign(other: Tensor) /** @@ -142,14 +142,14 @@ public interface TensorAlgebra: Algebra> { * * @param value the number to be multiplied by each element of this tensor. */ - public operator fun Tensor.timesAssign(value: T): Unit + public operator fun Tensor.timesAssign(value: T) /** * Each element of the tensor [other] is multiplied by each element of this tensor. * * @param other tensor to be multiplied. */ - public operator fun Tensor.timesAssign(other: Tensor): Unit + public operator fun Tensor.timesAssign(other: Tensor) /** * Numerical negative, element-wise. @@ -217,7 +217,7 @@ public interface TensorAlgebra: Algebra> { * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. 
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix * multiple and removed after. - * The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). + * The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). * For example, if `input` is a (j × 1 × n × n) tensor and `other` is a * (k × n × n) tensor, out will be a (j × k × n × n) tensor. * @@ -255,7 +255,7 @@ public interface TensorAlgebra: Algebra> { diagonalEntries: Tensor, offset: Int = 0, dim1: Int = -2, - dim2: Int = -1 + dim2: Int = -1, ): Tensor /** diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 78989f1f3..1fd46bd57 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -9,20 +9,9 @@ import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra -import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra import space.kscience.kmath.tensors.api.Tensor -import space.kscience.kmath.tensors.core.internal.dotHelper -import space.kscience.kmath.tensors.core.internal.getRandomNormals +import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra import space.kscience.kmath.tensors.core.internal.* -import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors -import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency -import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer -import space.kscience.kmath.tensors.core.internal.checkEmptyShape -import space.kscience.kmath.tensors.core.internal.checkShapesCompatible -import space.kscience.kmath.tensors.core.internal.checkSquareMatrix -import space.kscience.kmath.tensors.core.internal.checkTranspose -import space.kscience.kmath.tensors.core.internal.checkView -import space.kscience.kmath.tensors.core.internal.minusIndexFrom import kotlin.math.* /** @@ -38,8 +27,8 @@ public open class DoubleTensorAlgebra : override fun Tensor.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) tensor.mutableBuffer.array()[tensor.bufferStart] else null - override fun Tensor.value(): Double = - valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") + override fun Tensor.value(): Double = valueOrNull() + ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") /** * Constructs a tensor with the specified shape and data. 
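The two hunks above settle the value()/valueOrNull() contract: valueOrNull() yields the single element only when the tensor shape is exactly [1], and value() now delegates to it and throws for any other shape. Below is a minimal sketch of the intended behaviour; it is not part of the patch and assumes DoubleTensorAlgebra can be instantiated directly and that the constructor documented just above is named fromArray(shape, buffer) — adjust the names if the actual API differs.

import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    with(DoubleTensorAlgebra()) {
        // Shape [1]: both accessors return the single element.
        val scalar = fromArray(intArrayOf(1), doubleArrayOf(42.0))
        println(scalar.valueOrNull()) // 42.0
        println(scalar.value())       // 42.0

        // Any other shape: valueOrNull() is null, value() throws IllegalArgumentException.
        val matrix = fromArray(intArrayOf(2, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0))
        println(matrix.valueOrNull()) // null
    }
}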
@@ -466,7 +455,7 @@ public open class DoubleTensorAlgebra : private fun Tensor.eq( other: Tensor, - eqFunction: (Double, Double) -> Boolean + eqFunction: (Double, Double) -> Boolean, ): Boolean { checkShapesCompatible(tensor, other) val n = tensor.numElements @@ -540,7 +529,7 @@ public open class DoubleTensorAlgebra : internal fun Tensor.foldDim( foldFunction: (DoubleArray) -> Double, dim: Int, - keepDim: Boolean + keepDim: Boolean, ): DoubleTensor { check(dim < dimension) { "Dimension $dim out of range $dimension" } val resShape = if (keepDim) { @@ -729,7 +718,7 @@ public open class DoubleTensorAlgebra : */ public fun luPivot( luTensor: Tensor, - pivotsTensor: Tensor + pivotsTensor: Tensor, ): Triple { checkSquareMatrix(luTensor.shape) check(
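To close, a usage sketch of the new Nd4j-backed tensor algebra introduced by this patch. It is not part of the diff: it assumes an ND4J backend (e.g. nd4j-native) on the runtime classpath and that an Nd4jArrayStructure wrapped via asDoubleStructure() can be passed where Tensor is expected, mirroring what the algebra's own wrap() calls do above.

import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.nd4j.DoubleNd4jTensorAlgebra
import space.kscience.kmath.nd4j.asDoubleStructure

fun main() {
    with(DoubleNd4jTensorAlgebra) {
        // Wrap plain INDArrays into KMath structures; the algebra unwraps them back via ndArray.
        val a = Nd4j.create(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2)).asDoubleStructure()
        val b = Nd4j.create(doubleArrayOf(1.0, 0.0, 0.0, 1.0), intArrayOf(2, 2)).asDoubleStructure()

        val sum = a + b        // element-wise addition, delegated to INDArray.add
        val product = a.dot(b) // matrix multiplication, delegated to INDArray.mmul

        println(sum.ndArray)
        println(product.ndArray)
        println(a.sum())       // 10.0
    }
}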