From a8182fad23417f4302ea77ccdeba3b931e50a471 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Thu, 4 Aug 2022 09:54:59 +0300 Subject: [PATCH] Multik went MPP --- CHANGELOG.md | 1 + benchmarks/build.gradle.kts | 10 ++++++--- .../kscience/kmath/benchmarks/DotBenchmark.kt | 3 +-- .../kmath/benchmarks/NDFieldBenchmark.kt | 6 ++---- .../kscience/kmath/benchmarks/globals.kt | 11 ++++++++++ build.gradle.kts | 2 ++ examples/build.gradle.kts | 3 +++ .../space/kscience/kmath/tensors/multik.kt | 9 +++++--- kmath-multik/build.gradle.kts | 21 +++++++++++++++---- .../kmath/multik/MultikDoubleAlgebra.kt | 9 +++++--- .../kmath/multik/MultikFloatAlgebra.kt | 9 +++++--- .../kscience/kmath/multik/MultikIntAlgebra.kt | 9 +++++--- .../kmath/multik/MultikLongAlgebra.kt | 9 +++++--- .../kmath/multik/MultikShortAlgebra.kt | 9 +++++--- .../kscience/kmath/multik/MultikTensor.kt | 1 + .../kmath/multik/MultikTensorAlgebra.kt | 18 +++++++++------- .../kscience/kmath/multik/MultikNDTest.kt | 16 +++++++------- 17 files changed, 100 insertions(+), 46 deletions(-) create mode 100644 benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/globals.kt rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt (87%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt (62%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt (63%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt (63%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt (62%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikTensor.kt (97%) rename kmath-multik/src/{main => commonMain}/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt (94%) rename kmath-multik/src/{test => commonTest}/kotlin/space/kscience/kmath/multik/MultikNDTest.kt (72%) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5ed80597..acec30836 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Kotlin 1.7.20 - `LazyStructure` `deffered` -> `async` to comply with coroutines code style - Default `dot` operation in tensor algebra no longer support broadcasting. Instead `matmul` operation is added to `DoubleTensorAlgebra`. +- Multik went MPP ### Deprecated diff --git a/benchmarks/build.gradle.kts b/benchmarks/build.gradle.kts index e86990043..9c8dc613a 100644 --- a/benchmarks/build.gradle.kts +++ b/benchmarks/build.gradle.kts @@ -1,5 +1,6 @@ @file:Suppress("UNUSED_VARIABLE") +import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile import space.kscience.kmath.benchmarks.addBenchmarkProperties plugins { @@ -15,6 +16,8 @@ repositories { mavenCentral() } +val multikVersion: String by rootProject.extra + kotlin { jvm() @@ -39,7 +42,9 @@ kotlin { implementation(project(":kmath-dimensions")) implementation(project(":kmath-for-real")) implementation(project(":kmath-tensors")) - implementation("org.jetbrains.kotlinx:kotlinx-benchmark-runtime:0.4.2") + implementation(project(":kmath-multik")) + implementation("org.jetbrains.kotlinx:multik-default:$multikVersion") + implementation(npmlibs.kotlinx.benchmark.runtime) } } @@ -51,7 +56,6 @@ kotlin { implementation(project(":kmath-kotlingrad")) implementation(project(":kmath-viktor")) implementation(project(":kmath-jafama")) - implementation(project(":kmath-multik")) implementation(projects.kmath.kmathTensorflow) implementation("org.tensorflow:tensorflow-core-platform:0.4.0") implementation("org.nd4j:nd4j-native:1.0.0-M1") @@ -155,7 +159,7 @@ kotlin.sourceSets.all { } } -tasks.withType { +tasks.withType { kotlinOptions { jvmTarget = "11" freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy" diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt index 7ceecb5ab..d7ead6dd8 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/DotBenchmark.kt @@ -13,7 +13,6 @@ import space.kscience.kmath.commons.linear.CMLinearSpace import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM import space.kscience.kmath.linear.invoke import space.kscience.kmath.linear.linearSpace -import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensorflow.produceWithTF @@ -78,7 +77,7 @@ internal class DotBenchmark { } @Benchmark - fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) { + fun multikDot(blackhole: Blackhole) = with(multikAlgebra) { blackhole.consume(matrix1 dot matrix2) } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt index 89673acd4..eb5e1da47 100644 --- a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/NDFieldBenchmark.kt @@ -13,7 +13,6 @@ import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ones import org.jetbrains.kotlinx.multik.ndarray.data.DN import org.jetbrains.kotlinx.multik.ndarray.data.DataType -import space.kscience.kmath.multik.multikAlgebra import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.ndAlgebra @@ -43,7 +42,7 @@ internal class NDFieldBenchmark { } @Benchmark - fun multikAdd(blackhole: Blackhole) = with(multikField) { + fun multikAdd(blackhole: Blackhole) = with(multikAlgebra) { var res: StructureND = one(shape) repeat(n) { res += 1.0 } blackhole.consume(res) @@ -71,7 +70,7 @@ internal class NDFieldBenchmark { } @Benchmark - fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) { + fun multikInPlaceAdd(blackhole: Blackhole) = with(multikAlgebra) { val res = Multik.ones(shape, DataType.DoubleDataType).wrap() repeat(n) { res += 1.0 } blackhole.consume(res) @@ -91,7 +90,6 @@ internal class NDFieldBenchmark { private val specializedField = DoubleField.ndAlgebra private val genericField = BufferedFieldOpsND(DoubleField) private val nd4jField = DoubleField.nd4j - private val multikField = DoubleField.multikAlgebra private val viktorField = DoubleField.viktorAlgebra } } diff --git a/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/globals.kt b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/globals.kt new file mode 100644 index 000000000..9f235c56e --- /dev/null +++ b/benchmarks/src/jvmMain/kotlin/space/kscience/kmath/benchmarks/globals.kt @@ -0,0 +1,11 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.benchmarks + +import org.jetbrains.kotlinx.multik.default.DefaultEngine +import space.kscience.kmath.multik.MultikDoubleAlgebra + +val multikAlgebra = MultikDoubleAlgebra(DefaultEngine()) \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index c57b29105..b6e5b0748 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -74,3 +74,5 @@ ksciencePublish { } apiValidation.nonPublicMarkers.add("space.kscience.kmath.misc.UnstableKMathAPI") + +val multikVersion by extra("0.2.0") diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts index cfa50815d..1d224bace 100644 --- a/examples/build.gradle.kts +++ b/examples/build.gradle.kts @@ -8,6 +8,8 @@ repositories { maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers") } +val multikVersion: String by rootProject.extra + dependencies { implementation(project(":kmath-ast")) implementation(project(":kmath-kotlingrad")) @@ -30,6 +32,7 @@ dependencies { implementation(project(":kmath-jafama")) //multik implementation(project(":kmath-multik")) + implementation("org.jetbrains.kotlinx:multik-default:$multikVersion") implementation("org.nd4j:nd4j-native:1.0.0-beta7") diff --git a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt index f2d1f0b41..7cf1c5120 100644 --- a/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt +++ b/examples/src/main/kotlin/space/kscience/kmath/tensors/multik.kt @@ -7,11 +7,14 @@ package space.kscience.kmath.tensors import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarray -import space.kscience.kmath.multik.multikAlgebra +import org.jetbrains.kotlinx.multik.default.DefaultEngine +import space.kscience.kmath.multik.MultikDoubleAlgebra import space.kscience.kmath.nd.one -import space.kscience.kmath.operations.DoubleField -fun main(): Unit = with(DoubleField.multikAlgebra) { + +val multikAlgebra = MultikDoubleAlgebra(DefaultEngine()) + +fun main(): Unit = with(multikAlgebra) { val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType().wrap() val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap() one(a.shape) - a + b * 3.0 diff --git a/kmath-multik/build.gradle.kts b/kmath-multik/build.gradle.kts index ee31ead3c..39c03ab9b 100644 --- a/kmath-multik/build.gradle.kts +++ b/kmath-multik/build.gradle.kts @@ -1,12 +1,25 @@ plugins { - id("space.kscience.gradle.jvm") + id("space.kscience.gradle.mpp") } description = "JetBrains Multik connector" -dependencies { - api(project(":kmath-tensors")) - api("org.jetbrains.kotlinx:multik-default:0.2.0") +val multikVersion: String by rootProject.extra + +kotlin{ + sourceSets{ + commonMain{ + dependencies{ + api(project(":kmath-tensors")) + api("org.jetbrains.kotlinx:multik-core:$multikVersion") + } + } + commonTest{ + dependencies{ + api("org.jetbrains.kotlinx:multik-default:$multikVersion") + } + } + } } readme { diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt similarity index 87% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt index 1d4665618..dcc1623cc 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt @@ -5,6 +5,7 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Engine import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType @@ -14,7 +15,9 @@ import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.ExponentialOperations import space.kscience.kmath.operations.TrigonometricOperations -public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra(), +public class MultikDoubleAlgebra( + multikEngine: Engine +) : MultikDivisionTensorAlgebra(multikEngine), TrigonometricOperations>, ExponentialOperations> { override val elementAlgebra: DoubleField get() = DoubleField override val type: DataType get() = DataType.DoubleDataType @@ -60,6 +63,6 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra = Multik.ndarrayOf(value).wrap() } -public val Double.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra -public val DoubleField.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra +//public val Double.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra +//public val DoubleField.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt similarity index 62% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt index 8dd9b3f60..db251b552 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt @@ -5,12 +5,15 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Engine import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.operations.FloatField -public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { +public class MultikFloatAlgebra( + multikEngine: Engine +) : MultikDivisionTensorAlgebra(multikEngine) { override val elementAlgebra: FloatField get() = FloatField override val type: DataType get() = DataType.FloatDataType @@ -18,5 +21,5 @@ public object MultikFloatAlgebra : MultikDivisionTensorAlgebra get() = MultikFloatAlgebra -public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra \ No newline at end of file +//public val Float.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra +//public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt similarity index 63% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt index b9e0125e8..517845939 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt @@ -5,16 +5,19 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Engine import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.operations.IntRing -public object MultikIntAlgebra : MultikTensorAlgebra() { +public class MultikIntAlgebra( + multikEngine: Engine +) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: IntRing get() = IntRing override val type: DataType get() = DataType.IntDataType override fun scalar(value: Int): MultikTensor = Multik.ndarrayOf(value).wrap() } -public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra -public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra \ No newline at end of file +//public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra +//public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt similarity index 63% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt index 1f3dfb0e5..9c6c1af54 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt @@ -5,12 +5,15 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Engine import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.operations.LongRing -public object MultikLongAlgebra : MultikTensorAlgebra() { +public class MultikLongAlgebra( + multikEngine: Engine +) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: LongRing get() = LongRing override val type: DataType get() = DataType.LongDataType @@ -18,5 +21,5 @@ public object MultikLongAlgebra : MultikTensorAlgebra() { } -public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra -public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file +//public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra +//public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt similarity index 62% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt index b9746571d..5288351b2 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt @@ -5,16 +5,19 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Engine import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.operations.ShortRing -public object MultikShortAlgebra : MultikTensorAlgebra() { +public class MultikShortAlgebra( + multikEngine: Engine +) : MultikTensorAlgebra(multikEngine) { override val elementAlgebra: ShortRing get() = ShortRing override val type: DataType get() = DataType.ShortDataType override fun scalar(value: Short): MultikTensor = Multik.ndarrayOf(value).wrap() } -public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra -public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra \ No newline at end of file +//public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra +//public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt similarity index 97% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt index 7efe672b4..3ab4b0e3d 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensor.kt @@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.multik.ndarray.data.* import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.Shape import space.kscience.kmath.tensors.api.Tensor +import kotlin.jvm.JvmInline @JvmInline public value class MultikTensor(public val array: MutableMultiArray) : Tensor { diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt similarity index 94% rename from kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt rename to kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index 1cc1b9b6d..f9768ea39 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/commonMain/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -20,14 +20,15 @@ import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra -public abstract class MultikTensorAlgebra> : TensorAlgebra - where T : Number, T : Comparable { +public abstract class MultikTensorAlgebra>( + private val multikEngine: Engine, +) : TensorAlgebra where T : Number, T : Comparable { public abstract val type: DataType - protected val multikMath: Math = mk.math - protected val multikLinAl: LinAlg = mk.linalg - protected val multikStat: Statistics = mk.stat + protected val multikMath: Math = multikEngine.getMath() + protected val multikLinAl: LinAlg = multikEngine.getLinAlg() + protected val multikStat: Statistics = multikEngine.getStatistics() override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor { val strides = DefaultStrides(shape) @@ -255,12 +256,13 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor { if (keepDim) TODO("keepDim not implemented") val res = multikMath.argMaxDN(asMultik().array, dim) - return with(MultikIntAlgebra) { res.wrap() } + return with(MultikIntAlgebra(multikEngine)) { res.wrap() } } } -public abstract class MultikDivisionTensorAlgebra> - : MultikTensorAlgebra(), TensorPartialDivisionAlgebra where T : Number, T : Comparable { +public abstract class MultikDivisionTensorAlgebra>( + multikEngine: Engine, +) : MultikTensorAlgebra(multikEngine), TensorPartialDivisionAlgebra where T : Number, T : Comparable { override fun T.div(arg: StructureND): MultikTensor = Multik.ones(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap() diff --git a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt b/kmath-multik/src/commonTest/kotlin/space/kscience/kmath/multik/MultikNDTest.kt similarity index 72% rename from kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt rename to kmath-multik/src/commonTest/kotlin/space/kscience/kmath/multik/MultikNDTest.kt index 72c43d8e6..8bc43eef2 100644 --- a/kmath-multik/src/test/kotlin/space/kscience/kmath/multik/MultikNDTest.kt +++ b/kmath-multik/src/commonTest/kotlin/space/kscience/kmath/multik/MultikNDTest.kt @@ -5,33 +5,35 @@ package space.kscience.kmath.multik -import org.junit.jupiter.api.Test +import org.jetbrains.kotlinx.multik.default.DefaultEngine import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.one import space.kscience.kmath.operations.DoubleField -import space.kscience.kmath.operations.invoke import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.tensorAlgebra +import kotlin.test.Test import kotlin.test.assertTrue internal class MultikNDTest { + val multikAlgebra = MultikDoubleAlgebra(DefaultEngine()) + @Test - fun basicAlgebra(): Unit = DoubleField.multikAlgebra{ - one(2,2) + 1.0 + fun basicAlgebra(): Unit = with(multikAlgebra) { + one(2, 2) + 1.0 } @Test - fun dotResult(){ + fun dotResult() { val dim = 100 val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224) val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225) - val multikResult = with(DoubleField.multikAlgebra){ + val multikResult = with(multikAlgebra) { tensor1 dot tensor2 } - val defaultResult = with(DoubleField.tensorAlgebra){ + val defaultResult = with(DoubleField.tensorAlgebra) { tensor1 dot tensor2 }