stack
This commit is contained in:
parent
b7cac3a015
commit
fe81dea243
@ -382,7 +382,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for (i in 0 until n) {
|
for (i in 0 until n) {
|
||||||
if (!eqFunction(tensor.mutableBuffer[tensor.bufferStart + i], other.tensor.mutableBuffer[other.tensor.bufferStart + i])) {
|
if (!eqFunction(
|
||||||
|
tensor.mutableBuffer[tensor.bufferStart + i],
|
||||||
|
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
|
||||||
|
)
|
||||||
|
) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -395,4 +399,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
|
|||||||
public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor =
|
public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor =
|
||||||
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
|
||||||
|
|
||||||
|
// stack tensors by axis 0
|
||||||
|
public fun stack(tensors: List<DoubleTensor>): DoubleTensor {
|
||||||
|
val shape = tensors.firstOrNull()?.shape
|
||||||
|
check(shape != null) { "Collection must have at least 1 element" }
|
||||||
|
check(tensors.all { it.shape contentEquals shape }) {"Stacking tensors must have same shapes"}
|
||||||
|
val resShape = intArrayOf(tensors.size) + shape
|
||||||
|
val resBuffer = tensors.flatMap {
|
||||||
|
it.tensor.mutableBuffer.array().drop(it.bufferStart).take(it.numElements)
|
||||||
|
}.toDoubleArray()
|
||||||
|
return DoubleTensor(resShape, resBuffer, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// build tensor from this rows by given indices
|
||||||
|
public fun TensorStructure<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
|
||||||
|
return stack(indices.map { this[it] })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user