rename bdot to matmul
This commit is contained in:
parent
e636ed27bd
commit
ee0d44e12e
@ -9,7 +9,7 @@
|
||||
### Changed
|
||||
- Kotlin 1.7.20
|
||||
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
|
||||
- Default `dot` operation in tensor algebra no longer support broadcasting. Instead `bdot` operation is added to `DoubleTensorAlgebra`.
|
||||
- Default `dot` operation in tensor algebra no longer support broadcasting. Instead `matmul` operation is added to `DoubleTensorAlgebra`.
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
@ -410,7 +411,8 @@ public open class DoubleTensorAlgebra :
|
||||
* @param other tensor to be multiplied.
|
||||
* @return a mathematical product of two tensors.
|
||||
*/
|
||||
public infix fun StructureND<Double>.bdot(other: StructureND<Double>): DoubleTensor {
|
||||
@UnstableKMathAPI
|
||||
public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor {
|
||||
if (tensor.shape.size == 1 && other.shape.size == 1) {
|
||||
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
|
||||
}
|
||||
@ -460,7 +462,7 @@ public open class DoubleTensorAlgebra :
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
|
||||
return if (dimension in 0..2 && other.dimension in 0..2) bdot(other)
|
||||
return if (dimension in 0..2 && other.dimension in 0..2) matmul(other)
|
||||
else error("Only vectors and matrices are allowed in non-broadcasting dot operation")
|
||||
}
|
||||
|
||||
@ -945,7 +947,7 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
val (u, s, v) = tensor.svd(epsilon)
|
||||
val shp = s.shape + intArrayOf(1)
|
||||
val utv = u.transpose() bdot v
|
||||
val utv = u.transpose() matmul v
|
||||
val n = s.shape.last()
|
||||
for (matrix in utv.matrixSequence()) {
|
||||
matrix.as2D().cleanSym(n)
|
||||
|
@ -115,7 +115,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
assertTrue { q.shape contentEquals shape }
|
||||
assertTrue { r.shape contentEquals shape }
|
||||
|
||||
assertTrue((q bdot r).eq(tensor))
|
||||
assertTrue((q matmul r).eq(tensor))
|
||||
|
||||
}
|
||||
|
||||
@ -136,17 +136,17 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
assertTrue { l.shape contentEquals shape }
|
||||
assertTrue { u.shape contentEquals shape }
|
||||
|
||||
assertTrue((p bdot tensor).eq(l bdot u))
|
||||
assertTrue((p matmul tensor).eq(l matmul u))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testCholesky() = DoubleTensorAlgebra {
|
||||
val tensor = randomNormal(intArrayOf(2, 5, 5), 0)
|
||||
val sigma = (tensor bdot tensor.transpose()) + diagonalEmbedding(
|
||||
val sigma = (tensor matmul tensor.transpose()) + diagonalEmbedding(
|
||||
fromArray(intArrayOf(2, 5), DoubleArray(10) { 0.1 })
|
||||
)
|
||||
val low = sigma.cholesky()
|
||||
val sigmChol = low bdot low.transpose()
|
||||
val sigmChol = low matmul low.transpose()
|
||||
assertTrue(sigma.eq(sigmChol))
|
||||
}
|
||||
|
||||
@ -171,7 +171,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
fun testBatchedSVD() = DoubleTensorAlgebra {
|
||||
val tensor = randomNormal(intArrayOf(2, 5, 3), 0)
|
||||
val (tensorU, tensorS, tensorV) = tensor.svd()
|
||||
val tensorSVD = tensorU bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
|
||||
val tensorSVD = tensorU matmul (diagonalEmbedding(tensorS) matmul tensorV.transpose())
|
||||
assertTrue(tensor.eq(tensorSVD))
|
||||
}
|
||||
|
||||
@ -180,7 +180,7 @@ internal class TestDoubleLinearOpsTensorAlgebra {
|
||||
val tensor = randomNormal(shape = intArrayOf(2, 3, 3), 0)
|
||||
val tensorSigma = tensor + tensor.transpose()
|
||||
val (tensorS, tensorV) = tensorSigma.symEig()
|
||||
val tensorSigmaCalc = tensorV bdot (diagonalEmbedding(tensorS) bdot tensorV.transpose())
|
||||
val tensorSigmaCalc = tensorV matmul (diagonalEmbedding(tensorS) matmul tensorV.transpose())
|
||||
assertTrue(tensorSigma.eq(tensorSigmaCalc))
|
||||
}
|
||||
|
||||
|
@ -114,7 +114,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
assertTrue(res12.mutableBuffer.array() contentEquals doubleArrayOf(140.0, 320.0))
|
||||
assertTrue(res12.shape contentEquals intArrayOf(2))
|
||||
|
||||
val res32 = tensor3.bdot(tensor2)
|
||||
val res32 = tensor3.matmul(tensor2)
|
||||
assertTrue(res32.mutableBuffer.array() contentEquals doubleArrayOf(-140.0))
|
||||
assertTrue(res32.shape contentEquals intArrayOf(1, 1))
|
||||
|
||||
@ -126,7 +126,7 @@ internal class TestDoubleTensorAlgebra {
|
||||
assertTrue(res11.mutableBuffer.array() contentEquals doubleArrayOf(22.0, 28.0, 49.0, 64.0))
|
||||
assertTrue(res11.shape contentEquals intArrayOf(2, 2))
|
||||
|
||||
val res45 = tensor4.bdot(tensor5)
|
||||
val res45 = tensor4.matmul(tensor5)
|
||||
assertTrue(res45.mutableBuffer.array() contentEquals doubleArrayOf(
|
||||
36.0, 42.0, 48.0, 81.0, 96.0, 111.0, 126.0, 150.0, 174.0,
|
||||
468.0, 501.0, 534.0, 594.0, 636.0, 678.0, 720.0, 771.0, 822.0
|
||||
|
Loading…
Reference in New Issue
Block a user