IndexTensor type added to LinearOps

This commit is contained in:
Roland Grinis 2021-03-16 14:43:20 +00:00
parent 138d376289
commit bd3425e7a5
3 changed files with 6 additions and 3 deletions

View File

@ -12,6 +12,9 @@ public open class BufferedTensor<T>(
public val strides: TensorStrides
get() = TensorStrides(shape)
public val numel: Int
get() = strides.linearSize
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)]
override fun set(index: IntArray, value: T) {

View File

@ -1,7 +1,7 @@
package space.kscience.kmath.tensors
public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor>,
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
DoubleTensorAlgebra() {
override fun DoubleTensor.inv(): DoubleTensor {

View File

@ -1,7 +1,7 @@
package space.kscience.kmath.tensors
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, IndexTensorType: TensorStructure<Int>> :
TensorPartialDivisionAlgebra<T, TensorType> {
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
public fun TensorType.qr(): TensorType
//https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, IntTensor>
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
public fun luPivot(lu: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>