Sync with Andrew

This commit is contained in:
Roland Grinis 2021-03-15 17:08:45 +00:00
commit b6a5fbfc14
2 changed files with 119 additions and 50 deletions

View File

@ -14,6 +14,7 @@ public open class BufferedTensor<T>(
buffer buffer
) { ) {
public operator fun get(i: Int, j: Int): T { public operator fun get(i: Int, j: Int): T {
check(this.dimension == 2) { "Not matrix" } check(this.dimension == 2) { "Not matrix" }
return this[intArrayOf(i, j)] return this[intArrayOf(i, j)]
@ -24,8 +25,60 @@ public open class BufferedTensor<T>(
this[intArrayOf(i, j)] = value this[intArrayOf(i, j)] = value
} }
} }
//todo make generator mb nextMatrixIndex?
public class InnerMatrix<T>(private val tensor: BufferedTensor<T>){
private var offset: Int = 0
private val n : Int = tensor.shape.size
//stride?
private val step = tensor.shape[n - 1] * tensor.shape[n - 2]
public operator fun get(i: Int, j: Int): T {
val index = tensor.strides.index(offset)
index[n - 2] = i
index[n - 1] = j
return tensor[index]
}
public operator fun set(i: Int, j: Int, value: T): Unit {
val index = tensor.strides.index(offset)
index[n - 2] = i
index[n - 1] = j
tensor[index] = value
}
public fun makeStep(){
offset += step
}
}
public class InnerVector<T>(private val tensor: BufferedTensor<T>){
private var offset: Int = 0
private val n : Int = tensor.shape.size
//stride?
private val step = tensor.shape[n - 1]
public operator fun get(i: Int): T {
val index = tensor.strides.index(offset)
index[n - 1] = i
return tensor[index]
}
public operator fun set(i: Int, value: T): Unit {
val index = tensor.strides.index(offset)
index[n - 1] = i
tensor[index] = value
}
public fun makeStep(){
offset += step
}
}
//todo default buffer = arrayOf(0)???
public class IntTensor( public class IntTensor(
shape: IntArray, shape: IntArray,
buffer: IntArray buffer: IntArray

View File

@ -1,9 +1,8 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
public class RealLinearOpsTensorAlgebra : public class DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, DoubleTensor>, LinearOpsTensorAlgebra<Double, DoubleTensor>,
RealTensorAlgebra() RealTensorAlgebra() {
{
override fun eye(n: Int): DoubleTensor { override fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n) val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 } val buffer = DoubleArray(n * n) { 0.0 }
@ -26,14 +25,25 @@ public class RealLinearOpsTensorAlgebra :
override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> { override fun DoubleTensor.lu(): Pair<DoubleTensor, IntTensor> {
// todo checks // todo checks
val lu = this.copy() val luTensor = this.copy()
val m = this.shape[0] val lu = InnerMatrix(luTensor)
val pivot = IntArray(m) //stride TODO!!! move to generator?
var matCnt = 1
for (i in 0 until this.shape.size - 2) {
// Initialize permutation array and parity matCnt *= this.shape[i]
}
val n = this.shape.size
val m = this.shape[n - 1]
val pivotsShape = IntArray(n - 1) { i ->
this.shape[i]
}
val pivotsTensor = IntTensor(
pivotsShape,
IntArray(matCnt * m) { 0 }
)
val pivot = InnerVector(pivotsTensor)
for (i in 0 until matCnt) {
for (row in 0 until m) pivot[row] = row for (row in 0 until m) pivot[row] = row
var even = true
for (i in 0 until m) { for (i in 0 until m) {
var maxA = -1.0 var maxA = -1.0
@ -54,7 +64,6 @@ public class RealLinearOpsTensorAlgebra :
val j = pivot[i] val j = pivot[i]
pivot[i] = pivot[iMax] pivot[i] = pivot[iMax]
pivot[iMax] = j pivot[iMax] = j
even != even
for (k in 0 until m) { for (k in 0 until m) {
val tmp = lu[i, k] val tmp = lu[i, k]
@ -71,10 +80,15 @@ public class RealLinearOpsTensorAlgebra :
} }
} }
} }
return Pair(lu, IntTensor(intArrayOf(m), pivot)) lu.makeStep()
pivot.makeStep()
}
return Pair(luTensor, pivotsTensor)
} }
override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { override fun luPivot(lu: DoubleTensor, pivots: IntTensor): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
// todo checks // todo checks
val n = lu.shape[0] val n = lu.shape[0]
val p = lu.zeroesLike() val p = lu.zeroesLike()
@ -97,10 +111,12 @@ public class RealLinearOpsTensorAlgebra :
} }
} }
} }
return Triple(p, l, u) return Triple(p, l, u)
} }
override fun DoubleTensor.det(): DoubleTensor { override fun DoubleTensor.det(): DoubleTensor {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
@ -127,5 +143,5 @@ public class RealLinearOpsTensorAlgebra :
} }
public inline fun <R> RealLinearOpsTensorAlgebra(block: RealLinearOpsTensorAlgebra.() -> R): R = public inline fun <R> DoubleLinearOpsTensorAlgebra(block: DoubleLinearOpsTensorAlgebra.() -> R): R =
RealLinearOpsTensorAlgebra().block() DoubleLinearOpsTensorAlgebra().block()