Drop unused functionality in TensorLinearStructure

This commit is contained in:
Roland Grinis 2021-05-06 09:59:58 +01:00
parent 477e64e4d3
commit 16bed53997
2 changed files with 2 additions and 27 deletions

View File

@ -155,7 +155,7 @@ public object DoubleLinearOpsTensorAlgebra :
* @return triple `(U, S, V)`. * @return triple `(U, S, V)`.
*/ */
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.linearStructure.dim val size = tensor.dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = tensor.shape.sliceArray(size - 2 until size)
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n)) val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))

View File

@ -39,45 +39,20 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
return res return res
} }
internal fun stepIndex(index: IntArray, shape: IntArray, nDim: Int): IntArray {
val res = index.copyOf()
var current = nDim - 1
var carry = 0
do {
res[current]++
if (res[current] >= shape[current]) {
carry = 1
res[current] = 0
}
current--
} while (carry != 0 && current >= 0)
return res
}
/** /**
* This [Strides] implementation follows the last dimension first convention * This [Strides] implementation follows the last dimension first convention
* For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html * For more information: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
* *
* @param shape the shape of the tensor. * @param shape the shape of the tensor.
*/ */
public class TensorLinearStructure(override val shape: IntArray) : Strides public class TensorLinearStructure(override val shape: IntArray) : Strides {
{
override val strides: IntArray override val strides: IntArray
get() = stridesFromShape(shape) get() = stridesFromShape(shape)
override fun index(offset: Int): IntArray = override fun index(offset: Int): IntArray =
indexFromOffset(offset, strides, shape.size) indexFromOffset(offset, strides, shape.size)
// TODO: documentation (Alya)
public fun stepIndex(index: IntArray): IntArray =
stepIndex(index, shape, shape.size)
override val linearSize: Int override val linearSize: Int
get() = shape.reduce(Int::times) get() = shape.reduce(Int::times)
public val dim: Int
get() = shape.size
} }