v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
Showing only changes of commit d31726a0d9 - Show all commits

View File

@ -56,6 +56,17 @@ public open class DoubleTensorAlgebra :
return DoubleTensor(shape, buffer, 0)
}
/**
* Constructs a tensor with the specified shape and initializer.
*
* @param shape the desired shape for the tensor.
* @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by initializer.
*/
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
fromArray(shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray())
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)