v0.3.0-dev-18 #459
@ -13,6 +13,7 @@ import space.kscience.kmath.commons.linear.CMLinearSpace
|
|||||||
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
import space.kscience.kmath.ejml.EjmlLinearSpaceDDRM
|
||||||
import space.kscience.kmath.linear.invoke
|
import space.kscience.kmath.linear.invoke
|
||||||
import space.kscience.kmath.linear.linearSpace
|
import space.kscience.kmath.linear.linearSpace
|
||||||
|
import space.kscience.kmath.multik.multikAlgebra
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
@ -58,6 +59,16 @@ internal class DotBenchmark {
|
|||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// @Benchmark
|
||||||
|
// fun tensorDot(blackhole: Blackhole) = with(Double.tensorAlgebra) {
|
||||||
|
// blackhole.consume(matrix1 dot matrix2)
|
||||||
|
// }
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun multikDot(blackhole: Blackhole) = with(Double.multikAlgebra) {
|
||||||
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) {
|
fun bufferedDot(blackhole: Blackhole) = with(DoubleField.linearSpace(Buffer.Companion::auto)) {
|
||||||
blackhole.consume(matrix1 dot matrix2)
|
blackhole.consume(matrix1 dot matrix2)
|
||||||
|
@ -35,7 +35,7 @@ public interface WithShape {
|
|||||||
* @param T the type of ND-structure element.
|
* @param T the type of ND-structure element.
|
||||||
* @param C the type of the element context.
|
* @param C the type of the element context.
|
||||||
*/
|
*/
|
||||||
public interface AlgebraND<T, out C : Algebra<T>> {
|
public interface AlgebraND<T, out C : Algebra<T>>: Algebra<StructureND<T>> {
|
||||||
/**
|
/**
|
||||||
* The algebra over elements of ND structure.
|
* The algebra over elements of ND structure.
|
||||||
*/
|
*/
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.MutableBuffer
|
import space.kscience.kmath.structures.MutableBuffer
|
||||||
@ -27,11 +26,6 @@ public open class BufferND<out T>(
|
|||||||
|
|
||||||
override val shape: IntArray get() = indices.shape
|
override val shape: IntArray get() = indices.shape
|
||||||
|
|
||||||
@PerformancePitfall
|
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = indices.asSequence().map {
|
|
||||||
it to this[it]
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String = StructureND.toString(this)
|
override fun toString(): String = StructureND.toString(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stub for DSL the [Algebra] is.
|
* Stub for DSL the [Algebra] is.
|
||||||
@ -99,6 +100,14 @@ public interface Algebra<T> {
|
|||||||
*/
|
*/
|
||||||
public fun binaryOperation(operation: String, left: T, right: T): T =
|
public fun binaryOperation(operation: String, left: T, right: T): T =
|
||||||
binaryOperationFunction(operation)(left, right)
|
binaryOperationFunction(operation)(left, right)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Export an algebra element, so it could be accessed even after algebra scope is closed.
|
||||||
|
* This method must be used on algebras where data is stored externally or any local algebra state is used.
|
||||||
|
* By default (if not overridden), exports the object itself.
|
||||||
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun export(arg: T): T = arg
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
|
public fun <T> Algebra<T>.bindSymbolOrNull(symbol: Symbol): T? = bindSymbolOrNull(symbol.identity)
|
||||||
@ -162,6 +171,7 @@ public interface GroupOps<T> : Algebra<T> {
|
|||||||
* @return the difference.
|
* @return the difference.
|
||||||
*/
|
*/
|
||||||
public operator fun T.minus(arg: T): T = add(this, -arg)
|
public operator fun T.minus(arg: T): T = add(this, -arg)
|
||||||
|
|
||||||
// Dynamic dispatch of operations
|
// Dynamic dispatch of operations
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> { arg -> +arg }
|
PLUS_OPERATION -> { arg -> +arg }
|
||||||
|
@ -0,0 +1,136 @@
|
|||||||
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DN
|
||||||
|
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
|
||||||
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
|
|
||||||
|
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>(),
|
||||||
|
AnalyticTensorAlgebra<Double, DoubleField>, LinearOpsTensorAlgebra<Double, DoubleField> {
|
||||||
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
override val type: DataType get() = DataType.DoubleDataType
|
||||||
|
|
||||||
|
override fun StructureND<Double>.mean(): Double = multikStat.mean(asMultik().array)
|
||||||
|
|
||||||
|
override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): Tensor<Double> =
|
||||||
|
multikStat.mean<Double,DN, DN>(asMultik().array, dim).wrap()
|
||||||
|
|
||||||
|
override fun StructureND<Double>.std(): Double {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.variance(): Double {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.exp(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.ln(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sqrt(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cos(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.acos(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cosh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.acosh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sin(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.asin(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.sinh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.asinh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.tan(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.atan(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.tanh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.atanh(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.ceil(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.floor(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.det(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.inv(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.cholesky(): Tensor<Double> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.qr(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.lu(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.svd(): Triple<Tensor<Double>, Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<Double>.symEig(): Pair<Tensor<Double>, Tensor<Double>> {
|
||||||
|
TODO("Not yet implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
||||||
|
|
@ -7,11 +7,9 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.multik
|
package space.kscience.kmath.multik
|
||||||
|
|
||||||
import org.jetbrains.kotlinx.multik.api.Multik
|
import org.jetbrains.kotlinx.multik.api.*
|
||||||
import org.jetbrains.kotlinx.multik.api.linalg.dot
|
import org.jetbrains.kotlinx.multik.api.linalg.LinAlg
|
||||||
import org.jetbrains.kotlinx.multik.api.mk
|
import org.jetbrains.kotlinx.multik.api.math.Math
|
||||||
import org.jetbrains.kotlinx.multik.api.ndarrayOf
|
|
||||||
import org.jetbrains.kotlinx.multik.api.zeros
|
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
import org.jetbrains.kotlinx.multik.ndarray.data.*
|
||||||
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
import org.jetbrains.kotlinx.multik.ndarray.operations.*
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
@ -52,10 +50,16 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
|
|||||||
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
else throw ClassCastException("Cannot cast MultiArray to NDArray.")
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
|
public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
||||||
|
where T : Number, T : Comparable<T> {
|
||||||
|
|
||||||
public abstract val type: DataType
|
public abstract val type: DataType
|
||||||
|
|
||||||
|
protected val multikMath: Math = mk.math
|
||||||
|
protected val multikLinAl: LinAlg = mk.linalg
|
||||||
|
protected val multikStat: Statistics = mk.stat
|
||||||
|
|
||||||
|
|
||||||
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||||
@ -65,6 +69,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
|
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
|
||||||
val data = initMemoryView<T>(array.size, type)
|
val data = initMemoryView<T>(array.size, type)
|
||||||
var count = 0
|
var count = 0
|
||||||
@ -76,6 +81,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
|
||||||
if (this is MultikTensor) {
|
if (this is MultikTensor) {
|
||||||
val array = asMultik().array
|
val array = asMultik().array
|
||||||
@ -96,6 +102,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@OptIn(PerformancePitfall::class)
|
||||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
|
||||||
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
|
||||||
val leftArray = left.asMultik().array
|
val leftArray = left.asMultik().array
|
||||||
@ -208,9 +215,9 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
|
override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
|
||||||
asMultik().array.unaryMinus().wrap()
|
asMultik().array.unaryMinus().wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
|
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
|
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
|
||||||
|
|
||||||
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> {
|
||||||
require(shape.all { it > 0 })
|
require(shape.all { it > 0 })
|
||||||
@ -236,12 +243,12 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A>
|
|||||||
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
|
||||||
if (this.shape.size == 1 && other.shape.size == 1) {
|
if (this.shape.size == 1 && other.shape.size == 1) {
|
||||||
Multik.ndarrayOf(
|
Multik.ndarrayOf(
|
||||||
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
|
multikLinAl.linAlgEx.dotVV(asMultik().array.asD1Array(), other.asMultik().array.asD1Array())
|
||||||
).asDNArray().wrap()
|
).wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
} else if (this.shape.size == 2 && other.shape.size == 2) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD2Array()).asDNArray().wrap()
|
multikLinAl.linAlgEx.dotMM(asMultik().array.asD2Array(), other.asMultik().array.asD2Array()).wrap()
|
||||||
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
} else if (this.shape.size == 2 && other.shape.size == 1) {
|
||||||
(asMultik().array.asD2Array() dot other.asMultik().array.asD1Array()).asDNArray().wrap()
|
multikLinAl.linAlgEx.dotMV(asMultik().array.asD2Array(), other.asMultik().array.asD1Array()).wrap()
|
||||||
} else {
|
} else {
|
||||||
TODO("Not implemented for broadcasting")
|
TODO("Not implemented for broadcasting")
|
||||||
}
|
}
|
||||||
@ -303,14 +310,6 @@ public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
|
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
|
||||||
override val type: DataType get() = DataType.DoubleDataType
|
|
||||||
}
|
|
||||||
|
|
||||||
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
|
||||||
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
|
|
||||||
|
|
||||||
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
|
||||||
override val elementAlgebra: FloatField get() = FloatField
|
override val elementAlgebra: FloatField get() = FloatField
|
||||||
override val type: DataType get() = DataType.FloatDataType
|
override val type: DataType get() = DataType.FloatDataType
|
||||||
|
@ -92,8 +92,8 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
|
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
|
||||||
override fun StructureND<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
|
override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
|
||||||
override fun StructureND<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
|
override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
|
||||||
override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
|
override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
|
||||||
|
|
||||||
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.operations.Field
|
import space.kscience.kmath.operations.Field
|
||||||
|
|
||||||
@ -121,4 +122,7 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
|
|||||||
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
||||||
public fun StructureND<T>.floor(): Tensor<T>
|
public fun StructureND<T>.floor(): Tensor<T>
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@UnstableKMathAPI
|
||||||
|
public fun <T, ATA : AnalyticTensorAlgebra<T, *>> ATA.exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
@ -166,7 +166,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
|
|||||||
* @param i index of the extractable tensor
|
* @param i index of the extractable tensor
|
||||||
* @return subtensor of the original tensor with index [i]
|
* @return subtensor of the original tensor with index [i]
|
||||||
*/
|
*/
|
||||||
public operator fun StructureND<T>.get(i: Int): Tensor<T>
|
public operator fun Tensor<T>.get(i: Int): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
|
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
|
||||||
@ -176,7 +176,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
|
|||||||
* @param j the second dimension to be transposed
|
* @param j the second dimension to be transposed
|
||||||
* @return transposed tensor
|
* @return transposed tensor
|
||||||
*/
|
*/
|
||||||
public fun StructureND<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
|
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a new tensor with the same data as the self tensor but of a different shape.
|
* Returns a new tensor with the same data as the self tensor but of a different shape.
|
||||||
|
@ -115,7 +115,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
|
TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
|
||||||
)
|
)
|
||||||
|
|
||||||
override operator fun StructureND<Double>.get(i: Int): DoubleTensor {
|
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
|
||||||
val lastShape = tensor.shape.drop(1).toIntArray()
|
val lastShape = tensor.shape.drop(1).toIntArray()
|
||||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
|
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
|
||||||
@ -344,7 +344,7 @@ public open class DoubleTensorAlgebra :
|
|||||||
return DoubleTensor(tensor.shape, resBuffer)
|
return DoubleTensor(tensor.shape, resBuffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<Double>.transpose(i: Int, j: Int): DoubleTensor {
|
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
|
||||||
val ii = tensor.minusIndex(i)
|
val ii = tensor.minusIndex(i)
|
||||||
val jj = tensor.minusIndex(j)
|
val jj = tensor.minusIndex(j)
|
||||||
checkTranspose(tensor.dimension, ii, jj)
|
checkTranspose(tensor.dimension, ii, jj)
|
||||||
|
Loading…
Reference in New Issue
Block a user