Restrict tensor dot to vectors and matrices only. Introduce bdot to DoubleTensorAlgebra for broadcasting operations.
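In short: `dot` in the generic tensor algebra is now defined only for operands of dimension ≤ 2, while broadcasting/batched products move to the new `bdot` in `DoubleTensorAlgebra`. A minimal usage sketch of the new contract (the import path is an assumption; `fromArray` and `randomNormal` are the helpers used in the tests below):

```kotlin
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    DoubleTensorAlgebra {
        val matrix = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
        val vector = fromArray(intArrayOf(3), doubleArrayOf(1.0, 1.0, 1.0))

        // Still allowed: both operands are at most 2-dimensional.
        val mv = matrix dot vector
        println(mv.shape.contentToString())      // [2]

        // Batched operands must now go through bdot; plain dot throws.
        val batch = randomNormal(intArrayOf(4, 3, 3), 0)
        val batched = batch bdot batch
        println(batched.shape.contentToString()) // [4, 3, 3]
    }
}
```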

Alexander Nozik 2022-08-03 18:10:44 +03:00
parent 9456217935
commit 5402ba47c9
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
5 changed files with 52 additions and 24 deletions


@@ -7,8 +7,9 @@
 - Algebra now has an obligatory `bufferFactory` (#477).
 ### Changed
-- Kotlin 1.7
+- Kotlin 1.7.20
 - `LazyStructure` `deffered` -> `async` to comply with coroutines code style
+- Default `dot` operation in tensor algebra no longer supports broadcasting. Instead, a `bdot` operation is added to `DoubleTensorAlgebra`.
 ### Deprecated


@@ -213,16 +213,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
      * 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
      * the matrix-vector product is returned.
      *
-     * 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
-     * then a batched matrix multiply is returned. If the first argument is 1-dimensional,
-     * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
-     * If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
-     * multiple and removed after.
-     * The non-matrix (i.e., batch) dimensions are broadcast (and thus must be broadcastable).
-     * For example, if `input` is a (j &times; 1 &times; n &times; n) tensor and `other` is a
-     * (k &times; n &times; n) tensor, out will be a (j &times; k &times; n &times; n) tensor.
-     *
-     * For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
+     * Otherwise, throw an exception.
      *
      * @param other tensor to be multiplied.
      * @return a mathematical product of two tensors.
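A hedged sketch of the four remaining `dot` cases (shapes in the comments follow the contract above; `DoubleTensorAlgebra` and its helpers are taken from the implementation and tests below, the import path is an assumption):

```kotlin
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    DoubleTensorAlgebra {
        val v = fromArray(intArrayOf(3), doubleArrayOf(1.0, 2.0, 3.0))
        val m = fromArray(intArrayOf(2, 3), doubleArrayOf(1.0, 0.0, 0.0, 0.0, 1.0, 0.0))

        val vv = v dot v             // case 1: dot product of two vectors
        val mm = m dot m.transpose() // case 2: matrix-matrix product, shape [2, 2]
        val vm = v dot m.transpose() // case 3: 1-d times 2-d, a 1 is prepended to v
        val mv = m dot v             // case 4: matrix-vector product, shape [2]
        println(mm.shape.contentToString())
    }
}
```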


@@ -381,7 +381,36 @@ public open class DoubleTensorAlgebra :
     override fun Tensor<Double>.viewAs(other: StructureND<Double>): DoubleTensor =
         tensor.view(other.shape)
 
-    override infix fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
+    /**
+     * Broadcasting matrix product of two tensors.
+     *
+     * The behavior depends on the dimensionality of the tensors as follows:
+     * 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
+     *
+     * 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
+     *
+     * 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
+     * a 1 is prepended to its dimension for the purpose of the matrix multiply.
+     * After the matrix multiply, depending on the implementation, the prepended dimension might be removed.
+     *
+     * 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
+     * the matrix-vector product is returned.
+     *
+     * 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
+     * then a batched matrix multiply is returned. If the first argument is 1-dimensional,
+     * a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
+     * If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
+     * multiply and removed after.
+     * The non-matrix (i.e., batch) dimensions are broadcast (and thus must be broadcastable).
+     * For example, if `input` is a (j &times; 1 &times; n &times; n) tensor and `other` is a
+     * (k &times; n &times; n) tensor, out will be a (j &times; k &times; n &times; n) tensor.
+     *
+     * For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
+     *
+     * @param other tensor to be multiplied.
+     * @return a mathematical product of two tensors.
+     */
+    public infix fun StructureND<Double>.bdot(other: StructureND<Double>): DoubleTensor {
         if (tensor.shape.size == 1 && other.shape.size == 1) {
             return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
         }
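To illustrate case 5 from the KDoc above, a speculative batched example (the shapes mirror the documented (j &times; 1 &times; n &times; n) by (k &times; n &times; n) case; the import path and the exact output shape are assumptions based on that documentation):

```kotlin
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    DoubleTensorAlgebra {
        // Batch dimensions broadcast: (2 x 1 x 3 x 3) bdot (4 x 3 x 3)
        // should give a (2 x 4 x 3 x 3) result if the documented semantics hold.
        val a = randomNormal(intArrayOf(2, 1, 3, 3), 0)
        val b = randomNormal(intArrayOf(4, 3, 3), 0)
        val c = a bdot b
        println(c.shape.contentToString()) // expected [2, 4, 3, 3]
    }
}
```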
@@ -430,6 +459,11 @@ public open class DoubleTensorAlgebra :
         }
     }
 
+    override fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
+        return if (dimension in 0..2 && other.dimension in 0..2) bdot(other)
+        else error("Only vectors and matrices are allowed in non-broadcasting dot operation")
+    }
+
     override fun diagonalEmbedding(
         diagonalEntries: Tensor<Double>,
         offset: Int,
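The `dot` override above makes the restriction concrete: it delegates to `bdot` for vectors and matrices and fails fast otherwise. A small sketch of the observable behavior (assumed, not taken from the commit's tests):

```kotlin
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    DoubleTensorAlgebra {
        val batch = randomNormal(intArrayOf(2, 3, 3), 0)

        // dot now rejects anything above 2 dimensions...
        runCatching { batch.dot(batch) }
            .onFailure { println(it.message) } // "Only vectors and matrices are allowed..."

        // ...while bdot still handles the batched case.
        val ok = batch bdot batch
        println(ok.shape.contentToString())    // [2, 3, 3]
    }
}
```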
@@ -587,7 +621,8 @@ public open class DoubleTensorAlgebra :
         val resNumElements = resShape.reduce(Int::times)
         val init = foldFunction(DoubleArray(1) { 0.0 })
         val resTensor = BufferedTensor(resShape,
-            MutableBuffer.auto(resNumElements) { init }, 0)
+            MutableBuffer.auto(resNumElements) { init }, 0
+        )
         for (index in resTensor.indices) {
             val prefix = index.take(dim).toIntArray()
             val suffix = index.takeLast(dimension - dim - 1).toIntArray()
@@ -882,7 +917,8 @@ public open class DoubleTensorAlgebra :
         return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
     }
 
-    override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEigJacobi(maxIteration = 50, epsilon = 1e-15)
+    override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
+        symEigJacobi(maxIteration = 50, epsilon = 1e-15)
 
     /**
      * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
@@ -909,7 +945,7 @@ public open class DoubleTensorAlgebra :
         val (u, s, v) = tensor.svd(epsilon)
         val shp = s.shape + intArrayOf(1)
-        val utv = u.transpose() dot v
+        val utv = u.transpose() bdot v
         val n = s.shape.last()
         for (matrix in utv.matrixSequence()) {
             matrix.as2D().cleanSym(n)
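Since the internals (such as the `u.transpose() bdot v` step above) now route through `bdot`, the public eigendecomposition still works on batches. A usage sketch mirroring the symEig test below (import path assumed):

```kotlin
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra

fun main() {
    DoubleTensorAlgebra {
        val t = randomNormal(intArrayOf(3, 3), 0)
        val sym = t + t.transpose() // symmetrize the input
        val (eigenvalues, eigenvectors) = sym.symEig()

        // Reconstruct V diag(S) V^T; it should recover the input, as the tests below assert.
        val rebuilt = eigenvectors bdot (diagonalEmbedding(eigenvalues) bdot eigenvectors.transpose())
        println(sym.eq(rebuilt)) // expected: true
    }
}
```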
@@ -951,7 +987,7 @@ public open class DoubleTensorAlgebra :
     private fun MutableStructure2D<Double>.jacobiHelper(
         maxIteration: Int,
-        epsilon: Double
+        epsilon: Double,
     ): Pair<Structure1D<Double>, Structure2D<Double>> {
         val n = this.shape[0]
         val A_ = this.copy()


@@ -115,7 +115,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         assertTrue { q.shape contentEquals shape }
         assertTrue { r.shape contentEquals shape }
 
-        assertTrue((q dot r).eq(tensor))
+        assertTrue((q bdot r).eq(tensor))
     }
@@ -136,17 +136,17 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         assertTrue { l.shape contentEquals shape }
         assertTrue { u.shape contentEquals shape }
 
-        assertTrue((p dot tensor).eq(l dot u))
+        assertTrue((p bdot tensor).eq(l bdot u))
     }
 
     @Test
     fun testCholesky() = DoubleTensorAlgebra {
         val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
-        val sigma = (tensor dot tensor.transpose()) + diagonalEmbedding(
+        val sigma = (tensor bdot tensor.transpose()) + diagonalEmbedding(
             fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
         )
         val low = sigma.cholesky()
-        val sigmChol = low dot low.transpose()
+        val sigmChol = low bdot low.transpose()
         assertTrue(sigma.eq(sigmChol))
     }
@@ -171,7 +171,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
     fun testBatchedSVD() = DoubleTensorAlgebra {
         val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
         val (tensorU, tensorS, tensorV) = tensor.svd()
-        val tensorSVD = tensorU dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
+        val tensorSVD = tensorU bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
         assertTrue(tensor.eq(tensorSVD))
     }
@@ -180,7 +180,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
         val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
         val tensorSigma = tensor + tensor.transpose()
         val (tensorS, tensorV) = tensorSigma.symEig()
-        val tensorSigmaCalc = tensorV dot (diagonalEmbedding(tensorS) dot tensorV.transpose())
+        val tensorSigmaCalc = tensorV bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
         assertTrue(tensorSigma.eq(tensorSigmaCalc))
     }


@@ -114,7 +114,7 @@ internal class TestDoubleTensorAlgebra {
         assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
         assertTrue(res12.shape contentEquals intArrayOf(2))
 
-        val res32 = tensor3.dot(tensor2)
+        val res32 = tensor3.bdot(tensor2)
         assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
         assertTrue(res32.shape contentEquals intArrayOf(1, 1))
@@ -126,7 +126,7 @@ internal class TestDoubleTensorAlgebra {
         assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
         assertTrue(res11.shape contentEquals intArrayOf(2, 2))
 
-        val res45 = tensor4.dot(tensor5)
+        val res45 = tensor4.bdot(tensor5)
         assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
             36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
             468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0