eq moved to interface
This commit is contained in:
parent
f70f60c0e8
commit
94b5afa6c4
@ -1,5 +1,7 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
|
||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
@ -52,6 +54,8 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
public fun TensorType.view(shape: IntArray): TensorType
|
||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||
|
||||
public fun TensorType.eq(other: TensorType, delta: T): Boolean
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||
public fun TensorType.abs(): TensorType
|
||||
|
||||
|
@ -344,10 +344,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double = 1e-5): Boolean {
|
||||
override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
||||
|
||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
||||
if (!(this.shape contentEquals other.shape)) {
|
||||
return false
|
||||
|
Loading…
Reference in New Issue
Block a user