forked from kscience/kmath
Dropping creation methods from interface
This commit is contained in:
parent
ae30d3a03e
commit
b5d3ca76db
@ -1,29 +1,10 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
|
||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
public fun TensorType.value(): T
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
||||
public fun full(value: T, shape: IntArray): TensorType
|
||||
|
||||
public fun ones(shape: IntArray): TensorType
|
||||
public fun zeros(shape: IntArray): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
||||
public fun TensorType.fullLike(value: T): TensorType
|
||||
|
||||
public fun TensorType.zeroesLike(): TensorType
|
||||
public fun TensorType.onesLike(): TensorType
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
||||
public fun eye(n: Int): TensorType
|
||||
|
||||
public fun TensorType.copy(): TensorType
|
||||
|
||||
public operator fun T.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plus(value: T): TensorType
|
||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||
@ -53,8 +34,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
public fun TensorType.view(shape: IntArray): TensorType
|
||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||
|
||||
public fun TensorType.eq(other: TensorType, delta: T): Boolean
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||
|
||||
|
@ -27,27 +27,27 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||
}
|
||||
|
||||
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||
public fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||
checkEmptyShape(shape)
|
||||
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||
public fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||
val shape = this.shape
|
||||
val buffer = DoubleArray(this.linearStructure.size) { value }
|
||||
return DoubleTensor(shape, buffer)
|
||||
}
|
||||
|
||||
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||
|
||||
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||
public fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||
|
||||
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||
|
||||
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||
public fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||
|
||||
override fun eye(n: Int): DoubleTensor {
|
||||
public fun eye(n: Int): DoubleTensor {
|
||||
val shape = intArrayOf(n, n)
|
||||
val buffer = DoubleArray(n * n) { 0.0 }
|
||||
val res = DoubleTensor(shape, buffer)
|
||||
@ -57,7 +57,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return res
|
||||
}
|
||||
|
||||
override fun DoubleTensor.copy(): DoubleTensor {
|
||||
public fun DoubleTensor.copy(): DoubleTensor {
|
||||
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
|
||||
}
|
||||
|
||||
@ -299,7 +299,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
package space.kscience.kmath.tensors.core
|
||||
|
||||
import space.kscience.kmath.structures.*
|
||||
|
||||
import kotlin.random.Random
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
@ -34,3 +35,8 @@ internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
||||
is DoubleBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
||||
|
||||
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||
val u = Random(seed)
|
||||
return (0 until n).map { sqrt(-2.0 * u.nextDouble()) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user