add buffered tensor + lu
This commit is contained in:
parent
95b814e163
commit
0911efd4aa
@ -1,33 +0,0 @@
|
|||||||
package space.kscience.kmath.tensors
|
|
||||||
|
|
||||||
import space.kscience.kmath.linear.BufferMatrix
|
|
||||||
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
|
|
||||||
import space.kscience.kmath.nd.Matrix
|
|
||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
|
||||||
import space.kscience.kmath.nd.Structure2D
|
|
||||||
import space.kscience.kmath.nd.as2D
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
import space.kscience.kmath.structures.BufferAccessor2D
|
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
|
||||||
import space.kscience.kmath.structures.toList
|
|
||||||
|
|
||||||
public open class BufferTensor<T>(
|
|
||||||
override val shape: IntArray,
|
|
||||||
buffer: MutableBuffer<T>
|
|
||||||
) :
|
|
||||||
TensorStructure<T>,
|
|
||||||
MutableNDBuffer<T>(
|
|
||||||
TensorStrides(shape),
|
|
||||||
buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
public fun <T : Any> BufferTensor<T>.toBufferMatrix(): BufferMatrix<T> {
|
|
||||||
return BufferMatrix(shape[0], shape[1], this.buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// T???
|
|
||||||
public fun BufferMatrix<Double>.BufferTensor(): BufferTensor<Double> {
|
|
||||||
return BufferTensor(intArrayOf(rowNum, colNum), BufferAccessor2D(rowNum, colNum, Buffer.Companion::real).create(this))
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,33 @@
|
|||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.linear.BufferMatrix
|
||||||
|
import space.kscience.kmath.nd.MutableNDBuffer
|
||||||
|
import space.kscience.kmath.structures.*
|
||||||
|
import space.kscience.kmath.structures.BufferAccessor2D
|
||||||
|
|
||||||
|
public open class BufferedTensor<T>(
|
||||||
|
override val shape: IntArray,
|
||||||
|
buffer: MutableBuffer<T>
|
||||||
|
) :
|
||||||
|
TensorStructure<T>,
|
||||||
|
MutableNDBuffer<T>(
|
||||||
|
TensorStrides(shape),
|
||||||
|
buffer
|
||||||
|
) {
|
||||||
|
|
||||||
|
public operator fun get(i: Int, j: Int): T{
|
||||||
|
check(this.dimension == 2) {"Not matrix"}
|
||||||
|
return this[intArrayOf(i, j)]
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun set(i: Int, j: Int, value: T): Unit{
|
||||||
|
check(this.dimension == 2) {"Not matrix"}
|
||||||
|
this[intArrayOf(i, j)] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public class IntTensor(
|
||||||
|
shape: IntArray,
|
||||||
|
buffer: IntArray
|
||||||
|
) : BufferedTensor<Int>(shape, IntBuffer(buffer))
|
@ -1,24 +1,19 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.linear.BufferMatrix
|
|
||||||
import space.kscience.kmath.linear.RealMatrixContext.toBufferMatrix
|
|
||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
|
||||||
import space.kscience.kmath.nd.as2D
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
import space.kscience.kmath.structures.RealBuffer
|
import space.kscience.kmath.structures.RealBuffer
|
||||||
import space.kscience.kmath.structures.array
|
import space.kscience.kmath.structures.array
|
||||||
import space.kscience.kmath.structures.toList
|
import kotlin.math.abs
|
||||||
|
|
||||||
|
|
||||||
public class RealTensor(
|
public class RealTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: DoubleArray
|
buffer: DoubleArray
|
||||||
) : BufferTensor<Double>(shape, RealBuffer(buffer))
|
) : BufferedTensor<Double>(shape, RealBuffer(buffer))
|
||||||
|
|
||||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
|
|
||||||
override fun RealTensor.value(): Double {
|
override fun RealTensor.value(): Double {
|
||||||
check(this.shape contentEquals intArrayOf(1)) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
}
|
}
|
||||||
return this.buffer.array[0]
|
return this.buffer.array[0]
|
||||||
@ -51,7 +46,8 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.copy(): RealTensor {
|
override fun RealTensor.copy(): RealTensor {
|
||||||
TODO("Not yet implemented")
|
// should be rework as soon as copy() method for NDBuffer will be available
|
||||||
|
return RealTensor(this.shape, this.buffer.array.copyOf())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Double.plus(other: RealTensor): RealTensor {
|
override fun Double.plus(other: RealTensor): RealTensor {
|
||||||
@ -249,12 +245,59 @@ public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor
|
|||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.lu(): Pair<RealTensor, RealTensor> {
|
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
|
||||||
TODO()
|
// todo checks
|
||||||
|
val lu = this.copy()
|
||||||
|
val m = this.shape[0]
|
||||||
|
val pivot = IntArray(m)
|
||||||
|
|
||||||
|
|
||||||
|
// Initialize permutation array and parity
|
||||||
|
for (row in 0 until m) pivot[row] = row
|
||||||
|
var even = true
|
||||||
|
|
||||||
|
for (i in 0 until m) {
|
||||||
|
var maxA = -1.0
|
||||||
|
var iMax = i
|
||||||
|
|
||||||
|
for (k in i until m) {
|
||||||
|
val absA = abs(lu[k, i])
|
||||||
|
if (absA > maxA) {
|
||||||
|
maxA = absA
|
||||||
|
iMax = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//todo check singularity
|
||||||
|
|
||||||
|
if (iMax != i) {
|
||||||
|
|
||||||
|
val j = pivot[i]
|
||||||
|
pivot[i] = pivot[iMax]
|
||||||
|
pivot[iMax] = j
|
||||||
|
even != even
|
||||||
|
|
||||||
|
for (k in 0 until m) {
|
||||||
|
val tmp = lu[i, k]
|
||||||
|
lu[i, k] = lu[iMax, k]
|
||||||
|
lu[iMax, k] = tmp
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for (j in i + 1 until m) {
|
||||||
|
lu[j, i] /= lu[i, i]
|
||||||
|
for (k in i + 1 until m) {
|
||||||
|
lu[j, k] -= lu[j, i] * lu[i, k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luUnpack(A_LU: RealTensor, pivots: RealTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
override fun luUnpack(A_LU: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
TODO("Not yet implemented")
|
// todo checks
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
override fun RealTensor.svd(): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
|
@ -78,11 +78,12 @@ public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>
|
|||||||
public fun TensorType.log(): TensorType
|
public fun TensorType.log(): TensorType
|
||||||
public fun TensorType.logAssign(): Unit
|
public fun TensorType.logAssign(): Unit
|
||||||
|
|
||||||
|
// todo change type of pivots
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
public fun TensorType.lu(): Pair<TensorType, TensorType>
|
public fun TensorType.lu(): Pair<TensorType, IntTensor>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||||
public fun luUnpack(A_LU: TensorType, pivots: TensorType): Triple<TensorType, TensorType, TensorType>
|
public fun luUnpack(A_LU: TensorType, pivots: IntTensor): Triple<TensorType, TensorType, TensorType>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.svd.html
|
//https://pytorch.org/docs/stable/generated/torch.svd.html
|
||||||
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
public fun TensorType.svd(): Triple<TensorType, TensorType, TensorType>
|
||||||
|
Loading…
Reference in New Issue
Block a user