Multik wrapper prototype

Alexander Nozik 2021-10-16 11:10:34 +03:00
parent 8d2770c275
commit 0ac5363acf
4 changed files with 244 additions and 4 deletions

View File

@@ -29,6 +29,11 @@ dependencies {
     implementation(project(":kmath-tensors"))
     implementation(project(":kmath-symja"))
     implementation(project(":kmath-for-real"))
+    //jafama
+    implementation(project(":kmath-jafama"))
+    //multik
+    implementation(projects.kmathMultik)
     implementation("org.nd4j:nd4j-native:1.0.0-beta7")
@@ -42,11 +47,12 @@ dependencies {
     // } else
     implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
-    implementation("org.slf4j:slf4j-simple:1.7.31")
+    // multik implementation
+    implementation("org.jetbrains.kotlinx:multik-default:0.1.0")
+    implementation("org.slf4j:slf4j-simple:1.7.32")
     // plotting
-    implementation("space.kscience:plotlykt-server:0.4.2")
+    implementation("space.kscience:plotlykt-server:0.5.0")
-    //jafama
-    implementation(project(":kmath-jafama"))
 }
 kotlin.sourceSets.all {

View File

@@ -0,0 +1,21 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.tensors

import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.linalg.dot
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
import org.jetbrains.kotlinx.multik.ndarray.operations.unaryMinus

fun main() {
    val a = Multik.ndarray(intArrayOf(1, 2, 3))
    val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0))

    // Smoke-test the imported operator extensions; the results are discarded.
    2 + (-a) - 2
    a dot a
}

View File

@@ -0,0 +1,14 @@
plugins {
    id("ru.mipt.npm.gradle.jvm")
}

description = "JetBrains Multik connector"

dependencies {
    api(project(":kmath-tensors"))
    api("org.jetbrains.kotlinx:multik-api:0.1.0")
}

readme {
    maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
}

View File

@@ -0,0 +1,199 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
*/
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.Ring
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
@JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: IntArray get() = array.shape
override fun get(index: IntArray): T = array[index]
@PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> =
array.multiIndices.iterator().asSequence().map { it to get(it) }
override fun set(index: IntArray, value: T) {
array[index] = value
}
}
public abstract class MultikTensorAlgebra<T>(
public val elementAlgebra: Ring<T>,
public val comparator: Comparator<T>
) : TensorAlgebra<T> {
/**
* Convert a tensor to [MultikTensor] if necessary. If the tensor is converted, changes on the resulting tensor
* are not reflected back onto the source.
*/
public fun Tensor<T>.asMultik(): MultikTensor<T> {
return if (this is MultikTensor) {
this
} else {
TODO()
}
}
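// A possible way to fill in the TODO branch above (a sketch, not part of this commit):
// allocate a Multik array of the same shape and copy the elements over, e.g.
//
//     val res = mk.zeros<T, DN>(shape, dataType).asDNArray()
//     for (index in res.multiIndices) res[index] = this[index]
//     res.wrap()
//
// This assumes a Multik DataType for T can be supplied (Multik supports only its own numeric
// element types) and relies on Multik's mk.zeros(dims, dtype) factory and the asDNArray() view.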
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0))
} else null
override fun T.plus(other: Tensor<T>): MultikTensor<T> =
other.plus(this)
override fun Tensor<T>.plus(value: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap()
override fun Tensor<T>.plusAssign(value: T) {
if (this is MultikTensor) {
array.plusAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.add(t, value) }
}
}
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.plusAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.add(t, other[index]) }
}
}
//TODO avoid additional copy
override fun T.minus(other: Tensor<T>): MultikTensor<T> = -(other - this)
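// The extra copy comes from materializing (other - this) and then negating it. One possible
// way to avoid it (a sketch, not part of this commit) is a single copy mutated in place:
//
//     other.asMultik().array.deepCopy().wrap().apply {
//         mapInPlace { _, t -> elementAlgebra.run { this@minus - t } }
//     }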
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap()
override fun Tensor<T>.minusAssign(value: T) {
if (this is MultikTensor) {
array.minusAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.run { t - value } }
}
}
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.minusAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.run { t - other[index] } }
}
}
override fun T.times(other: Tensor<T>): MultikTensor<T> =
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
override fun Tensor<T>.times(value: T): Tensor<T> =
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap()
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap()
override fun Tensor<T>.timesAssign(value: T) {
if (this is MultikTensor) {
array.timesAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.multiply(t, value) }
}
}
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
if (this is MultikTensor) {
array.timesAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.multiply(t, other[index]) }
}
}
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
require(shape.all { it > 0 })
val linearSize = this.shape.fold(1, Int::times)
require(shape.fold(1, Int::times) == linearSize) {
"Cannot reshape array of size $linearSize into a new shape ${
shape.joinToString(
prefix = "(",
postfix = ")"
)
}"
}
val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) {
@Suppress("UNCHECKED_CAST")
this as NDArray<T, DN>
} else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap()
}
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
elementAlgebra.add(acc, t)
}
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.min(): T =
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.max(): T =
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented")
}
}
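
For orientation, a minimal usage sketch (not part of this commit): the object MultikIntAlgebra below is hypothetical, and it assumes kmath's IntRing plus Multik's asDNArray() extension for viewing a one-dimensional array as a MultiArray over DN.

package space.kscience.kmath.multik

import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.data.*
import space.kscience.kmath.operations.IntRing

// Hypothetical concrete specialization: Int elements with their natural ordering.
public object MultikIntAlgebra : MultikTensorAlgebra<Int>(IntRing, naturalOrder())

fun main(): Unit = with(MultikIntAlgebra) {
    // Wrap a Multik array as a kmath Tensor; asDNArray() erases the static dimension to DN.
    val tensor = MultikTensor(Multik.ndarray(intArrayOf(1, 2, 3)).asDNArray())
    val shifted = tensor + 5 // element-wise scalar addition through the wrapper
    println(shifted.sum())   // (1 + 5) + (2 + 5) + (3 + 5) = 21
}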