Refactor TensorAlgebra to take StructureND and inherit AlgebraND

This commit is contained in:
Alexander Nozik 2021-10-26 16:08:02 +03:00
parent 7e59ec5804
commit 4635cd3fb3
3 changed files with 76 additions and 51 deletions

View File

@ -13,7 +13,10 @@ import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.ops.NDBase
import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Field
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
@ -22,7 +25,24 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
/**
* ND4J based [TensorAlgebra] implementation.
*/
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T> {
val array =
}
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
/**
* Wraps [INDArray] to [Nd4jArrayStructure].
*/
@ -33,105 +53,107 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
*/
public val StructureND<T>.ndArray: INDArray
override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap()
override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap()
override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
override fun StructureND<T>.plus(value: T): Nd4jArrayStructure<T> = ndArray.add(value).wrap()
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap()
override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
override fun Tensor<T>.plusAssign(value: T) {
ndArray.addi(value)
}
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
override fun Tensor<T>.plusAssign(other: StructureND<T>) {
ndArray.addi(other.ndArray)
}
override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap()
override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap()
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap()
override fun T.minus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.rsub(this).wrap()
override fun StructureND<T>.minus(arg: T): Nd4jArrayStructure<T> = ndArray.sub(arg).wrap()
override fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.sub(other.ndArray).wrap()
override fun Tensor<T>.minusAssign(value: T) {
ndArray.rsubi(value)
}
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
override fun Tensor<T>.minusAssign(other: StructureND<T>) {
ndArray.subi(other.ndArray)
}
override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap()
override fun T.times(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.mul(this).wrap()
override fun Tensor<T>.times(value: T): Tensor<T> =
ndArray.mul(value).wrap()
override fun StructureND<T>.times(arg: T): Nd4jArrayStructure<T> =
ndArray.mul(arg).wrap()
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap()
override fun StructureND<T>.times(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mul(other.ndArray).wrap()
override fun Tensor<T>.timesAssign(value: T) {
ndArray.muli(value)
}
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
override fun Tensor<T>.timesAssign(other: StructureND<T>) {
ndArray.mmuli(other.ndArray)
}
override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap()
override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap()
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
override fun Tensor<T>.dot(other: Tensor<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> =
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.min(keepDim, dim).wrap()
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> =
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.sum(keepDim, dim).wrap()
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> =
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.max(keepDim, dim).wrap()
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: Tensor<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = ndArray.mean(keepDim, dim).wrap()
override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap()
override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap()
override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap()
override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap()
override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap()
override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap()
override fun Tensor<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
override fun Tensor<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
override fun Tensor<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
override fun Tensor<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
override fun Tensor<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
override fun Tensor<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
override fun Tensor<T>.acosh(): Tensor<T> =
override fun Tensor<T>.acosh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap()
override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap()
override fun Tensor<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
override fun Tensor<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
override fun Tensor<T>.asinh(): Tensor<T> =
override fun Tensor<T>.asinh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap()
override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap()
override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap()
override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap()
override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap()
override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap()
override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap()
override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap()
override fun Tensor<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
override fun Tensor<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun Tensor<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun Tensor<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun Tensor<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun Tensor<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.std(true, keepDim, dim).wrap()
override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
override fun StructureND<T>.div(arg: T): Nd4jArrayStructure<T> = ndArray.div(arg).wrap()
override fun StructureND<T>.div(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.div(other.ndArray).wrap()
override fun Tensor<T>.divAssign(value: T) {
ndArray.divi(value)
}
override fun Tensor<T>.divAssign(other: Tensor<T>) {
override fun Tensor<T>.divAssign(other: StructureND<T>) {
ndArray.divi(other.ndArray)
}
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> =
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
private companion object {
@ -142,7 +164,10 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
/**
* [Double] specialization of [Nd4jTensorAlgebra].
*/
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class)
@ -154,7 +179,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
}
}
override fun Tensor<Double>.valueOrNull(): Double? =
override fun StructureND<Double>.valueOrNull(): Double? =
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
// TODO rewrite

View File

@ -57,5 +57,5 @@ public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T
*
* @param other tensor to be divided by.
*/
public operator fun Tensor<T>.divAssign(other: Tensor<T>)
public operator fun Tensor<T>.divAssign(other: StructureND<T>)
}

View File

@ -330,7 +330,7 @@ public open class DoubleTensorAlgebra :
}
}
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=