Multik went MPP
This commit is contained in:
parent
ee0d44e12e
commit
a8182fad23
@ -10,6 +10,7 @@
|
|||||||
- Kotlin 1.7.20
|
- Kotlin 1.7.20
|
||||||
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
|
- `LazyStructure` `deffered` -> `async` to comply with coroutines code style
|
||||||
- Default `dot` operation in tensor algebra no longer support broadcasting. Instead `matmul` operation is added to `DoubleTensorAlgebra`.
|
- Default `dot` operation in tensor algebra no longer support broadcasting. Instead `matmul` operation is added to `DoubleTensorAlgebra`.
|
||||||
|
- Multik went MPP
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
@file:Suppress("UNUSED_VARIABLE")
|
@file:Suppress("UNUSED_VARIABLE")
|
||||||
|
|
||||||
|
import org.jetbrains.kotlin.gradle.tasks.KotlinJvmCompile
|
||||||
import space.kscience.kmath.benchmarks.addBenchmarkProperties
|
import space.kscience.kmath.benchmarks.addBenchmarkProperties
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
@ -15,6 +16,8 @@ repositories {
|
|||||||
mavenCentral()
|
mavenCentral()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val multikVersion: String by rootProject.extra
|
||||||
|
|
||||||
kotlin {
|
kotlin {
|
||||||
jvm()
|
jvm()
|
||||||
|
|
||||||
@ -39,7 +42,9 @@ kotlin {
|
|||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation(project(":kmath-for-real"))
|
implementation(project(":kmath-for-real"))
|
||||||
implementation(project(":kmath-tensors"))
|
implementation(project(":kmath-tensors"))
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-benchmark-runtime:0.4.2")
|
implementation(project(":kmath-multik"))
|
||||||
|
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||||
|
implementation(npmlibs.kotlinx.benchmark.runtime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,7 +56,6 @@ kotlin {
|
|||||||
implementation(project(":kmath-kotlingrad"))
|
implementation(project(":kmath-kotlingrad"))
|
||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
implementation(project(":kmath-multik"))
|
|
||||||
implementation(projects.kmath.kmathTensorflow)
|
implementation(projects.kmath.kmathTensorflow)
|
||||||
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
implementation("org.tensorflow:tensorflow-core-platform:0.4.0")
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
implementation("org.nd4j:nd4j-native:1.0.0-M1")
|
||||||
@ -155,7 +159,7 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<org.jetbrains.kotlin.gradle.dsl.KotlinJvmCompile> {
|
tasks.withType<KotlinJvmCompile> {
|
||||||
kotlinOptions {
|
kotlinOptions {
|
||||||
jvmTarget = "11"
|
jvmTarget = "11"
|
||||||
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
freeCompilerArgs = freeCompilerArgs + "-Xjvm-default=all" + "-Xlambdas=indy"
|
||||||
|
@ -13,7 +13,6 @@ import space.kscience.kmath.commons.linear.CMLinearSpace
|
|||||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||||
import space.kscience.kmath.linear.invoke
|
import space.kscience.kmath.linear.invoke
|
||||||
import space.kscience.kmath.linear.linearSpace
|
import space.kscience.kmath.linear.linearSpace
|
||||||
import space.kscience.kmath.multik.multikAlgebra
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.tensorflow.produceWithTF
|
import space.kscience.kmath.tensorflow.produceWithTF
|
||||||
@ -78,7 +77,7 @@ internal class DotBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikDot(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
fun multikDot(blackhole: Blackhole) = with(multikAlgebra) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,7 +13,6 @@ import org.jetbrains.kotlinx.multik.api.Multik
|
|||||||
import org.jetbrains.kotlinx.multik.api.ones
|
import org.jetbrains.kotlinx.multik.api.ones
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.multik.multikAlgebra
|
|
||||||
import space.kscience.kmath.nd.BufferedFieldOpsND
|
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
@ -43,7 +42,7 @@ internal class NDFieldBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikAdd(blackhole: Blackhole) = with(multikField) {
|
fun multikAdd(blackhole: Blackhole) = with(multikAlgebra) {
|
||||||
var res: StructureND<Double> = one(shape)
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
@ -71,7 +70,7 @@ internal class NDFieldBenchmark {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
|
fun multikInPlaceAdd(blackhole: Blackhole) = with(multikAlgebra) {
|
||||||
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
|
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
@ -91,7 +90,6 @@ internal class NDFieldBenchmark {
|
|||||||
private val specializedField = DoubleField.ndAlgebra
|
private val specializedField = DoubleField.ndAlgebra
|
||||||
private val genericField = BufferedFieldOpsND(DoubleField)
|
private val genericField = BufferedFieldOpsND(DoubleField)
|
||||||
private val nd4jField = DoubleField.nd4j
|
private val nd4jField = DoubleField.nd4j
|
||||||
private val multikField = DoubleField.multikAlgebra
|
|
||||||
private val viktorField = DoubleField.viktorAlgebra
|
private val viktorField = DoubleField.viktorAlgebra
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,11 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.benchmarks
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.default.DefaultEngine
|
||||||
|
import space.kscience.kmath.multik.MultikDoubleAlgebra
|
||||||
|
|
||||||
|
val multikAlgebra = MultikDoubleAlgebra(DefaultEngine())
|
@ -74,3 +74,5 @@ ksciencePublish {
|
|||||||
}
|
}
|
||||||
|
|
||||||
apiValidation.nonPublicMarkers.add("space.kscience.kmath.misc.UnstableKMathAPI")
|
apiValidation.nonPublicMarkers.add("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||||
|
|
||||||
|
val multikVersion by extra("0.2.0")
|
||||||
|
@ -8,6 +8,8 @@ repositories {
|
|||||||
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
maven("https://maven.pkg.jetbrains.space/kotlin/p/kotlin/kotlin-js-wrappers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val multikVersion: String by rootProject.extra
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-kotlingrad"))
|
implementation(project(":kmath-kotlingrad"))
|
||||||
@ -30,6 +32,7 @@ dependencies {
|
|||||||
implementation(project(":kmath-jafama"))
|
implementation(project(":kmath-jafama"))
|
||||||
//multik
|
//multik
|
||||||
implementation(project(":kmath-multik"))
|
implementation(project(":kmath-multik"))
|
||||||
|
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||||
|
|
||||||
|
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||||
|
@ -7,11 +7,14 @@ package space.kscience.kmath.tensors
|
|||||||
|
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarray
|
import org.jetbrains.kotlinx.multik.api.ndarray
|
||||||
import space.kscience.kmath.multik.multikAlgebra
|
import org.jetbrains.kotlinx.multik.default.DefaultEngine
|
||||||
|
import space.kscience.kmath.multik.MultikDoubleAlgebra
|
||||||
import space.kscience.kmath.nd.one
|
import space.kscience.kmath.nd.one
|
||||||
import space.kscience.kmath.operations.DoubleField
|
|
||||||
|
|
||||||
fun main(): Unit = with(DoubleField.multikAlgebra) {
|
|
||||||
|
val multikAlgebra = MultikDoubleAlgebra(DefaultEngine())
|
||||||
|
|
||||||
|
fun main(): Unit = with(multikAlgebra) {
|
||||||
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
|
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
|
||||||
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
|
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
|
||||||
one(a.shape) - a + b * 3.0
|
one(a.shape) - a + b * 3.0
|
||||||
|
@ -1,12 +1,25 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id("space.kscience.gradle.jvm")
|
id("space.kscience.gradle.mpp")
|
||||||
}
|
}
|
||||||
|
|
||||||
description = "JetBrains Multik connector"
|
description = "JetBrains Multik connector"
|
||||||
|
|
||||||
|
val multikVersion: String by rootProject.extra
|
||||||
|
|
||||||
|
kotlin{
|
||||||
|
sourceSets{
|
||||||
|
commonMain{
|
||||||
dependencies{
|
dependencies{
|
||||||
api(project(":kmath-tensors"))
|
api(project(":kmath-tensors"))
|
||||||
api("org.jetbrains.kotlinx:multik-default:0.2.0")
|
api("org.jetbrains.kotlinx:multik-core:$multikVersion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
commonTest{
|
||||||
|
dependencies{
|
||||||
|
api("org.jetbrains.kotlinx:multik-default:$multikVersion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Engine
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
@ -14,7 +15,9 @@ import space.kscience.kmath.operations.DoubleField
|
|||||||
import space.kscience.kmath.operations.ExponentialOperations
|
import space.kscience.kmath.operations.ExponentialOperations
|
||||||
import space.kscience.kmath.operations.TrigonometricOperations
|
import space.kscience.kmath.operations.TrigonometricOperations
|
||||||
|
|
||||||
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>(),
|
public class MultikDoubleAlgebra(
|
||||||
|
multikEngine: Engine
|
||||||
|
) : MultikDivisionTensorAlgebra<Double, DoubleField>(multikEngine),
|
||||||
TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> {
|
TrigonometricOperations<StructureND<Double>>, ExponentialOperations<StructureND<Double>> {
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
override val type: DataType get() = DataType.DoubleDataType
|
override val type: DataType get() = DataType.DoubleDataType
|
||||||
@ -60,6 +63,6 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
|
|||||||
override fun scalar(value: Double): MultikTensor<Double> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Double): MultikTensor<Double> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
//public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
//public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
|
@ -5,12 +5,15 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Engine
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.operations.FloatField
|
import space.kscience.kmath.operations.FloatField
|
||||||
|
|
||||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
public class MultikFloatAlgebra(
|
||||||
|
multikEngine: Engine
|
||||||
|
) : MultikDivisionTensorAlgebra<Float, FloatField>(multikEngine) {
|
||||||
override val elementAlgebra: FloatField get() = FloatField
|
override val elementAlgebra: FloatField get() = FloatField
|
||||||
override val type: DataType get() = DataType.FloatDataType
|
override val type: DataType get() = DataType.FloatDataType
|
||||||
|
|
||||||
@ -18,5 +21,5 @@ public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
//public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
||||||
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
//public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
@ -5,16 +5,19 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Engine
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.operations.IntRing
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
|
||||||
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
|
public class MultikIntAlgebra(
|
||||||
|
multikEngine: Engine
|
||||||
|
) : MultikTensorAlgebra<Int, IntRing>(multikEngine) {
|
||||||
override val elementAlgebra: IntRing get() = IntRing
|
override val elementAlgebra: IntRing get() = IntRing
|
||||||
override val type: DataType get() = DataType.IntDataType
|
override val type: DataType get() = DataType.IntDataType
|
||||||
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
//public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
||||||
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
//public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
@ -5,12 +5,15 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Engine
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.operations.LongRing
|
import space.kscience.kmath.operations.LongRing
|
||||||
|
|
||||||
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
public class MultikLongAlgebra(
|
||||||
|
multikEngine: Engine
|
||||||
|
) : MultikTensorAlgebra<Long, LongRing>(multikEngine) {
|
||||||
override val elementAlgebra: LongRing get() = LongRing
|
override val elementAlgebra: LongRing get() = LongRing
|
||||||
override val type: DataType get() = DataType.LongDataType
|
override val type: DataType get() = DataType.LongDataType
|
||||||
|
|
||||||
@ -18,5 +21,5 @@ public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
//public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
||||||
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
//public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
@ -5,16 +5,19 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Engine
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.operations.ShortRing
|
import space.kscience.kmath.operations.ShortRing
|
||||||
|
|
||||||
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
|
public class MultikShortAlgebra(
|
||||||
|
multikEngine: Engine
|
||||||
|
) : MultikTensorAlgebra<Short, ShortRing>(multikEngine) {
|
||||||
override val elementAlgebra: ShortRing get() = ShortRing
|
override val elementAlgebra: ShortRing get() = ShortRing
|
||||||
override val type: DataType get() = DataType.ShortDataType
|
override val type: DataType get() = DataType.ShortDataType
|
||||||
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
//public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
||||||
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
//public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.multik.ndarray.data.*
|
|||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.Shape
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
import kotlin.jvm.JvmInline
|
||||||
|
|
||||||
@JvmInline
|
@JvmInline
|
||||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
@ -20,14 +20,15 @@ import space.kscience.kmath.tensors.api.Tensor
|
|||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||||
|
|
||||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||||
where T : Number, T : Comparable<T> {
|
private val multikEngine: Engine,
|
||||||
|
) : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
public abstract val type: DataType
|
public abstract val type: DataType
|
||||||
|
|
||||||
protected val multikMath: Math = mk.math
|
protected val multikMath: Math = multikEngine.getMath()
|
||||||
protected val multikLinAl: LinAlg = mk.linalg
|
protected val multikLinAl: LinAlg = multikEngine.getLinAlg()
|
||||||
protected val multikStat: Statistics = mk.stat
|
protected val multikStat: Statistics = multikEngine.getStatistics()
|
||||||
|
|
||||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
@ -255,12 +256,13 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
||||||
if (keepDim) TODO("keepDim not implemented")
|
if (keepDim) TODO("keepDim not implemented")
|
||||||
val res = multikMath.argMaxDN(asMultik().array, dim)
|
val res = multikMath.argMaxDN(asMultik().array, dim)
|
||||||
return with(MultikIntAlgebra) { res.wrap() }
|
return with(MultikIntAlgebra(multikEngine)) { res.wrap() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>(
|
||||||
: MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
|
multikEngine: Engine,
|
||||||
|
) : MultikTensorAlgebra<T, A>(multikEngine), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
||||||
Multik.ones<T, DN>(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap()
|
Multik.ones<T, DN>(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap()
|
@ -5,18 +5,20 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
import org.junit.jupiter.api.Test
|
import org.jetbrains.kotlinx.multik.default.DefaultEngine
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.one
|
import space.kscience.kmath.nd.one
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.core.tensorAlgebra
|
import space.kscience.kmath.tensors.core.tensorAlgebra
|
||||||
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
internal class MultikNDTest {
|
internal class MultikNDTest {
|
||||||
|
val multikAlgebra = MultikDoubleAlgebra(DefaultEngine())
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun basicAlgebra(): Unit = DoubleField.multikAlgebra{
|
fun basicAlgebra(): Unit = with(multikAlgebra) {
|
||||||
one(2, 2) + 1.0
|
one(2, 2) + 1.0
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,7 +29,7 @@ internal class MultikNDTest {
|
|||||||
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
val tensor1 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12224)
|
||||||
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
val tensor2 = DoubleTensorAlgebra.randomNormal(shape = intArrayOf(dim, dim), 12225)
|
||||||
|
|
||||||
val multikResult = with(DoubleField.multikAlgebra){
|
val multikResult = with(multikAlgebra) {
|
||||||
tensor1 dot tensor2
|
tensor1 dot tensor2
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user