diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 1673657d3..e41b7546c 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -6,12 +6,13 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.full.html public fun full(value: T, shape: IntArray): TensorType + public fun ones(shape: IntArray): TensorType + public fun zeros(shape: IntArray): TensorType + //https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like public fun TensorType.fullLike(value: T): TensorType - public fun zeros(shape: IntArray): TensorType - public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ? - public fun ones(shape: IntArray): TensorType + public fun TensorType.zeroesLike(): TensorType public fun TensorType.onesLike(): TensorType //https://pytorch.org/docs/stable/generated/torch.eye.html diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index cbfb15be0..5a0fc6c75 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -33,28 +33,6 @@ public open class BufferedTensor( override fun hashCode(): Int = 0 - // todo rename to vector - public inline fun forEachVector(vectorAction : (MutableStructure1D) -> Unit): Unit { - check(shape.size >= 1) {"todo"} - val vectorOffset = strides.strides[0] - val vectorShape = intArrayOf(shape.last()) - for (offset in 0 until numel step vectorOffset) { - val vector = BufferedTensor(vectorShape, buffer, offset).as1D() - vectorAction(vector) - } - } - - public inline fun forEachMatrix(matrixAction : (MutableStructure2D) -> Unit): Unit { - check(shape.size >= 2) {"todo"} - val matrixOffset = strides.strides[1] - val matrixShape = intArrayOf(shape[shape.size - 2], shape.last()) //todo better way? - for (offset in 0 until numel step matrixOffset) { - val matrix = BufferedTensor(matrixShape, buffer, offset).as2D() - matrixAction(matrix) - } - } - // todo remove code copy-pasting - public fun vectorSequence(): Sequence> = sequence { check(shape.size >= 1) {"todo"} val vectorOffset = strides.strides[0] @@ -75,6 +53,18 @@ public open class BufferedTensor( } } + public inline fun forEachVector(vectorAction : (MutableStructure1D) -> Unit): Unit { + for (vector in vectorSequence()){ + vectorAction(vector) + } + } + + public inline fun forEachMatrix(matrixAction : (MutableStructure2D) -> Unit): Unit { + for (matrix in matrixSequence()){ + matrixAction(matrix) + } + } + } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 1c156b4e3..465c2ddc1 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -20,24 +20,41 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra>>>>>> ups/feature/tensor-algebra:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt override fun DoubleTensor.fullLike(value: Double): DoubleTensor { val shape = this.shape val buffer = DoubleArray(this.strides.linearSize) { value } return DoubleTensor(shape, buffer) } +<<<<<<< HEAD:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt + override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0) + +======= override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0) override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) +>>>>>>> ups/feature/tensor-algebra:kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0) override fun eye(n: Int): DoubleTensor {