diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 7ec920d88..1673657d3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -3,19 +3,17 @@ package space.kscience.kmath.tensors // https://proofwiki.org/wiki/Definition:Algebra_over_Ring public interface TensorAlgebra<T, TensorType : TensorStructure<T>> { + //https://pytorch.org/docs/stable/generated/torch.full.html + public fun full(value: T, shape: IntArray): TensorType + + //https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like + public fun TensorType.fullLike(value: T): TensorType public fun zeros(shape: IntArray): TensorType public fun TensorType.zeroesLike(): TensorType // mb it shouldn't be tensor but algebra method (like in numpy/torch) ? public fun ones(shape: IntArray): TensorType public fun TensorType.onesLike(): TensorType - - //https://pytorch.org/docs/stable/generated/torch.full.html - public fun full(shape: IntArray, value: T): TensorType - - //https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like - public fun TensorType.fullLike(value: T): TensorType - //https://pytorch.org/docs/stable/generated/torch.eye.html public fun eye(n: Int): TensorType diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 9decc0e6a..1c156b4e3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -20,23 +20,25 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou return DoubleTensor(newShape, this.buffer.array(), newStart) } - override fun zeros(shape: IntArray): DoubleTensor { - TODO("Not yet implemented") - } - - override fun DoubleTensor.zeroesLike(): DoubleTensor { - val shape = this.shape - val buffer = DoubleArray(this.strides.linearSize) { 0.0 } + override fun full(value: Double, shape: IntArray): DoubleTensor { + checkEmptyShape(shape) + val buffer = DoubleArray(shape.reduce(Int::times)) { value } return DoubleTensor(shape, buffer) } - override fun ones(shape: IntArray): DoubleTensor { - TODO("Not yet implemented") + override fun DoubleTensor.fullLike(value: Double): DoubleTensor { + val shape = this.shape + val buffer = DoubleArray(this.strides.linearSize) { value } + return DoubleTensor(shape, buffer) } - override fun DoubleTensor.onesLike(): DoubleTensor { - TODO("Not yet implemented") - } + override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) + + override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0) + + override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) + + override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0) override fun eye(n: Int): DoubleTensor { val shape = intArrayOf(n, n) @@ -225,15 +227,6 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou TODO("Not yet implemented") } - override fun full(shape: IntArray, value: Double): DoubleTensor { - TODO("Not yet implemented") - } - - override fun DoubleTensor.fullLike(value: Double): DoubleTensor { - TODO("Not yet implemented") - } - - override fun DoubleTensor.sum(dim: Int, keepDim: Boolean): DoubleTensor { TODO("Not yet implemented") } diff --git a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt index 8b4d5ca16..ddfba0d59 100644 --- a/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt +++ b/kmath-core/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensorAlgebra.kt @@ -47,4 +47,36 @@ class TestDoubleTensorAlgebra { assertTrue(res12.buffer.array() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) } + @Test + fun linearStructure() = DoubleTensorAlgebra { + val shape = intArrayOf(3) + val tensorA = full(value = -4.5, shape = shape) + val tensorB = full(value = 10.9, shape = shape) + val tensorC = full(value = 789.3, shape = shape) + val tensorD = full(value = -72.9, shape = shape) + val tensorE = full(value = 553.1, shape = shape) + val result = 15.8 * tensorA - 1.5 * tensorB * (-tensorD) + 0.02 * tensorC / tensorE - 39.4 + + val expected = fromArray( + shape, + (1..3).map { + 15.8 * (-4.5) - 1.5 * 10.9 * 72.9 + 0.02 * 789.3 / 553.1 - 39.4 + }.toDoubleArray() + ) + + val assignResult = zeros(shape) + tensorA *= 15.8 + tensorB *= 1.5 + tensorB *= -tensorD + tensorC *= 0.02 + tensorC /= tensorE + assignResult += tensorA + assignResult -= tensorB + assignResult += tensorC + assignResult += -39.4 + + assertTrue(expected.buffer.array() contentEquals result.buffer.array()) + assertTrue(expected.buffer.array() contentEquals assignResult.buffer.array()) + } + } \ No newline at end of file