Multik wrapper prototype
This commit is contained in:
parent
8d2770c275
commit
0ac5363acf
@ -29,6 +29,11 @@ dependencies {
|
||||
implementation(project(":kmath-tensors"))
|
||||
implementation(project(":kmath-symja"))
|
||||
implementation(project(":kmath-for-real"))
|
||||
//jafama
|
||||
implementation(project(":kmath-jafama"))
|
||||
//multik
|
||||
implementation(projects.kmathMultik)
|
||||
|
||||
|
||||
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||
|
||||
@ -42,11 +47,12 @@ dependencies {
|
||||
// } else
|
||||
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||
|
||||
implementation("org.slf4j:slf4j-simple:1.7.31")
|
||||
// multik implementation
|
||||
implementation("org.jetbrains.kotlinx:multik-default:0.1.0")
|
||||
|
||||
implementation("org.slf4j:slf4j-simple:1.7.32")
|
||||
// plotting
|
||||
implementation("space.kscience:plotlykt-server:0.4.2")
|
||||
//jafama
|
||||
implementation(project(":kmath-jafama"))
|
||||
implementation("space.kscience:plotlykt-server:0.5.0")
|
||||
}
|
||||
|
||||
kotlin.sourceSets.all {
|
||||
|
@ -0,0 +1,21 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||
import org.jetbrains.kotlinx.multik.api.ndarray
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.unaryMinus
|
||||
|
||||
fun main() {
|
||||
val a = Multik.ndarray(intArrayOf(1, 2, 3))
|
||||
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0))
|
||||
2 + (-a) - 2
|
||||
|
||||
a dot a
|
||||
}
|
14
kmath-multik/build.gradle.kts
Normal file
14
kmath-multik/build.gradle.kts
Normal file
@ -0,0 +1,14 @@
|
||||
plugins {
|
||||
id("ru.mipt.npm.gradle.jvm")
|
||||
}
|
||||
|
||||
description = "JetBrains Multik connector"
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-tensors"))
|
||||
api("org.jetbrains.kotlinx:multik-api:0.1.0")
|
||||
}
|
||||
|
||||
readme {
|
||||
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||
}
|
@ -0,0 +1,199 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.mapInPlace
|
||||
import space.kscience.kmath.operations.Ring
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
|
||||
@JvmInline
|
||||
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
|
||||
override val shape: IntArray get() = array.shape
|
||||
|
||||
override fun get(index: IntArray): T = array[index]
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> =
|
||||
array.multiIndices.iterator().asSequence().map { it to get(it) }
|
||||
|
||||
override fun set(index: IntArray, value: T) {
|
||||
array[index] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public abstract class MultikTensorAlgebra<T>(
|
||||
public val elementAlgebra: Ring<T>,
|
||||
public val comparator: Comparator<T>
|
||||
) : TensorAlgebra<T> {
|
||||
|
||||
/**
|
||||
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
|
||||
* are not reflected back onto the source
|
||||
*/
|
||||
public fun Tensor<T>.asMultik(): MultikTensor<T> {
|
||||
return if (this is MultikTensor) {
|
||||
this
|
||||
} else {
|
||||
TODO()
|
||||
}
|
||||
}
|
||||
|
||||
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
|
||||
|
||||
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
|
||||
get(intArrayOf(0))
|
||||
} else null
|
||||
|
||||
override fun T.plus(other: Tensor<T>): MultikTensor<T> =
|
||||
other.plus(this)
|
||||
|
||||
override fun Tensor<T>.plus(value: T): MultikTensor<T> =
|
||||
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.plus(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.plusAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.plusAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.add(t, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.plusAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.plusAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.add(t, other[index]) }
|
||||
}
|
||||
}
|
||||
|
||||
//TODO avoid additional copy
|
||||
override fun T.minus(other: Tensor<T>): MultikTensor<T> = -(other - this)
|
||||
|
||||
override fun Tensor<T>.minus(value: T): MultikTensor<T> =
|
||||
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.minus(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.minusAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.minusAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.run { t - value } }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.minusAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.minusAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.run { t - other[index] } }
|
||||
}
|
||||
}
|
||||
|
||||
override fun T.times(other: Tensor<T>): MultikTensor<T> =
|
||||
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
|
||||
|
||||
override fun Tensor<T>.times(value: T): Tensor<T> =
|
||||
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap()
|
||||
|
||||
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> =
|
||||
asMultik().array.times(other.asMultik().array).wrap()
|
||||
|
||||
override fun Tensor<T>.timesAssign(value: T) {
|
||||
if (this is MultikTensor) {
|
||||
array.timesAssign(value)
|
||||
} else {
|
||||
mapInPlace { _, t -> elementAlgebra.multiply(t, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.timesAssign(other: Tensor<T>) {
|
||||
if (this is MultikTensor) {
|
||||
array.timesAssign(other.asMultik().array)
|
||||
} else {
|
||||
mapInPlace { index, t -> elementAlgebra.multiply(t, other[index]) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun Tensor<T>.unaryMinus(): MultikTensor<T> =
|
||||
asMultik().array.unaryMinus().wrap()
|
||||
|
||||
override fun Tensor<T>.get(i: Int): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
||||
require(shape.all { it > 0 })
|
||||
require(shape.fold(1, Int::times) == this.shape.size) {
|
||||
"Cannot reshape array of size ${this.shape.size} into a new shape ${
|
||||
shape.joinToString(
|
||||
prefix = "(",
|
||||
postfix = ")"
|
||||
)
|
||||
}"
|
||||
}
|
||||
|
||||
val mt = asMultik().array
|
||||
return if (mt.shape.contentEquals(shape)) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
this as NDArray<T, DN>
|
||||
} else {
|
||||
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
|
||||
}.wrap()
|
||||
}
|
||||
|
||||
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun diagonalEmbedding(diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
|
||||
elementAlgebra.add(acc, t)
|
||||
}
|
||||
|
||||
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.min(): T =
|
||||
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
|
||||
|
||||
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.max(): T =
|
||||
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
|
||||
|
||||
|
||||
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user