rename bdot to matmul

This commit is contained in:
Alexander Nozik 2022-08-03 18:20:46 +03:00
parent e636ed27bd
commit ee0d44e12e
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
4 changed files with 14 additions and 12 deletions

View File

@ -9,7 +9,7 @@
### Changed ### Changed
- Kotlin 1.7.20 - Kotlin 1.7.20
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style - `LazyStructure` `deffered` -> `async` to comply with coroutines code style
- Default `dot` operation in tensor algebra no longer support broadcasting. Instead `bdot` operation is added to `DoubleTensorAlgebra`. - Default `dot` operation in tensor algebra no longer support broadcasting. Instead `matmul` operation is added to `DoubleTensorAlgebra`.
### Deprecated ### Deprecated

View File

@ -9,6 +9,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.* import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.MutableBuffer import space.kscience.kmath.structures.MutableBuffer
@ -410,7 +411,8 @@ public open class DoubleTensorAlgebra :
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return a mathematical product of two tensors. * @return a mathematical product of two tensors.
*/ */
public infix fun StructureND<Double>.bdot(other: StructureND<Double>): DoubleTensor { @UnstableKMathAPI
public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
} }
@ -460,7 +462,7 @@ public open class DoubleTensorAlgebra :
} }
override fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
return if (dimension in 0..2 && other.dimension in 0..2) bdot(other) return if (dimension in 0..2 && other.dimension in 0..2) matmul(other)
else error("Only vectors and matrices are allowed in non-broadcasting dot operation") else error("Only vectors and matrices are allowed in non-broadcasting dot operation")
} }
@ -945,7 +947,7 @@ public open class DoubleTensorAlgebra :
val (u, s, v) = tensor.svd(epsilon) val (u, s, v) = tensor.svd(epsilon)
val shp = s.shape + intArrayOf(1) val shp = s.shape + intArrayOf(1)
val utv = u.transpose() bdot v val utv = u.transpose() matmul v
val n = s.shape.last() val n = s.shape.last()
for (matrix in utv.matrixSequence()) { for (matrix in utv.matrixSequence()) {
matrix.as2D().cleanSym(n) matrix.as2D().cleanSym(n)

View File

@ -115,7 +115,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
assertTrue { q.shape contentEquals shape } assertTrue { q.shape contentEquals shape }
assertTrue { r.shape contentEquals shape } assertTrue { r.shape contentEquals shape }
assertTrue((q bdot r).eq(tensor)) assertTrue((q matmul r).eq(tensor))
} }
@ -136,17 +136,17 @@ internal class TestDoubleLinearOpsTensorAlgebra {
assertTrue { l.shape contentEquals shape } assertTrue { l.shape contentEquals shape }
assertTrue { u.shape contentEquals shape } assertTrue { u.shape contentEquals shape }
assertTrue((p bdot tensor).eq(l bdot u)) assertTrue((p matmul tensor).eq(l matmul u))
} }
@Test @Test
fun testCholesky() = DoubleTensorAlgebra { fun testCholesky() = DoubleTensorAlgebra {
val tensor = randomNormal(intArrayOf(2, 5, 5), 0) val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
val sigma = (tensor bdot tensor.transpose()) + diagonalEmbedding( val sigma = (tensor matmul tensor.transpose()) + diagonalEmbedding(
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 }) fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
) )
val low = sigma.cholesky() val low = sigma.cholesky()
val sigmChol = low bdot low.transpose() val sigmChol = low matmul low.transpose()
assertTrue(sigma.eq(sigmChol)) assertTrue(sigma.eq(sigmChol))
} }
@ -171,7 +171,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
fun testBatchedSVD() = DoubleTensorAlgebra { fun testBatchedSVD() = DoubleTensorAlgebra {
val tensor = randomNormal(intArrayOf(2, 5, 3), 0) val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
val (tensorU, tensorS, tensorV) = tensor.svd() val (tensorU, tensorS, tensorV) = tensor.svd()
val tensorSVD = tensorU bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose()) val tensorSVD = tensorU matmul (diagonalEmbedding(tensorS) matmul tensorV.transpose())
assertTrue(tensor.eq(tensorSVD)) assertTrue(tensor.eq(tensorSVD))
} }
@ -180,7 +180,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0) val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
val tensorSigma = tensor + tensor.transpose() val tensorSigma = tensor + tensor.transpose()
val (tensorS, tensorV) = tensorSigma.symEig() val (tensorS, tensorV) = tensorSigma.symEig()
val tensorSigmaCalc = tensorV bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose()) val tensorSigmaCalc = tensorV matmul (diagonalEmbedding(tensorS) matmul tensorV.transpose())
assertTrue(tensorSigma.eq(tensorSigmaCalc)) assertTrue(tensorSigma.eq(tensorSigmaCalc))
} }

View File

@ -114,7 +114,7 @@ internal class TestDoubleTensorAlgebra {
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0)) assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
assertTrue(res12.shape contentEquals intArrayOf(2)) assertTrue(res12.shape contentEquals intArrayOf(2))
val res32 = tensor3.bdot(tensor2) val res32 = tensor3.matmul(tensor2)
assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0)) assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
assertTrue(res32.shape contentEquals intArrayOf(1, 1)) assertTrue(res32.shape contentEquals intArrayOf(1, 1))
@ -126,7 +126,7 @@ internal class TestDoubleTensorAlgebra {
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0)) assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
assertTrue(res11.shape contentEquals intArrayOf(2, 2)) assertTrue(res11.shape contentEquals intArrayOf(2, 2))
val res45 = tensor4.bdot(tensor5) val res45 = tensor4.matmul(tensor5)
assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf( assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0, 36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0 468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0