forked from kscience/kmath
Dropping creation methods from interface
This commit is contained in:
parent
ae30d3a03e
commit
b5d3ca76db
@ -1,29 +1,10 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||||
|
|
||||||
public fun TensorType.value(): T
|
public fun TensorType.value(): T
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full.html
|
|
||||||
public fun full(value: T, shape: IntArray): TensorType
|
|
||||||
|
|
||||||
public fun ones(shape: IntArray): TensorType
|
|
||||||
public fun zeros(shape: IntArray): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.full_like.html#torch.full_like
|
|
||||||
public fun TensorType.fullLike(value: T): TensorType
|
|
||||||
|
|
||||||
public fun TensorType.zeroesLike(): TensorType
|
|
||||||
public fun TensorType.onesLike(): TensorType
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.eye.html
|
|
||||||
public fun eye(n: Int): TensorType
|
|
||||||
|
|
||||||
public fun TensorType.copy(): TensorType
|
|
||||||
|
|
||||||
public operator fun T.plus(other: TensorType): TensorType
|
public operator fun T.plus(other: TensorType): TensorType
|
||||||
public operator fun TensorType.plus(value: T): TensorType
|
public operator fun TensorType.plus(value: T): TensorType
|
||||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||||
@ -53,8 +34,6 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
|||||||
public fun TensorType.view(shape: IntArray): TensorType
|
public fun TensorType.view(shape: IntArray): TensorType
|
||||||
public fun TensorType.viewAs(other: TensorType): TensorType
|
public fun TensorType.viewAs(other: TensorType): TensorType
|
||||||
|
|
||||||
public fun TensorType.eq(other: TensorType, delta: T): Boolean
|
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||||
|
|
||||||
|
@ -27,27 +27,27 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
return DoubleTensor(newShape, this.buffer.array(), newStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun full(value: Double, shape: IntArray): DoubleTensor {
|
public fun full(value: Double, shape: IntArray): DoubleTensor {
|
||||||
checkEmptyShape(shape)
|
checkEmptyShape(shape)
|
||||||
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
val buffer = DoubleArray(shape.reduce(Int::times)) { value }
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
public fun DoubleTensor.fullLike(value: Double): DoubleTensor {
|
||||||
val shape = this.shape
|
val shape = this.shape
|
||||||
val buffer = DoubleArray(this.linearStructure.size) { value }
|
val buffer = DoubleArray(this.linearStructure.size) { value }
|
||||||
return DoubleTensor(shape, buffer)
|
return DoubleTensor(shape, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
|
||||||
|
|
||||||
override fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
public fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0)
|
||||||
|
|
||||||
override fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
|
||||||
|
|
||||||
override fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
public fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0)
|
||||||
|
|
||||||
override fun eye(n: Int): DoubleTensor {
|
public fun eye(n: Int): DoubleTensor {
|
||||||
val shape = intArrayOf(n, n)
|
val shape = intArrayOf(n, n)
|
||||||
val buffer = DoubleArray(n * n) { 0.0 }
|
val buffer = DoubleArray(n * n) { 0.0 }
|
||||||
val res = DoubleTensor(shape, buffer)
|
val res = DoubleTensor(shape, buffer)
|
||||||
@ -57,7 +57,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.copy(): DoubleTensor {
|
public fun DoubleTensor.copy(): DoubleTensor {
|
||||||
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
|
return DoubleTensor(this.shape, this.buffer.array().copyOf(), this.bufferStart)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -299,7 +299,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double, Dou
|
|||||||
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
return this.contentEquals(other) { x, y -> abs(x - y) < delta }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean {
|
||||||
return this.eq(other) { x, y -> abs(x - y) < delta }
|
return this.eq(other) { x, y -> abs(x - y) < delta }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
package space.kscience.kmath.tensors.core
|
package space.kscience.kmath.tensors.core
|
||||||
|
|
||||||
import space.kscience.kmath.structures.*
|
import space.kscience.kmath.structures.*
|
||||||
|
import kotlin.random.Random
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||||
@ -34,3 +35,8 @@ internal fun Buffer<Double>.array(): DoubleArray = when (this) {
|
|||||||
is DoubleBuffer -> array
|
is DoubleBuffer -> array
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal inline fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||||
|
val u = Random(seed)
|
||||||
|
return (0 until n).map { sqrt(-2.0 * u.nextDouble()) * cos(2.0 * PI * u.nextDouble()) }.toDoubleArray()
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user