forked from kscience/kmath
Update multik algebra
This commit is contained in:
parent
0e9072710f
commit
9456217935
@ -6,7 +6,7 @@ kotlin.code.style=official
|
||||
kotlin.jupyter.add.scanner=false
|
||||
kotlin.mpp.stability.nowarn=true
|
||||
kotlin.native.ignoreDisabledTargets=true
|
||||
//kotlin.incremental.js.ir=true
|
||||
kotlin.incremental.js.ir=true
|
||||
|
||||
org.gradle.configureondemand=true
|
||||
org.gradle.parallel=true
|
||||
|
@ -6,7 +6,7 @@ description = "JetBrains Multik connector"
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-tensors"))
|
||||
api("org.jetbrains.kotlinx:multik-default:0.1.0")
|
||||
api("org.jetbrains.kotlinx:multik-default:0.2.0")
|
||||
}
|
||||
|
||||
readme {
|
||||
|
@ -5,6 +5,8 @@
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
@ -54,6 +56,8 @@ public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleFi
|
||||
|
||||
@PerformancePitfall
|
||||
override fun atanh(arg: StructureND<Double>): MultikTensor<Double> = arg.map { atanh(it) }
|
||||
|
||||
override fun scalar(value: Double): MultikTensor<Double> = Multik.ndarrayOf(value).wrap()
|
||||
}
|
||||
|
||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||
|
@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.operations.FloatField
|
||||
|
||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||
override val elementAlgebra: FloatField get() = FloatField
|
||||
override val type: DataType get() = DataType.FloatDataType
|
||||
|
||||
override fun scalar(value: Float): MultikTensor<Float> = Multik.ndarrayOf(value).wrap()
|
||||
}
|
||||
|
||||
|
||||
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
||||
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.operations.IntRing
|
||||
|
||||
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
|
||||
override val elementAlgebra: IntRing get() = IntRing
|
||||
override val type: DataType get() = DataType.IntDataType
|
||||
override fun scalar(value: Int): MultikTensor<Int> = Multik.ndarrayOf(value).wrap()
|
||||
}
|
||||
|
||||
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
||||
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.operations.LongRing
|
||||
|
||||
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
||||
override val elementAlgebra: LongRing get() = LongRing
|
||||
override val type: DataType get() = DataType.LongDataType
|
||||
|
||||
override fun scalar(value: Long): MultikTensor<Long> = Multik.ndarrayOf(value).wrap()
|
||||
}
|
||||
|
||||
|
||||
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
||||
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.operations.ShortRing
|
||||
|
||||
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
|
||||
override val elementAlgebra: ShortRing get() = ShortRing
|
||||
override val type: DataType get() = DataType.ShortDataType
|
||||
override fun scalar(value: Short): MultikTensor<Short> = Multik.ndarrayOf(value).wrap()
|
||||
}
|
||||
|
||||
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
||||
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
|
||||
@JvmInline
|
||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||
override val shape: Shape get() = array.shape
|
||||
|
||||
override fun get(index: IntArray): T = array[index]
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
array[index] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
internal fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD1Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
|
||||
internal fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD2Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
@ -10,46 +10,16 @@ package space.kscience.kmath.multik
|
||||
import org.jetbrains.kotlinx.multik.api.*
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
||||
import org.jetbrains.kotlinx.multik.api.math.Math
|
||||
import org.jetbrains.kotlinx.multik.api.stat.Statistics
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.DefaultStrides
|
||||
import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.nd.mapInPlace
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||
|
||||
@JvmInline
|
||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||
override val shape: Shape get() = array.shape
|
||||
|
||||
override fun get(index: IntArray): T = array[index]
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
array[index] = value
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T, D : Dimension> MultiArray<T, D>.asD1Array(): D1Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD1Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
|
||||
private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||
if (this is NDArray<T, D>)
|
||||
return this.asD2Array()
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
where T : Number, T : Comparable<T> {
|
||||
|
||||
@ -59,7 +29,6 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
protected val multikLinAl: LinAlg = mk.linalg
|
||||
protected val multikStat: Statistics = mk.stat
|
||||
|
||||
|
||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||
val strides = DefaultStrides(shape)
|
||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||
@ -240,11 +209,15 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
|
||||
override fun Tensor<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
|
||||
|
||||
public abstract fun scalar(value: T): MultikTensor<T>
|
||||
|
||||
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||
Multik.ndarrayOf(
|
||||
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
|
||||
).wrap()
|
||||
scalar(
|
||||
multikLinAl.linAlgEx.dotVV(
|
||||
asMultik().array.asD1Array(), other.asMultik().array.asD1Array()
|
||||
)
|
||||
)
|
||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
||||
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||
@ -254,41 +227,46 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||
|
||||
TODO("Diagonal embedding not implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
||||
elementAlgebra.add(acc, t)
|
||||
}
|
||||
override fun StructureND<T>.sum(): T = multikMath.sum(asMultik().array)
|
||||
|
||||
override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
if (keepDim) TODO("keepDim not implemented")
|
||||
return multikMath.sumDN(asMultik().array, dim).wrap()
|
||||
}
|
||||
|
||||
override fun StructureND<T>.min(): T? = asMultik().array.min()
|
||||
|
||||
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
if (keepDim) TODO("keepDim not implemented")
|
||||
return multikMath.minDN(asMultik().array, dim).wrap()
|
||||
}
|
||||
|
||||
override fun StructureND<T>.max(): T? = asMultik().array.max()
|
||||
|
||||
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
if (keepDim) TODO("keepDim not implemented")
|
||||
return multikMath.maxDN(asMultik().array, dim).wrap()
|
||||
}
|
||||
|
||||
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> {
|
||||
TODO("Not yet implemented")
|
||||
if (keepDim) TODO("keepDim not implemented")
|
||||
val res = multikMath.argMaxDN(asMultik().array, dim)
|
||||
return with(MultikIntAlgebra) { res.wrap() }
|
||||
}
|
||||
}
|
||||
|
||||
public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
||||
: MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||
|
||||
override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
|
||||
override fun T.div(arg: StructureND<T>): MultikTensor<T> =
|
||||
Multik.ones<T, DN>(arg.shape, type).apply { divAssign(arg.asMultik().array) }.wrap()
|
||||
|
||||
override fun StructureND<T>.div(arg: T): MultikTensor<T> =
|
||||
asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
|
||||
asMultik().array.div(arg).wrap()
|
||||
|
||||
override fun StructureND<T>.div(arg: StructureND<T>): MultikTensor<T> =
|
||||
asMultik().array.div(arg.asMultik().array).wrap()
|
||||
@ -309,35 +287,3 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||
override val elementAlgebra: FloatField get() = FloatField
|
||||
override val type: DataType get() = DataType.FloatDataType
|
||||
}
|
||||
|
||||
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
||||
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
|
||||
|
||||
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
|
||||
override val elementAlgebra: ShortRing get() = ShortRing
|
||||
override val type: DataType get() = DataType.ShortDataType
|
||||
}
|
||||
|
||||
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
||||
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
|
||||
|
||||
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
|
||||
override val elementAlgebra: IntRing get() = IntRing
|
||||
override val type: DataType get() = DataType.IntDataType
|
||||
}
|
||||
|
||||
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
||||
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
|
||||
|
||||
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
|
||||
override val elementAlgebra: LongRing get() = LongRing
|
||||
override val type: DataType get() = DataType.LongDataType
|
||||
}
|
||||
|
||||
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
||||
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
|
Loading…
Reference in New Issue
Block a user