diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 702dc15e2..37c6a34f8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -1,5 +1,7 @@ package space.kscience.kmath.tensors +import space.kscience.kmath.tensors.core.DoubleTensor + // https://proofwiki.org/wiki/Definition:Algebra_over_Ring public interface TensorAlgebra> { @@ -52,6 +54,8 @@ public interface TensorAlgebra> { public fun TensorType.view(shape: IntArray): TensorType public fun TensorType.viewAs(other: TensorType): TensorType + public fun TensorType.eq(other: TensorType, delta: T): Boolean + //https://pytorch.org/docs/stable/generated/torch.abs.html public fun TensorType.abs(): TensorType diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index df4a32bd8..0c82ec9c8 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -344,10 +344,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra abs(x - y) < delta } } - public fun DoubleTensor.eq(other: DoubleTensor, delta: Double = 1e-5): Boolean { + override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean { return this.eq(other) { x, y -> abs(x - y) < delta } } + public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5) + public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean { if (!(this.shape contentEquals other.shape)) { return false