produce routine
This commit is contained in:
parent
0ef64130ae
commit
d31726a0d9
@ -56,6 +56,17 @@ public open class DoubleTensorAlgebra :
|
|||||||
return DoubleTensor(shape, buffer, 0)
|
return DoubleTensor(shape, buffer, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a tensor with the specified shape and initializer.
|
||||||
|
*
|
||||||
|
* @param shape the desired shape for the tensor.
|
||||||
|
* @param initializer mapping tensor indices to values.
|
||||||
|
* @return tensor with the [shape] shape and data generated by initializer.
|
||||||
|
*/
|
||||||
|
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor =
|
||||||
|
fromArray(shape,
|
||||||
|
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray())
|
||||||
|
|
||||||
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
|
override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
|
||||||
val lastShape = tensor.shape.drop(1).toIntArray()
|
val lastShape = tensor.shape.drop(1).toIntArray()
|
||||||
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
|
||||||
|
Loading…
Reference in New Issue
Block a user