KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
2 changed files with 7 additions and 4 deletions
Showing only changes of commit 287e2aeba2 - Show all commits

View File

@ -8,6 +8,7 @@ package space.kscience.kmath.tensors.api
// https://proofwiki.org/wiki/Definition:Division_Algebra // https://proofwiki.org/wiki/Definition:Division_Algebra
public interface TensorPartialDivisionAlgebra<T> : public interface TensorPartialDivisionAlgebra<T> :
TensorAlgebra<T> { TensorAlgebra<T> {
public operator fun T.div(other: TensorStructure<T>): TensorStructure<T>
public operator fun TensorStructure<T>.div(value: T): TensorStructure<T> public operator fun TensorStructure<T>.div(value: T): TensorStructure<T>
public operator fun TensorStructure<T>.div(other: TensorStructure<T>): TensorStructure<T> public operator fun TensorStructure<T>.div(other: TensorStructure<T>): TensorStructure<T>
public operator fun TensorStructure<T>.divAssign(value: T) public operator fun TensorStructure<T>.divAssign(value: T)

View File

@ -178,13 +178,15 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
} }
override fun TensorStructure<Double>.div(value: Double): DoubleTensor { override fun Double.div(other: TensorStructure<Double>): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
tensor.buffer.array()[tensor.bufferStart + i] / value other.tensor.buffer.array()[other.tensor.bufferStart + i] / this
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.div(value: Double): DoubleTensor = value / tensor
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor { override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->