forked from kscience/kmath
add buffered tensor + lu
This commit is contained in:
parent
95b814e163
commit
0911efd4aa
@ -1,33 +0,0 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.linear.BufferMatrix
|
||||
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
|
||||
import space.kscience.kmath.nd.Matrix
|
||||
import space.kscience.kmath.nd.MutableNDBuffer
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferAccessor2D
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.toList
|
||||
|
||||
public open class BufferTensor<T>(
|
||||
override val shape: IntArray,
|
||||
buffer: MutableBuffer<T>
|
||||
) :
|
||||
TensorStructure<T>,
|
||||
MutableNDBuffer<T>(
|
||||
TensorStrides(shape),
|
||||
buffer
|
||||
)
|
||||
|
||||
|
||||
public fun <T : Any> BufferTensor<T>.toBufferMatrix(): BufferMatrix<T> {
|
||||
return BufferMatrix(shape[0], shape[1], this.buffer)
|
||||
}
|
||||
|
||||
// T???
|
||||
public fun BufferMatrix<Double>.BufferTensor(): BufferTensor<Double> {
|
||||
return BufferTensor(intArrayOf(rowNum, colNum), BufferAccessor2D(rowNum, colNum, Buffer.Companion::real).create(this))
|
||||
}
|
||||
|
@ -0,0 +1,33 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.linear.BufferMatrix
|
||||
import space.kscience.kmath.nd.MutableNDBuffer
|
||||
import space.kscience.kmath.structures.*
|
||||
import space.kscience.kmath.structures.BufferAccessor2D
|
||||
|
||||
public open class BufferedTensor<T>(
|
||||
override val shape: IntArray,
|
||||
buffer: MutableBuffer<T>
|
||||
) :
|
||||
TensorStructure<T>,
|
||||
MutableNDBuffer<T>(
|
||||
TensorStrides(shape),
|
||||
buffer
|
||||
) {
|
||||
|
||||
public operator fun get(i: Int, j: Int): T{
|
||||
check(this.dimension == 2) {"Not matrix"}
|
||||
return this[intArrayOf(i, j)]
|
||||
}
|
||||
|
||||
public operator fun set(i: Int, j: Int, value: T): Unit{
|
||||
check(this.dimension == 2) {"Not matrix"}
|
||||
this[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public class IntTensor(
|
||||
shape: IntArray,
|
||||
buffer: IntArray
|
||||
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
@ -1,24 +1,19 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.linear.BufferMatrix
|
||||
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
|
||||
import space.kscience.kmath.nd.MutableNDBuffer
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.RealBuffer
|
||||
import space.kscience.kmath.structures.array
|
||||
import space.kscience.kmath.structures.toList
|
||||
import kotlin.math.abs
|
||||
|
||||
|
||||
public class RealTensor(
|
||||
shape: IntArray,
|
||||
buffer: DoubleArray
|
||||
) : BufferTensor<Double>(shape, RealBuffer(buffer))
|
||||
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
||||
|
||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||
|
||||
override fun RealTensor.value(): Double {
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||
}
|
||||
return this.buffer.array[0]
|
||||
@ -51,7 +46,8 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
}
|
||||
|
||||
override fun RealTensor.copy(): RealTensor {
|
||||
TODO("Not yet implemented")
|
||||
// should be rework as soon as copy() method for NDBuffer will be available
|
||||
return RealTensor(this.shape, this.buffer.array.copyOf())
|
||||
}
|
||||
|
||||
override fun Double.plus(other: RealTensor): RealTensor {
|
||||
@ -249,12 +245,59 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> {
|
||||
TODO()
|
||||
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
|
||||
// todo checks
|
||||
val lu = this.copy()
|
||||
val m = this.shape[0]
|
||||
val pivot = IntArray(m)
|
||||
|
||||
|
||||
// Initialize permutation array and parity
|
||||
for (row in 0 until m) pivot[row] = row
|
||||
var even = true
|
||||
|
||||
for (i in 0 until m) {
|
||||
var maxA = -1.0
|
||||
var iMax = i
|
||||
|
||||
for (k in i until m) {
|
||||
val absA = abs(lu[k, i])
|
||||
if (absA > maxA) {
|
||||
maxA = absA
|
||||
iMax = k
|
||||
}
|
||||
}
|
||||
|
||||
//todo check singularity
|
||||
|
||||
if (iMax != i) {
|
||||
|
||||
val j = pivot[i]
|
||||
pivot[i] = pivot[iMax]
|
||||
pivot[iMax] = j
|
||||
even != even
|
||||
|
||||
for (k in 0 until m) {
|
||||
val tmp = lu[i, k]
|
||||
lu[i, k] = lu[iMax, k]
|
||||
lu[iMax, k] = tmp
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for (j in i + 1 until m) {
|
||||
lu[j, i] /= lu[i, i]
|
||||
for (k in i + 1 until m) {
|
||||
lu[j, k] -= lu[j, i] * lu[i, k]
|
||||
}
|
||||
}
|
||||
}
|
||||
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
||||
}
|
||||
|
||||
override fun luUnpack(A_LU: RealTensor, pivots: RealTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
TODO("Not yet implemented")
|
||||
override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
// todo checks
|
||||
|
||||
}
|
||||
|
||||
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
|
@ -78,11 +78,12 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
||||
public fun TensorType.log(): TensorType
|
||||
public fun TensorType.logAssign(): Unit
|
||||
|
||||
// todo change type of pivots
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
public fun TensorType.lu(): Pair<TensorType, TensorType>
|
||||
public fun TensorType.lu(): Pair<TensorType, IntTensor>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||
public fun luUnpack(A_LU: TensorType, pivots: TensorType): Triple<TensorType, TensorType, TensorType>
|
||||
public fun luUnpack(A_LU: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.svd.html
|
||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||
|
Loading…
Reference in New Issue
Block a user