forked from kscience/kmath
Merge remote-tracking branch 'ups/feature/tensor-algebra' into andrew
This commit is contained in:
commit
0e793eba26
@ -155,7 +155,7 @@ public object DoubleLinearOpsTensorAlgebra :
|
||||
* @return triple `(U, S, V)`.
|
||||
*/
|
||||
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
val size = tensor.linearStructure.dim
|
||||
val size = tensor.dimension
|
||||
val commonShape = tensor.shape.sliceArray(0 until size - 2)
|
||||
val (n, m) = tensor.shape.sliceArray(size - 2 until size)
|
||||
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
|
||||
|
@ -39,39 +39,20 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
|
||||
return res
|
||||
}
|
||||
|
||||
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
|
||||
val res = index.copyOf()
|
||||
var current = nDim - 1
|
||||
var carry = 0
|
||||
|
||||
do {
|
||||
res[current]++
|
||||
if (res[current] >= shape[current]) {
|
||||
carry = 1
|
||||
res[current] = 0
|
||||
}
|
||||
current--
|
||||
} while (carry != 0 && current >= 0)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
|
||||
public class TensorLinearStructure(override val shape: IntArray) : Strides
|
||||
{
|
||||
/**
|
||||
* This [Strides] implementation follows the last dimension first convention
|
||||
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
|
||||
*
|
||||
* @param shape the shape of the tensor.
|
||||
*/
|
||||
public class TensorLinearStructure(override val shape: IntArray) : Strides {
|
||||
override val strides: IntArray
|
||||
get() = stridesFromShape(shape)
|
||||
|
||||
override fun index(offset: Int): IntArray =
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
public fun stepIndex(index: IntArray): IntArray =
|
||||
stepIndex(index, shape, shape.size)
|
||||
|
||||
override val linearSize: Int
|
||||
get() = shape.reduce(Int::times)
|
||||
|
||||
public val dim: Int
|
||||
get() = shape.size
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user