produce routine

This commit is contained in:
Roland Grinis 2021-05-07 13:12:18 +01:00
parent 0ef64130ae
commit d31726a0d9

View File

@ -56,6 +56,17 @@ public open class DoubleTensorAlgebra :
return DoubleTensor(shape, buffer, 0) return DoubleTensor(shape, buffer, 0)
} }
/**
* Constructs a tensor with the specified shape and initializer.
*
* @param shape the desired shape for the tensor.
* @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by initializer.
*/
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
fromArray(shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray())
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)