Update stack docs

This commit is contained in:
Roland Grinis 2021-05-06 14:50:05 +01:00
parent db5378c9f4
commit febe526325

View File

@ -493,14 +493,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
/** /**
* Concatenates a sequence of tensors along the first dimension. * Concatenates a sequence of tensors with equal shapes along the first dimension.
* *
* @param tensors the [List] of tensors with same shapes to concatenate * @param tensors the [List] of tensors with same shapes to concatenate
* @param dim the dimension to insert
* @return tensor with concatenation result * @return tensor with concatenation result
*/ */
public fun stack(tensors: List<Tensor<Double>>, dim: Int = 0): DoubleTensor { public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
check(dim == 0) { "Stack by non-zero dimension not implemented yet" }
check(tensors.isNotEmpty()) { "List must have at least 1 element" } check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val shape = tensors[0].shape val shape = tensors[0].shape
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }