Fixed tests

This commit is contained in:
Roland Grinis 2021-04-09 07:33:25 +01:00
parent ea4d6618b4
commit a09a1c7adc
4 changed files with 31 additions and 48 deletions

View File

@ -93,7 +93,7 @@ public class DoubleLinearOpsTensorAlgebra :
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = this.shape.size val size = this.linearStructure.dim
val commonShape = this.shape.sliceArray(0 until size - 2) val commonShape = this.shape.sliceArray(0 until size - 2)
val (n, m) = this.shape.sliceArray(size - 2 until size) val (n, m) = this.shape.sliceArray(size - 2 until size)
val resU = zeros(commonShape + intArrayOf(min(n, m), n)) val resU = zeros(commonShape + intArrayOf(min(n, m), n))
@ -109,11 +109,11 @@ public class DoubleLinearOpsTensorAlgebra :
) )
svdHelper(curMatrix, USV, m, n) svdHelper(curMatrix, USV, m, n)
} }
return Triple(resU.transpose(size - 2, size - 1), resS, resV) return Triple(resU.transpose(), resS, resV.transpose())
} }
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> { override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
TODO("ANDREI") TODO()
} }
public fun DoubleTensor.detLU(): DoubleTensor { public fun DoubleTensor.detLU(): DoubleTensor {

View File

@ -339,19 +339,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
) )
} }
public fun DoubleTensor.contentEquals(other: DoubleTensor, delta: Double = 1e-5): Boolean {
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
}
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean { public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
return this.eq(other) { x, y -> abs(x - y) < delta } return this.eq(other) { x, y -> abs(x - y) < delta }
} }
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5) public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
this.eq(other, eqFunction)
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean { private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
checkShapesCompatible(this, other) checkShapesCompatible(this, other)
val n = this.linearStructure.size val n = this.linearStructure.size

View File

@ -57,3 +57,4 @@ internal inline fun <T, TensorType : TensorStructure<T>,
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices" "Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
} }
} }

View File

@ -20,15 +20,17 @@ class TestDoubleLinearOpsTensorAlgebra {
) )
) )
val expectedShape = intArrayOf(2, 1) val expectedTensor = fromArray(
val expectedBuffer = doubleArrayOf( intArrayOf(2, 1),
-1.0, doubleArrayOf(
-7.0 -1.0,
-7.0
)
) )
val detTensor = tensor.detLU() val detTensor = tensor.detLU()
assertTrue { detTensor.shape contentEquals expectedShape } assertTrue(detTensor.eq(expectedTensor))
assertTrue { detTensor.buffer.array().epsEqual(expectedBuffer) }
} }
@Test @Test
@ -43,17 +45,17 @@ class TestDoubleLinearOpsTensorAlgebra {
) )
) )
val expectedShape = intArrayOf(2, 2, 2) val expectedTensor = fromArray(
val expectedBuffer = doubleArrayOf( intArrayOf(2, 2, 2), doubleArrayOf(
1.0, 0.0, 1.0, 0.0,
0.0, 0.5, 0.0, 0.5,
0.0, 1.0, 0.0, 1.0,
1.0, -1.0 1.0, -1.0
)
) )
val invTensor = tensor.invLU() val invTensor = tensor.invLU()
assertTrue { invTensor.shape contentEquals expectedShape } assertTrue(invTensor.eq(expectedTensor))
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
} }
@Test @Test
@ -80,7 +82,7 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { q.shape contentEquals shape } assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape } assertTrue { r.shape contentEquals shape }
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) } assertTrue((q dot r).eq(tensor))
//todo check orthogonality/upper triang. //todo check orthogonality/upper triang.
} }
@ -106,7 +108,7 @@ class TestDoubleLinearOpsTensorAlgebra {
assertTrue { l.shape contentEquals shape } assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape } assertTrue { u.shape contentEquals shape }
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) } assertTrue((p dot tensor).eq(l dot u))
} }
@Test @Test
@ -116,8 +118,8 @@ class TestDoubleLinearOpsTensorAlgebra {
val res = svd1d(tensor2) val res = svd1d(tensor2)
assertTrue(res.shape contentEquals intArrayOf(2)) assertTrue(res.shape contentEquals intArrayOf(2))
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01} assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01 }
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01} assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 }
} }
@Test @Test
@ -130,11 +132,12 @@ class TestDoubleLinearOpsTensorAlgebra {
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra { fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0) val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd() val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV) val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
assertTrue(tensor.eq(tensorSVD)) assertTrue(tensor.eq(tensorSVD))
} }
@Test @Ignore @Test
@Ignore
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra { fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0) val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
val tensorSigma = tensor + tensor.transpose(1, 2) val tensorSigma = tensor + tensor.transpose(1, 2)
@ -146,31 +149,17 @@ class TestDoubleLinearOpsTensorAlgebra {
} }
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
return abs(this - other) < eps
}
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
if (abs(elem1 - elem2) > eps) {
return false
}
}
return true
}
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit { private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
val svd = tensor.svd() val svd = tensor.svd()
val tensorSVD = svd.first val tensorSVD = svd.first
.dot( .dot(
diagonalEmbedding(svd.second, 0, 0, 1) diagonalEmbedding(svd.second)
.dot(svd.third.transpose(0, 1)) .dot(svd.third.transpose())
) )
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) { assertTrue(tensor.eq(tensorSVD, epsilon))
assertTrue { abs(x1 - x2) < epsilon }
}
} }