add function diagonalEmbedding with tests
This commit is contained in:
parent
b36281fa39
commit
3e98240b94
@ -283,7 +283,47 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||||
TODO("Alya")
|
val n = diagonalEntries.shape.size
|
||||||
|
if (dim1 == dim2) {
|
||||||
|
throw RuntimeException("Diagonal dimensions cannot be identical $dim1, $dim2")
|
||||||
|
}
|
||||||
|
if (dim1 > n || dim2 > n) {
|
||||||
|
throw RuntimeException("Dimension out of range")
|
||||||
|
}
|
||||||
|
|
||||||
|
var lessDim = dim1
|
||||||
|
var greaterDim = dim2
|
||||||
|
var realOffset = offset
|
||||||
|
if (lessDim > greaterDim) {
|
||||||
|
realOffset *= -1
|
||||||
|
lessDim = greaterDim.also {greaterDim = lessDim}
|
||||||
|
}
|
||||||
|
|
||||||
|
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
|
||||||
|
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||||
|
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||||
|
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||||
|
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||||
|
val resTensor = zeros(resShape)
|
||||||
|
|
||||||
|
for (i in 0 until diagonalEntries.linearStructure.size) {
|
||||||
|
val multiIndex = diagonalEntries.linearStructure.index(i)
|
||||||
|
|
||||||
|
var offset1 = 0
|
||||||
|
var offset2 = abs(realOffset)
|
||||||
|
if (realOffset < 0) {
|
||||||
|
offset1 = offset2.also {offset2 = offset1}
|
||||||
|
}
|
||||||
|
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
|
||||||
|
intArrayOf(multiIndex[n - 1] + offset1) +
|
||||||
|
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||||
|
intArrayOf(multiIndex[n - 1] + offset2) +
|
||||||
|
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||||
|
|
||||||
|
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -115,6 +115,39 @@ class TestDoubleTensorAlgebra {
|
|||||||
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5))
|
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun diagonalEmbedding() = DoubleTensorAlgebra {
|
||||||
|
val tensor1 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||||
|
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
|
val tensor3 = zeros(intArrayOf(2, 3, 4, 5))
|
||||||
|
|
||||||
|
assertTrue(diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
|
||||||
|
intArrayOf(2, 3, 4, 5, 5))
|
||||||
|
assertTrue(diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
|
||||||
|
intArrayOf(2, 3, 4, 6, 6))
|
||||||
|
assertTrue(diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
|
||||||
|
intArrayOf(7, 2, 3, 7, 4))
|
||||||
|
|
||||||
|
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
||||||
|
assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3))
|
||||||
|
assertTrue(diagonal1.buffer.array() contentEquals
|
||||||
|
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0))
|
||||||
|
|
||||||
|
val diagonal1_offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
||||||
|
assertTrue(diagonal1_offset.shape contentEquals intArrayOf(4, 4))
|
||||||
|
assertTrue(diagonal1_offset.buffer.array() contentEquals
|
||||||
|
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0))
|
||||||
|
|
||||||
|
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
||||||
|
assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4))
|
||||||
|
assertTrue(diagonal2.buffer.array() contentEquals
|
||||||
|
doubleArrayOf(
|
||||||
|
0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0,
|
||||||
|
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testContentEqual() = DoubleTensorAlgebra {
|
fun testContentEqual() = DoubleTensorAlgebra {
|
||||||
//TODO()
|
//TODO()
|
||||||
|
Loading…
Reference in New Issue
Block a user