Update multik algebra

This commit is contained in:
Alexander Nozik 2022-08-03 17:29:01 +03:00
parent 0e9072710f
commit 9456217935
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
9 changed files with 154 additions and 80 deletions

View File

@ -6,7 +6,7 @@ kotlin.code.style=official
kotlin.jupyter.add.scanner=false
kotlin.mpp.stability.nowarn=true
kotlin.native.ignoreDisabledTargets=true
//kotlin.incremental.js.ir=true
kotlin.incremental.js.ir=true
org.gradle.configureondemand=true
org.gradle.parallel=true

View File

@ -6,7 +6,7 @@ description = "JetBrains Multik connector"
dependencies {
api(project(":kmath-tensors"))
api("org.jetbrains.kotlinx:multik-default:0.1.0")
api("org.jetbrains.kotlinx:multik-default:0.2.0")
}
readme {

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND
@ -54,6 +56,8 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
@PerformancePitfall
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
override fun scalar(value: Double): MultikTensor<Double> = Multik.ndarrayOf(value).wrap()
}
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra

View File

@ -0,0 +1,22 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.operations.FloatField
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
override val elementAlgebra: FloatField get() = FloatField
override val type: DataType get() = DataType.FloatDataType
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
}
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra

View File

@ -0,0 +1,20 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.operations.IntRing
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
override val elementAlgebra: IntRing get() = IntRing
override val type: DataType get() = DataType.IntDataType
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
}
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra

View File

@ -0,0 +1,22 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.operations.LongRing
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
override val elementAlgebra: LongRing get() = LongRing
override val type: DataType get() = DataType.LongDataType
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
}
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra

View File

@ -0,0 +1,20 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarrayOf
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.operations.ShortRing
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
override val elementAlgebra: ShortRing get() = ShortRing
override val type: DataType get() = DataType.ShortDataType
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
}
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra

View File

@ -0,0 +1,40 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.ndarray.data.*
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.tensors.api.Tensor
@JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: Shape get() = array.shape
override fun get(index: IntArray): T = array[index]
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> =
array.multiIndices.iterator().asSequence().map { it to get(it) }
override fun set(index: IntArray, value: T) {
array[index] = value
}
}
internal fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
if (this is NDArray<T, D>)
return this.asD1Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}
internal fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
if (this is NDArray<T, D>)
return this.asD2Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}

View File

@ -10,46 +10,16 @@ package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.*
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
import org.jetbrains.kotlinx.multik.api.math.Math
import org.jetbrains.kotlinx.multik.api.stat.Statistics
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
@JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: Shape get() = array.shape
override fun get(index: IntArray): T = array[index]
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> =
array.multiIndices.iterator().asSequence().map { it to get(it) }
override fun set(index: IntArray, value: T) {
array[index] = value
}
}
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
if (this is NDArray<T, D>)
return this.asD1Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
if (this is NDArray<T, D>)
return this.asD2Array()
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
}
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
where T : Number, T : Comparable<T> {
@ -59,7 +29,6 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
protected val multikLinAl: LinAlg = mk.linalg
protected val multikStat: Statistics = mk.stat
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val strides = DefaultStrides(shape)
val memoryView = initMemoryView<T>(strides.linearSize, type)
@ -240,11 +209,15 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
override fun Tensor<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
public abstract fun scalar(value: T): MultikTensor<T>
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
if (this.shape.size == 1 && other.shape.size == 1) {
Multik.ndarrayOf(
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
).wrap()
scalar(
multikLinAl.linAlgEx.dotVV(
asMultik().array.asD1Array(), other.asMultik().array.asD1Array()
)
)
} else if (this.shape.size == 2 && other.shape.size == 2) {
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
} else if (this.shape.size == 2 && other.shape.size == 1) {
@ -254,41 +227,46 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
}
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
TODO("Diagonal embedding not implemented")
}
override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
elementAlgebra.add(acc, t)
}
override fun StructureND<T>.sum(): T = multikMath.sum(asMultik().array)
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
if (keepDim) TODO("keepDim not implemented")
return multikMath.sumDN(asMultik().array, dim).wrap()
}
override fun StructureND<T>.min(): T? = asMultik().array.min()
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented")
if (keepDim) TODO("keepDim not implemented")
return multikMath.minDN(asMultik().array, dim).wrap()
}
override fun StructureND<T>.max(): T? = asMultik().array.max()
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented")
if (keepDim) TODO("keepDim not implemented")
return multikMath.maxDN(asMultik().array, dim).wrap()
}
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
TODO("Not yet implemented")
if (keepDim) TODO("keepDim not implemented")
val res = multikMath.argMaxDN(asMultik().array, dim)
return with(MultikIntAlgebra) { res.wrap() }
}
}
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
: MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
Multik.ones<T, DN>(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap()
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
asMultik().array.div(arg).wrap()
override fun StructureND<T>.div(arg: StructureND<T>): MultikTensor<T> =
asMultik().array.div(arg.asMultik().array).wrap()
@ -308,36 +286,4 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
mapInPlace { index, t -> elementAlgebra.divide(t, arg[index]) }
}
}
}
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
override val elementAlgebra: FloatField get() = FloatField
override val type: DataType get() = DataType.FloatDataType
}
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
override val elementAlgebra: ShortRing get() = ShortRing
override val type: DataType get() = DataType.ShortDataType
}
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
override val elementAlgebra: IntRing get() = IntRing
override val type: DataType get() = DataType.IntDataType
}
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
override val elementAlgebra: LongRing get() = LongRing
override val type: DataType get() = DataType.LongDataType
}
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
}