From 94562179354d5cf86d5894826ebee0afa88ca771 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Wed, 3 Aug 2022 17:29:01 +0300 Subject: [PATCH] Update multik algebra --- gradle.properties | 2 +- kmath-multik/build.gradle.kts | 2 +- .../kmath/multik/MultikDoubleAlgebra.kt | 4 + .../kmath/multik/MultikFloatAlgebra.kt | 22 ++++ .../kscience/kmath/multik/MultikIntAlgebra.kt | 20 ++++ .../kmath/multik/MultikLongAlgebra.kt | 22 ++++ .../kmath/multik/MultikShortAlgebra.kt | 20 ++++ .../kscience/kmath/multik/MultikTensor.kt | 40 +++++++ .../kmath/multik/MultikTensorAlgebra.kt | 102 +++++------------- 9 files changed, 154 insertions(+), 80 deletions(-) create mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt create mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt create mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt create mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt create mode 100644 kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt diff --git a/gradle.properties b/gradle.properties index ca8c91191..b3e3f4cda 100644 --- a/gradle.properties +++ b/gradle.properties @@ -6,7 +6,7 @@ kotlin.code.style=official kotlin.jupyter.add.scanner=false kotlin.mpp.stability.nowarn=true kotlin.native.ignoreDisabledTargets=true -//kotlin.incremental.js.ir=true +kotlin.incremental.js.ir=true org.gradle.configureondemand=true org.gradle.parallel=true diff --git a/kmath-multik/build.gradle.kts b/kmath-multik/build.gradle.kts index 0f31a183a..ee31ead3c 100644 --- a/kmath-multik/build.gradle.kts +++ b/kmath-multik/build.gradle.kts @@ -6,7 +6,7 @@ description = "JetBrains Multik connector" dependencies { api(project(":kmath-tensors")) - api("org.jetbrains.kotlinx:multik-default:0.1.0") + api("org.jetbrains.kotlinx:multik-default:0.2.0") } readme { diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt index 0de2d8349..1d4665618 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikDoubleAlgebra.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.multik +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarrayOf import org.jetbrains.kotlinx.multik.ndarray.data.DataType import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.nd.StructureND @@ -54,6 +56,8 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra): MultikTensor = arg.map { atanh(it) } + + override fun scalar(value: Double): MultikTensor = Multik.ndarrayOf(value).wrap() } public val Double.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikDoubleAlgebra diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt new file mode 100644 index 000000000..8dd9b3f60 --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikFloatAlgebra.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarrayOf +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.operations.FloatField + +public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { + override val elementAlgebra: FloatField get() = FloatField + override val type: DataType get() = DataType.FloatDataType + + override fun scalar(value: Float): MultikTensor = Multik.ndarrayOf(value).wrap() +} + + +public val Float.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra +public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt new file mode 100644 index 000000000..b9e0125e8 --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikIntAlgebra.kt @@ -0,0 +1,20 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarrayOf +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.operations.IntRing + +public object MultikIntAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: IntRing get() = IntRing + override val type: DataType get() = DataType.IntDataType + override fun scalar(value: Int): MultikTensor = Multik.ndarrayOf(value).wrap() +} + +public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra +public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt new file mode 100644 index 000000000..1f3dfb0e5 --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikLongAlgebra.kt @@ -0,0 +1,22 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarrayOf +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.operations.LongRing + +public object MultikLongAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: LongRing get() = LongRing + override val type: DataType get() = DataType.LongDataType + + override fun scalar(value: Long): MultikTensor = Multik.ndarrayOf(value).wrap() +} + + +public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra +public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt new file mode 100644 index 000000000..b9746571d --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikShortAlgebra.kt @@ -0,0 +1,20 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.api.Multik +import org.jetbrains.kotlinx.multik.api.ndarrayOf +import org.jetbrains.kotlinx.multik.ndarray.data.DataType +import space.kscience.kmath.operations.ShortRing + +public object MultikShortAlgebra : MultikTensorAlgebra() { + override val elementAlgebra: ShortRing get() = ShortRing + override val type: DataType get() = DataType.ShortDataType + override fun scalar(value: Short): MultikTensor = Multik.ndarrayOf(value).wrap() +} + +public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra +public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt new file mode 100644 index 000000000..7efe672b4 --- /dev/null +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensor.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2018-2021 KMath contributors. + * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + */ + +package space.kscience.kmath.multik + +import org.jetbrains.kotlinx.multik.ndarray.data.* +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.nd.Shape +import space.kscience.kmath.tensors.api.Tensor + +@JvmInline +public value class MultikTensor(public val array: MutableMultiArray) : Tensor { + override val shape: Shape get() = array.shape + + override fun get(index: IntArray): T = array[index] + + @PerformancePitfall + override fun elements(): Sequence> = + array.multiIndices.iterator().asSequence().map { it to get(it) } + + override fun set(index: IntArray, value: T) { + array[index] = value + } +} + + +internal fun MultiArray.asD1Array(): D1Array { + if (this is NDArray) + return this.asD1Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} + + +internal fun MultiArray.asD2Array(): D2Array { + if (this is NDArray) + return this.asD2Array() + else throw ClassCastException("Cannot cast MultiArray to NDArray.") +} \ No newline at end of file diff --git a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt index 250ef7e7f..1cc1b9b6d 100644 --- a/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt +++ b/kmath-multik/src/main/kotlin/space/kscience/kmath/multik/MultikTensorAlgebra.kt @@ -10,46 +10,16 @@ package space.kscience.kmath.multik import org.jetbrains.kotlinx.multik.api.* import org.jetbrains.kotlinx.multik.api.linalg.LinAlg import org.jetbrains.kotlinx.multik.api.math.Math +import org.jetbrains.kotlinx.multik.api.stat.Statistics import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.operations.* import space.kscience.kmath.misc.PerformancePitfall -import space.kscience.kmath.nd.DefaultStrides -import space.kscience.kmath.nd.Shape -import space.kscience.kmath.nd.StructureND -import space.kscience.kmath.nd.mapInPlace +import space.kscience.kmath.nd.* import space.kscience.kmath.operations.* import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra -@JvmInline -public value class MultikTensor(public val array: MutableMultiArray) : Tensor { - override val shape: Shape get() = array.shape - - override fun get(index: IntArray): T = array[index] - - @PerformancePitfall - override fun elements(): Sequence> = - array.multiIndices.iterator().asSequence().map { it to get(it) } - - override fun set(index: IntArray, value: T) { - array[index] = value - } -} - -private fun MultiArray.asD1Array(): D1Array { - if (this is NDArray) - return this.asD1Array() - else throw ClassCastException("Cannot cast MultiArray to NDArray.") -} - - -private fun MultiArray.asD2Array(): D2Array { - if (this is NDArray) - return this.asD2Array() - else throw ClassCastException("Cannot cast MultiArray to NDArray.") -} - public abstract class MultikTensorAlgebra> : TensorAlgebra where T : Number, T : Comparable { @@ -59,7 +29,6 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra protected val multikLinAl: LinAlg = mk.linalg protected val multikStat: Statistics = mk.stat - override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor { val strides = DefaultStrides(shape) val memoryView = initMemoryView(strides.linearSize, type) @@ -240,11 +209,15 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra override fun Tensor.viewAs(other: StructureND): MultikTensor = view(other.shape) + public abstract fun scalar(value: T): MultikTensor + override fun StructureND.dot(other: StructureND): MultikTensor = if (this.shape.size == 1 && other.shape.size == 1) { - Multik.ndarrayOf( - multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array()) - ).wrap() + scalar( + multikLinAl.linAlgEx.dotVV( + asMultik().array.asD1Array(), other.asMultik().array.asD1Array() + ) + ) } else if (this.shape.size == 2 && other.shape.size == 2) { multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap() } else if (this.shape.size == 2 && other.shape.size == 1) { @@ -254,41 +227,46 @@ public abstract class MultikTensorAlgebra> : TensorAlgebra } override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): MultikTensor { + TODO("Diagonal embedding not implemented") } - override fun StructureND.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> - elementAlgebra.add(acc, t) - } + override fun StructureND.sum(): T = multikMath.sum(asMultik().array) override fun StructureND.sum(dim: Int, keepDim: Boolean): MultikTensor { - TODO("Not yet implemented") + if (keepDim) TODO("keepDim not implemented") + return multikMath.sumDN(asMultik().array, dim).wrap() } override fun StructureND.min(): T? = asMultik().array.min() override fun StructureND.min(dim: Int, keepDim: Boolean): Tensor { - TODO("Not yet implemented") + if (keepDim) TODO("keepDim not implemented") + return multikMath.minDN(asMultik().array, dim).wrap() } override fun StructureND.max(): T? = asMultik().array.max() override fun StructureND.max(dim: Int, keepDim: Boolean): Tensor { - TODO("Not yet implemented") + if (keepDim) TODO("keepDim not implemented") + return multikMath.maxDN(asMultik().array, dim).wrap() } override fun StructureND.argMax(dim: Int, keepDim: Boolean): Tensor { - TODO("Not yet implemented") + if (keepDim) TODO("keepDim not implemented") + val res = multikMath.argMaxDN(asMultik().array, dim) + return with(MultikIntAlgebra) { res.wrap() } } } public abstract class MultikDivisionTensorAlgebra> : MultikTensorAlgebra(), TensorPartialDivisionAlgebra where T : Number, T : Comparable { - override fun T.div(arg: StructureND): MultikTensor = arg.map { elementAlgebra.divide(this@div, it) } + override fun T.div(arg: StructureND): MultikTensor = + Multik.ones(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap() override fun StructureND.div(arg: T): MultikTensor = - asMultik().array.deepCopy().apply { divAssign(arg) }.wrap() + asMultik().array.div(arg).wrap() override fun StructureND.div(arg: StructureND): MultikTensor = asMultik().array.div(arg.asMultik().array).wrap() @@ -308,36 +286,4 @@ public abstract class MultikDivisionTensorAlgebra> mapInPlace { index, t -> elementAlgebra.divide(t, arg[index]) } } } -} - -public object MultikFloatAlgebra : MultikDivisionTensorAlgebra() { - override val elementAlgebra: FloatField get() = FloatField - override val type: DataType get() = DataType.FloatDataType -} - -public val Float.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra -public val FloatField.multikAlgebra: MultikTensorAlgebra get() = MultikFloatAlgebra - -public object MultikShortAlgebra : MultikTensorAlgebra() { - override val elementAlgebra: ShortRing get() = ShortRing - override val type: DataType get() = DataType.ShortDataType -} - -public val Short.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra -public val ShortRing.multikAlgebra: MultikTensorAlgebra get() = MultikShortAlgebra - -public object MultikIntAlgebra : MultikTensorAlgebra() { - override val elementAlgebra: IntRing get() = IntRing - override val type: DataType get() = DataType.IntDataType -} - -public val Int.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra -public val IntRing.multikAlgebra: MultikTensorAlgebra get() = MultikIntAlgebra - -public object MultikLongAlgebra : MultikTensorAlgebra() { - override val elementAlgebra: LongRing get() = LongRing - override val type: DataType get() = DataType.LongDataType -} - -public val Long.Companion.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra -public val LongRing.multikAlgebra: MultikTensorAlgebra get() = MultikLongAlgebra \ No newline at end of file +} \ No newline at end of file