KMP library for tensors #300
@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
|
||||
public fun TensorType.cholesky(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||
public fun TensorType.qr(): TensorType
|
||||
public fun TensorType.qr(): Pair<TensorType, TensorType>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>
|
||||
|
@ -42,7 +42,8 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
//todo checks
|
||||
checkSquareMatrix(luTensor.shape)
|
||||
check(
|
||||
luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape
|
||||
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
|
||||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
|
||||
) { "Bed shapes ((" } //todo rewrite
|
||||
|
||||
val n = luTensor.shape.last()
|
||||
@ -76,8 +77,56 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
return lTensor
|
||||
}
|
||||
|
||||
override fun DoubleTensor.qr(): DoubleTensor {
|
||||
TODO("ANDREI")
|
||||
private fun MutableStructure1D<Double>.dot(other: MutableStructure1D<Double>): Double {
|
||||
var res = 0.0
|
||||
for (i in 0 until size) {
|
||||
res += this[i] * other[i]
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
private fun MutableStructure1D<Double>.l2Norm(): Double {
|
||||
var squareSum = 0.0
|
||||
for (i in 0 until size) {
|
||||
squareSum += this[i] * this[i]
|
||||
}
|
||||
return sqrt(squareSum)
|
||||
}
|
||||
|
||||
fun qrHelper(
|
||||
matrix: MutableStructure2D<Double>,
|
||||
q: MutableStructure2D<Double>,
|
||||
r: MutableStructure2D<Double>
|
||||
) {
|
||||
//todo check square
|
||||
val n = matrix.colNum
|
||||
for (j in 0 until n) {
|
||||
val v = matrix.columns[j]
|
||||
if (j > 0) {
|
||||
for (i in 0 until j) {
|
||||
r[i, j] = q.columns[i].dot(matrix.columns[j])
|
||||
for (k in 0 until n) {
|
||||
v[k] = v[k] - r[i, j] * q.columns[i][k]
|
||||
}
|
||||
}
|
||||
}
|
||||
r[j, j] = v.l2Norm()
|
||||
for (i in 0 until n) {
|
||||
q[i, j] = v[i] / r[j, j]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(shape)
|
||||
val qTensor = zeroesLike()
|
||||
val rTensor = zeroesLike()
|
||||
val seq = matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence())))
|
||||
for ((matrix, qr) in seq) {
|
||||
val (q, r) = qr
|
||||
qrHelper(matrix.as2D(), q.as2D(), r.as2D())
|
||||
}
|
||||
return Pair(qTensor, rTensor)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
|
@ -76,4 +76,50 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
|
||||
assertEquals(a.dot(b).value(), 59.92)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testQR() = DoubleLinearOpsTensorAlgebra {
|
||||
val shape = intArrayOf(2, 2, 2)
|
||||
val buffer = doubleArrayOf(
|
||||
1.0, 3.0,
|
||||
1.0, 2.0,
|
||||
1.5, 1.0,
|
||||
10.0, 2.0
|
||||
)
|
||||
|
||||
val tensor = fromArray(shape, buffer)
|
||||
|
||||
val (q, r) = tensor.qr()
|
||||
|
||||
assertTrue { q.shape contentEquals shape }
|
||||
assertTrue { r.shape contentEquals shape }
|
||||
|
||||
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
|
||||
|
||||
//todo check orthogonality/upper triang.
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLU() = DoubleLinearOpsTensorAlgebra {
|
||||
val shape = intArrayOf(2, 2, 2)
|
||||
val buffer = doubleArrayOf(
|
||||
1.0, 3.0,
|
||||
1.0, 2.0,
|
||||
1.5, 1.0,
|
||||
10.0, 2.0
|
||||
)
|
||||
val tensor = fromArray(shape, buffer)
|
||||
|
||||
val (lu, pivots) = tensor.lu()
|
||||
|
||||
// todo check lu
|
||||
|
||||
val (p, l, u) = luPivot(lu, pivots)
|
||||
|
||||
assertTrue { p.shape contentEquals shape }
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
assertTrue { u.shape contentEquals shape }
|
||||
|
||||
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user