more updates
This commit is contained in:
commit
1e7ee53c82
@ -13,6 +13,7 @@ import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||
import space.kscience.kmath.linear.invoke
|
||||
import space.kscience.kmath.linear.linearSpace
|
||||
import space.kscience.kmath.multik.multikAlgebra
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import kotlin.random.Random
|
||||
@ -58,6 +59,16 @@ internal class DotBenchmark {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
// @Benchmark
|
||||
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||
// blackhole.consume(matrix1 dot matrix2)
|
||||
// }
|
||||
|
||||
@Benchmark
|
||||
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
}
|
||||
|
||||
@Benchmark
|
||||
fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) {
|
||||
blackhole.consume(matrix1 dot matrix2)
|
||||
|
@ -35,7 +35,7 @@ public interface WithShape {
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
*/
|
||||
public interface AlgebraND<T, out C : Algebra<T>> {
|
||||
public interface AlgebraND<T, out C : Algebra<T>>: Algebra<StructureND<T>> {
|
||||
/**
|
||||
* The algebra over elements of ND structure.
|
||||
*/
|
||||
|
@ -5,7 +5,6 @@
|
||||
|
||||
package space.kscience.kmath.nd
|
||||
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
@ -27,11 +26,6 @@ public open class BufferND<out T>(
|
||||
|
||||
override val shape: IntArray get() = indices.shape
|
||||
|
||||
@PerformancePitfall
|
||||
override fun elements(): Sequence<Pair<IntArray, T>> = indices.asSequence().map {
|
||||
it to this[it]
|
||||
}
|
||||
|
||||
override fun toString(): String = StructureND.toString(this)
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
|
||||
/**
|
||||
* Stub for DSL the [Algebra] is.
|
||||
@ -99,6 +100,14 @@ public interface Algebra<T> {
|
||||
*/
|
||||
public fun binaryOperation(operation: String, left: T, right: T): T =
|
||||
binaryOperationFunction(operation)(left, right)
|
||||
|
||||
/**
|
||||
* Export an algebra element, so it could be accessed even after algebra scope is closed.
|
||||
* This method must be used on algebras where data is stored externally or any local algebra state is used.
|
||||
* By default (if not overridden), exports the object itself.
|
||||
*/
|
||||
@UnstableKMathAPI
|
||||
public fun export(arg: T): T = arg
|
||||
}
|
||||
|
||||
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
|
||||
@ -162,6 +171,7 @@ public interface GroupOps<T> : Algebra<T> {
|
||||
* @return the difference.
|
||||
*/
|
||||
public operator fun T.minus(arg: T): T = add(this, -arg)
|
||||
|
||||
// Dynamic dispatch of operations
|
||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||
PLUS_OPERATION -> { arg -> +arg }
|
||||
|
@ -0,0 +1,136 @@
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
|
||||
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>(),
|
||||
AnalyticTensorAlgebra<Double, DoubleField>, LinearOpsTensorAlgebra<Double, DoubleField> {
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
override val type: DataType get() = DataType.DoubleDataType
|
||||
|
||||
override fun StructureND<Double>.mean(): Double = multikStat.mean(asMultik().array)
|
||||
|
||||
override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): Tensor<Double> =
|
||||
multikStat.mean<Double,DN, DN>(asMultik().array, dim).wrap()
|
||||
|
||||
override fun StructureND<Double>.std(): Double {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.variance(): Double {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.exp(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.ln(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.sqrt(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.cos(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.acos(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.cosh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.acosh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.sin(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.asin(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.sinh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.asinh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.tan(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.atan(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.tanh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.atanh(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.ceil(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.floor(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.det(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.inv(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.cholesky(): Tensor<Double> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.qr(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.lu(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.svd(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.symEig(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||
TODO("Not yet implemented")
|
||||
}
|
||||
}
|
||||
|
||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||
|
@ -7,11 +7,9 @@
|
||||
|
||||
package space.kscience.kmath.multik
|
||||
|
||||
import org.jetbrains.kotlinx.multik.api.Multik
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
||||
import org.jetbrains.kotlinx.multik.api.zeros
|
||||
import org.jetbrains.kotlinx.multik.api.*
|
||||
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
||||
import org.jetbrains.kotlinx.multik.api.math.Math
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
@ -52,10 +50,16 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||
}
|
||||
|
||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
where T : Number, T : Comparable<T> {
|
||||
|
||||
public abstract val type: DataType
|
||||
|
||||
protected val multikMath: Math = mk.math
|
||||
protected val multikLinAl: LinAlg = mk.linalg
|
||||
protected val multikStat: Statistics = mk.stat
|
||||
|
||||
|
||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||
val strides = DefaultStrides(shape)
|
||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||
@ -65,6 +69,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
||||
val data = initMemoryView<T>(array.size, type)
|
||||
var count = 0
|
||||
@ -76,6 +81,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
||||
if (this is MultikTensor) {
|
||||
val array = asMultik().array
|
||||
@ -96,6 +102,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||
val leftArray = left.asMultik().array
|
||||
@ -236,12 +243,12 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||
Multik.ndarrayOf(
|
||||
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
|
||||
).asDNArray().wrap()
|
||||
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
|
||||
).wrap()
|
||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
||||
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
||||
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
||||
multikLinAl.linAlgEx.dotMV(asMultik().array.asD2Array(), other.asMultik().array.asD1Array()).wrap()
|
||||
} else {
|
||||
TODO("Not implemented for broadcasting")
|
||||
}
|
||||
@ -303,14 +310,6 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
||||
}
|
||||
}
|
||||
|
||||
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
override val type: DataType get() = DataType.DoubleDataType
|
||||
}
|
||||
|
||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||
|
||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||
override val elementAlgebra: FloatField get() = FloatField
|
||||
override val type: DataType get() = DataType.FloatDataType
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
package space.kscience.kmath.tensors.api
|
||||
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.Field
|
||||
|
||||
@ -122,3 +123,6 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
|
||||
public fun StructureND<T>.floor(): Tensor<T>
|
||||
|
||||
}
|
||||
|
||||
@UnstableKMathAPI
|
||||
public fun <T, ATA : AnalyticTensorAlgebra<T, *>> ATA.exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
@ -24,9 +24,20 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
override fun equals(other: Any?): Boolean = false
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other == null || this::class != other::class) return false
|
||||
|
||||
override fun hashCode(): Int = 0
|
||||
other as TensorLinearStructure
|
||||
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
return shape.contentHashCode()
|
||||
}
|
||||
|
||||
public companion object {
|
||||
|
||||
|
@ -1,71 +0,0 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.tensors.core.internal
|
||||
|
||||
import space.kscience.kmath.nd.Strides
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
internal fun stridesFromShape(shape: IntArray): IntArray {
|
||||
val nDim = shape.size
|
||||
val res = IntArray(nDim)
|
||||
if (nDim == 0)
|
||||
return res
|
||||
|
||||
var current = nDim - 1
|
||||
res[current] = 1
|
||||
|
||||
while (current > 0) {
|
||||
res[current - 1] = max(1, shape[current]) * res[current]
|
||||
current--
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
||||
val res = IntArray(nDim)
|
||||
var current = offset
|
||||
var strideIndex = 0
|
||||
|
||||
while (strideIndex < nDim) {
|
||||
res[strideIndex] = (current / strides[strideIndex])
|
||||
current %= strides[strideIndex]
|
||||
strideIndex++
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
/**
|
||||
* This [Strides] implementation follows the last dimension first convention
|
||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
||||
*
|
||||
* @param shape the shape of the tensor.
|
||||
*/
|
||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||
override val strides: IntArray
|
||||
get() = stridesFromShape(shape)
|
||||
|
||||
override fun index(offset: Int): IntArray =
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
override fun equals(other: Any?): Boolean {
|
||||
if (this === other) return true
|
||||
if (other == null || this::class != other::class) return false
|
||||
|
||||
other as TensorLinearStructure
|
||||
|
||||
if (!shape.contentEquals(other.shape)) return false
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
override fun hashCode(): Int {
|
||||
return shape.contentHashCode()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user