IndexTensor type added to LinearOps
This commit is contained in:
parent
138d376289
commit
bd3425e7a5
@ -12,6 +12,9 @@ public open class BufferedTensor<T>(
|
|||||||
public val strides: TensorStrides
|
public val strides: TensorStrides
|
||||||
get() = TensorStrides(shape)
|
get() = TensorStrides(shape)
|
||||||
|
|
||||||
|
public val numel: Int
|
||||||
|
get() = strides.linearSize
|
||||||
|
|
||||||
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)]
|
override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)]
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) {
|
override fun set(index: IntArray, value: T) {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
public class DoubleLinearOpsTensorAlgebra :
|
public class DoubleLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, DoubleTensor>,
|
LinearOpsTensorAlgebra<Double, DoubleTensor, IntTensor>,
|
||||||
DoubleTensorAlgebra() {
|
DoubleTensorAlgebra() {
|
||||||
|
|
||||||
override fun DoubleTensor.inv(): DoubleTensor {
|
override fun DoubleTensor.inv(): DoubleTensor {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
|
||||||
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, IndexTensorType: TensorStructure<Int>> :
|
||||||
TensorPartialDivisionAlgebra<T, TensorType> {
|
TensorPartialDivisionAlgebra<T, TensorType> {
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||||
@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>> :
|
|||||||
public fun TensorType.qr(): TensorType
|
public fun TensorType.qr(): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
public fun TensorType.lu(): Pair<TensorType, IntTensor>
|
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||||
public fun luPivot(lu: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
public fun luPivot(lu: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
||||||
|
Loading…
Reference in New Issue
Block a user