KMP library for tensors #300
@ -155,7 +155,7 @@ public object DoubleLinearOpsTensorAlgebra :
|
|||||||
* @return triple `(U, S, V)`.
|
* @return triple `(U, S, V)`.
|
||||||
*/
|
*/
|
||||||
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||||
val size = tensor.linearStructure.dim
|
val size = tensor.dimension
|
||||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||||
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||||
|
@ -39,45 +39,20 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
|
||||||
val res = index.copyOf()
|
|
||||||
var current = nDim - 1
|
|
||||||
var carry = 0
|
|
||||||
|
|
||||||
do {
|
|
||||||
res[current]++
|
|
||||||
if (res[current] >= shape[current]) {
|
|
||||||
carry = 1
|
|
||||||
res[current] = 0
|
|
||||||
}
|
|
||||||
current--
|
|
||||||
} while (carry != 0 && current >= 0)
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This [Strides] implementation follows the last dimension first convention
|
* This [Strides] implementation follows the last dimension first convention
|
||||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
||||||
*
|
*
|
||||||
* @param shape the shape of the tensor.
|
* @param shape the shape of the tensor.
|
||||||
*/
|
*/
|
||||||
public class TensorLinearStructure(override val shape: IntArray) : Strides
|
public class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||||
{
|
|
||||||
override val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = stridesFromShape(shape)
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray =
|
override fun index(offset: Int): IntArray =
|
||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
// TODO: documentation (Alya)
|
|
||||||
public fun stepIndex(index: IntArray): IntArray =
|
|
||||||
stepIndex(index, shape, shape.size)
|
|
||||||
|
|
||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
public val dim: Int
|
|
||||||
get() = shape.size
|
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user