This commit is contained in:
Andrei Kislitsyn 2021-05-01 14:22:05 +03:00
parent b7cac3a015
commit fe81dea243

View File

@ -382,7 +382,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return false return false
} }
for (i in 0 until n) { for (i in 0 until n) {
if (!eqFunction(tensor.mutableBuffer[tensor.bufferStart + i], other.tensor.mutableBuffer[other.tensor.bufferStart + i])) { if (!eqFunction(
tensor.mutableBuffer[tensor.bufferStart + i],
other.tensor.mutableBuffer[other.tensor.bufferStart + i]
)
) {
return false return false
} }
} }
@ -395,4 +399,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor = public fun TensorStructure<Double>.randNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
// stack tensors by axis 0
public fun stack(tensors: List<DoubleTensor>): DoubleTensor {
val shape = tensors.firstOrNull()?.shape
check(shape != null) { "Collection must have at least 1 element" }
check(tensors.all { it.shape contentEquals shape }) {"Stacking tensors must have same shapes"}
val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.bufferStart).take(it.numElements)
}.toDoubleArray()
return DoubleTensor(resShape, resBuffer, 0)
}
// build tensor from this rows by given indices
public fun TensorStructure<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
return stack(indices.map { this[it] })
}
} }