Fixed tests
This commit is contained in:
parent
ea4d6618b4
commit
a09a1c7adc
@ -93,7 +93,7 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
|
|
||||||
|
|
||||||
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val size = this.shape.size
|
val size = this.linearStructure.dim
|
||||||
val commonShape = this.shape.sliceArray(0 until size - 2)
|
val commonShape = this.shape.sliceArray(0 until size - 2)
|
||||||
val (n, m) = this.shape.sliceArray(size - 2 until size)
|
val (n, m) = this.shape.sliceArray(size - 2 until size)
|
||||||
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
|
val resU = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||||
@ -109,11 +109,11 @@ public class DoubleLinearOpsTensorAlgebra :
|
|||||||
)
|
)
|
||||||
svdHelper(curMatrix, USV, m, n)
|
svdHelper(curMatrix, USV, m, n)
|
||||||
}
|
}
|
||||||
return Triple(resU.transpose(size - 2, size - 1), resS, resV)
|
return Triple(resU.transpose(), resS, resV.transpose())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
override fun DoubleTensor.symEig(eigenvectors: Boolean): Pair<DoubleTensor, DoubleTensor> {
|
||||||
TODO("ANDREI")
|
TODO()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.detLU(): DoubleTensor {
|
public fun DoubleTensor.detLU(): DoubleTensor {
|
||||||
|
@ -339,19 +339,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, delta: Double = 1e-5): Boolean {
|
|
||||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
||||||
|
|
||||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean =
|
|
||||||
this.eq(other, eqFunction)
|
|
||||||
|
|
||||||
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
||||||
checkShapesCompatible(this, other)
|
checkShapesCompatible(this, other)
|
||||||
val n = this.linearStructure.size
|
val n = this.linearStructure.size
|
||||||
|
@ -57,3 +57,4 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
|
"Tensor must be batches of square matrices, but they are ${shape[n - 1]} by ${shape[n - 1]} matrices"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,15 +20,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
val expectedShape = intArrayOf(2, 1)
|
val expectedTensor = fromArray(
|
||||||
val expectedBuffer = doubleArrayOf(
|
intArrayOf(2, 1),
|
||||||
|
doubleArrayOf(
|
||||||
-1.0,
|
-1.0,
|
||||||
-7.0
|
-7.0
|
||||||
)
|
)
|
||||||
|
)
|
||||||
val detTensor = tensor.detLU()
|
val detTensor = tensor.detLU()
|
||||||
|
|
||||||
assertTrue { detTensor.shape contentEquals expectedShape }
|
assertTrue(detTensor.eq(expectedTensor))
|
||||||
assertTrue { detTensor.buffer.array().epsEqual(expectedBuffer) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -43,17 +45,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
val expectedShape = intArrayOf(2, 2, 2)
|
val expectedTensor = fromArray(
|
||||||
val expectedBuffer = doubleArrayOf(
|
intArrayOf(2, 2, 2), doubleArrayOf(
|
||||||
1.0, 0.0,
|
1.0, 0.0,
|
||||||
0.0, 0.5,
|
0.0, 0.5,
|
||||||
0.0, 1.0,
|
0.0, 1.0,
|
||||||
1.0, -1.0
|
1.0, -1.0
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
val invTensor = tensor.invLU()
|
val invTensor = tensor.invLU()
|
||||||
assertTrue { invTensor.shape contentEquals expectedShape }
|
assertTrue(invTensor.eq(expectedTensor))
|
||||||
assertTrue { invTensor.buffer.array().epsEqual(expectedBuffer) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -80,7 +82,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue { q.shape contentEquals shape }
|
assertTrue { q.shape contentEquals shape }
|
||||||
assertTrue { r.shape contentEquals shape }
|
assertTrue { r.shape contentEquals shape }
|
||||||
|
|
||||||
assertTrue { q.dot(r).buffer.array().epsEqual(buffer) }
|
assertTrue((q dot r).eq(tensor))
|
||||||
|
|
||||||
//todo check orthogonality/upper triang.
|
//todo check orthogonality/upper triang.
|
||||||
}
|
}
|
||||||
@ -106,7 +108,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
assertTrue { l.shape contentEquals shape }
|
assertTrue { l.shape contentEquals shape }
|
||||||
assertTrue { u.shape contentEquals shape }
|
assertTrue { u.shape contentEquals shape }
|
||||||
|
|
||||||
assertTrue { p.dot(tensor).buffer.array().epsEqual(l.dot(u).buffer.array()) }
|
assertTrue((p dot tensor).eq(l dot u))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -130,11 +132,12 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
|
val tensor = randNormal(intArrayOf(1, 15, 4, 7, 5, 3), 0)
|
||||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV)
|
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
|
||||||
assertTrue(tensor.eq(tensorSVD))
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test @Ignore
|
@Test
|
||||||
|
@Ignore
|
||||||
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSymEig() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||||
val tensorSigma = tensor + tensor.transpose(1, 2)
|
val tensorSigma = tensor + tensor.transpose(1, 2)
|
||||||
@ -146,31 +149,17 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun Double.epsEqual(other: Double, eps: Double = 1e-5): Boolean {
|
|
||||||
return abs(this - other) < eps
|
|
||||||
}
|
|
||||||
|
|
||||||
private inline fun DoubleArray.epsEqual(other: DoubleArray, eps: Double = 1e-5): Boolean {
|
|
||||||
for ((elem1, elem2) in this.asSequence().zip(other.asSequence())) {
|
|
||||||
if (abs(elem1 - elem2) > eps) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
|
private inline fun DoubleLinearOpsTensorAlgebra.testSVDFor(tensor: DoubleTensor, epsilon: Double = 1e-10): Unit {
|
||||||
val svd = tensor.svd()
|
val svd = tensor.svd()
|
||||||
|
|
||||||
val tensorSVD = svd.first
|
val tensorSVD = svd.first
|
||||||
.dot(
|
.dot(
|
||||||
diagonalEmbedding(svd.second, 0, 0, 1)
|
diagonalEmbedding(svd.second)
|
||||||
.dot(svd.third.transpose(0, 1))
|
.dot(svd.third.transpose())
|
||||||
)
|
)
|
||||||
|
|
||||||
for ((x1, x2) in tensor.buffer.array() zip tensorSVD.buffer.array()) {
|
assertTrue(tensor.eq(tensorSVD, epsilon))
|
||||||
assertTrue { abs(x1 - x2) < epsilon }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user