forked from kscience/kmath
Refactor TensorAlgebra to take StructureND and inherit AlgebraND
This commit is contained in:
parent
7e59ec5804
commit
4635cd3fb3
@ -13,7 +13,10 @@ import org.nd4j.linalg.factory.Nd4j
|
|||||||
import org.nd4j.linalg.factory.ops.NDBase
|
import org.nd4j.linalg.factory.ops.NDBase
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms
|
import org.nd4j.linalg.ops.transforms.Transforms
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.Field
|
||||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
@ -22,7 +25,24 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
|||||||
/**
|
/**
|
||||||
* ND4J based [TensorAlgebra] implementation.
|
* ND4J based [TensorAlgebra] implementation.
|
||||||
*/
|
*/
|
||||||
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
|
public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
|
||||||
|
|
||||||
|
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||||
|
val array =
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Wraps [INDArray] to [Nd4jArrayStructure].
|
* Wraps [INDArray] to [Nd4jArrayStructure].
|
||||||
*/
|
*/
|
||||||
@ -33,105 +53,107 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
*/
|
*/
|
||||||
public val StructureND<T>.ndArray: INDArray
|
public val StructureND<T>.ndArray: INDArray
|
||||||
|
|
||||||
override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap()
|
override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
|
||||||
override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap()
|
override fun StructureND<T>.plus(value: T): Nd4jArrayStructure<T> = ndArray.add(value).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap()
|
override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(value: T) {
|
override fun Tensor<T>.plusAssign(value: T) {
|
||||||
ndArray.addi(value)
|
ndArray.addi(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
|
override fun Tensor<T>.plusAssign(other: StructureND<T>) {
|
||||||
ndArray.addi(other.ndArray)
|
ndArray.addi(other.ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap()
|
override fun T.minus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.rsub(this).wrap()
|
||||||
override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap()
|
override fun StructureND<T>.minus(arg: T): Nd4jArrayStructure<T> = ndArray.sub(arg).wrap()
|
||||||
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap()
|
override fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.sub(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(value: T) {
|
override fun Tensor<T>.minusAssign(value: T) {
|
||||||
ndArray.rsubi(value)
|
ndArray.rsubi(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
|
override fun Tensor<T>.minusAssign(other: StructureND<T>) {
|
||||||
ndArray.subi(other.ndArray)
|
ndArray.subi(other.ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap()
|
override fun T.times(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.mul(this).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.times(value: T): Tensor<T> =
|
override fun StructureND<T>.times(arg: T): Nd4jArrayStructure<T> =
|
||||||
ndArray.mul(value).wrap()
|
ndArray.mul(arg).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap()
|
override fun StructureND<T>.times(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mul(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.timesAssign(value: T) {
|
override fun Tensor<T>.timesAssign(value: T) {
|
||||||
ndArray.muli(value)
|
ndArray.muli(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
|
override fun Tensor<T>.timesAssign(other: StructureND<T>) {
|
||||||
ndArray.mmuli(other.ndArray)
|
ndArray.mmuli(other.ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap()
|
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
|
||||||
override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap()
|
override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
|
||||||
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap()
|
override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
|
||||||
override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap()
|
override fun Tensor<T>.dot(other: Tensor<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
ndArray.min(keepDim, dim).wrap()
|
ndArray.min(keepDim, dim).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
ndArray.sum(keepDim, dim).wrap()
|
ndArray.sum(keepDim, dim).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
ndArray.max(keepDim, dim).wrap()
|
ndArray.max(keepDim, dim).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
|
||||||
override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
override fun Tensor<T>.viewAs(other: Tensor<T>): Nd4jArrayStructure<T> = view(other.shape)
|
||||||
|
|
||||||
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = ndArray.mean(keepDim, dim).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap()
|
override fun Tensor<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
|
||||||
override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap()
|
override fun Tensor<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
|
||||||
override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap()
|
override fun Tensor<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
|
||||||
override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap()
|
override fun Tensor<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
|
||||||
override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap()
|
override fun Tensor<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
|
||||||
override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap()
|
override fun Tensor<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.acosh(): Tensor<T> =
|
override fun Tensor<T>.acosh(): Nd4jArrayStructure<T> =
|
||||||
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
|
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap()
|
override fun Tensor<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
|
||||||
override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap()
|
override fun Tensor<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
|
||||||
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
|
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.asinh(): Tensor<T> =
|
override fun Tensor<T>.asinh(): Nd4jArrayStructure<T> =
|
||||||
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
|
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap()
|
override fun Tensor<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
|
||||||
override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap()
|
override fun Tensor<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
|
||||||
override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap()
|
override fun Tensor<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
|
||||||
override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap()
|
override fun Tensor<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
|
||||||
override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap()
|
override fun Tensor<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
|
||||||
override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap()
|
override fun Tensor<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
|
||||||
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap()
|
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap()
|
ndArray.std(true, keepDim, dim).wrap()
|
||||||
override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
|
|
||||||
override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap()
|
override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
|
||||||
|
override fun StructureND<T>.div(arg: T): Nd4jArrayStructure<T> = ndArray.div(arg).wrap()
|
||||||
|
override fun StructureND<T>.div(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.div(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.divAssign(value: T) {
|
override fun Tensor<T>.divAssign(value: T) {
|
||||||
ndArray.divi(value)
|
ndArray.divi(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.divAssign(other: Tensor<T>) {
|
override fun Tensor<T>.divAssign(other: StructureND<T>) {
|
||||||
ndArray.divi(other.ndArray)
|
ndArray.divi(other.ndArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> =
|
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
|
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
|
||||||
|
|
||||||
private companion object {
|
private companion object {
|
||||||
@ -142,7 +164,10 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
|
|||||||
/**
|
/**
|
||||||
* [Double] specialization of [Nd4jTensorAlgebra].
|
* [Double] specialization of [Nd4jTensorAlgebra].
|
||||||
*/
|
*/
|
||||||
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||||
|
|
||||||
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
@ -154,7 +179,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<Double>.valueOrNull(): Double? =
|
override fun StructureND<Double>.valueOrNull(): Double? =
|
||||||
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
|
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
|
||||||
|
|
||||||
// TODO rewrite
|
// TODO rewrite
|
||||||
|
@ -57,5 +57,5 @@ public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T
|
|||||||
*
|
*
|
||||||
* @param other tensor to be divided by.
|
* @param other tensor to be divided by.
|
||||||
*/
|
*/
|
||||||
public operator fun Tensor<T>.divAssign(other: Tensor<T>)
|
public operator fun Tensor<T>.divAssign(other: StructureND<T>)
|
||||||
}
|
}
|
||||||
|
@ -330,7 +330,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
|
override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
|
||||||
checkShapesCompatible(tensor, other)
|
checkShapesCompatible(tensor, other)
|
||||||
for (i in 0 until tensor.numElements) {
|
for (i in 0 until tensor.numElements) {
|
||||||
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
|
||||||
|
Loading…
Reference in New Issue
Block a user