TensorStructure to Tensor

This commit is contained in:
Roland Grinis 2021-05-02 16:19:05 +01:00
parent 48d86fac56
commit 5aaba0dae4
11 changed files with 156 additions and 156 deletions

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
/** /**
* Element-wise analytic operations on [TensorStructure]. * Element-wise analytic operations on [Tensor].
* *
* @param T the type of items closed under analytic functions in the tensors. * @param T the type of items closed under analytic functions in the tensors.
*/ */
@ -14,54 +14,54 @@ public interface AnalyticTensorAlgebra<T> :
TensorPartialDivisionAlgebra<T> { TensorPartialDivisionAlgebra<T> {
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun TensorStructure<T>.exp(): TensorStructure<T> public fun Tensor<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html //For information: https://pytorch.org/docs/stable/generated/torch.log.html
public fun TensorStructure<T>.log(): TensorStructure<T> public fun Tensor<T>.log(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun TensorStructure<T>.sqrt(): TensorStructure<T> public fun Tensor<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun TensorStructure<T>.cos(): TensorStructure<T> public fun Tensor<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun TensorStructure<T>.acos(): TensorStructure<T> public fun Tensor<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun TensorStructure<T>.cosh(): TensorStructure<T> public fun Tensor<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun TensorStructure<T>.acosh(): TensorStructure<T> public fun Tensor<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun TensorStructure<T>.sin(): TensorStructure<T> public fun Tensor<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun TensorStructure<T>.asin(): TensorStructure<T> public fun Tensor<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun TensorStructure<T>.sinh(): TensorStructure<T> public fun Tensor<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun TensorStructure<T>.asinh(): TensorStructure<T> public fun Tensor<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun TensorStructure<T>.tan(): TensorStructure<T> public fun Tensor<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun TensorStructure<T>.atan(): TensorStructure<T> public fun Tensor<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun TensorStructure<T>.tanh(): TensorStructure<T> public fun Tensor<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun TensorStructure<T>.atanh(): TensorStructure<T> public fun Tensor<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun TensorStructure<T>.ceil(): TensorStructure<T> public fun Tensor<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun TensorStructure<T>.floor(): TensorStructure<T> public fun Tensor<T>.floor(): Tensor<T>
} }

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
/** /**
* Common linear algebra operations. Operates on [TensorStructure]. * Common linear algebra operations. Operates on [Tensor].
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
*/ */
@ -19,7 +19,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return the determinant. * @return the determinant.
*/ */
public fun TensorStructure<T>.det(): TensorStructure<T> public fun Tensor<T>.det(): Tensor<T>
/** /**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
@ -29,7 +29,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
*/ */
public fun TensorStructure<T>.inv(): TensorStructure<T> public fun Tensor<T>.inv(): Tensor<T>
/** /**
* Cholesky decomposition. * Cholesky decomposition.
@ -44,7 +44,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return the batch of L matrices. * @return the batch of L matrices.
*/ */
public fun TensorStructure<T>.cholesky(): TensorStructure<T> public fun Tensor<T>.cholesky(): Tensor<T>
/** /**
* QR decomposition. * QR decomposition.
@ -57,7 +57,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return pair of Q and R tensors. * @return pair of Q and R tensors.
*/ */
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>> public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>>
/** /**
* LUP decomposition * LUP decomposition
@ -70,7 +70,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* * @return triple of P, L and U tensors * * @return triple of P, L and U tensors
*/ */
public fun TensorStructure<T>.lu(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>> public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Singular Value Decomposition. * Singular Value Decomposition.
@ -83,7 +83,7 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return the determinant. * @return the determinant.
*/ */
public fun TensorStructure<T>.svd(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>> public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
@ -92,6 +92,6 @@ public interface LinearOpsTensorAlgebra<T> :
* *
* @return a pair (eigenvalues, eigenvectors) * @return a pair (eigenvalues, eigenvectors)
*/ */
public fun TensorStructure<T>.symEig(): Pair<TensorStructure<T>, TensorStructure<T>> public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
} }

View File

@ -2,4 +2,4 @@ package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.MutableStructureND import space.kscience.kmath.nd.MutableStructureND
public typealias TensorStructure<T> = MutableStructureND<T> public typealias Tensor<T> = MutableStructureND<T>

View File

@ -8,19 +8,19 @@ package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Algebra import space.kscience.kmath.operations.Algebra
/** /**
* Algebra over a ring on [TensorStructure]. * Algebra over a ring on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring * For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring
* *
* @param T the type of items in the tensors. * @param T the type of items in the tensors.
*/ */
public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> { public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
/** /**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
* *
* @return the value of a scalar tensor. * @return the value of a scalar tensor.
*/ */
public fun TensorStructure<T>.value(): T public fun Tensor<T>.value(): T
/** /**
* Each element of the tensor [other] is added to this value. * Each element of the tensor [other] is added to this value.
@ -29,7 +29,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this value and tensor [other]. * @return the sum of this value and tensor [other].
*/ */
public operator fun T.plus(other: TensorStructure<T>): TensorStructure<T> public operator fun T.plus(other: Tensor<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
@ -37,7 +37,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param value the number to be added to each element of this tensor. * @param value the number to be added to each element of this tensor.
* @return the sum of this tensor and [value]. * @return the sum of this tensor and [value].
*/ */
public operator fun TensorStructure<T>.plus(value: T): TensorStructure<T> public operator fun Tensor<T>.plus(value: T): Tensor<T>
/** /**
* Each element of the tensor [other] is added to each element of this tensor. * Each element of the tensor [other] is added to each element of this tensor.
@ -46,21 +46,21 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this tensor and [other]. * @return the sum of this tensor and [other].
*/ */
public operator fun TensorStructure<T>.plus(other: TensorStructure<T>): TensorStructure<T> public operator fun Tensor<T>.plus(other: Tensor<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor. * Adds the scalar [value] to each element of this tensor.
* *
* @param value the number to be added to each element of this tensor. * @param value the number to be added to each element of this tensor.
*/ */
public operator fun TensorStructure<T>.plusAssign(value: T): Unit public operator fun Tensor<T>.plusAssign(value: T): Unit
/** /**
* Each element of the tensor [other] is added to each element of this tensor. * Each element of the tensor [other] is added to each element of this tensor.
* *
* @param other tensor to be added. * @param other tensor to be added.
*/ */
public operator fun TensorStructure<T>.plusAssign(other: TensorStructure<T>): Unit public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
/** /**
@ -70,7 +70,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this value and tensor [other]. * @return the difference between this value and tensor [other].
*/ */
public operator fun T.minus(other: TensorStructure<T>): TensorStructure<T> public operator fun T.minus(other: Tensor<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. * Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor.
@ -78,7 +78,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param value the number to be subtracted from each element of this tensor. * @param value the number to be subtracted from each element of this tensor.
* @return the difference between this tensor and [value]. * @return the difference between this tensor and [value].
*/ */
public operator fun TensorStructure<T>.minus(value: T): TensorStructure<T> public operator fun Tensor<T>.minus(value: T): Tensor<T>
/** /**
* Each element of the tensor [other] is subtracted from each element of this tensor. * Each element of the tensor [other] is subtracted from each element of this tensor.
@ -87,21 +87,21 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this tensor and [other]. * @return the difference between this tensor and [other].
*/ */
public operator fun TensorStructure<T>.minus(other: TensorStructure<T>): TensorStructure<T> public operator fun Tensor<T>.minus(other: Tensor<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor. * Subtracts the scalar [value] from each element of this tensor.
* *
* @param value the number to be subtracted from each element of this tensor. * @param value the number to be subtracted from each element of this tensor.
*/ */
public operator fun TensorStructure<T>.minusAssign(value: T): Unit public operator fun Tensor<T>.minusAssign(value: T): Unit
/** /**
* Each element of the tensor [other] is subtracted from each element of this tensor. * Each element of the tensor [other] is subtracted from each element of this tensor.
* *
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
*/ */
public operator fun TensorStructure<T>.minusAssign(other: TensorStructure<T>): Unit public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
/** /**
@ -111,7 +111,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this value and tensor [other]. * @return the product of this value and tensor [other].
*/ */
public operator fun T.times(other: TensorStructure<T>): TensorStructure<T> public operator fun T.times(other: Tensor<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. * Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor.
@ -119,7 +119,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param value the number to be multiplied by each element of this tensor. * @param value the number to be multiplied by each element of this tensor.
* @return the product of this tensor and [value]. * @return the product of this tensor and [value].
*/ */
public operator fun TensorStructure<T>.times(value: T): TensorStructure<T> public operator fun Tensor<T>.times(value: T): Tensor<T>
/** /**
* Each element of the tensor [other] is multiplied by each element of this tensor. * Each element of the tensor [other] is multiplied by each element of this tensor.
@ -128,28 +128,28 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this tensor and [other]. * @return the product of this tensor and [other].
*/ */
public operator fun TensorStructure<T>.times(other: TensorStructure<T>): TensorStructure<T> public operator fun Tensor<T>.times(other: Tensor<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor. * Multiplies the scalar [value] by each element of this tensor.
* *
* @param value the number to be multiplied by each element of this tensor. * @param value the number to be multiplied by each element of this tensor.
*/ */
public operator fun TensorStructure<T>.timesAssign(value: T): Unit public operator fun Tensor<T>.timesAssign(value: T): Unit
/** /**
* Each element of the tensor [other] is multiplied by each element of this tensor. * Each element of the tensor [other] is multiplied by each element of this tensor.
* *
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
*/ */
public operator fun TensorStructure<T>.timesAssign(other: TensorStructure<T>): Unit public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
/** /**
* Numerical negative, element-wise. * Numerical negative, element-wise.
* *
* @return tensor negation of the original tensor. * @return tensor negation of the original tensor.
*/ */
public operator fun TensorStructure<T>.unaryMinus(): TensorStructure<T> public operator fun Tensor<T>.unaryMinus(): Tensor<T>
/** /**
* Returns the tensor at index i * Returns the tensor at index i
@ -158,7 +158,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param i index of the extractable tensor * @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i] * @return subtensor of the original tensor with index [i]
*/ */
public operator fun TensorStructure<T>.get(i: Int): TensorStructure<T> public operator fun Tensor<T>.get(i: Int): Tensor<T>
/** /**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
@ -168,7 +168,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param j the second dimension to be transposed * @param j the second dimension to be transposed
* @return transposed tensor * @return transposed tensor
*/ */
public fun TensorStructure<T>.transpose(i: Int = -2, j: Int = -1): TensorStructure<T> public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
/** /**
* Returns a new tensor with the same data as the self tensor but of a different shape. * Returns a new tensor with the same data as the self tensor but of a different shape.
@ -178,7 +178,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param shape the desired size * @param shape the desired size
* @return tensor with new shape * @return tensor with new shape
*/ */
public fun TensorStructure<T>.view(shape: IntArray): TensorStructure<T> public fun Tensor<T>.view(shape: IntArray): Tensor<T>
/** /**
* View this tensor as the same size as [other]. * View this tensor as the same size as [other].
@ -188,7 +188,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other the result tensor has the same size as other. * @param other the result tensor has the same size as other.
* @return the result tensor with the same size as other. * @return the result tensor with the same size as other.
*/ */
public fun TensorStructure<T>.viewAs(other: TensorStructure<T>): TensorStructure<T> public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T>
/** /**
* Matrix product of two tensors. * Matrix product of two tensors.
@ -219,7 +219,7 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* @param other tensor to be multiplied * @param other tensor to be multiplied
* @return mathematical product of two tensors * @return mathematical product of two tensors
*/ */
public infix fun TensorStructure<T>.dot(other: TensorStructure<T>): TensorStructure<T> public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T>
/** /**
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
@ -245,10 +245,10 @@ public interface TensorAlgebra<T>: Algebra<TensorStructure<T>> {
* are filled by [diagonalEntries] * are filled by [diagonalEntries]
*/ */
public fun diagonalEmbedding( public fun diagonalEmbedding(
diagonalEntries: TensorStructure<T>, diagonalEntries: Tensor<T>,
offset: Int = 0, offset: Int = 0,
dim1: Int = -2, dim1: Int = -2,
dim2: Int = -1 dim2: Int = -1
): TensorStructure<T> ): Tensor<T>
} }

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
/** /**
* Algebra over a field with partial division on [TensorStructure]. * Algebra over a field with partial division on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
@ -21,7 +21,7 @@ public interface TensorPartialDivisionAlgebra<T> :
* @param other tensor to divide by. * @param other tensor to divide by.
* @return the division of this value by the tensor [other]. * @return the division of this value by the tensor [other].
*/ */
public operator fun T.div(other: TensorStructure<T>): TensorStructure<T> public operator fun T.div(other: Tensor<T>): Tensor<T>
/** /**
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor. * Divide by the scalar [value] each element of this tensor returns a new resulting tensor.
@ -29,7 +29,7 @@ public interface TensorPartialDivisionAlgebra<T> :
* @param value the number to divide by each element of this tensor. * @param value the number to divide by each element of this tensor.
* @return the division of this tensor by the [value]. * @return the division of this tensor by the [value].
*/ */
public operator fun TensorStructure<T>.div(value: T): TensorStructure<T> public operator fun Tensor<T>.div(value: T): Tensor<T>
/** /**
* Each element of the tensor [other] is divided by each element of this tensor. * Each element of the tensor [other] is divided by each element of this tensor.
@ -38,19 +38,19 @@ public interface TensorPartialDivisionAlgebra<T> :
* @param other tensor to be divided by. * @param other tensor to be divided by.
* @return the division of this tensor by [other]. * @return the division of this tensor by [other].
*/ */
public operator fun TensorStructure<T>.div(other: TensorStructure<T>): TensorStructure<T> public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T>
/** /**
* Divides by the scalar [value] each element of this tensor. * Divides by the scalar [value] each element of this tensor.
* *
* @param value the number to divide by each element of this tensor. * @param value the number to divide by each element of this tensor.
*/ */
public operator fun TensorStructure<T>.divAssign(value: T) public operator fun Tensor<T>.divAssign(value: T)
/** /**
* Each element of this tensor is divided by each element of the [other] tensor. * Each element of this tensor is divided by each element of the [other] tensor.
* *
* @param other tensor to be divide by. * @param other tensor to be divide by.
*/ */
public operator fun TensorStructure<T>.divAssign(other: TensorStructure<T>) public operator fun Tensor<T>.divAssign(other: Tensor<T>)
} }

View File

@ -2,7 +2,7 @@ package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableBufferND import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure
@ -10,7 +10,7 @@ public open class BufferedTensor<T>(
override val shape: IntArray, override val shape: IntArray,
internal val mutableBuffer: MutableBuffer<T>, internal val mutableBuffer: MutableBuffer<T>,
internal val bufferStart: Int internal val bufferStart: Int
) : TensorStructure<T> { ) : Tensor<T> {
public val linearStructure: TensorLinearStructure public val linearStructure: TensorLinearStructure
get() = TensorLinearStructure(shape) get() = TensorLinearStructure(shape)
@ -53,33 +53,33 @@ internal fun BufferedTensor<Int>.asTensor(): IntTensor =
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> TensorStructure<T>.copyToBufferedTensor(): BufferedTensor<T> = internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor( BufferedTensor(
this.shape, this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
) )
internal fun <T> TensorStructure<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }
internal val TensorStructure<Double>.tensor: DoubleTensor internal val Tensor<Double>.tensor: DoubleTensor
get() = when (this) { get() = when (this) {
is DoubleTensor -> this is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor() else -> this.toBufferedTensor().asTensor()
} }
internal val TensorStructure<Int>.tensor: IntTensor internal val Tensor<Int>.tensor: IntTensor
get() = when (this) { get() = when (this) {
is IntTensor -> this is IntTensor -> this
else -> this.toBufferedTensor().asTensor() else -> this.toBufferedTensor().asTensor()
} }
public fun TensorStructure<Double>.toDoubleTensor(): DoubleTensor = this.tensor public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
public fun TensorStructure<Int>.toIntTensor(): IntTensor = this.tensor public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
public fun Array<DoubleArray>.toDoubleTensor(): DoubleTensor { public fun Array<DoubleArray>.toDoubleTensor(): DoubleTensor {
val n = size val n = size

View File

@ -5,7 +5,7 @@
package space.kscience.kmath.tensors.core.algebras package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.broadcastTensors import space.kscience.kmath.tensors.core.broadcastTensors
import space.kscience.kmath.tensors.core.broadcastTo import space.kscience.kmath.tensors.core.broadcastTo
@ -16,7 +16,7 @@ import space.kscience.kmath.tensors.core.broadcastTo
*/ */
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
@ -26,7 +26,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
@ -34,7 +34,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
} }
override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
@ -44,7 +44,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
@ -52,7 +52,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
} }
override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
@ -63,7 +63,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
@ -71,7 +71,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
} }
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
@ -82,7 +82,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.linearStructure.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=

