diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index 4e286b489..867b4fb7a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -49,6 +49,7 @@ public class DoubleTensor internal constructor( internal fun BufferedTensor.asTensor(): IntTensor = IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) + internal fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) @@ -78,4 +79,39 @@ internal val TensorStructure.tensor: IntTensor } public fun TensorStructure.toDoubleTensor(): DoubleTensor = this.tensor -public fun TensorStructure.toIntTensor(): IntTensor = this.tensor \ No newline at end of file +public fun TensorStructure.toIntTensor(): IntTensor = this.tensor + +public fun Array.toDoubleTensor(): DoubleTensor { + val n = size + check(n > 0) { "An empty array cannot be casted to tensor" } + val m = first().size + check(m > 0) { "Inner arrays must have at least 1 argument" } + check(all { size == m }) { "Inner arrays must be the same size" } + + val shape = intArrayOf(n, m) + val buffer = this.flatMap { arr -> arr.map { it } }.toDoubleArray() + + return DoubleTensor(shape, buffer, 0) +} + + +public fun Array.toIntTensor(): IntTensor { + val n = size + check(n > 0) { "An empty array cannot be casted to tensor" } + val m = first().size + check(m > 0) { "Inner arrays must have at least 1 argument" } + check(all { size == m }) { "Inner arrays must be the same size" } + + val shape = intArrayOf(n, m) + val buffer = this.flatMap { arr -> arr.map { it } }.toIntArray() + + return IntTensor(shape, buffer, 0) +} + +public fun DoubleTensor.toDoubleArray(): DoubleArray { + return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toDoubleArray() +} + +public fun IntTensor.toIntArray(): IntArray { + return tensor.mutableBuffer.array().drop(bufferStart).take(numElements).toIntArray() +} \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt index f428b9d2e..c6fb301b5 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt @@ -382,7 +382,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return false } for (i in 0 until n) { - if (!eqFunction(tensor.mutableBuffer[tensor.bufferStart + i], other.tensor.mutableBuffer[other.tensor.bufferStart + i])) { + if (!eqFunction( + tensor.mutableBuffer[tensor.bufferStart + i], + other.tensor.mutableBuffer[other.tensor.bufferStart + i] + ) + ) { return false } } @@ -395,4 +399,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { public fun TensorStructure.randNormalLike(seed: Long = 0): DoubleTensor = DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) + // stack tensors by axis 0 + public fun stack(tensors: List): DoubleTensor { + val shape = tensors.firstOrNull()?.shape + check(shape != null) { "Collection must have at least 1 element" } + check(tensors.all { it.shape contentEquals shape }) {"Stacking tensors must have same shapes"} + val resShape = intArrayOf(tensors.size) + shape + val resBuffer = tensors.flatMap { + it.tensor.mutableBuffer.array().drop(it.bufferStart).take(it.numElements) + }.toDoubleArray() + return DoubleTensor(resShape, resBuffer, 0) + } + + // build tensor from this rows by given indices + public fun TensorStructure.rowsByIndices(indices: IntArray): DoubleTensor { + return stack(indices.map { this[it] }) + } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linUtils.kt index a152b3a17..e54cc4d26 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linUtils.kt @@ -18,7 +18,7 @@ internal fun BufferedTensor.vectorSequence(): Sequence> val vectorOffset = shape[n - 1] val vectorShape = intArrayOf(shape.last()) for (offset in 0 until numElements step vectorOffset) { - val vector = BufferedTensor(vectorShape, mutableBuffer, offset) + val vector = BufferedTensor(vectorShape, mutableBuffer, bufferStart + offset) yield(vector) } } @@ -29,7 +29,7 @@ internal fun BufferedTensor.matrixSequence(): Sequence> val matrixOffset = shape[n - 1] * shape[n - 2] val matrixShape = intArrayOf(shape[n - 2], shape[n - 1]) for (offset in 0 until numElements step matrixOffset) { - val matrix = BufferedTensor(matrixShape, mutableBuffer, offset) + val matrix = BufferedTensor(matrixShape, mutableBuffer, bufferStart + offset) yield(matrix) } }