Refactor TensorAlgebra to take StructureND and inherit AlgebraND

This commit is contained in:
Alexander Nozik 2021-10-26 16:08:02 +03:00
parent 7e59ec5804
commit 4635cd3fb3
3 changed files with 76 additions and 51 deletions

View File

@ -13,7 +13,10 @@ import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.factory.ops.NDBase
import org.nd4j.linalg.ops.transforms.Transforms import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Field
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
@ -22,7 +25,24 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
/** /**
* ND4J based [TensorAlgebra] implementation. * ND4J based [TensorAlgebra] implementation.
*/ */
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> { public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T> {
val array =
}
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
/** /**
* Wraps [INDArray] to [Nd4jArrayStructure]. * Wraps [INDArray] to [Nd4jArrayStructure].
*/ */
@ -33,105 +53,107 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap() override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap() override fun StructureND<T>.plus(value: T): Nd4jArrayStructure<T> = ndArray.add(value).wrap()
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap() override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
override fun Tensor<T>.plusAssign(value: T) { override fun Tensor<T>.plusAssign(value: T) {
ndArray.addi(value) ndArray.addi(value)
} }
override fun Tensor<T>.plusAssign(other: Tensor<T>) { override fun Tensor<T>.plusAssign(other: StructureND<T>) {
ndArray.addi(other.ndArray) ndArray.addi(other.ndArray)
} }
override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap() override fun T.minus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.rsub(this).wrap()
override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap() override fun StructureND<T>.minus(arg: T): Nd4jArrayStructure<T> = ndArray.sub(arg).wrap()
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap() override fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.sub(other.ndArray).wrap()
override fun Tensor<T>.minusAssign(value: T) { override fun Tensor<T>.minusAssign(value: T) {
ndArray.rsubi(value) ndArray.rsubi(value)
} }
override fun Tensor<T>.minusAssign(other: Tensor<T>) { override fun Tensor<T>.minusAssign(other: StructureND<T>) {
ndArray.subi(other.ndArray) ndArray.subi(other.ndArray)
} }
override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap() override fun T.times(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.mul(this).wrap()
override fun Tensor<T>.times(value: T): Tensor<T> = override fun StructureND<T>.times(arg: T): Nd4jArrayStructure<T> =
ndArray.mul(value).wrap() ndArray.mul(arg).wrap()
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap() override fun StructureND<T>.times(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mul(other.ndArray).wrap()
override fun Tensor<T>.timesAssign(value: T) { override fun Tensor<T>.timesAssign(value: T) {
ndArray.muli(value) ndArray.muli(value)
} }
override fun Tensor<T>.timesAssign(other: Tensor<T>) { override fun Tensor<T>.timesAssign(other: StructureND<T>) {
ndArray.mmuli(other.ndArray) ndArray.mmuli(other.ndArray)
} }
override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap() override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap() override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap() override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap() override fun Tensor<T>.dot(other: Tensor<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.min(keepDim, dim).wrap() ndArray.min(keepDim, dim).wrap()
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> = override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.sum(keepDim, dim).wrap() ndArray.sum(keepDim, dim).wrap()
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.max(keepDim, dim).wrap() ndArray.max(keepDim, dim).wrap()
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap() override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape) override fun Tensor<T>.viewAs(other: Tensor<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> = override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap() ndBase.get().argmax(ndArray, keepDim, dim).wrap()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap() override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = ndArray.mean(keepDim, dim).wrap()
override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap() override fun Tensor<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap() override fun Tensor<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap() override fun Tensor<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap() override fun Tensor<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap() override fun Tensor<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap() override fun Tensor<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
override fun Tensor<T>.acosh(): Tensor<T> = override fun Tensor<T>.acosh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap() override fun Tensor<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap() override fun Tensor<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap() override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
override fun Tensor<T>.asinh(): Tensor<T> = override fun Tensor<T>.asinh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap() override fun Tensor<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap() override fun Tensor<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap() override fun Tensor<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap() override fun Tensor<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap() override fun Tensor<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap() override fun Tensor<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap() override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap() ndArray.std(true, keepDim, dim).wrap()
override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap() override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
override fun StructureND<T>.div(arg: T): Nd4jArrayStructure<T> = ndArray.div(arg).wrap()
override fun StructureND<T>.div(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.div(other.ndArray).wrap()
override fun Tensor<T>.divAssign(value: T) { override fun Tensor<T>.divAssign(value: T) {
ndArray.divi(value) ndArray.divi(value)
} }
override fun Tensor<T>.divAssign(other: Tensor<T>) { override fun Tensor<T>.divAssign(other: StructureND<T>) {
ndArray.divi(other.ndArray) ndArray.divi(other.ndArray)
} }
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> = override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
private companion object { private companion object {
@ -142,7 +164,10 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
/** /**
* [Double] specialization of [Nd4jTensorAlgebra]. * [Double] specialization of [Nd4jTensorAlgebra].
*/ */
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> { public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
@ -154,7 +179,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
} }
} }
override fun Tensor<Double>.valueOrNull(): Double? = override fun StructureND<Double>.valueOrNull(): Double? =
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
// TODO rewrite // TODO rewrite

View File

@ -57,5 +57,5 @@ public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T
* *
* @param other tensor to be divided by. * @param other tensor to be divided by.
*/ */
public operator fun Tensor<T>.divAssign(other: Tensor<T>) public operator fun Tensor<T>.divAssign(other: StructureND<T>)
} }

View File

@ -330,7 +330,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.divAssign(other: Tensor<Double>) { override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=