From bd3425e7a545622c117740bba178ec1b1b531eab Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Tue, 16 Mar 2021 14:43:20 +0000 Subject: [PATCH] IndexTensor type added to LinearOps --- .../kotlin/space/kscience/kmath/tensors/BufferedTensor.kt | 3 +++ .../kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt | 2 +- .../space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt index 264692d0c..c9a401ad5 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt @@ -12,6 +12,9 @@ public open class BufferedTensor( public val strides: TensorStrides get() = TensorStrides(shape) + public val numel: Int + get() = strides.linearSize + override fun get(index: IntArray): T = buffer[bufferStart + strides.offset(index)] override fun set(index: IntArray, value: T) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt index d6b202556..689eda9e0 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleLinearOpsTensorAlgebra.kt @@ -1,7 +1,7 @@ package space.kscience.kmath.tensors public class DoubleLinearOpsTensorAlgebra : - LinearOpsTensorAlgebra, + LinearOpsTensorAlgebra, DoubleTensorAlgebra() { override fun DoubleTensor.inv(): DoubleTensor { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index 7450c09c1..2e1c4a92c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -1,7 +1,7 @@ package space.kscience.kmath.tensors -public interface LinearOpsTensorAlgebra> : +public interface LinearOpsTensorAlgebra, IndexTensorType: TensorStructure> : TensorPartialDivisionAlgebra { //https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv @@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra> : public fun TensorType.qr(): TensorType //https://pytorch.org/docs/stable/generated/torch.lu.html - public fun TensorType.lu(): Pair + public fun TensorType.lu(): Pair //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html public fun luPivot(lu: TensorType, pivots: IntTensor): Triple