forked from kscience/kmath
Fix tf dot
This commit is contained in:
parent
a78e361b17
commit
ac3adfa644
@ -52,6 +52,8 @@ kotlin {
|
|||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
|
implementation(projects.kmath.kmathTensorflow)
|
||||||
|
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||||
// uncomment if your system supports AVX2
|
// uncomment if your system supports AVX2
|
||||||
// val os = System.getProperty("os.name")
|
// val os = System.getProperty("os.name")
|
||||||
|
@ -17,7 +17,9 @@ import space.kscience.kmath.multik.multikAlgebra
|
|||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.tensorflow.produceWithTF
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
@ -44,6 +46,16 @@ internal class DotBenchmark {
|
|||||||
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
val ejmlMatrix2 = EjmlLinearSpaceDDRM { matrix2.toEjml() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun tfDot(blackhole: Blackhole){
|
||||||
|
blackhole.consume(
|
||||||
|
DoubleField.produceWithTF {
|
||||||
|
tensor1 dot tensor2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
fun cmDotWithConversion(blackhole: Blackhole) = CMLinearSpace {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
@ -64,13 +76,13 @@ internal class DotBenchmark {
|
|||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Benchmark
|
@Benchmark
|
||||||
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
fun tensorDot(blackhole: Blackhole) = with(DoubleField.tensorAlgebra) {
|
||||||
// blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
// }
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ allprojects {
|
|||||||
}
|
}
|
||||||
|
|
||||||
group = "space.kscience"
|
group = "space.kscience"
|
||||||
version = "0.3.0-dev-18"
|
version = "0.3.0-dev-19"
|
||||||
}
|
}
|
||||||
|
|
||||||
subprojects {
|
subprojects {
|
||||||
|
@ -199,8 +199,9 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
|
|||||||
|
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
override fun StructureND<T>.dot(other: StructureND<T>): TensorFlowOutput<T, TT> = operate(other) { l, r ->
|
||||||
ops.linalg.matMul(
|
ops.linalg.matMul(
|
||||||
if (l.asTensor().shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
if (l.shape().numDimensions() == 1) ops.expandDims(l, ops.constant(0)) else l,
|
||||||
if (r.asTensor().shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r)
|
if (r.shape().numDimensions() == 1) ops.expandDims(r, ops.constant(-1)) else r
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(
|
override fun diagonalEmbedding(
|
||||||
|
@ -4,6 +4,9 @@ import org.junit.jupiter.api.Test
|
|||||||
import space.kscience.kmath.nd.get
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.nd.structureND
|
import space.kscience.kmath.nd.structureND
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra.Companion.sum
|
||||||
|
import kotlin.random.Random
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class DoubleTensorFlowOps {
|
class DoubleTensorFlowOps {
|
||||||
@ -18,6 +21,19 @@ class DoubleTensorFlowOps {
|
|||||||
assertEquals(3.0, res[0, 0])
|
assertEquals(3.0, res[0, 0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun dot(){
|
||||||
|
val random = Random(12224)
|
||||||
|
val dim = 1000
|
||||||
|
|
||||||
|
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
||||||
|
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
||||||
|
|
||||||
|
DoubleField.produceWithTF {
|
||||||
|
tensor1 dot tensor2
|
||||||
|
}.sum()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun extensionOps(){
|
fun extensionOps(){
|
||||||
val res = DoubleField.produceWithTF {
|
val res = DoubleField.produceWithTF {
|
||||||
|
@ -997,5 +997,6 @@ public open class DoubleTensorAlgebra :
|
|||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||||
|
public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user