forked from kscience/kmath
Fix tf dot
This commit is contained in:
parent
a78e361b17
commit
ac3adfa644
@ -52,6 +52,8 @@ kotlin {
|
||||
implementation(project(":kmath-viktor"))
|
||||
implementation(project(":kmath-jafama"))
|
||||
implementation(project(":kmath-multik"))
|
||||
implementation(projects.kmath.kmathTensorflow)
|
||||
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||
// uncomment if your system supports AVX2
|
||||
// val os = System.getProperty("os.name")
|
||||
|
@ -17,7 +17,9 @@ import space.kscience.kmath.multik.multikAlgebra
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.invoke
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.tensorflow.produceWithTF
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||
import kotlin.random.Random
|
||||
|
||||
@State(Scope.Benchmark)
|
||||
@ -44,6 +46,16 @@ internal class DotBenchmark {
|
||||
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
||||
}
|
||||
|
||||
|
||||
@Benchmark
|
||||
fun tfDot(blackhole: Blackhole){
|
||||
blackhole.consume(
|
||||
DoubleField.produceWithTF {
|
||||
tensor1 dot tensor2
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
@ -64,13 +76,13 @@ internal class DotBenchmark {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
// @Benchmark
|
||||
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||
// blackhole.consume(matrix1 dot matrix2)
|
||||
// }
|
||||
@Benchmark
|
||||
fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
||||
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@ allprojects {
|
||||
}
|
||||
|
||||
group = "space.kscience"
|
||||
version = "0.3.0-dev-18"
|
||||
version = "0.3.0-dev-19"
|
||||
}
|
||||
|
||||
subprojects {
|
||||
|
@ -199,8 +199,9 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
||||
|
||||
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
||||
ops.linalg.matMul(
|
||||
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
||||
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
|
||||
if (l.shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
||||
if (r.shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r
|
||||
)
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(
|
||||
|
@ -4,6 +4,9 @@ import org.junit.jupiter.api.Test
|
||||
import space.kscience.kmath.nd.get
|
||||
import space.kscience.kmath.nd.structureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum
|
||||
import kotlin.random.Random
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class DoubleTensorFlowOps {
|
||||
@ -18,6 +21,19 @@ class DoubleTensorFlowOps {
|
||||
assertEquals(3.0, res[0, 0])
|
||||
}
|
||||
|
||||
@Test
|
||||
fun dot(){
|
||||
val random = Random(12224)
|
||||
val dim = 1000
|
||||
|
||||
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
||||
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
||||
|
||||
DoubleField.produceWithTF {
|
||||
tensor1 dot tensor2
|
||||
}.sum()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun extensionOps(){
|
||||
val res = DoubleField.produceWithTF {
|
||||
|
@ -997,5 +997,6 @@ public open class DoubleTensorAlgebra :
|
||||
}
|
||||
|
||||
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||
public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user