Update stack docs
This commit is contained in:
parent
db5378c9f4
commit
febe526325
@ -493,14 +493,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Concatenates a sequence of tensors along the first dimension.
|
* Concatenates a sequence of tensors with equal shapes along the first dimension.
|
||||||
*
|
*
|
||||||
* @param tensors the [List] of tensors with same shapes to concatenate
|
* @param tensors the [List] of tensors with same shapes to concatenate
|
||||||
* @param dim the dimension to insert
|
|
||||||
* @return tensor with concatenation result
|
* @return tensor with concatenation result
|
||||||
*/
|
*/
|
||||||
public fun stack(tensors: List<Tensor<Double>>, dim: Int = 0): DoubleTensor {
|
public fun stack(tensors: List<Tensor<Double>>): DoubleTensor {
|
||||||
check(dim == 0) { "Stack by non-zero dimension not implemented yet" }
|
|
||||||
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
check(tensors.isNotEmpty()) { "List must have at least 1 element" }
|
||||||
val shape = tensors[0].shape
|
val shape = tensors[0].shape
|
||||||
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
|
||||||
|
Loading…
Reference in New Issue
Block a user