forked from kscience/kmath
stack
This commit is contained in:
parent
b7cac3a015
commit
fe81dea243
@ -382,7 +382,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
return false
|
||||
}
|
||||
for (i in 0 until n) {
|
||||
if (!eqFunction(tensor.mutableBuffer[tensor.bufferStart + i], other.tensor.mutableBuffer[other.tensor.bufferStart + i])) {
|
||||
if (!eqFunction(
|
||||
tensor.mutableBuffer[tensor.bufferStart + i],
|
||||
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
|
||||
)
|
||||
) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -395,4 +399,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
||||
public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor =
|
||||
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
||||
|
||||
// stack tensors by axis 0
|
||||
public fun stack(tensors: List<DoubleTensor>): DoubleTensor {
|
||||
val shape = tensors.firstOrNull()?.shape
|
||||
check(shape != null) { "Collection must have at least 1 element" }
|
||||
check(tensors.all { it.shape contentEquals shape }) {"Stacking tensors must have same shapes"}
|
||||
val resShape = intArrayOf(tensors.size) + shape
|
||||
val resBuffer = tensors.flatMap {
|
||||
it.tensor.mutableBuffer.array().drop(it.bufferStart).take(it.numElements)
|
||||
}.toDoubleArray()
|
||||
return DoubleTensor(resShape, resBuffer, 0)
|
||||
}
|
||||
|
||||
// build tensor from this rows by given indices
|
||||
public fun TensorStructure<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
|
||||
return stack(indices.map { this[it] })
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user