View File

@ -6,7 +6,7 @@
package space.kscience.kmath.tensors.core.algebras package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.tensor import space.kscience.kmath.tensors.core.tensor
import kotlin.math.* import kotlin.math.*
@ -14,38 +14,38 @@ import kotlin.math.*
public object DoubleAnalyticTensorAlgebra : public object DoubleAnalyticTensorAlgebra :
AnalyticTensorAlgebra<Double>, AnalyticTensorAlgebra<Double>,
DoubleTensorAlgebra() { DoubleTensorAlgebra() {
override fun TensorStructure<Double>.exp(): DoubleTensor = tensor.map(::exp) override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp)
override fun TensorStructure<Double>.log(): DoubleTensor = tensor.map(::ln) override fun Tensor<Double>.log(): DoubleTensor = tensor.map(::ln)
override fun TensorStructure<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt) override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt)
override fun TensorStructure<Double>.cos(): DoubleTensor = tensor.map(::cos) override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos)
override fun TensorStructure<Double>.acos(): DoubleTensor = tensor.map(::acos) override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos)
override fun TensorStructure<Double>.cosh(): DoubleTensor = tensor.map(::cosh) override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh)
override fun TensorStructure<Double>.acosh(): DoubleTensor = tensor.map(::acosh) override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh)
override fun TensorStructure<Double>.sin(): DoubleTensor = tensor.map(::sin) override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin)
override fun TensorStructure<Double>.asin(): DoubleTensor = tensor.map(::asin) override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin)
override fun TensorStructure<Double>.sinh(): DoubleTensor = tensor.map(::sinh) override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh)
override fun TensorStructure<Double>.asinh(): DoubleTensor = tensor.map(::asinh) override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh)
override fun TensorStructure<Double>.tan(): DoubleTensor = tensor.map(::tan) override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan)
override fun TensorStructure<Double>.atan(): DoubleTensor = tensor.map(::atan) override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan)
override fun TensorStructure<Double>.tanh(): DoubleTensor = tensor.map(::tanh) override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh)
override fun TensorStructure<Double>.atanh(): DoubleTensor = tensor.map(::atanh) override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh)
override fun TensorStructure<Double>.ceil(): DoubleTensor = tensor.map(::ceil) override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil)
override fun TensorStructure<Double>.floor(): DoubleTensor = tensor.map(::floor) override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor)
} }

