forked from kscience/kmath
matrixhelper
This commit is contained in:
parent
4e4690e510
commit
2d2c5aa684
@ -13,6 +13,7 @@ public open class BufferedTensor<T>(
|
||||
TensorStrides(shape),
|
||||
buffer
|
||||
) {
|
||||
/*
|
||||
|
||||
public operator fun get(i: Int, j: Int): T {
|
||||
check(this.dimension == 2) { "Not matrix" }
|
||||
@ -24,8 +25,60 @@ public open class BufferedTensor<T>(
|
||||
this[intArrayOf(i, j)] = value
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
|
||||
//todo make generator mb nextMatrixIndex?
|
||||
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
|
||||
private var offset: Int = 0
|
||||
private val n : Int = tensor.shape.size
|
||||
//stride?
|
||||
private val step = tensor.shape[n - 1] * tensor.shape[n - 2]
|
||||
|
||||
public operator fun get(i: Int, j: Int): T {
|
||||
val index = tensor.strides.index(offset)
|
||||
index[n - 2] = i
|
||||
index[n - 1] = j
|
||||
return tensor[index]
|
||||
}
|
||||
|
||||
public operator fun set(i: Int, j: Int, value: T): Unit {
|
||||
val index = tensor.strides.index(offset)
|
||||
index[n - 2] = i
|
||||
index[n - 1] = j
|
||||
tensor[index] = value
|
||||
}
|
||||
|
||||
public fun makeStep(){
|
||||
offset += step
|
||||
}
|
||||
}
|
||||
|
||||
public class InnerVector<T>(private val tensor: BufferedTensor<T>){
|
||||
private var offset: Int = 0
|
||||
private val n : Int = tensor.shape.size
|
||||
//stride?
|
||||
private val step = tensor.shape[n - 1]
|
||||
|
||||
public operator fun get(i: Int): T {
|
||||
val index = tensor.strides.index(offset)
|
||||
index[n - 1] = i
|
||||
return tensor[index]
|
||||
}
|
||||
|
||||
public operator fun set(i: Int, value: T): Unit {
|
||||
val index = tensor.strides.index(offset)
|
||||
index[n - 1] = i
|
||||
tensor[index] = value
|
||||
}
|
||||
|
||||
public fun makeStep(){
|
||||
offset += step
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//todo default buffer = arrayOf(0)???
|
||||
public class IntTensor(
|
||||
shape: IntArray,
|
||||
buffer: IntArray
|
||||
|
@ -2,8 +2,7 @@ package space.kscience.kmath.tensors
|
||||
|
||||
public class RealLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||
RealTensorAlgebra()
|
||||
{
|
||||
RealTensorAlgebra() {
|
||||
override fun eye(n: Int): RealTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
val buffer = DoubleArray(n * n) { 0.0 }
|
||||
@ -26,14 +25,25 @@ public class RealLinearOpsTensorAlgebra :
|
||||
|
||||
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
|
||||
// todo checks
|
||||
val lu = this.copy()
|
||||
val m = this.shape[0]
|
||||
val pivot = IntArray(m)
|
||||
|
||||
|
||||
// Initialize permutation array and parity
|
||||
val luTensor = this.copy()
|
||||
val lu = InnerMatrix(luTensor)
|
||||
//stride TODO!!! move to generator?
|
||||
var matCnt = 1
|
||||
for (i in 0 until this.shape.size - 2) {
|
||||
matCnt *= this.shape[i]
|
||||
}
|
||||
val n = this.shape.size
|
||||
val m = this.shape[n - 1]
|
||||
val pivotsShape = IntArray(n - 1) { i ->
|
||||
this.shape[i]
|
||||
}
|
||||
val pivotsTensor = IntTensor(
|
||||
pivotsShape,
|
||||
IntArray(matCnt * m) { 0 }
|
||||
)
|
||||
val pivot = InnerVector(pivotsTensor)
|
||||
for (i in 0 until matCnt) {
|
||||
for (row in 0 until m) pivot[row] = row
|
||||
var even = true
|
||||
|
||||
for (i in 0 until m) {
|
||||
var maxA = -1.0
|
||||
@ -54,7 +64,6 @@ public class RealLinearOpsTensorAlgebra :
|
||||
val j = pivot[i]
|
||||
pivot[i] = pivot[iMax]
|
||||
pivot[iMax] = j
|
||||
even != even
|
||||
|
||||
for (k in 0 until m) {
|
||||
val tmp = lu[i, k]
|
||||
@ -71,10 +80,16 @@ public class RealLinearOpsTensorAlgebra :
|
||||
}
|
||||
}
|
||||
}
|
||||
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
||||
lu.makeStep()
|
||||
pivot.makeStep()
|
||||
}
|
||||
|
||||
return Pair(luTensor, pivotsTensor)
|
||||
}
|
||||
|
||||
override fun luPivot(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||
TODO()
|
||||
/*
|
||||
// todo checks
|
||||
val n = lu.shape[0]
|
||||
val p = lu.zeroesLike()
|
||||
@ -97,9 +112,12 @@ public class RealLinearOpsTensorAlgebra :
|
||||
}
|
||||
}
|
||||
}
|
||||
return Triple(p, l, u)
|
||||
|
||||
return Triple(p, l, u)*/
|
||||
}
|
||||
|
||||
|
||||
|
||||
override fun RealTensor.det(): RealTensor {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
@ -127,5 +145,5 @@ public class RealLinearOpsTensorAlgebra :
|
||||
|
||||
}
|
||||
|
||||
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
|
||||
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealLinearOpsTensorAlgebra.() -> R): R =
|
||||
RealLinearOpsTensorAlgebra().block()
|
Loading…
Reference in New Issue
Block a user