forked from kscience/kmath
Update multik algebra
This commit is contained in:
parent
0e9072710f
commit
9456217935
@ -6,7 +6,7 @@ kotlin.code.style=official
|
|||||||
kotlin.jupyter.add.scanner=false
|
kotlin.jupyter.add.scanner=false
|
||||||
kotlin.mpp.stability.nowarn=true
|
kotlin.mpp.stability.nowarn=true
|
||||||
kotlin.native.ignoreDisabledTargets=true
|
kotlin.native.ignoreDisabledTargets=true
|
||||||
//kotlin.incremental.js.ir=true
|
kotlin.incremental.js.ir=true
|
||||||
|
|
||||||
org.gradle.configureondemand=true
|
org.gradle.configureondemand=true
|
||||||
org.gradle.parallel=true
|
org.gradle.parallel=true
|
||||||
|
@ -6,7 +6,7 @@ description = "JetBrains Multik connector"
|
|||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
api(project(":kmath-tensors"))
|
api(project(":kmath-tensors"))
|
||||||
api("org.jetbrains.kotlinx:multik-default:0.1.0")
|
api("org.jetbrains.kotlinx:multik-default:0.2.0")
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
readme {
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
@ -54,6 +56,8 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
|
|||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
|
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
|
||||||
|
|
||||||
|
override fun scalar(value: Double): MultikTensor<Double> = Multik.ndarrayOf(value).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
@ -0,0 +1,22 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.operations.FloatField
|
||||||
|
|
||||||
|
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||||
|
override val elementAlgebra: FloatField get() = FloatField
|
||||||
|
override val type: DataType get() = DataType.FloatDataType
|
||||||
|
|
||||||
|
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
||||||
|
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
|
|
||||||
|
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
|
||||||
|
override val elementAlgebra: IntRing get() = IntRing
|
||||||
|
override val type: DataType get() = DataType.IntDataType
|
||||||
|
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
||||||
|
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
@ -0,0 +1,22 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.operations.LongRing
|
||||||
|
|
||||||
|
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
||||||
|
override val elementAlgebra: LongRing get() = LongRing
|
||||||
|
override val type: DataType get() = DataType.LongDataType
|
||||||
|
|
||||||
|
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
||||||
|
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
@ -0,0 +1,20 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.operations.ShortRing
|
||||||
|
|
||||||
|
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
|
||||||
|
override val elementAlgebra: ShortRing get() = ShortRing
|
||||||
|
override val type: DataType get() = DataType.ShortDataType
|
||||||
|
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
||||||
|
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
|
||||||
|
@JvmInline
|
||||||
|
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||||
|
override val shape: Shape get() = array.shape
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = array[index]
|
||||||
|
|
||||||
|
@PerformancePitfall
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
|
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
internal fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
||||||
|
if (this is NDArray<T, D>)
|
||||||
|
return this.asD1Array()
|
||||||
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
internal fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||||
|
if (this is NDArray<T, D>)
|
||||||
|
return this.asD2Array()
|
||||||
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
|
}
|
@ -10,46 +10,16 @@ package space.kscience.kmath.multik
|
|||||||
import org.jetbrains.kotlinx.multik.api.*
|
import org.jetbrains.kotlinx.multik.api.*
|
||||||
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
||||||
import org.jetbrains.kotlinx.multik.api.math.Math
|
import org.jetbrains.kotlinx.multik.api.math.Math
|
||||||
|
import org.jetbrains.kotlinx.multik.api.stat.Statistics
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.DefaultStrides
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.Shape
|
|
||||||
import space.kscience.kmath.nd.StructureND
|
|
||||||
import space.kscience.kmath.nd.mapInPlace
|
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||||
|
|
||||||
@JvmInline
|
|
||||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
|
||||||
override val shape: Shape get() = array.shape
|
|
||||||
|
|
||||||
override fun get(index: IntArray): T = array[index]
|
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
|
||||||
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
|
||||||
|
|
||||||
override fun set(index: IntArray, value: T) {
|
|
||||||
array[index] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
|
||||||
if (this is NDArray<T, D>)
|
|
||||||
return this.asD1Array()
|
|
||||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
|
||||||
if (this is NDArray<T, D>)
|
|
||||||
return this.asD2Array()
|
|
||||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||||
where T : Number, T : Comparable<T> {
|
where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
@ -59,7 +29,6 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
protected val multikLinAl: LinAlg = mk.linalg
|
protected val multikLinAl: LinAlg = mk.linalg
|
||||||
protected val multikStat: Statistics = mk.stat
|
protected val multikStat: Statistics = mk.stat
|
||||||
|
|
||||||
|
|
||||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||||
@ -240,11 +209,15 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
|
|
||||||
override fun Tensor<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
|
override fun Tensor<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
|
||||||
|
|
||||||
|
public abstract fun scalar(value: T): MultikTensor<T>
|
||||||
|
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
||||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
Multik.ndarrayOf(
|
scalar(
|
||||||
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
|
multikLinAl.linAlgEx.dotVV(
|
||||||
).wrap()
|
asMultik().array.asD1Array(), other.asMultik().array.asD1Array()
|
||||||
|
)
|
||||||
|
)
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||||
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||||
@ -254,41 +227,46 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||||
|
|
||||||
TODO("Diagonal embedding not implemented")
|
TODO("Diagonal embedding not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
override fun StructureND<T>.sum(): T = multikMath.sum(asMultik().array)
|
||||||
elementAlgebra.add(acc, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||||
TODO("Not yet implemented")
|
if (keepDim) TODO("keepDim not implemented")
|
||||||
|
return multikMath.sumDN(asMultik().array, dim).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.min(): T? = asMultik().array.min()
|
override fun StructureND<T>.min(): T? = asMultik().array.min()
|
||||||
|
|
||||||
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||||
TODO("Not yet implemented")
|
if (keepDim) TODO("keepDim not implemented")
|
||||||
|
return multikMath.minDN(asMultik().array, dim).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.max(): T? = asMultik().array.max()
|
override fun StructureND<T>.max(): T? = asMultik().array.max()
|
||||||
|
|
||||||
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
|
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||||
TODO("Not yet implemented")
|
if (keepDim) TODO("keepDim not implemented")
|
||||||
|
return multikMath.maxDN(asMultik().array, dim).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
||||||
TODO("Not yet implemented")
|
if (keepDim) TODO("keepDim not implemented")
|
||||||
|
val res = multikMath.argMaxDN(asMultik().array, dim)
|
||||||
|
return with(MultikIntAlgebra) { res.wrap() }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
||||||
: MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
|
: MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
|
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
||||||
|
Multik.ones<T, DN>(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
|
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
|
||||||
asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
|
asMultik().array.div(arg).wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.div(arg: StructureND<T>): MultikTensor<T> =
|
override fun StructureND<T>.div(arg: StructureND<T>): MultikTensor<T> =
|
||||||
asMultik().array.div(arg.asMultik().array).wrap()
|
asMultik().array.div(arg.asMultik().array).wrap()
|
||||||
@ -309,35 +287,3 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
|
||||||
override val elementAlgebra: FloatField get() = FloatField
|
|
||||||
override val type: DataType get() = DataType.FloatDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
|
||||||
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
|
||||||
|
|
||||||
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
|
|
||||||
override val elementAlgebra: ShortRing get() = ShortRing
|
|
||||||
override val type: DataType get() = DataType.ShortDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
|
||||||
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
|
||||||
|
|
||||||
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
|
|
||||||
override val elementAlgebra: IntRing get() = IntRing
|
|
||||||
override val type: DataType get() = DataType.IntDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
|
||||||
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
|
||||||
|
|
||||||
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
|
||||||
override val elementAlgebra: LongRing get() = LongRing
|
|
||||||
override val type: DataType get() = DataType.LongDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
|
||||||
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
|
Loading…
Reference in New Issue
Block a user