Dropping creation methods from interface

This commit is contained in:
Roland Grinis 2021-03-30 19:20:20 +01:00
parent ae30d3a03e
commit b5d3ca76db
4 changed files with 17 additions and 32 deletions

View File

@ -1,29 +1,10 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.tensors.core.DoubleTensor
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.value(): T
//https://pytorch.org/docs/stable/generated/torch.full.html
public fun full(value: T, shape: IntArray): TensorType
public fun ones(shape: IntArray): TensorType
public fun zeros(shape: IntArray): TensorType
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
public fun TensorType.fullLike(value: T): TensorType
public fun TensorType.zeroesLike(): TensorType
public fun TensorType.onesLike(): TensorType
//https://pytorch.org/docs/stable/generated/torch.eye.html
public fun eye(n: Int): TensorType
public fun TensorType.copy(): TensorType
public operator fun T.plus(other: TensorType): TensorType
public operator fun TensorType.plus(value: T): TensorType
public operator fun TensorType.plus(other: TensorType): TensorType
@ -53,8 +34,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun TensorType.view(shape: IntArray): TensorType
public fun TensorType.viewAs(other: TensorType): TensorType
public fun TensorType.eq(other: TensorType, delta: T): Boolean
//https://pytorch.org/docs/stable/generated/torch.matmul.html
public infix fun TensorType.dot(other: TensorType): TensorType

View File

@ -2,7 +2,7 @@ package space.kscience.kmath.tensors
// https://proofwiki.org/wiki/Definition:Division_Algebra
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
TensorAlgebra<T, TensorType> {
public operator fun TensorType.div(value: T): TensorType
public operator fun TensorType.div(other: TensorType): TensorType
public operator fun TensorType.divAssign(value: T)

View File

@ -27,27 +27,27 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return DoubleTensor(newShape, this.buffer.array(), newStart)
}
override fun full(value: Double, shape: IntArray): DoubleTensor {
public fun full(value: Double, shape: IntArray): DoubleTensor {
checkEmptyShape(shape)
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
return DoubleTensor(shape, buffer)
}
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
public fun DoubleTensor.fullLike(value: Double): DoubleTensor {
val shape = this.shape
val buffer = DoubleArray(this.linearStructure.size) { value }
return DoubleTensor(shape, buffer)
}
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
public fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
public fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
override fun eye(n: Int): DoubleTensor {
public fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n)
val buffer = DoubleArray(n * n) { 0.0 }
val res = DoubleTensor(shape, buffer)
@ -57,7 +57,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return res
}
override fun DoubleTensor.copy(): DoubleTensor {
public fun DoubleTensor.copy(): DoubleTensor {
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
}
@ -299,7 +299,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
}
override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
return this.eq(other) { x, y -> abs(x - y) < delta }
}

View File

@ -1,7 +1,8 @@
package space.kscience.kmath.tensors.core
import space.kscience.kmath.structures.*
import kotlin.random.Random
import kotlin.math.*
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
@ -34,3 +35,8 @@ internal fun Buffer<Double>.array(): DoubleArray = when (this) {
is DoubleBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
val u = Random(seed)
return (0 until n).map { sqrt(-2.0 * u.nextDouble()) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
}