diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 5b7515b20..d7c6eaefd 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -1,29 +1,10 @@ package space.kscience.kmath.tensors -import space.kscience.kmath.tensors.core.DoubleTensor - // https://proofwiki.org/wiki/Definition:Algebra_over_Ring public interface TensorAlgebra> { public fun TensorType.value(): T - //https://pytorch.org/docs/stable/generated/torch.full.html - public fun full(value: T, shape: IntArray): TensorType - - public fun ones(shape: IntArray): TensorType - public fun zeros(shape: IntArray): TensorType - - //https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like - public fun TensorType.fullLike(value: T): TensorType - - public fun TensorType.zeroesLike(): TensorType - public fun TensorType.onesLike(): TensorType - - //https://pytorch.org/docs/stable/generated/torch.eye.html - public fun eye(n: Int): TensorType - - public fun TensorType.copy(): TensorType - public operator fun T.plus(other: TensorType): TensorType public operator fun TensorType.plus(value: T): TensorType public operator fun TensorType.plus(other: TensorType): TensorType @@ -53,8 +34,6 @@ public interface TensorAlgebra> { public fun TensorType.view(shape: IntArray): TensorType public fun TensorType.viewAs(other: TensorType): TensorType - public fun TensorType.eq(other: TensorType, delta: T): Boolean - //https://pytorch.org/docs/stable/generated/torch.matmul.html public infix fun TensorType.dot(other: TensorType): TensorType diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt index 67b9c9d73..0b9079967 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt @@ -2,7 +2,7 @@ package space.kscience.kmath.tensors // https://proofwiki.org/wiki/Definition:Division_Algebra public interface TensorPartialDivisionAlgebra> : - TensorAlgebra { + TensorAlgebra { public operator fun TensorType.div(value: T): TensorType public operator fun TensorType.div(other: TensorType): TensorType public operator fun TensorType.divAssign(value: T) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 965aef1ea..323ade45d 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -27,27 +27,27 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra abs(x - y) < delta } } - override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean { + public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean { return this.eq(other) { x, y -> abs(x - y) < delta } } diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt index 591ebb89c..33e78db33 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/tensors/core/utils.kt @@ -1,7 +1,8 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.structures.* - +import kotlin.random.Random +import kotlin.math.* /** * Returns a reference to [IntArray] containing all of the elements of this [Buffer]. @@ -34,3 +35,8 @@ internal fun Buffer.array(): DoubleArray = when (this) { is DoubleBuffer -> array else -> throw RuntimeException("Failed to cast Buffer to DoubleArray") } + +internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray { + val u = Random(seed) + return (0 until n).map { sqrt(-2.0 * u.nextDouble()) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray() +}