KMP library for tensors #300
12
README.md
12
README.md
@ -230,6 +230,18 @@ One can still use generic algebras though.
|
|||||||
> **Maturity**: EXPERIMENTAL
|
> **Maturity**: EXPERIMENTAL
|
||||||
<hr/>
|
<hr/>
|
||||||
|
|
||||||
|
* ### [kmath-tensors](kmath-tensors)
|
||||||
|
>
|
||||||
|
>
|
||||||
|
> **Maturity**: PROTOTYPE
|
||||||
|
>
|
||||||
|
> **Features:**
|
||||||
|
> - [tensor algebra](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
|
||||||
|
> - [tensor algebra with broadcasting](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
|
||||||
|
> - [linear algebra operations](kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
|
||||||
|
|
||||||
|
<hr/>
|
||||||
|
|
||||||
* ### [kmath-viktor](kmath-viktor)
|
* ### [kmath-viktor](kmath-viktor)
|
||||||
>
|
>
|
||||||
>
|
>
|
||||||
|
37
kmath-tensors/README.md
Normal file
37
kmath-tensors/README.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# Module kmath-tensors
|
||||||
|
|
||||||
|
Common linear algebra operations on tensors.
|
||||||
|
|
||||||
|
- [tensor algebra](src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt) : Basic linear algebra operations on tensors (plus, dot, etc.)
|
||||||
|
- [tensor algebra with broadcasting](src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt) : Basic linear algebra operations implemented with broadcasting.
|
||||||
|
- [linear algebra operations](src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt) : Advanced linear algebra operations like LU decomposition, SVD, etc.
|
||||||
|
|
||||||
|
|
||||||
|
## Artifact:
|
||||||
|
|
||||||
|
The Maven coordinates of this project are `space.kscience:kmath-tensors:0.3.0-dev-7`.
|
||||||
|
|
||||||
|
**Gradle:**
|
||||||
|
```gradle
|
||||||
|
repositories {
|
||||||
|
maven { url 'https://repo.kotlin.link' }
|
||||||
|
maven { url 'https://dl.bintray.com/hotkeytlt/maven' }
|
||||||
|
maven { url "https://dl.bintray.com/kotlin/kotlin-eap" } // include for builds based on kotlin-eap
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation 'space.kscience:kmath-tensors:0.3.0-dev-7'
|
||||||
|
}
|
||||||
|
```
|
||||||
|
**Gradle Kotlin DSL:**
|
||||||
|
```kotlin
|
||||||
|
repositories {
|
||||||
|
maven("https://repo.kotlin.link")
|
||||||
|
maven("https://dl.bintray.com/kotlin/kotlin-eap") // include for builds based on kotlin-eap
|
||||||
|
maven("https://dl.bintray.com/hotkeytlt/maven") // required for a
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
implementation("space.kscience:kmath-tensors:0.3.0-dev-7")
|
||||||
|
}
|
||||||
|
```
|
@ -11,6 +11,30 @@ kotlin.sourceSets {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
readme {
|
tasks.dokkaHtml {
|
||||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
dependsOn(tasks.build)
|
||||||
|
}
|
||||||
|
|
||||||
|
readme {
|
||||||
|
maturity = ru.mipt.npm.gradle.Maturity.PROTOTYPE
|
||||||
|
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "tensor algebra",
|
||||||
|
description = "Basic linear algebra operations on tensors (plus, dot, etc.)",
|
||||||
|
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "tensor algebra with broadcasting",
|
||||||
|
description = "Basic linear algebra operations implemented with broadcasting.",
|
||||||
|
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt"
|
||||||
|
)
|
||||||
|
|
||||||
|
feature(
|
||||||
|
id = "linear algebra operations",
|
||||||
|
description = "Advanced linear algebra operations like LU decomposition, SVD, etc.",
|
||||||
|
ref = "src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt"
|
||||||
|
)
|
||||||
|
|
||||||
}
|
}
|
7
kmath-tensors/docs/README-TEMPLATE.md
Normal file
7
kmath-tensors/docs/README-TEMPLATE.md
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# Module kmath-tensors
|
||||||
|
|
||||||
|
Common linear algebra operations on tensors.
|
||||||
|
|
||||||
|
${features}
|
||||||
|
|
||||||
|
${artifact}
|
@ -5,33 +5,99 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Common linear algebra operations. Operates on [TensorStructure].
|
||||||
|
*
|
||||||
|
* @param T the type of items in the tensors.
|
||||||
|
*/
|
||||||
public interface LinearOpsTensorAlgebra<T> :
|
public interface LinearOpsTensorAlgebra<T> :
|
||||||
TensorPartialDivisionAlgebra<T> {
|
TensorPartialDivisionAlgebra<T> {
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
/**
|
||||||
|
* Computes the determinant of a square matrix input, or of each square matrix in a batched input.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.det
|
||||||
|
*
|
||||||
|
* @return the determinant.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.det(): TensorStructure<T>
|
public fun TensorStructure<T>.det(): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
/**
|
||||||
|
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
|
||||||
|
* Given a square matrix `a`, return the matrix `aInv` satisfying
|
||||||
|
* ``a.dot(aInv) = aInv.dot(a) = eye(a.shape[0])``.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv
|
||||||
|
*
|
||||||
|
* @return the multiplicative inverse of a matrix.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.inv(): TensorStructure<T>
|
public fun TensorStructure<T>.inv(): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
/**
|
||||||
|
* Cholesky decomposition.
|
||||||
|
*
|
||||||
|
* Computes the Cholesky decomposition of a Hermitian (or symmetric for real-valued matrices)
|
||||||
|
* positive-definite matrix or the Cholesky decompositions for a batch of such matrices.
|
||||||
|
* Each decomposition has the form:
|
||||||
|
* Given a tensor `input`, return the tensor `L` satisfying ``input = L * L.H``,
|
||||||
|
* where L is a lower-triangular matrix and L.H is the conjugate transpose of L,
|
||||||
|
* which is just a transpose for the case of real-valued input matrices.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky
|
||||||
|
*
|
||||||
|
* @return the batch of L matrices.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.cholesky(): TensorStructure<T>
|
public fun TensorStructure<T>.cholesky(): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
/**
|
||||||
|
* QR decomposition.
|
||||||
|
*
|
||||||
|
* Computes the QR decomposition of a matrix or a batch of matrices, and returns a namedtuple `(Q, R)` of tensors.
|
||||||
|
* Given a tensor `input`, return tensors (Q, R) satisfying ``input = Q * R``,
|
||||||
|
* with `Q` being an orthogonal matrix or batch of orthogonal matrices
|
||||||
|
* and `R` being an upper triangular matrix or batch of upper triangular matrices.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr
|
||||||
|
*
|
||||||
|
* @return tuple of Q and R tensors.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
|
public fun TensorStructure<T>.qr(): Pair<TensorStructure<T>, TensorStructure<T>>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu.html
|
/**
|
||||||
|
* TODO('Andrew')
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.lu.html
|
||||||
|
*
|
||||||
|
* @return ...
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>>
|
public fun TensorStructure<T>.lu(): Pair<TensorStructure<T>, TensorStructure<Int>>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
/**
|
||||||
|
* TODO('Andrew')
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.lu_unpack.html
|
||||||
|
*
|
||||||
|
* @param luTensor ...
|
||||||
|
* @param pivotsTensor ...
|
||||||
|
* @return ...
|
||||||
|
*/
|
||||||
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>):
|
public fun luPivot(luTensor: TensorStructure<T>, pivotsTensor: TensorStructure<Int>):
|
||||||
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
/**
|
||||||
|
* Singular Value Decomposition.
|
||||||
|
*
|
||||||
|
* Computes the singular value decomposition of either a matrix or batch of matrices `input`.
|
||||||
|
* The singular value decomposition is represented as a namedtuple `(U, S, V)`,
|
||||||
|
* such that ``input = U.dot(diagonalEmbedding(S).dot(V.T))``.
|
||||||
|
* If input is a batch of tensors, then U, S, and Vh are also batched with the same batch dimensions as input.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd
|
||||||
|
*
|
||||||
|
* @return the determinant.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.svd(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
public fun TensorStructure<T>.svd(): Triple<TensorStructure<T>, TensorStructure<T>, TensorStructure<T>>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.symeig.html
|
/**
|
||||||
|
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
|
||||||
|
* represented by a namedtuple (eigenvalues, eigenvectors).
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.symeig.html
|
||||||
|
*
|
||||||
|
* @return a namedtuple (eigenvalues, eigenvectors)
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.symEig(): Pair<TensorStructure<T>, TensorStructure<T>>
|
public fun TensorStructure<T>.symEig(): Pair<TensorStructure<T>, TensorStructure<T>>
|
||||||
|
|
||||||
}
|
}
|
@ -5,44 +5,243 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.tensors.api
|
package space.kscience.kmath.tensors.api
|
||||||
|
|
||||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
/**
|
||||||
|
* Basic linear algebra operations on [TensorStructure].
|
||||||
|
* For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||||
|
*
|
||||||
|
* @param T the type of items in the tensors.
|
||||||
|
*/
|
||||||
public interface TensorAlgebra<T> {
|
public interface TensorAlgebra<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
|
||||||
|
*
|
||||||
|
* @return the value of a scalar tensor.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.value(): T
|
public fun TensorStructure<T>.value(): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is added to this value.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be added.
|
||||||
|
* @return the sum of this value and tensor [other].
|
||||||
|
*/
|
||||||
public operator fun T.plus(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun T.plus(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be added to each element of this tensor.
|
||||||
|
* @return the sum of this tensor and [value].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.plus(value: T): TensorStructure<T>
|
public operator fun TensorStructure<T>.plus(value: T): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is added to each element of this tensor.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be added.
|
||||||
|
* @return the sum of this tensor and [other].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.plus(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun TensorStructure<T>.plus(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds the scalar [value] to each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be added to each element of this tensor.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.plusAssign(value: T): Unit
|
public operator fun TensorStructure<T>.plusAssign(value: T): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is added to each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param other tensor to be added.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.plusAssign(other: TensorStructure<T>): Unit
|
public operator fun TensorStructure<T>.plusAssign(other: TensorStructure<T>): Unit
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is subtracted from this value.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be subtracted.
|
||||||
|
* @return the difference between this value and tensor [other].
|
||||||
|
*/
|
||||||
public operator fun T.minus(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun T.minus(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be subtracted from each element of this tensor.
|
||||||
|
* @return the difference between this tensor and [value].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.minus(value: T): TensorStructure<T>
|
public operator fun TensorStructure<T>.minus(value: T): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is subtracted from each element of this tensor.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be subtracted.
|
||||||
|
* @return the difference between this tensor and [other].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.minus(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun TensorStructure<T>.minus(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtracts the scalar [value] from each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be subtracted from each element of this tensor.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.minusAssign(value: T): Unit
|
public operator fun TensorStructure<T>.minusAssign(value: T): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is subtracted from each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param other tensor to be subtracted.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.minusAssign(other: TensorStructure<T>): Unit
|
public operator fun TensorStructure<T>.minusAssign(other: TensorStructure<T>): Unit
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is multiplied by this value.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be multiplied.
|
||||||
|
* @return the product of this value and tensor [other].
|
||||||
|
*/
|
||||||
public operator fun T.times(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun T.times(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be multiplied by each element of this tensor.
|
||||||
|
* @return the product of this tensor and [value].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.times(value: T): TensorStructure<T>
|
public operator fun TensorStructure<T>.times(value: T): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is multiplied by each element of this tensor.
|
||||||
|
* The resulting tensor is returned.
|
||||||
|
*
|
||||||
|
* @param other tensor to be multiplied.
|
||||||
|
* @return the product of this tensor and [other].
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.times(other: TensorStructure<T>): TensorStructure<T>
|
public operator fun TensorStructure<T>.times(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplies the scalar [value] by each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param value the number to be multiplied by each element of this tensor.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.timesAssign(value: T): Unit
|
public operator fun TensorStructure<T>.timesAssign(value: T): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Each element of the tensor [other] is multiplied by each element of this tensor.
|
||||||
|
*
|
||||||
|
* @param other tensor to be multiplied.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.timesAssign(other: TensorStructure<T>): Unit
|
public operator fun TensorStructure<T>.timesAssign(other: TensorStructure<T>): Unit
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Numerical negative, element-wise.
|
||||||
|
*
|
||||||
|
* @return tensor - negation of the original tensor.
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.unaryMinus(): TensorStructure<T>
|
public operator fun TensorStructure<T>.unaryMinus(): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
/**
|
||||||
|
* Returns the tensor at index i
|
||||||
|
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||||
|
*
|
||||||
|
* @param i index of the extractable tensor
|
||||||
|
* @return subtensor of the original tensor with index [i]
|
||||||
|
*/
|
||||||
public operator fun TensorStructure<T>.get(i: Int): TensorStructure<T>
|
public operator fun TensorStructure<T>.get(i: Int): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.transpose.html
|
/**
|
||||||
|
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.transpose.html
|
||||||
|
*
|
||||||
|
* @param i the first dimension to be transposed
|
||||||
|
* @param j the second dimension to be transposed
|
||||||
|
* @return transposed tensor
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.transpose(i: Int = -2, j: Int = -1): TensorStructure<T>
|
public fun TensorStructure<T>.transpose(i: Int = -2, j: Int = -1): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/tensor_view.html
|
/**
|
||||||
|
* Returns a new tensor with the same data as the self tensor but of a different shape.
|
||||||
|
* The returned tensor shares the same data and must have the same number of elements, but may have a different size
|
||||||
|
* For more information: https://pytorch.org/docs/stable/tensor_view.html
|
||||||
|
*
|
||||||
|
* @param shape the desired size
|
||||||
|
* @return tensor with new shape
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.view(shape: IntArray): TensorStructure<T>
|
public fun TensorStructure<T>.view(shape: IntArray): TensorStructure<T>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* View this tensor as the same size as [other].
|
||||||
|
* ``this.viewAs(other) is equivalent to this.view(other.shape)``.
|
||||||
|
* For more information: https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
||||||
|
*
|
||||||
|
* @param other the result tensor has the same size as other.
|
||||||
|
* @return the result tensor with the same size as other.
|
||||||
|
*/
|
||||||
public fun TensorStructure<T>.viewAs(other: TensorStructure<T>): TensorStructure<T>
|
public fun TensorStructure<T>.viewAs(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.matmul.html
|
/**
|
||||||
|
* Matrix product of two tensors.
|
||||||
|
*
|
||||||
|
* The behavior depends on the dimensionality of the tensors as follows:
|
||||||
|
* 1. If both tensors are 1-dimensional, the dot product (scalar) is returned.
|
||||||
|
*
|
||||||
|
* 2. If both arguments are 2-dimensional, the matrix-matrix product is returned.
|
||||||
|
*
|
||||||
|
* 3. If the first argument is 1-dimensional and the second argument is 2-dimensional,
|
||||||
|
* a 1 is prepended to its dimension for the purpose of the matrix multiply.
|
||||||
|
* After the matrix multiply, the prepended dimension is removed.
|
||||||
|
*
|
||||||
|
* 4. If the first argument is 2-dimensional and the second argument is 1-dimensional,
|
||||||
|
* the matrix-vector product is returned.
|
||||||
|
*
|
||||||
|
* 5. If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2),
|
||||||
|
* then a batched matrix multiply is returned. If the first argument is 1-dimensional,
|
||||||
|
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
|
||||||
|
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
|
||||||
|
* multiple and removed after.
|
||||||
|
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
|
||||||
|
* For example, if `input` is a (j \times 1 \times n \times n) tensor and `other` is a
|
||||||
|
* (k \times n \times n) tensor, out will be a (j \times k \times n \times n) tensor.
|
||||||
|
*
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||||
|
*
|
||||||
|
* @param other tensor to be multiplied
|
||||||
|
* @return mathematical product of two tensors
|
||||||
|
*/
|
||||||
public infix fun TensorStructure<T>.dot(other: TensorStructure<T>): TensorStructure<T>
|
public infix fun TensorStructure<T>.dot(other: TensorStructure<T>): TensorStructure<T>
|
||||||
|
|
||||||
//https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
/**
|
||||||
|
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
|
||||||
|
* are filled by [diagonalEntries].
|
||||||
|
* To facilitate creating batched diagonal matrices,
|
||||||
|
* the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
|
||||||
|
*
|
||||||
|
* The argument [offset] controls which diagonal to consider:
|
||||||
|
* 1. If [offset] = 0, it is the main diagonal.
|
||||||
|
* 2. If [offset] > 0, it is above the main diagonal.
|
||||||
|
* 3. If [offset] < 0, it is below the main diagonal.
|
||||||
|
*
|
||||||
|
* The size of the new matrix will be calculated
|
||||||
|
* to make the specified diagonal of the size of the last input dimension.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/generated/torch.diag_embed.html
|
||||||
|
*
|
||||||
|
* @param diagonalEntries - the input tensor. Must be at least 1-dimensional.
|
||||||
|
* @param offset - which diagonal to consider. Default: 0 (main diagonal).
|
||||||
|
* @param dim1 - first dimension with respect to which to take diagonal. Default: -2.
|
||||||
|
* @param dim2 - second dimension with respect to which to take diagonal. Default: -1.
|
||||||
|
*
|
||||||
|
* @return tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
|
||||||
|
* are filled by [diagonalEntries]
|
||||||
|
*/
|
||||||
public fun diagonalEmbedding(
|
public fun diagonalEmbedding(
|
||||||
diagonalEntries: TensorStructure<T>,
|
diagonalEntries: TensorStructure<T>,
|
||||||
offset: Int = 0,
|
offset: Int = 0,
|
||||||
|
@ -10,6 +10,10 @@ import space.kscience.kmath.tensors.core.*
|
|||||||
import space.kscience.kmath.tensors.core.broadcastTensors
|
import space.kscience.kmath.tensors.core.broadcastTensors
|
||||||
import space.kscience.kmath.tensors.core.broadcastTo
|
import space.kscience.kmath.tensors.core.broadcastTo
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Basic linear algebra operations implemented with broadcasting.
|
||||||
|
* For more information: https://pytorch.org/docs/stable/notes/broadcasting.html
|
||||||
|
*/
|
||||||
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
|
||||||
|
|
||||||
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor {
|
override fun TensorStructure<Double>.plus(other: TensorStructure<Double>): DoubleTensor {
|
||||||
|
Loading…
Reference in New Issue
Block a user