complete qr + test qr and lu
This commit is contained in:
parent
6e85d496f2
commit
2503d35ba8
@ -118,7 +118,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
//todo checks
|
//todo checks
|
||||||
checkSquareMatrix(luTensor.shape)
|
checkSquareMatrix(luTensor.shape)
|
||||||
check(
|
check(
|
||||||
luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape
|
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
|
||||||
|
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||||
) { "Bed shapes ((" } //todo rewrite
|
) { "Bed shapes ((" } //todo rewrite
|
||||||
|
|
||||||
val n = luTensor.shape.last()
|
val n = luTensor.shape.last()
|
||||||
@ -173,16 +174,56 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
return lTensor
|
return lTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun matrixQR(
|
private fun MutableStructure1D<Double>.dot(other: MutableStructure1D<Double>): Double {
|
||||||
matrix: Structure2D<Double>,
|
var res = 0.0
|
||||||
|
for (i in 0 until size) {
|
||||||
|
res += this[i] * other[i]
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun MutableStructure1D<Double>.l2Norm(): Double {
|
||||||
|
var squareSum = 0.0
|
||||||
|
for (i in 0 until size) {
|
||||||
|
squareSum += this[i] * this[i]
|
||||||
|
}
|
||||||
|
return sqrt(squareSum)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun qrHelper(
|
||||||
|
matrix: MutableStructure2D<Double>,
|
||||||
q: MutableStructure2D<Double>,
|
q: MutableStructure2D<Double>,
|
||||||
r: MutableStructure2D<Double>
|
r: MutableStructure2D<Double>
|
||||||
) {
|
) {
|
||||||
|
//todo check square
|
||||||
|
val n = matrix.colNum
|
||||||
|
for (j in 0 until n) {
|
||||||
|
val v = matrix.columns[j]
|
||||||
|
if (j > 0) {
|
||||||
|
for (i in 0 until j) {
|
||||||
|
r[i, j] = q.columns[i].dot(matrix.columns[j])
|
||||||
|
for (k in 0 until n) {
|
||||||
|
v[k] = v[k] - r[i, j] * q.columns[i][k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r[j, j] = v.l2Norm()
|
||||||
|
for (i in 0 until n) {
|
||||||
|
q[i, j] = v[i] / r[j, j]
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||||
TODO("ANDREI")
|
checkSquareMatrix(shape)
|
||||||
|
val qTensor = zeroesLike()
|
||||||
|
val rTensor = zeroesLike()
|
||||||
|
val seq = matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence())))
|
||||||
|
for ((matrix, qr) in seq) {
|
||||||
|
val (q, r) = qr
|
||||||
|
qrHelper(matrix.as2D(), q.as2D(), r.as2D())
|
||||||
|
}
|
||||||
|
return Pair(qTensor, rTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
|
@ -76,4 +76,50 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
|
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
|
||||||
assertEquals(a.dot(b).value(), 59.92)
|
assertEquals(a.dot(b).value(), 59.92)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testQR() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val shape = intArrayOf(2, 2, 2)
|
||||||
|
val buffer = doubleArrayOf(
|
||||||
|
1.0, 3.0,
|
||||||
|
1.0, 2.0,
|
||||||
|
1.5, 1.0,
|
||||||
|
10.0, 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
val tensor = fromArray(shape, buffer)
|
||||||
|
|
||||||
|
val (q, r) = tensor.qr()
|
||||||
|
|
||||||
|
assertTrue { q.shape contentEquals shape }
|
||||||
|
assertTrue { r.shape contentEquals shape }
|
||||||
|
|
||||||
|
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
|
||||||
|
|
||||||
|
//todo check orthogonality/upper triang.
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLU() = DoubleLinearOpsTensorAlgebra {
|
||||||
|
val shape = intArrayOf(2, 2, 2)
|
||||||
|
val buffer = doubleArrayOf(
|
||||||
|
1.0, 3.0,
|
||||||
|
1.0, 2.0,
|
||||||
|
1.5, 1.0,
|
||||||
|
10.0, 2.0
|
||||||
|
)
|
||||||
|
val tensor = fromArray(shape, buffer)
|
||||||
|
|
||||||
|
val (lu, pivots) = tensor.lu()
|
||||||
|
|
||||||
|
// todo check lu
|
||||||
|
|
||||||
|
val (p, l, u) = luPivot(lu, pivots)
|
||||||
|
|
||||||
|
assertTrue { p.shape contentEquals shape }
|
||||||
|
assertTrue { l.shape contentEquals shape }
|
||||||
|
assertTrue { u.shape contentEquals shape }
|
||||||
|
|
||||||
|
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user