add function diagonalEmbedding with tests
This commit is contained in:
parent
b36281fa39
commit
3e98240b94
@ -283,7 +283,47 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||
TODO("Alya")
|
||||
val n = diagonalEntries.shape.size
|
||||
if (dim1 == dim2) {
|
||||
throw RuntimeException("Diagonal dimensions cannot be identical $dim1, $dim2")
|
||||
}
|
||||
if (dim1 > n || dim2 > n) {
|
||||
throw RuntimeException("Dimension out of range")
|
||||
}
|
||||
|
||||
var lessDim = dim1
|
||||
var greaterDim = dim2
|
||||
var realOffset = offset
|
||||
if (lessDim > greaterDim) {
|
||||
realOffset *= -1
|
||||
lessDim = greaterDim.also {greaterDim = lessDim}
|
||||
}
|
||||
|
||||
val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||
val resTensor = zeros(resShape)
|
||||
|
||||
for (i in 0 until diagonalEntries.linearStructure.size) {
|
||||
val multiIndex = diagonalEntries.linearStructure.index(i)
|
||||
|
||||
var offset1 = 0
|
||||
var offset2 = abs(realOffset)
|
||||
if (realOffset < 0) {
|
||||
offset1 = offset2.also {offset2 = offset1}
|
||||
}
|
||||
val diagonalMultiIndex = multiIndex.slice(0 until lessDim).toIntArray() +
|
||||
intArrayOf(multiIndex[n - 1] + offset1) +
|
||||
multiIndex.slice(lessDim until greaterDim - 1).toIntArray() +
|
||||
intArrayOf(multiIndex[n - 1] + offset2) +
|
||||
multiIndex.slice(greaterDim - 1 until n - 1).toIntArray()
|
||||
|
||||
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
|
||||
}
|
||||
|
||||
return resTensor
|
||||
}
|
||||
|
||||
|
||||
|
@ -115,6 +115,39 @@ class TestDoubleTensorAlgebra {
|
||||
assertTrue(tensor4.dot(tensor5).shape contentEquals intArrayOf(5, 4, 2, 8, 3, 8, 5))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun diagonalEmbedding() = DoubleTensorAlgebra {
|
||||
val tensor1 = fromArray(intArrayOf(3), doubleArrayOf(10.0, 20.0, 30.0))
|
||||
val tensor2 = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val tensor3 = zeros(intArrayOf(2, 3, 4, 5))
|
||||
|
||||
assertTrue(diagonalEmbedding(tensor3, 0, 3, 4).shape contentEquals
|
||||
intArrayOf(2, 3, 4, 5, 5))
|
||||
assertTrue(diagonalEmbedding(tensor3, 1, 3, 4).shape contentEquals
|
||||
intArrayOf(2, 3, 4, 6, 6))
|
||||
assertTrue(diagonalEmbedding(tensor3, 2, 0, 3).shape contentEquals
|
||||
intArrayOf(7, 2, 3, 7, 4))
|
||||
|
||||
val diagonal1 = diagonalEmbedding(tensor1, 0, 1, 0)
|
||||
assertTrue(diagonal1.shape contentEquals intArrayOf(3, 3))
|
||||
assertTrue(diagonal1.buffer.array() contentEquals
|
||||
doubleArrayOf(10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0))
|
||||
|
||||
val diagonal1_offset = diagonalEmbedding(tensor1, 1, 1, 0)
|
||||
assertTrue(diagonal1_offset.shape contentEquals intArrayOf(4, 4))
|
||||
assertTrue(diagonal1_offset.buffer.array() contentEquals
|
||||
doubleArrayOf(0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 0.0, 30.0, 0.0))
|
||||
|
||||
val diagonal2 = diagonalEmbedding(tensor2, 1, 0, 2)
|
||||
assertTrue(diagonal2.shape contentEquals intArrayOf(4, 2, 4))
|
||||
assertTrue(diagonal2.buffer.array() contentEquals
|
||||
doubleArrayOf(
|
||||
0.0, 1.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0,
|
||||
0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 5.0, 0.0,
|
||||
0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 6.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testContentEqual() = DoubleTensorAlgebra {
|
||||
//TODO()
|
||||
|
Loading…
Reference in New Issue
Block a user