Fixed tests
This commit is contained in:
parent
ea4d6618b4
commit
a09a1c7adc
@ -93,7 +93,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
|
||||
|
||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
val size = this.shape.size
|
||||
val size = this.linearStructure.dim
|
||||
val commonShape = this.shape.sliceArray(0 until size - 2)
|
||||
val (n, m) = this.shape.sliceArray(size - 2 until size)
|
||||
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||
@ -109,11 +109,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
||||
)
|
||||
svdHelper(curMatrix, USV, m, n)
|
||||
}
|
||||
return Triple(resU.transpose(size - 2, size - 1), resS, resV)
|
||||
return Triple(resU.transpose(), resS, resV.transpose())
|
||||
}
|
||||
|
||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||
TODO("ANDREI")
|
||||
TODO()
|
||||
}
|
||||
|
||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||
|
@ -339,19 +339,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
)
|
||||
}
|
||||
|
||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, delta: Double = 1e-5): Boolean {
|
||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
||||
|
||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
|
||||
this.eq(other, eqFunction)
|
||||
|
||||
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
||||
checkShapesCompatible(this, other)
|
||||
val n = this.linearStructure.size
|
||||
|
@ -56,4 +56,5 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
check(shape[n - 1] == shape[n - 2]) {
|
||||
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,15 +20,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
)
|
||||
)
|
||||
|
||||
val expectedShape = intArrayOf(2, 1)
|
||||
val expectedBuffer = doubleArrayOf(
|
||||
-1.0,
|
||||
-7.0
|
||||
val expectedTensor = fromArray(
|
||||
intArrayOf(2, 1),
|
||||
doubleArrayOf(
|
||||
-1.0,
|
||||
-7.0
|
||||
)
|
||||
)
|
||||
val detTensor = tensor.detLU()
|
||||
|
||||
assertTrue { detTensor.shape contentEquals expectedShape }
|
||||
assertTrue { detTensor.buffer.array().epsEqual(expectedBuffer) }
|
||||
assertTrue(detTensor.eq(expectedTensor))
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -43,17 +45,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
)
|
||||
)
|
||||
|
||||
val expectedShape = intArrayOf(2, 2, 2)
|
||||
val expectedBuffer = doubleArrayOf(
|
||||
1.0, 0.0,
|
||||
0.0, 0.5,
|
||||
0.0, 1.0,
|
||||
1.0, -1.0
|
||||
val expectedTensor = fromArray(
|
||||
intArrayOf(2, 2, 2), doubleArrayOf(
|
||||
1.0, 0.0,
|
||||
0.0, 0.5,
|
||||
0.0, 1.0,
|
||||
1.0, -1.0
|
||||
)
|
||||
)
|
||||
|
||||
val invTensor = tensor.invLU()
|
||||
assertTrue { invTensor.shape contentEquals expectedShape }
|
||||
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
|
||||
assertTrue(invTensor.eq(expectedTensor))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -80,7 +82,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
assertTrue { q.shape contentEquals shape }
|
||||
assertTrue { r.shape contentEquals shape }
|
||||
|
||||
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
|
||||
assertTrue((q dot r).eq(tensor))
|
||||
|
||||
//todo check orthogonality/upper triang.
|
||||
}
|
||||
@ -106,7 +108,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
assertTrue { u.shape contentEquals shape }
|
||||
|
||||
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
|
||||
assertTrue((p dot tensor).eq(l dot u))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -116,8 +118,8 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
val res = svd1d(tensor2)
|
||||
|
||||
assertTrue(res.shape contentEquals intArrayOf(2))
|
||||
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01}
|
||||
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01}
|
||||
assertTrue { abs(abs(res.buffer.array()[res.bufferStart]) - 0.386) < 0.01 }
|
||||
assertTrue { abs(abs(res.buffer.array()[res.bufferStart + 1]) - 0.922) < 0.01 }
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -130,11 +132,12 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
|
||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV)
|
||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||
assertTrue(tensor.eq(tensorSVD))
|
||||
}
|
||||
|
||||
@Test @Ignore
|
||||
@Test
|
||||
@Ignore
|
||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||
val tensorSigma = tensor + tensor.transpose(1, 2)
|
||||
@ -146,31 +149,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
||||
|
||||
}
|
||||
|
||||
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
|
||||
return abs(this - other) < eps
|
||||
}
|
||||
|
||||
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
||||
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
||||
if (abs(elem1 - elem2) > eps) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
|
||||
val svd = tensor.svd()
|
||||
|
||||
val tensorSVD = svd.first
|
||||
.dot(
|
||||
diagonalEmbedding(svd.second, 0, 0, 1)
|
||||
.dot(svd.third.transpose(0, 1))
|
||||
diagonalEmbedding(svd.second)
|
||||
.dot(svd.third.transpose())
|
||||
)
|
||||
|
||||
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
|
||||
assertTrue { abs(x1 - x2) < epsilon }
|
||||
}
|
||||
assertTrue(tensor.eq(tensorSVD, epsilon))
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user