Negative indices
This commit is contained in:
parent
2bbe10e41c
commit
174f6566e1
@ -28,7 +28,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public operator fun TensorType.get(i: Int): TensorType
|
public operator fun TensorType.get(i: Int): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
public fun TensorType.transpose(i: Int = -2, j: Int = -1): TensorType
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/tensor_view.html
|
//https://pytorch.org/docs/stable/tensor_view.html
|
||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
@ -40,7 +40,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
diagonalEntries: TensorType,
|
diagonalEntries: TensorType,
|
||||||
offset: Int = 0, dim1: Int = 0, dim2: Int = 1
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
): TensorType
|
): TensorType
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -198,19 +198,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
|
override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor {
|
||||||
checkTranspose(this.dimension, i, j)
|
val ii = minusIndex(i)
|
||||||
|
val jj = minusIndex(j)
|
||||||
|
checkTranspose(this.dimension, ii, jj)
|
||||||
val n = this.linearStructure.size
|
val n = this.linearStructure.size
|
||||||
val resBuffer = DoubleArray(n)
|
val resBuffer = DoubleArray(n)
|
||||||
|
|
||||||
val resShape = this.shape.copyOf()
|
val resShape = this.shape.copyOf()
|
||||||
resShape[i] = resShape[j].also { resShape[j] = resShape[i] }
|
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
|
||||||
|
|
||||||
val resTensor = DoubleTensor(resShape, resBuffer)
|
val resTensor = DoubleTensor(resShape, resBuffer)
|
||||||
|
|
||||||
for (offset in 0 until n) {
|
for (offset in 0 until n) {
|
||||||
val oldMultiIndex = this.linearStructure.index(offset)
|
val oldMultiIndex = this.linearStructure.index(offset)
|
||||||
val newMultiIndex = oldMultiIndex.copyOf()
|
val newMultiIndex = oldMultiIndex.copyOf()
|
||||||
newMultiIndex[i] = newMultiIndex[j].also { newMultiIndex[j] = newMultiIndex[i] }
|
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
|
||||||
|
|
||||||
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
|
val linearIndex = resTensor.linearStructure.offset(newMultiIndex)
|
||||||
resTensor.buffer.array()[linearIndex] =
|
resTensor.buffer.array()[linearIndex] =
|
||||||
@ -283,16 +285,18 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
override fun diagonalEmbedding(diagonalEntries: DoubleTensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor {
|
||||||
|
val d1 = minusIndexFrom(diagonalEntries.linearStructure.dim + 1, dim1)
|
||||||
|
val d2 = minusIndexFrom(diagonalEntries.linearStructure.dim + 1, dim2)
|
||||||
val n = diagonalEntries.shape.size
|
val n = diagonalEntries.shape.size
|
||||||
if (dim1 == dim2) {
|
if (d1 == d2) {
|
||||||
throw RuntimeException("Diagonal dimensions cannot be identical $dim1, $dim2")
|
throw RuntimeException("Diagonal dimensions cannot be identical $d1, $d2")
|
||||||
}
|
}
|
||||||
if (dim1 > n || dim2 > n) {
|
if (d1 > n || d2 > n) {
|
||||||
throw RuntimeException("Dimension out of range")
|
throw RuntimeException("Dimension out of range")
|
||||||
}
|
}
|
||||||
|
|
||||||
var lessDim = dim1
|
var lessDim = d1
|
||||||
var greaterDim = dim2
|
var greaterDim = d2
|
||||||
var realOffset = offset
|
var realOffset = offset
|
||||||
if (lessDim > greaterDim) {
|
if (lessDim > greaterDim) {
|
||||||
realOffset *= -1
|
realOffset *= -1
|
||||||
|
@ -73,6 +73,9 @@ public class TensorLinearStructure(public val shape: IntArray)
|
|||||||
public val size: Int
|
public val size: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
|
public val dim: Int
|
||||||
|
get() = shape.size
|
||||||
|
|
||||||
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
|
public fun indices(): Sequence<IntArray> = (0 until size).asSequence().map {
|
||||||
index(it)
|
index(it)
|
||||||
}
|
}
|
||||||
|
@ -40,3 +40,13 @@ internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
|||||||
val u = Random(seed)
|
val u = Random(seed)
|
||||||
return (0 until n).map { sqrt(-2.0 * ln(u.nextDouble())) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
|
return (0 until n).map { sqrt(-2.0 * ln(u.nextDouble())) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun minusIndexFrom(n: Int, i: Int) : Int = if (i >= 0) i else {
|
||||||
|
val ii = n + i
|
||||||
|
check(ii >= 0) {
|
||||||
|
"Out of bound index $i for tensor of dim $n"
|
||||||
|
}
|
||||||
|
ii
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T> BufferedTensor<T>.minusIndex(i: Int): Int = minusIndexFrom(this.linearStructure.dim, i)
|
@ -130,7 +130,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
fun testBatchedSVD() = DoubleLinearOpsTensorAlgebra {
|
||||||
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
|
val tensor = randNormal(intArrayOf(7, 5, 3), 0)
|
||||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||||
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS,0,1,2) dot tensorV)
|
val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV)
|
||||||
assertTrue(tensor.eq(tensorSVD))
|
assertTrue(tensor.eq(tensorSVD))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ class TestDoubleLinearOpsTensorAlgebra {
|
|||||||
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
val tensor = randNormal(shape = intArrayOf(5, 2, 2), 0)
|
||||||
val tensorSigma = tensor + tensor.transpose(1, 2)
|
val tensorSigma = tensor + tensor.transpose(1, 2)
|
||||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||||
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS, 0,1,2) dot tensorV.transpose(1, 2))
|
val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose(1, 2))
|
||||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,8 +35,8 @@ class TestDoubleTensorAlgebra {
|
|||||||
fun transpose1x2x3() = DoubleTensorAlgebra {
|
fun transpose1x2x3() = DoubleTensorAlgebra {
|
||||||
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor = fromArray(intArrayOf(1, 2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val res01 = tensor.transpose(0, 1)
|
val res01 = tensor.transpose(0, 1)
|
||||||
val res02 = tensor.transpose(0, 2)
|
val res02 = tensor.transpose(-3, 2)
|
||||||
val res12 = tensor.transpose(1, 2)
|
val res12 = tensor.transpose()
|
||||||
|
|
||||||
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
assertTrue(res01.shape contentEquals intArrayOf(2, 1, 3))
|
||||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||||
|
Loading…
Reference in New Issue
Block a user