Fixed tests

This commit is contained in:
Roland Grinis 2021-04-09 07:33:25 +01:00
parent ea4d6618b4
commit a09a1c7adc
4 changed files with 31 additions and 48 deletions

View File

@ -93,7 +93,7 @@ public class DoubleLinearOpsTensorAlgebra :
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = this.shape.size
val size = this.linearStructure.dim
val commonShape = this.shape.sliceArray(0 until size - 2)
val (n, m) = this.shape.sliceArray(size - 2 until size)
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
@ -109,11 +109,11 @@ public class DoubleLinearOpsTensorAlgebra :
)
svdHelper(curMatrix, USV, m, n)
}
return Triple(resU.transpose(size - 2, size - 1), resS, resV)
return Triple(resU.transpose(), resS, resV.transpose())
}
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
TODO("ANDREI")
TODO()
}
public fun DoubleTensor.detLU(): DoubleTensor {

View File

@ -339,19 +339,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
)
}
public fun DoubleTensor.contentEquals(other: DoubleTensor, delta: Double = 1e-5): Boolean {
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
}
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
return this.eq(other) { x, y -> abs(x - y) < delta }
}
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
this.eq(other, eqFunction)
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
checkShapesCompatible(this, other)
val n = this.linearStructure.size

View File

@ -56,4 +56,5 @@ internal inline fun <T, TensorType : TensorStructure<T>,
check(shape[n - 1] == shape[n - 2]) {
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
}
}
}

View File

@ -20,15 +20,17 @@ class TestDoubleLinearOpsTensorAlgebra {
)
)
val expectedShape = intArrayOf(2, 1)
val expectedBuffer = doubleArrayOf(
-1.0,
-7.0
val expectedTensor = fromArray(
intArrayOf(2, 1),
doubleArrayOf(
-1.0,
-7.0
)
)
val detTensor = tensor.detLU()
assertTrue { detTensor.shape contentEquals expectedShape }
assertTrue { detTensor.buffer.array().epsEqual(expectedBuffer) }
assertTrue(detTensor.eq(expectedTensor))
}
@Test
@ -43,17 +45,17 @@ class TestDoubleLinearOpsTensorAlgebra {
)
)
val expectedShape = intArrayOf(2, 2, 2)
val expectedBuffer = doubleArrayOf(
1.0, 0.0,
0.0, 0.5,
0.0, 1.0,
1.0, -1.0
val expectedTensor = fromArray(
intArrayOf(2, 2, 2), doubleArrayOf(
1.0, 0.0,
0.0, 0.5,
0.0, 1.0,
1.0, -1.0
)
)
val invTensor = tensor.invLU()
assertTrue { invTensor.shape contentEquals expectedShape }
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
assertTrue(invTensor.eq(expectedTensor))
}
@Test
@ -80,7 +82,7 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape }
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
assertTrue((q dot r).eq(tensor))
//todo check orthogonality/upper triang.
}
@ -106,7 +108,7 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape }
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
assertTrue((p dot tensor).eq(l dot u))
}
@Test
@ -116,8 +118,8 @@ class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2)
assertTrue(res.shape contentEquals intArrayOf(2))
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01}
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01}
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 }
}
@Test
@ -130,11 +132,12 @@ class TestDoubleLinearOpsTensorAlgebra {
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV)
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD))
}
@Test @Ignore
@Test
@Ignore
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
val tensorSigma = tensor + tensor.transpose(1, 2)
@ -146,31 +149,17 @@ class TestDoubleLinearOpsTensorAlgebra {
}
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
return abs(this - other) < eps
}
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
}
}
return true
}
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
val svd = tensor.svd()
val tensorSVD = svd.first
.dot(
diagonalEmbedding(svd.second, 0, 0, 1)
.dot(svd.third.transpose(0, 1))
diagonalEmbedding(svd.second)
.dot(svd.third.transpose())
)
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
assertTrue { abs(x1 - x2) < epsilon }
}
assertTrue(tensor.eq(tensorSVD, epsilon))
}