forked from kscience/kmath
matrixhelper
This commit is contained in:
parent
4e4690e510
commit
2d2c5aa684
@ -13,6 +13,7 @@ public open class BufferedTensor<T>(
|
|||||||
TensorStrides(shape),
|
TensorStrides(shape),
|
||||||
buffer
|
buffer
|
||||||
) {
|
) {
|
||||||
|
/*
|
||||||
|
|
||||||
public operator fun get(i: Int, j: Int): T {
|
public operator fun get(i: Int, j: Int): T {
|
||||||
check(this.dimension == 2) { "Not matrix" }
|
check(this.dimension == 2) { "Not matrix" }
|
||||||
@ -24,8 +25,60 @@ public open class BufferedTensor<T>(
|
|||||||
this[intArrayOf(i, j)] = value
|
this[intArrayOf(i, j)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//todo make generator mb nextMatrixIndex?
|
||||||
|
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
|
||||||
|
private var offset: Int = 0
|
||||||
|
private val n : Int = tensor.shape.size
|
||||||
|
//stride?
|
||||||
|
private val step = tensor.shape[n - 1] * tensor.shape[n - 2]
|
||||||
|
|
||||||
|
public operator fun get(i: Int, j: Int): T {
|
||||||
|
val index = tensor.strides.index(offset)
|
||||||
|
index[n - 2] = i
|
||||||
|
index[n - 1] = j
|
||||||
|
return tensor[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun set(i: Int, j: Int, value: T): Unit {
|
||||||
|
val index = tensor.strides.index(offset)
|
||||||
|
index[n - 2] = i
|
||||||
|
index[n - 1] = j
|
||||||
|
tensor[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun makeStep(){
|
||||||
|
offset += step
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class InnerVector<T>(private val tensor: BufferedTensor<T>){
|
||||||
|
private var offset: Int = 0
|
||||||
|
private val n : Int = tensor.shape.size
|
||||||
|
//stride?
|
||||||
|
private val step = tensor.shape[n - 1]
|
||||||
|
|
||||||
|
public operator fun get(i: Int): T {
|
||||||
|
val index = tensor.strides.index(offset)
|
||||||
|
index[n - 1] = i
|
||||||
|
return tensor[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun set(i: Int, value: T): Unit {
|
||||||
|
val index = tensor.strides.index(offset)
|
||||||
|
index[n - 1] = i
|
||||||
|
tensor[index] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun makeStep(){
|
||||||
|
offset += step
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//todo default buffer = arrayOf(0)???
|
||||||
public class IntTensor(
|
public class IntTensor(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
buffer: IntArray
|
buffer: IntArray
|
||||||
|
@ -2,8 +2,7 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
public class RealLinearOpsTensorAlgebra :
|
public class RealLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||||
RealTensorAlgebra()
|
RealTensorAlgebra() {
|
||||||
{
|
|
||||||
override fun eye(n: Int): RealTensor {
|
override fun eye(n: Int): RealTensor {
|
||||||
val shape = intArrayOf(n, n)
|
val shape = intArrayOf(n, n)
|
||||||
val buffer = DoubleArray(n * n) { 0.0 }
|
val buffer = DoubleArray(n * n) { 0.0 }
|
||||||
@ -26,14 +25,25 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
|
override fun RealTensor.lu(): Pair<RealTensor, IntTensor> {
|
||||||
// todo checks
|
// todo checks
|
||||||
val lu = this.copy()
|
val luTensor = this.copy()
|
||||||
val m = this.shape[0]
|
val lu = InnerMatrix(luTensor)
|
||||||
val pivot = IntArray(m)
|
//stride TODO!!! move to generator?
|
||||||
|
var matCnt = 1
|
||||||
|
for (i in 0 until this.shape.size - 2) {
|
||||||
// Initialize permutation array and parity
|
matCnt *= this.shape[i]
|
||||||
|
}
|
||||||
|
val n = this.shape.size
|
||||||
|
val m = this.shape[n - 1]
|
||||||
|
val pivotsShape = IntArray(n - 1) { i ->
|
||||||
|
this.shape[i]
|
||||||
|
}
|
||||||
|
val pivotsTensor = IntTensor(
|
||||||
|
pivotsShape,
|
||||||
|
IntArray(matCnt * m) { 0 }
|
||||||
|
)
|
||||||
|
val pivot = InnerVector(pivotsTensor)
|
||||||
|
for (i in 0 until matCnt) {
|
||||||
for (row in 0 until m) pivot[row] = row
|
for (row in 0 until m) pivot[row] = row
|
||||||
var even = true
|
|
||||||
|
|
||||||
for (i in 0 until m) {
|
for (i in 0 until m) {
|
||||||
var maxA = -1.0
|
var maxA = -1.0
|
||||||
@ -54,7 +64,6 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
val j = pivot[i]
|
val j = pivot[i]
|
||||||
pivot[i] = pivot[iMax]
|
pivot[i] = pivot[iMax]
|
||||||
pivot[iMax] = j
|
pivot[iMax] = j
|
||||||
even != even
|
|
||||||
|
|
||||||
for (k in 0 until m) {
|
for (k in 0 until m) {
|
||||||
val tmp = lu[i, k]
|
val tmp = lu[i, k]
|
||||||
@ -71,10 +80,16 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Pair(lu, IntTensor(intArrayOf(m), pivot))
|
lu.makeStep()
|
||||||
|
pivot.makeStep()
|
||||||
|
}
|
||||||
|
|
||||||
|
return Pair(luTensor, pivotsTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun luPivot(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
override fun luPivot(lu: RealTensor, pivots: IntTensor): Triple<RealTensor, RealTensor, RealTensor> {
|
||||||
|
TODO()
|
||||||
|
/*
|
||||||
// todo checks
|
// todo checks
|
||||||
val n = lu.shape[0]
|
val n = lu.shape[0]
|
||||||
val p = lu.zeroesLike()
|
val p = lu.zeroesLike()
|
||||||
@ -97,9 +112,12 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Triple(p, l, u)
|
|
||||||
|
return Triple(p, l, u)*/
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
override fun RealTensor.det(): RealTensor {
|
override fun RealTensor.det(): RealTensor {
|
||||||
TODO("Not yet implemented")
|
TODO("Not yet implemented")
|
||||||
}
|
}
|
||||||
@ -127,5 +145,5 @@ public class RealLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealTensorAlgebra.() -> R): R =
|
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealLinearOpsTensorAlgebra.() -> R): R =
|
||||||
RealLinearOpsTensorAlgebra().block()
|
RealLinearOpsTensorAlgebra().block()
|
Loading…
Reference in New Issue
Block a user