Negative indices

This commit is contained in:
Roland Grinis 2021-04-06 12:07:39 +01:00
parent 2bbe10e41c
commit 174f6566e1
6 changed files with 32 additions and 15 deletions

View File

@ -28,7 +28,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public operator fun TensorType.get(i: Int): TensorType
//https://pytorch.org/docs/stable/generated/torch.transpose.html
public fun TensorType.transpose(i: Int, j: Int): TensorType
public fun TensorType.transpose(i: Int = -2, j: Int = -1): TensorType
//https://pytorch.org/docs/stable/tensor_view.html
public fun TensorType.view(shape: IntArray): TensorType
@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
public fun diagonalEmbedding(
diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = 0, dim2: Int = 1
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TensorType
}

View File

@ -198,19 +198,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
checkTranspose(this.dimension, i, j)
val ii = minusIndex(i)
val jj = minusIndex(j)
checkTranspose(this.dimension, ii, jj)
val n = this.linearStructure.size
val resBuffer = DoubleArray(n)
val resShape = this.shape.copyOf()
resShape[i] = resShape[j].also { resShape[j] = resShape[i] }
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) {
val oldMultiIndex = this.linearStructure.index(offset)
val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
resTensor.buffer.array()[linearIndex] =
@ -283,16 +285,18 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
}
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
val d1 = minusIndexFrom(diagonalEntries.linearStructure.dim + 1, dim1)
val d2 = minusIndexFrom(diagonalEntries.linearStructure.dim + 1, dim2)
val n = diagonalEntries.shape.size
if (dim1 == dim2) {
throw RuntimeException("Diagonal dimensions cannot be identical $dim1, $dim2")
if (d1 == d2) {
throw RuntimeException("Diagonal dimensions cannot be identical $d1, $d2")
}
if (dim1 > n || dim2 > n) {
if (d1 > n || d2 > n) {
throw RuntimeException("Dimension out of range")
}
var lessDim = dim1
var greaterDim = dim2
var lessDim = d1
var greaterDim = d2
var realOffset = offset
if (lessDim > greaterDim) {
realOffset *= -1

View File

@ -73,6 +73,9 @@ public class TensorLinearStructure(public val shape: IntArray)
public val size: Int
get() = shape.reduce(Int::times)
public val dim: Int
get() = shape.size
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
index(it)
}

View File

@ -39,4 +39,14 @@ internal fun Buffer<Double>.array(): DoubleArray = when (this) {
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
val u = Random(seed)
return (0 until n).map { sqrt(-2.0 * ln(u.nextDouble())) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
}
}
internal inline fun minusIndexFrom(n: Int, i: Int) : Int = if (i >= 0) i else {
val ii = n + i
check(ii >= 0) {
"Out of bound index $i for tensor of dim $n"
}
ii
}
internal inline fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.linearStructure.dim, i)

View File

@ -130,7 +130,7 @@ class TestDoubleLinearOpsTensorAlgebra {
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV)
assertTrue(tensor.eq(tensorSVD))
}
@ -139,7 +139,7 @@ class TestDoubleLinearOpsTensorAlgebra {
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
val tensorSigma = tensor + tensor.transpose(1, 2)
val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS, 0,1,2) dot tensorV.transpose(1, 2))
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2))
assertTrue(tensorSigma.eq(tensorSigmaCalc))
}

View File

@ -35,8 +35,8 @@ class TestDoubleTensorAlgebra {
fun transpose1x2x3() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res01 = tensor.transpose(0, 1)
val res02 = tensor.transpose(0, 2)
val res12 = tensor.transpose(1, 2)
val res02 = tensor.transpose(-3, 2)
val res12 = tensor.transpose()
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))