eq moved to interface
This commit is contained in:
parent
f70f60c0e8
commit
94b5afa6c4
@ -1,5 +1,7 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
|
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
|
|
||||||
@ -52,6 +54,8 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||||
|
|
||||||
|
public fun TensorType.eq(other: TensorType, delta: T): Boolean
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
//https://pytorch.org/docs/stable/generated/torch.abs.html
|
||||||
public fun TensorType.abs(): TensorType
|
public fun TensorType.abs(): TensorType
|
||||||
|
|
||||||
|
@ -344,10 +344,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double = 1e-5): Boolean {
|
override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5)
|
||||||
|
|
||||||
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
public fun DoubleTensor.contentEquals(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean {
|
||||||
if (!(this.shape contentEquals other.shape)) {
|
if (!(this.shape contentEquals other.shape)) {
|
||||||
return false
|
return false
|
||||||
|
Loading…
Reference in New Issue
Block a user