View File

@ -8,7 +8,7 @@ package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.checkSquareMatrix import space.kscience.kmath.tensors.core.checkSquareMatrix
import space.kscience.kmath.tensors.core.choleskyHelper import space.kscience.kmath.tensors.core.choleskyHelper
@ -25,19 +25,19 @@ public object DoubleLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double>, LinearOpsTensorAlgebra<Double>,
DoubleTensorAlgebra() { DoubleTensorAlgebra() {
override fun TensorStructure<Double>.inv(): DoubleTensor = invLU(1e-9) override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)
override fun TensorStructure<Double>.det(): DoubleTensor = detLU(1e-9) override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9)
public fun TensorStructure<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
public fun TensorStructure<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9) public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
public fun luPivot( public fun luPivot(
luTensor: TensorStructure<Double>, luTensor: Tensor<Double>,
pivotsTensor: TensorStructure<Int> pivotsTensor: Tensor<Int>
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape) checkSquareMatrix(luTensor.shape)
check( check(
@ -66,7 +66,7 @@ public object DoubleLinearOpsTensorAlgebra :
return Triple(pTensor, lTensor, uTensor) return Triple(pTensor, lTensor, uTensor)
} }
public fun TensorStructure<Double>.cholesky(epsilon: Double): DoubleTensor { public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon) checkPositiveDefinite(tensor, epsilon)
@ -79,9 +79,9 @@ public object DoubleLinearOpsTensorAlgebra :
return lTensor return lTensor
} }
override fun TensorStructure<Double>.cholesky(): DoubleTensor = cholesky(1e-6) override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
override fun TensorStructure<Double>.qr(): Pair<DoubleTensor, DoubleTensor> { override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
val rTensor = zeroesLike() val rTensor = zeroesLike()
@ -95,10 +95,10 @@ public object DoubleLinearOpsTensorAlgebra :
return qTensor to rTensor return qTensor to rTensor
} }
override fun TensorStructure<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10) svd(epsilon = 1e-10)
public fun TensorStructure<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.linearStructure.dim val size = tensor.linearStructure.dim
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = tensor.shape.sliceArray(size - 2 until size)
@ -122,11 +122,11 @@ public object DoubleLinearOpsTensorAlgebra :
return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
} }
override fun TensorStructure<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> =
symEig(epsilon = 1e-15) symEig(epsilon = 1e-15)
//For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html //For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html
public fun TensorStructure<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(tensor, epsilon)
val (u, s, v) = tensor.svd(epsilon) val (u, s, v) = tensor.svd(epsilon)
val shp = s.shape + intArrayOf(1) val shp = s.shape + intArrayOf(1)
@ -139,7 +139,7 @@ public object DoubleLinearOpsTensorAlgebra :
return eig to v return eig to v
} }
public fun TensorStructure<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape) checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy() val luTensor = tensor.copy()
@ -164,7 +164,7 @@ public object DoubleLinearOpsTensorAlgebra :
return detTensor return detTensor
} }
public fun TensorStructure<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor { public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon) val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
@ -177,11 +177,11 @@ public object DoubleLinearOpsTensorAlgebra :
return invTensor return invTensor
} }
public fun TensorStructure<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = this.luFactor(epsilon) val (lu, pivots) = this.luFactor(epsilon)
return luPivot(lu, pivots) return luPivot(lu, pivots)
} }
override fun TensorStructure<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9) override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
} }

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.tensors.core.algebras
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.*
import space.kscience.kmath.tensors.core.broadcastOuterTensors import space.kscience.kmath.tensors.core.broadcastOuterTensors
import space.kscience.kmath.tensors.core.checkBufferShapeConsistency import space.kscience.kmath.tensors.core.checkBufferShapeConsistency
@ -25,7 +25,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()
override fun TensorStructure<Double>.value(): Double { override fun Tensor<Double>.value(): Double {
check(tensor.shape contentEquals intArrayOf(1)) { check(tensor.shape contentEquals intArrayOf(1)) {
"Inconsistent value for tensor of shape ${shape.toList()}" "Inconsistent value for tensor of shape ${shape.toList()}"
} }
@ -39,7 +39,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(shape, buffer, 0) return DoubleTensor(shape, buffer, 0)
} }
override operator fun TensorStructure<Double>.get(i: Int): DoubleTensor { override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
@ -52,7 +52,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
} }
public fun TensorStructure<Double>.fullLike(value: Double): DoubleTensor { public fun Tensor<Double>.fullLike(value: Double): DoubleTensor {
val shape = tensor.shape val shape = tensor.shape
val buffer = DoubleArray(tensor.numElements) { value } val buffer = DoubleArray(tensor.numElements) { value }
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
@ -60,11 +60,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape)
public fun TensorStructure<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape)
public fun TensorStructure<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0) public fun Tensor<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0)
public fun eye(n: Int): DoubleTensor { public fun eye(n: Int): DoubleTensor {
val shape = intArrayOf(n, n) val shape = intArrayOf(n, n)
@ -76,20 +76,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return res return res
} }
public fun TensorStructure<Double>.copy(): DoubleTensor { public fun Tensor<Double>.copy(): DoubleTensor {
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
} }
override fun Double.plus(other: TensorStructure<Double>): DoubleTensor { override fun Double.plus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.plus(value: Double): DoubleTensor = value + tensor override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
@ -97,13 +97,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.plusAssign(value: Double) { override fun Tensor<Double>.plusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value tensor.mutableBuffer.array()[tensor.bufferStart + i] += value
} }
} }
override fun TensorStructure<Double>.plusAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.plusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
@ -111,21 +111,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
} }
override fun Double.minus(other: TensorStructure<Double>): DoubleTensor { override fun Double.minus(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.minus(value: Double): DoubleTensor { override fun Tensor<Double>.minus(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value tensor.mutableBuffer.array()[tensor.bufferStart + i] - value
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.minus(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
@ -133,13 +133,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.minusAssign(value: Double) { override fun Tensor<Double>.minusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value
} }
} }
override fun TensorStructure<Double>.minusAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.minusAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
@ -147,16 +147,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
} }
override fun Double.times(other: TensorStructure<Double>): DoubleTensor { override fun Double.times(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.times(value: Double): DoubleTensor = value * tensor override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor
override fun TensorStructure<Double>.times(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] * tensor.mutableBuffer.array()[tensor.bufferStart + i] *
@ -165,13 +165,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.timesAssign(value: Double) { override fun Tensor<Double>.timesAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value
} }
} }
override fun TensorStructure<Double>.timesAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.timesAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
@ -179,21 +179,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
} }
override fun Double.div(other: TensorStructure<Double>): DoubleTensor { override fun Double.div(other: Tensor<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun TensorStructure<Double>.div(value: Double): DoubleTensor { override fun Tensor<Double>.div(value: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value tensor.mutableBuffer.array()[tensor.bufferStart + i] / value
} }
return DoubleTensor(shape, resBuffer) return DoubleTensor(shape, resBuffer)
} }
override fun TensorStructure<Double>.div(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
@ -202,13 +202,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.divAssign(value: Double) { override fun Tensor<Double>.divAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value
} }
} }
override fun TensorStructure<Double>.divAssign(other: TensorStructure<Double>) { override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
@ -216,14 +216,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
} }
override fun TensorStructure<Double>.unaryMinus(): DoubleTensor { override fun Tensor<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun TensorStructure<Double>.transpose(i: Int, j: Int): DoubleTensor { override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i) val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j) val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj) checkTranspose(tensor.dimension, ii, jj)
@ -248,16 +248,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
override fun TensorStructure<Double>.view(shape: IntArray): DoubleTensor { override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape) checkView(tensor, shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
} }
override fun TensorStructure<Double>.viewAs(other: TensorStructure<Double>): DoubleTensor { override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor {
return tensor.view(other.shape) return tensor.view(other.shape)
} }
override infix fun TensorStructure<Double>.dot(other: TensorStructure<Double>): DoubleTensor { override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
} }
@ -309,7 +309,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
return resTensor return resTensor
} }
override fun diagonalEmbedding(diagonalEntries: TensorStructure<Double>, offset: Int, dim1: Int, dim2: Int): override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
DoubleTensor { DoubleTensor {
val n = diagonalEntries.shape.size val n = diagonalEntries.shape.size
val d1 = minusIndexFrom(n + 1, dim1) val d1 = minusIndexFrom(n + 1, dim1)
@ -358,7 +358,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
public fun TensorStructure<Double>.map(transform: (Double) -> Double): DoubleTensor { public fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor {
return DoubleTensor( return DoubleTensor(
tensor.shape, tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(), tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
@ -366,14 +366,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
) )
} }
public fun TensorStructure<Double>.eq(other: TensorStructure<Double>, epsilon: Double): Boolean { public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean {
return tensor.eq(other) { x, y -> abs(x - y) < epsilon } return tensor.eq(other) { x, y -> abs(x - y) < epsilon }
} }
public infix fun TensorStructure<Double>.eq(other: TensorStructure<Double>): Boolean = tensor.eq(other, 1e-5) public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = tensor.eq(other, 1e-5)
private fun TensorStructure<Double>.eq( private fun Tensor<Double>.eq(
other: TensorStructure<Double>, other: Tensor<Double>,
eqFunction: (Double, Double) -> Boolean eqFunction: (Double, Double) -> Boolean
): Boolean { ): Boolean {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
@ -396,7 +396,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor = public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor =
DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed)) DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed))
public fun TensorStructure<Double>.randomNormalLike(seed: Long = 0): DoubleTensor = public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed))
// stack tensors by axis 0 // stack tensors by axis 0
@ -412,7 +412,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra<Double> {
} }
// build tensor from this rows by given indices // build tensor from this rows by given indices
public fun TensorStructure<Double>.rowsByIndices(indices: IntArray): DoubleTensor { public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor {
return stack(indices.map { this[it] }) return stack(indices.map { this[it] })
} }
} }

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.tensors.api.TensorStructure import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra
import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra
@ -20,7 +20,7 @@ internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided" "Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
} }
internal fun <T> checkShapesCompatible(a: TensorStructure<T>, b: TensorStructure<T>) = internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) =
check(a.shape contentEquals b.shape) { check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
} }
@ -30,7 +30,7 @@ internal fun checkTranspose(dim: Int, i: Int, j: Int) =
"Cannot transpose $i to $j for a tensor of dim $dim" "Cannot transpose $i to $j for a tensor of dim $dim"
} }
internal fun <T> checkView(a: TensorStructure<T>, shape: IntArray) = internal fun <T> checkView(a: Tensor<T>, shape: IntArray) =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
internal fun checkSquareMatrix(shape: IntArray) { internal fun checkSquareMatrix(shape: IntArray) {
@ -44,7 +44,7 @@ internal fun checkSquareMatrix(shape: IntArray) {
} }
internal fun DoubleTensorAlgebra.checkSymmetric( internal fun DoubleTensorAlgebra.checkSymmetric(
tensor: TensorStructure<Double>, epsilon: Double = 1e-6 tensor: Tensor<Double>, epsilon: Double = 1e-6
) = ) =
check(tensor.eq(tensor.transpose(), epsilon)) { check(tensor.eq(tensor.transpose(), epsilon)) {
"Tensor is not symmetric about the last 2 dimensions at precision $epsilon" "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"