matrixhelper

This commit is contained in:
Andrei Kislitsyn 2021-03-15 15:18:21 +03:00
parent 4e4690e510
commit 2d2c5aa684
2 changed files with 120 additions and 49 deletions

View File

@ -13,6 +13,7 @@ public open class BufferedTensor<T>(
TensorStrides(shape), TensorStrides(shape),
buffer buffer
) { ) {
/*
public operator fun get(i: Int, j: Int): T { public operator fun get(i: Int, j: Int): T {
check(this.dimension == 2) { "Not matrix" } check(this.dimension == 2) { "Not matrix" }
@ -24,8 +25,60 @@ public open class BufferedTensor<T>(
this[intArrayOf(i, j)] = value this[intArrayOf(i, j)] = value
} }
*/
} }
//todo make generator mb nextMatrixIndex?
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
private var offset: Int = 0
private val n : Int = tensor.shape.size
//stride?
private val step = tensor.shape[n - 1] * tensor.shape[n - 2]
public operator fun get(i: Int, j: Int): T {
val index = tensor.strides.index(offset)
index[n - 2] = i
index[n - 1] = j
return tensor[index]
}
public operator fun set(i: Int, j: Int, value: T): Unit {
val index = tensor.strides.index(offset)
index[n - 2] = i
index[n - 1] = j
tensor[index] = value
}
public fun makeStep(){
offset += step
}
}
public class InnerVector<T>(private val tensor: BufferedTensor<T>){
private var offset: Int = 0
private val n : Int = tensor.shape.size
//stride?
private val step = tensor.shape[n - 1]
public operator fun get(i: Int): T {
val index = tensor.strides.index(offset)
index[n - 1] = i
return tensor[index]
}
public operator fun set(i: Int, value: T): Unit {
val index = tensor.strides.index(offset)
index[n - 1] = i
tensor[index] = value
}
public fun makeStep(){
offset += step
}
}
//todo default buffer = arrayOf(0)???
public class IntTensor( public class IntTensor(
shape: IntArray, shape: IntArray,
buffer: IntArray buffer: IntArray

View File

@ -2,8 +2,7 @@ package space.kscience.kmath.tensors
public class RealLinearOpsTensorAlgebra : public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>, LinearOpsTensorAlgebra<Double, RealTensor>,
RealTensorAlgebra() RealTensorAlgebra() {
{
override fun eye(n: Int): RealTensor { override fun eye(n: Int): RealTensor {
val shape = intArrayOf(n, n) val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 } val buffer = DoubleArray(n * n) { 0.0 }
@ -26,14 +25,25 @@ public class RealLinearOpsTensorAlgebra :
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> { override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
// todo checks // todo checks
val lu = this.copy() val luTensor = this.copy()
val m = this.shape[0] val lu = InnerMatrix(luTensor)
val pivot = IntArray(m) //stride TODO!!! move to generator?
var matCnt = 1
for (i in 0 until this.shape.size - 2) {
// Initialize permutation array and parity matCnt *= this.shape[i]
}
val n = this.shape.size
val m = this.shape[n - 1]
val pivotsShape = IntArray(n - 1) { i ->
this.shape[i]
}
val pivotsTensor = IntTensor(
pivotsShape,
IntArray(matCnt * m) { 0 }
)
val pivot = InnerVector(pivotsTensor)
for (i in 0 until matCnt) {
for (row in 0 until m) pivot[row] = row for (row in 0 until m) pivot[row] = row
var even = true
for (i in 0 until m) { for (i in 0 until m) {
var maxA = -1.0 var maxA = -1.0
@ -54,7 +64,6 @@ public class RealLinearOpsTensorAlgebra :
val j = pivot[i] val j = pivot[i]
pivot[i] = pivot[iMax] pivot[i] = pivot[iMax]
pivot[iMax] = j pivot[iMax] = j
even != even
for (k in 0 until m) { for (k in 0 until m) {
val tmp = lu[i, k] val tmp = lu[i, k]
@ -71,10 +80,16 @@ public class RealLinearOpsTensorAlgebra :
} }
} }
} }
return Pair(lu, IntTensor(intArrayOf(m), pivot)) lu.makeStep()
pivot.makeStep()
}
return Pair(luTensor, pivotsTensor)
} }
override fun luPivot(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> { override fun luPivot(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
TODO()
/*
// todo checks // todo checks
val n = lu.shape[0] val n = lu.shape[0]
val p = lu.zeroesLike() val p = lu.zeroesLike()
@ -97,9 +112,12 @@ public class RealLinearOpsTensorAlgebra :
} }
} }
} }
return Triple(p, l, u)
return Triple(p, l, u)*/
} }
override fun RealTensor.det(): RealTensor { override fun RealTensor.det(): RealTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
@ -127,5 +145,5 @@ public class RealLinearOpsTensorAlgebra :
} }
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealTensorAlgebra.() -> R): R = public inline fun <R> RealLinearOpsTensorAlgebra(block: RealLinearOpsTensorAlgebra.() -> R): R =
RealLinearOpsTensorAlgebra().block() RealLinearOpsTensorAlgebra().block()