From 0911efd4aa1944c89cb9df4913c251daef98d31c Mon Sep 17 00:00:00 2001 From: Andrei Kislitsyn Date: Sat, 13 Mar 2021 19:30:58 +0300 Subject: [PATCH] add buffered tensor + lu --- .../kscience/kmath/tensors/BufferTensor.kt | 33 --------- .../kscience/kmath/tensors/BufferedTensor.kt | 33 +++++++++ .../kmath/tensors/RealTensorAlgebra.kt | 69 +++++++++++++++---- .../kscience/kmath/tensors/TensorAlgebra.kt | 5 +- 4 files changed, 92 insertions(+), 48 deletions(-) delete mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferTensor.kt create mode 100644 kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferTensor.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferTensor.kt deleted file mode 100644 index 817c51f11..000000000 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferTensor.kt +++ /dev/null @@ -1,33 +0,0 @@ -package space.kscience.kmath.tensors - -import space.kscience.kmath.linear.BufferMatrix -import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix -import space.kscience.kmath.nd.Matrix -import space.kscience.kmath.nd.MutableNDBuffer -import space.kscience.kmath.nd.Structure2D -import space.kscience.kmath.nd.as2D -import space.kscience.kmath.structures.Buffer -import space.kscience.kmath.structures.BufferAccessor2D -import space.kscience.kmath.structures.MutableBuffer -import space.kscience.kmath.structures.toList - -public open class BufferTensor( - override val shape: IntArray, - buffer: MutableBuffer -) : - TensorStructure, - MutableNDBuffer( - TensorStrides(shape), - buffer - ) - - -public fun BufferTensor.toBufferMatrix(): BufferMatrix { - return BufferMatrix(shape[0], shape[1], this.buffer) -} - -// T??? -public fun BufferMatrix.BufferTensor(): BufferTensor { - return BufferTensor(intArrayOf(rowNum, colNum), BufferAccessor2D(rowNum, colNum, Buffer.Companion::real).create(this)) -} - diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt new file mode 100644 index 000000000..9ffe2db61 --- /dev/null +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/BufferedTensor.kt @@ -0,0 +1,33 @@ +package space.kscience.kmath.tensors + +import space.kscience.kmath.linear.BufferMatrix +import space.kscience.kmath.nd.MutableNDBuffer +import space.kscience.kmath.structures.* +import space.kscience.kmath.structures.BufferAccessor2D + +public open class BufferedTensor( + override val shape: IntArray, + buffer: MutableBuffer +) : + TensorStructure, + MutableNDBuffer( + TensorStrides(shape), + buffer + ) { + + public operator fun get(i: Int, j: Int): T{ + check(this.dimension == 2) {"Not matrix"} + return this[intArrayOf(i, j)] + } + + public operator fun set(i: Int, j: Int, value: T): Unit{ + check(this.dimension == 2) {"Not matrix"} + this[intArrayOf(i, j)] = value + } + +} + +public class IntTensor( + shape: IntArray, + buffer: IntArray +) : BufferedTensor(shape, IntBuffer(buffer)) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt index e38056097..3e21a3a98 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/RealTensorAlgebra.kt @@ -1,24 +1,19 @@ package space.kscience.kmath.tensors -import space.kscience.kmath.linear.BufferMatrix -import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix -import space.kscience.kmath.nd.MutableNDBuffer -import space.kscience.kmath.nd.as2D -import space.kscience.kmath.structures.Buffer import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.array -import space.kscience.kmath.structures.toList +import kotlin.math.abs public class RealTensor( shape: IntArray, buffer: DoubleArray -) : BufferTensor(shape, RealBuffer(buffer)) +) : BufferedTensor(shape, RealBuffer(buffer)) public class RealTensorAlgebra : TensorPartialDivisionAlgebra { override fun RealTensor.value(): Double { - check(this.shape contentEquals intArrayOf(1)) { + check(this.shape contentEquals intArrayOf(1)) { "Inconsistent value for tensor of shape ${shape.toList()}" } return this.buffer.array[0] @@ -51,7 +46,8 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra { - TODO() + override fun RealTensor.lu(): Pair { + // todo checks + val lu = this.copy() + val m = this.shape[0] + val pivot = IntArray(m) + + + // Initialize permutation array and parity + for (row in 0 until m) pivot[row] = row + var even = true + + for (i in 0 until m) { + var maxA = -1.0 + var iMax = i + + for (k in i until m) { + val absA = abs(lu[k, i]) + if (absA > maxA) { + maxA = absA + iMax = k + } + } + + //todo check singularity + + if (iMax != i) { + + val j = pivot[i] + pivot[i] = pivot[iMax] + pivot[iMax] = j + even != even + + for (k in 0 until m) { + val tmp = lu[i, k] + lu[i, k] = lu[iMax, k] + lu[iMax, k] = tmp + } + + } + + for (j in i + 1 until m) { + lu[j, i] /= lu[i, i] + for (k in i + 1 until m) { + lu[j, k] -= lu[j, i] * lu[i, k] + } + } + } + return Pair(lu, IntTensor(intArrayOf(m), pivot)) } - override fun luUnpack(A_LU: RealTensor, pivots: RealTensor): Triple { - TODO("Not yet implemented") + override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple { + // todo checks + } override fun RealTensor.svd(): Triple { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 6f42623e0..0dbcae044 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -78,11 +78,12 @@ public interface TensorPartialDivisionAlgebra public fun TensorType.log(): TensorType public fun TensorType.logAssign(): Unit + // todo change type of pivots //https://pytorch.org/docs/stable/generated/torch.lu.html - public fun TensorType.lu(): Pair + public fun TensorType.lu(): Pair //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html - public fun luUnpack(A_LU: TensorType, pivots: TensorType): Triple + public fun luUnpack(A_LU: TensorType, pivots: IntTensor): Triple //https://pytorch.org/docs/stable/generated/torch.svd.html public fun TensorType.svd(): Triple