diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt index 888f12923..675be2f33 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/DoubleTensorAlgebra.kt @@ -266,6 +266,10 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra DoubleTensorAlgebra(block: DoubleTensorAlgebra.() -> R): R = diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 0af757d1a..b1adf2962 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -86,4 +86,6 @@ public interface TensorAlgebra> { //https://pytorch.org/docs/stable/generated/torch.flatten.html#torch.flatten public fun TensorType.flatten(startDim: Int, endDim: Int): TensorType + //https://pytorch.org/docs/stable/generated/torch.squeeze.html + public fun TensorType.squeeze(dim: Int): TensorType }