Multik wrapper prototype
This commit is contained in:
parent
8d2770c275
commit
0ac5363acf
@ -29,6 +29,11 @@ dependencies {
|
|||||||
implementation(project(":kmath-tensors"))
|
implementation(project(":kmath-tensors"))
|
||||||
implementation(project(":kmath-symja"))
|
implementation(project(":kmath-symja"))
|
||||||
implementation(project(":kmath-for-real"))
|
implementation(project(":kmath-for-real"))
|
||||||
|
//jafama
|
||||||
|
implementation(project(":kmath-jafama"))
|
||||||
|
//multik
|
||||||
|
implementation(projects.kmathMultik)
|
||||||
|
|
||||||
|
|
||||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||||
|
|
||||||
@ -42,11 +47,12 @@ dependencies {
|
|||||||
// } else
|
// } else
|
||||||
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
|
||||||
implementation("org.slf4j:slf4j-simple:1.7.31")
|
// multik implementation
|
||||||
|
implementation("org.jetbrains.kotlinx:multik-default:0.1.0")
|
||||||
|
|
||||||
|
implementation("org.slf4j:slf4j-simple:1.7.32")
|
||||||
// plotting
|
// plotting
|
||||||
implementation("space.kscience:plotlykt-server:0.4.2")
|
implementation("space.kscience:plotlykt-server:0.5.0")
|
||||||
//jafama
|
|
||||||
implementation(project(":kmath-jafama"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kotlin.sourceSets.all {
|
kotlin.sourceSets.all {
|
||||||
|
@ -0,0 +1,21 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.api.Multik
|
||||||
|
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||||
|
import org.jetbrains.kotlinx.multik.api.ndarray
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.operations.unaryMinus
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val a = Multik.ndarray(intArrayOf(1, 2, 3))
|
||||||
|
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0))
|
||||||
|
2 + (-a) - 2
|
||||||
|
|
||||||
|
a dot a
|
||||||
|
}
|
14
kmath-multik/build.gradle.kts
Normal file
14
kmath-multik/build.gradle.kts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
plugins {
|
||||||
|
id("ru.mipt.npm.gradle.jvm")
|
||||||
|
}
|
||||||
|
|
||||||
|
description = "JetBrains Multik connector"
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
api(project(":kmath-tensors"))
|
||||||
|
api("org.jetbrains.kotlinx:multik-api:0.1.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
readme {
|
||||||
|
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||||
|
}
|
@ -0,0 +1,199 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
|
import space.kscience.kmath.nd.mapInPlace
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
|
|
||||||
|
@JvmInline
|
||||||
|
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||||
|
override val shape: IntArray get() = array.shape
|
||||||
|
|
||||||
|
override fun get(index: IntArray): T = array[index]
|
||||||
|
|
||||||
|
@PerformancePitfall
|
||||||
|
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||||
|
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||||
|
|
||||||
|
override fun set(index: IntArray, value: T) {
|
||||||
|
array[index] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public abstract class MultikTensorAlgebra<T>(
|
||||||
|
public val elementAlgebra: Ring<T>,
|
||||||
|
public val comparator: Comparator<T>
|
||||||
|
) : TensorAlgebra<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
|
||||||
|
* are not reflected back onto the source
|
||||||
|
*/
|
||||||
|
public fun Tensor<T>.asMultik(): MultikTensor<T> {
|
||||||
|
return if (this is MultikTensor) {
|
||||||
|
this
|
||||||
|
} else {
|
||||||
|
TODO()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
||||||
|
|
||||||
|
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
|
||||||
|
get(intArrayOf(0))
|
||||||
|
} else null
|
||||||
|
|
||||||
|
override fun T.plus(other: Tensor<T>): MultikTensor<T> =
|
||||||
|
other.plus(this)
|
||||||
|
|
||||||
|
override fun Tensor<T>.plus(value: T): MultikTensor<T> =
|
||||||
|
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.plus(other.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.plusAssign(value: T) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.plusAssign(value)
|
||||||
|
} else {
|
||||||
|
mapInPlace { _, t -> elementAlgebra.add(t, value) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.plusAssign(other.asMultik().array)
|
||||||
|
} else {
|
||||||
|
mapInPlace { index, t -> elementAlgebra.add(t, other[index]) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO avoid additional copy
|
||||||
|
override fun T.minus(other: Tensor<T>): MultikTensor<T> = -(other - this)
|
||||||
|
|
||||||
|
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
|
||||||
|
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.minus(other.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.minusAssign(value: T) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.minusAssign(value)
|
||||||
|
} else {
|
||||||
|
mapInPlace { _, t -> elementAlgebra.run { t - value } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.minusAssign(other.asMultik().array)
|
||||||
|
} else {
|
||||||
|
mapInPlace { index, t -> elementAlgebra.run { t - other[index] } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun T.times(other: Tensor<T>): MultikTensor<T> =
|
||||||
|
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.times(value: T): Tensor<T> =
|
||||||
|
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> =
|
||||||
|
asMultik().array.times(other.asMultik().array).wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.timesAssign(value: T) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.timesAssign(value)
|
||||||
|
} else {
|
||||||
|
mapInPlace { _, t -> elementAlgebra.multiply(t, value) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
|
||||||
|
if (this is MultikTensor) {
|
||||||
|
array.timesAssign(other.asMultik().array)
|
||||||
|
} else {
|
||||||
|
mapInPlace { index, t -> elementAlgebra.multiply(t, other[index]) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
|
||||||
|
asMultik().array.unaryMinus().wrap()
|
||||||
|
|
||||||
|
override fun Tensor<T>.get(i: Int): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
||||||
|
require(shape.all { it > 0 })
|
||||||
|
require(shape.fold(1, Int::times) == this.shape.size) {
|
||||||
|
"Cannot reshape array of size ${this.shape.size} into a new shape ${
|
||||||
|
shape.joinToString(
|
||||||
|
prefix = "(",
|
||||||
|
postfix = ")"
|
||||||
|
)
|
||||||
|
}"
|
||||||
|
}
|
||||||
|
|
||||||
|
val mt = asMultik().array
|
||||||
|
return if (mt.shape.contentEquals(shape)) {
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
this as NDArray<T, DN>
|
||||||
|
} else {
|
||||||
|
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
|
||||||
|
}.wrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
||||||
|
elementAlgebra.add(acc, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.min(): T =
|
||||||
|
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
|
||||||
|
|
||||||
|
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.max(): T =
|
||||||
|
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
|
||||||
|
|
||||||
|
|
||||||
|
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user