KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
3 changed files with 99 additions and 4 deletions
Showing only changes of commit 97f148c175 - Show all commits

View File

@ -14,7 +14,7 @@ public interface LinearOpsTensorAlgebra<T, TensorType : TensorStructure<T>, Inde
public fun TensorType.cholesky(): TensorType
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
public fun TensorType.qr(): TensorType
public fun TensorType.qr(): Pair<TensorType, TensorType>
//https://pytorch.org/docs/stable/generated/torch.lu.html
public fun TensorType.lu(): Pair<TensorType, IndexTensorType>

View File

@ -42,7 +42,8 @@ public class DoubleLinearOpsTensorAlgebra :
//todo checks
checkSquareMatrix(luTensor.shape)
check(
luTensor.shape.dropLast(1).toIntArray() contentEquals pivotsTensor.shape
luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() ||
luTensor.shape.last() == pivotsTensor.shape.last() - 1
) { "Bed shapes ((" } //todo rewrite
val n = luTensor.shape.last()
@ -76,8 +77,56 @@ public class DoubleLinearOpsTensorAlgebra :
return lTensor
}
override fun DoubleTensor.qr(): DoubleTensor {
TODO("ANDREI")
private fun MutableStructure1D<Double>.dot(other: MutableStructure1D<Double>): Double {
var res = 0.0
for (i in 0 until size) {
res += this[i] * other[i]
}
return res
}
private fun MutableStructure1D<Double>.l2Norm(): Double {
var squareSum = 0.0
for (i in 0 until size) {
squareSum += this[i] * this[i]
}
return sqrt(squareSum)
}
fun qrHelper(
matrix: MutableStructure2D<Double>,
q: MutableStructure2D<Double>,
r: MutableStructure2D<Double>
) {
//todo check square
val n = matrix.colNum
for (j in 0 until n) {
val v = matrix.columns[j]
if (j > 0) {
for (i in 0 until j) {
r[i, j] = q.columns[i].dot(matrix.columns[j])
for (k in 0 until n) {
v[k] = v[k] - r[i, j] * q.columns[i][k]
}
}
}
r[j, j] = v.l2Norm()
for (i in 0 until n) {
q[i, j] = v[i] / r[j, j]
}
}
}
override fun DoubleTensor.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape)
val qTensor = zeroesLike()
val rTensor = zeroesLike()
val seq = matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence())))
for ((matrix, qr) in seq) {
val (q, r) = qr
qrHelper(matrix.as2D(), q.as2D(), r.as2D())
}
return Pair(qTensor, rTensor)
}
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {

View File

@ -76,4 +76,50 @@ class TestDoubleLinearOpsTensorAlgebra {
val b = fromArray(intArrayOf(3), doubleArrayOf(5.5, 2.6, 6.4))
assertEquals(a.dot(b).value(), 59.92)
}
@Test
fun testQR() = DoubleLinearOpsTensorAlgebra {
val shape = intArrayOf(2, 2, 2)
val buffer = doubleArrayOf(
1.0, 3.0,
1.0, 2.0,
1.5, 1.0,
10.0, 2.0
)
val tensor = fromArray(shape, buffer)
val (q, r) = tensor.qr()
assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape }
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
//todo check orthogonality/upper triang.
}
@Test
fun testLU() = DoubleLinearOpsTensorAlgebra {
val shape = intArrayOf(2, 2, 2)
val buffer = doubleArrayOf(
1.0, 3.0,
1.0, 2.0,
1.5, 1.0,
10.0, 2.0
)
val tensor = fromArray(shape, buffer)
val (lu, pivots) = tensor.lu()
// todo check lu
val (p, l, u) = luPivot(lu, pivots)
assertTrue { p.shape contentEquals shape }
assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape }
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
}
}