v0.3.0-dev-9 #324
@ -56,6 +56,17 @@ public open class DoubleTensorAlgebra :
|
||||
return DoubleTensor(shape, buffer, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and initializer.
|
||||
*
|
||||
* @param shape the desired shape for the tensor.
|
||||
* @param initializer mapping tensor indices to values.
|
||||
* @return tensor with the [shape] shape and data generated by initializer.
|
||||
*/
|
||||
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
|
||||
fromArray(shape,
|
||||
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray())
|
||||
|
||||
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
|
||||
val lastShape = tensor.shape.drop(1).toIntArray()
|
||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||
|
Loading…
Reference in New Issue
Block a user