From 4635cd3fb3eaab8aac97936b50a08d463e7f91c6 Mon Sep 17 00:00:00 2001 From: darksnake Date: Tue, 26 Oct 2021 16:08:02 +0300 Subject: [PATCH] Refactor TensorAlgebra to take StructureND and inherit AlgebraND --- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 123 +++++++++++------- .../api/TensorPartialDivisionAlgebra.kt | 2 +- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 2 +- 3 files changed, 76 insertions(+), 51 deletions(-) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index ee9251dd1..e40ea3dbc 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -13,7 +13,10 @@ import org.nd4j.linalg.factory.Nd4j import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.ops.transforms.Transforms import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.StructureND +import space.kscience.kmath.operations.DoubleField +import space.kscience.kmath.operations.Field import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra @@ -22,7 +25,24 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra /** * ND4J based [TensorAlgebra] implementation. */ -public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra { +public sealed interface Nd4jTensorAlgebra> : AnalyticTensorAlgebra { + + override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure { + val array = + } + + override fun StructureND.map(transform: A.(T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + + override fun StructureND.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + + override fun zip(left: StructureND, right: StructureND, transform: A.(T, T) -> T): Nd4jArrayStructure { + TODO("Not yet implemented") + } + /** * Wraps [INDArray] to [Nd4jArrayStructure]. */ @@ -33,105 +53,107 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra */ public val StructureND.ndArray: INDArray - override fun T.plus(other: Tensor): Tensor = other.ndArray.add(this).wrap() - override fun Tensor.plus(value: T): Tensor = ndArray.add(value).wrap() + override fun T.plus(other: StructureND): Nd4jArrayStructure = other.ndArray.add(this).wrap() + override fun StructureND.plus(value: T): Nd4jArrayStructure = ndArray.add(value).wrap() - override fun Tensor.plus(other: Tensor): Tensor = ndArray.add(other.ndArray).wrap() + override fun StructureND.plus(other: StructureND): Nd4jArrayStructure = ndArray.add(other.ndArray).wrap() override fun Tensor.plusAssign(value: T) { ndArray.addi(value) } - override fun Tensor.plusAssign(other: Tensor) { + override fun Tensor.plusAssign(other: StructureND) { ndArray.addi(other.ndArray) } - override fun T.minus(other: Tensor): Tensor = other.ndArray.rsub(this).wrap() - override fun Tensor.minus(value: T): Tensor = ndArray.sub(value).wrap() - override fun Tensor.minus(other: Tensor): Tensor = ndArray.sub(other.ndArray).wrap() + override fun T.minus(other: StructureND): Nd4jArrayStructure = other.ndArray.rsub(this).wrap() + override fun StructureND.minus(arg: T): Nd4jArrayStructure = ndArray.sub(arg).wrap() + override fun StructureND.minus(other: StructureND): Nd4jArrayStructure = ndArray.sub(other.ndArray).wrap() override fun Tensor.minusAssign(value: T) { ndArray.rsubi(value) } - override fun Tensor.minusAssign(other: Tensor) { + override fun Tensor.minusAssign(other: StructureND) { ndArray.subi(other.ndArray) } - override fun T.times(other: Tensor): Tensor = other.ndArray.mul(this).wrap() + override fun T.times(arg: StructureND): Nd4jArrayStructure = arg.ndArray.mul(this).wrap() - override fun Tensor.times(value: T): Tensor = - ndArray.mul(value).wrap() + override fun StructureND.times(arg: T): Nd4jArrayStructure = + ndArray.mul(arg).wrap() - override fun Tensor.times(other: Tensor): Tensor = ndArray.mul(other.ndArray).wrap() + override fun StructureND.times(other: StructureND): Nd4jArrayStructure = ndArray.mul(other.ndArray).wrap() override fun Tensor.timesAssign(value: T) { ndArray.muli(value) } - override fun Tensor.timesAssign(other: Tensor) { + override fun Tensor.timesAssign(other: StructureND) { ndArray.mmuli(other.ndArray) } - override fun Tensor.unaryMinus(): Tensor = ndArray.neg().wrap() - override fun Tensor.get(i: Int): Tensor = ndArray.slice(i.toLong()).wrap() - override fun Tensor.transpose(i: Int, j: Int): Tensor = ndArray.swapAxes(i, j).wrap() - override fun Tensor.dot(other: Tensor): Tensor = ndArray.mmul(other.ndArray).wrap() + override fun StructureND.unaryMinus(): Nd4jArrayStructure = ndArray.neg().wrap() + override fun Tensor.get(i: Int): Nd4jArrayStructure = ndArray.slice(i.toLong()).wrap() + override fun Tensor.transpose(i: Int, j: Int): Nd4jArrayStructure = ndArray.swapAxes(i, j).wrap() + override fun Tensor.dot(other: Tensor): Nd4jArrayStructure = ndArray.mmul(other.ndArray).wrap() - override fun Tensor.min(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.min(keepDim, dim).wrap() - override fun Tensor.sum(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.sum(keepDim, dim).wrap() - override fun Tensor.max(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.max(keepDim, dim).wrap() - override fun Tensor.view(shape: IntArray): Tensor = ndArray.reshape(shape).wrap() - override fun Tensor.viewAs(other: Tensor): Tensor = view(other.shape) + override fun Tensor.view(shape: IntArray): Nd4jArrayStructure = ndArray.reshape(shape).wrap() + override fun Tensor.viewAs(other: Tensor): Nd4jArrayStructure = view(other.shape) - override fun Tensor.argMax(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndBase.get().argmax(ndArray, keepDim, dim).wrap() - override fun Tensor.mean(dim: Int, keepDim: Boolean): Tensor = ndArray.mean(keepDim, dim).wrap() + override fun Tensor.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure = ndArray.mean(keepDim, dim).wrap() - override fun Tensor.exp(): Tensor = Transforms.exp(ndArray).wrap() - override fun Tensor.ln(): Tensor = Transforms.log(ndArray).wrap() - override fun Tensor.sqrt(): Tensor = Transforms.sqrt(ndArray).wrap() - override fun Tensor.cos(): Tensor = Transforms.cos(ndArray).wrap() - override fun Tensor.acos(): Tensor = Transforms.acos(ndArray).wrap() - override fun Tensor.cosh(): Tensor = Transforms.cosh(ndArray).wrap() + override fun Tensor.exp(): Nd4jArrayStructure = Transforms.exp(ndArray).wrap() + override fun Tensor.ln(): Nd4jArrayStructure = Transforms.log(ndArray).wrap() + override fun Tensor.sqrt(): Nd4jArrayStructure = Transforms.sqrt(ndArray).wrap() + override fun Tensor.cos(): Nd4jArrayStructure = Transforms.cos(ndArray).wrap() + override fun Tensor.acos(): Nd4jArrayStructure = Transforms.acos(ndArray).wrap() + override fun Tensor.cosh(): Nd4jArrayStructure = Transforms.cosh(ndArray).wrap() - override fun Tensor.acosh(): Tensor = + override fun Tensor.acosh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.sin(): Tensor = Transforms.sin(ndArray).wrap() - override fun Tensor.asin(): Tensor = Transforms.asin(ndArray).wrap() + override fun Tensor.sin(): Nd4jArrayStructure = Transforms.sin(ndArray).wrap() + override fun Tensor.asin(): Nd4jArrayStructure = Transforms.asin(ndArray).wrap() override fun Tensor.sinh(): Tensor = Transforms.sinh(ndArray).wrap() - override fun Tensor.asinh(): Tensor = + override fun Tensor.asinh(): Nd4jArrayStructure = Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() - override fun Tensor.tan(): Tensor = Transforms.tan(ndArray).wrap() - override fun Tensor.atan(): Tensor = Transforms.atan(ndArray).wrap() - override fun Tensor.tanh(): Tensor = Transforms.tanh(ndArray).wrap() - override fun Tensor.atanh(): Tensor = Transforms.atanh(ndArray).wrap() - override fun Tensor.ceil(): Tensor = Transforms.ceil(ndArray).wrap() - override fun Tensor.floor(): Tensor = Transforms.floor(ndArray).wrap() - override fun Tensor.std(dim: Int, keepDim: Boolean): Tensor = ndArray.std(true, keepDim, dim).wrap() - override fun T.div(other: Tensor): Tensor = other.ndArray.rdiv(this).wrap() - override fun Tensor.div(value: T): Tensor = ndArray.div(value).wrap() - override fun Tensor.div(other: Tensor): Tensor = ndArray.div(other.ndArray).wrap() + override fun Tensor.tan(): Nd4jArrayStructure = Transforms.tan(ndArray).wrap() + override fun Tensor.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() + override fun Tensor.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() + override fun Tensor.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() + override fun Tensor.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() + override fun Tensor.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() + override fun Tensor.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = + ndArray.std(true, keepDim, dim).wrap() + + override fun T.div(arg: StructureND): Nd4jArrayStructure = arg.ndArray.rdiv(this).wrap() + override fun StructureND.div(arg: T): Nd4jArrayStructure = ndArray.div(arg).wrap() + override fun StructureND.div(other: StructureND): Nd4jArrayStructure = ndArray.div(other.ndArray).wrap() override fun Tensor.divAssign(value: T) { ndArray.divi(value) } - override fun Tensor.divAssign(other: Tensor) { + override fun Tensor.divAssign(other: StructureND) { ndArray.divi(other.ndArray) } - override fun Tensor.variance(dim: Int, keepDim: Boolean): Tensor = + override fun Tensor.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure = Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() private companion object { @@ -142,7 +164,10 @@ public sealed interface Nd4jTensorAlgebra : AnalyticTensorAlgebra /** * [Double] specialization of [Nd4jTensorAlgebra]. */ -public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { +public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { + + override val elementAlgebra: DoubleField get() = DoubleField + override fun INDArray.wrap(): Nd4jArrayStructure = asDoubleStructure() @OptIn(PerformancePitfall::class) @@ -154,7 +179,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra { } } - override fun Tensor.valueOrNull(): Double? = + override fun StructureND.valueOrNull(): Double? = if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null // TODO rewrite diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index 43c901ed5..dece54834 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -57,5 +57,5 @@ public interface TensorPartialDivisionAlgebra> : TensorAlgebra.divAssign(other: Tensor) + public operator fun Tensor.divAssign(other: StructureND) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 472db8f5f..3d343604b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -330,7 +330,7 @@ public open class DoubleTensorAlgebra : } } - override fun Tensor.divAssign(other: Tensor) { + override fun Tensor.divAssign(other: StructureND) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /=