Fix tf dot

This commit is contained in:
Alexander Nozik 2022-02-17 22:46:17 +03:00
parent a78e361b17
commit ac3adfa644
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
7 changed files with 40 additions and 8 deletions

View File

@ -52,6 +52,8 @@ kotlin {
implementation(project(":kmath-viktor"))
implementation(project(":kmath-jafama"))
implementation(project(":kmath-multik"))
implementation(projects.kmath.kmathTensorflow)
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
implementation("org.nd4j:nd4j-native:1.0.0-M1")
// uncomment if your system supports AVX2
// val os = System.getProperty("os.name")

View File

@ -17,7 +17,9 @@ import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import space.kscience.kmath.tensorflow.produceWithTF
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.tensorAlgebra
import kotlin.random.Random
@State(Scope.Benchmark)
@ -44,6 +46,16 @@ internal class DotBenchmark {
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
}
@Benchmark
fun tfDot(blackhole: Blackhole){
blackhole.consume(
DoubleField.produceWithTF {
tensor1 dot tensor2
}
)
}
@Benchmark
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
blackhole.consume(matrix1 dot matrix2)
@ -64,13 +76,13 @@ internal class DotBenchmark {
blackhole.consume(matrix1 dot matrix2)
}
// @Benchmark
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
// blackhole.consume(matrix1 dot matrix2)
// }
@Benchmark
fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
blackhole.consume(matrix1 dot matrix2)
}
@Benchmark
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
blackhole.consume(matrix1 dot matrix2)
}

View File

@ -10,7 +10,7 @@ allprojects {
}
group = "space.kscience"
version = "0.3.0-dev-18"
version = "0.3.0-dev-19"
}
subprojects {

View File

@ -199,8 +199,9 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
ops.linalg.matMul(
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
if (l.shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
if (r.shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r
)
}
override fun diagonalEmbedding(

View File

@ -4,6 +4,9 @@ import org.junit.jupiter.api.Test
import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum
import kotlin.random.Random
import kotlin.test.assertEquals
class DoubleTensorFlowOps {
@ -18,6 +21,19 @@ class DoubleTensorFlowOps {
assertEquals(3.0, res[0, 0])
}
@Test
fun dot(){
val random = Random(12224)
val dim = 1000
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
DoubleField.produceWithTF {
tensor1 dot tensor2
}.sum()
}
@Test
fun extensionOps(){
val res = DoubleField.produceWithTF {

View File

@ -997,5 +997,6 @@ public open class DoubleTensorAlgebra :
}
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra