more updates
This commit is contained in:
commit
1e7ee53c82
@ -13,6 +13,7 @@ import space.kscience.kmath.commons.linear.CMLinearSpace
|
|||||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||||
import space.kscience.kmath.linear.invoke
|
import space.kscience.kmath.linear.invoke
|
||||||
import space.kscience.kmath.linear.linearSpace
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.multik.multikAlgebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
@ -58,6 +59,16 @@ internal class DotBenchmark {
|
|||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @Benchmark
|
||||||
|
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
// blackhole.consume(matrix1 dot matrix2)
|
||||||
|
// }
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
||||||
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) {
|
fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
|
@ -35,7 +35,7 @@ public interface WithShape {
|
|||||||
* @param T the type of ND-structure element.
|
* @param T the type of ND-structure element.
|
||||||
* @param C the type of the element context.
|
* @param C the type of the element context.
|
||||||
*/
|
*/
|
||||||
public interface AlgebraND<T, out C : Algebra<T>> {
|
public interface AlgebraND<T, out C : Algebra<T>>: Algebra<StructureND<T>> {
|
||||||
/**
|
/**
|
||||||
* The algebra over elements of ND structure.
|
* The algebra over elements of ND structure.
|
||||||
*/
|
*/
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
@ -27,11 +26,6 @@ public open class BufferND<out T>(
|
|||||||
|
|
||||||
override val shape: IntArray get() = indices.shape
|
override val shape: IntArray get() = indices.shape
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = indices.asSequence().map {
|
|
||||||
it to this[it]
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = StructureND.toString(this)
|
override fun toString(): String = StructureND.toString(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stub for DSL the [Algebra] is.
|
* Stub for DSL the [Algebra] is.
|
||||||
@ -99,6 +100,14 @@ public interface Algebra<T> {
|
|||||||
*/
|
*/
|
||||||
public fun binaryOperation(operation: String, left: T, right: T): T =
|
public fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
binaryOperationFunction(operation)(left, right)
|
binaryOperationFunction(operation)(left, right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Export an algebra element, so it could be accessed even after algebra scope is closed.
|
||||||
|
* This method must be used on algebras where data is stored externally or any local algebra state is used.
|
||||||
|
* By default (if not overridden), exports the object itself.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun export(arg: T): T = arg
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
|
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
|
||||||
@ -162,6 +171,7 @@ public interface GroupOps<T> : Algebra<T> {
|
|||||||
* @return the difference.
|
* @return the difference.
|
||||||
*/
|
*/
|
||||||
public operator fun T.minus(arg: T): T = add(this, -arg)
|
public operator fun T.minus(arg: T): T = add(this, -arg)
|
||||||
|
|
||||||
// Dynamic dispatch of operations
|
// Dynamic dispatch of operations
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> { arg -> +arg }
|
PLUS_OPERATION -> { arg -> +arg }
|
||||||
|
@ -0,0 +1,136 @@
|
|||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
|
||||||
|
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>(),
|
||||||
|
AnalyticTensorAlgebra<Double, DoubleField>, LinearOpsTensorAlgebra<Double, DoubleField> {
|
||||||
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
override val type: DataType get() = DataType.DoubleDataType
|
||||||
|
|
||||||
|
override fun StructureND<Double>.mean(): Double = multikStat.mean(asMultik().array)
|
||||||
|
|
||||||
|
override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): Tensor<Double> =
|
||||||
|
multikStat.mean<Double,DN, DN>(asMultik().array, dim).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<Double>.std(): Double {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.variance(): Double {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.exp(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.ln(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sqrt(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cos(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.acos(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cosh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.acosh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sin(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.asin(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sinh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.asinh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.tan(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.atan(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.tanh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.atanh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.ceil(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.floor(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.det(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.inv(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cholesky(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.qr(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.lu(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.svd(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.symEig(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
|
@ -7,11 +7,9 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.*
|
||||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
||||||
import org.jetbrains.kotlinx.multik.api.mk
|
import org.jetbrains.kotlinx.multik.api.math.Math
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
|
||||||
import org.jetbrains.kotlinx.multik.api.zeros
|
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
@ -52,10 +50,16 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
|||||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||||
|
where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
public abstract val type: DataType
|
public abstract val type: DataType
|
||||||
|
|
||||||
|
protected val multikMath: Math = mk.math
|
||||||
|
protected val multikLinAl: LinAlg = mk.linalg
|
||||||
|
protected val multikStat: Statistics = mk.stat
|
||||||
|
|
||||||
|
|
||||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||||
@ -65,6 +69,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
|
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
||||||
val data = initMemoryView<T>(array.size, type)
|
val data = initMemoryView<T>(array.size, type)
|
||||||
var count = 0
|
var count = 0
|
||||||
@ -76,6 +81,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
||||||
if (this is MultikTensor) {
|
if (this is MultikTensor) {
|
||||||
val array = asMultik().array
|
val array = asMultik().array
|
||||||
@ -96,6 +102,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||||
val leftArray = left.asMultik().array
|
val leftArray = left.asMultik().array
|
||||||
@ -236,12 +243,12 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
||||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
Multik.ndarrayOf(
|
Multik.ndarrayOf(
|
||||||
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
|
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
|
||||||
).asDNArray().wrap()
|
).wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
multikLinAl.linAlgEx.dotMV(asMultik().array.asD2Array(), other.asMultik().array.asD1Array()).wrap()
|
||||||
} else {
|
} else {
|
||||||
TODO("Not implemented for broadcasting")
|
TODO("Not implemented for broadcasting")
|
||||||
}
|
}
|
||||||
@ -303,14 +310,6 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
|
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
|
||||||
override val type: DataType get() = DataType.DoubleDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
|
||||||
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
|
||||||
|
|
||||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||||
override val elementAlgebra: FloatField get() = FloatField
|
override val elementAlgebra: FloatField get() = FloatField
|
||||||
override val type: DataType get() = DataType.FloatDataType
|
override val type: DataType get() = DataType.FloatDataType
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
|
|
||||||
@ -122,3 +123,6 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
|
|||||||
public fun StructureND<T>.floor(): Tensor<T>
|
public fun StructureND<T>.floor(): Tensor<T>
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, ATA : AnalyticTensorAlgebra<T, *>> ATA.exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
@ -24,9 +24,20 @@ public class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
|||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean = false
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
override fun hashCode(): Int = 0
|
other as TensorLinearStructure
|
||||||
|
|
||||||
|
if (!shape.contentEquals(other.shape)) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
return shape.contentHashCode()
|
||||||
|
}
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
|
|
||||||
|
@ -1,71 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2018-2021 KMath contributors.
|
|
||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package space.kscience.kmath.tensors.core.internal
|
|
||||||
|
|
||||||
import space.kscience.kmath.nd.Strides
|
|
||||||
import kotlin.math.max
|
|
||||||
|
|
||||||
|
|
||||||
internal fun stridesFromShape(shape: IntArray): IntArray {
|
|
||||||
val nDim = shape.size
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
if (nDim == 0)
|
|
||||||
return res
|
|
||||||
|
|
||||||
var current = nDim - 1
|
|
||||||
res[current] = 1
|
|
||||||
|
|
||||||
while (current > 0) {
|
|
||||||
res[current - 1] = max(1, shape[current]) * res[current]
|
|
||||||
current--
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArray {
|
|
||||||
val res = IntArray(nDim)
|
|
||||||
var current = offset
|
|
||||||
var strideIndex = 0
|
|
||||||
|
|
||||||
while (strideIndex < nDim) {
|
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
|
||||||
current %= strides[strideIndex]
|
|
||||||
strideIndex++
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* This [Strides] implementation follows the last dimension first convention
|
|
||||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
|
||||||
*
|
|
||||||
* @param shape the shape of the tensor.
|
|
||||||
*/
|
|
||||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
|
||||||
override val strides: IntArray
|
|
||||||
get() = stridesFromShape(shape)
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
|
||||||
indexFromOffset(offset, strides, shape.size)
|
|
||||||
|
|
||||||
override val linearSize: Int
|
|
||||||
get() = shape.reduce(Int::times)
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
if (this === other) return true
|
|
||||||
if (other == null || this::class != other::class) return false
|
|
||||||
|
|
||||||
other as TensorLinearStructure
|
|
||||||
|
|
||||||
if (!shape.contentEquals(other.shape)) return false
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
return shape.contentHashCode()
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user