Initial draft for TensorAlgebra
This commit is contained in:
parent
758508ba96
commit
835d64d797
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
|||||||
build/
|
build/
|
||||||
out/
|
out/
|
||||||
.idea/
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
|
||||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||||
!gradle-wrapper.jar
|
!gradle-wrapper.jar
|
||||||
|
@ -0,0 +1,49 @@
|
|||||||
|
package kscience.kmath.structures
|
||||||
|
|
||||||
|
|
||||||
|
import kscience.kmath.operations.*
|
||||||
|
|
||||||
|
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
|
||||||
|
|
||||||
|
public operator fun T.plus(other: TorchTensorType): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.plus(value: T): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.plusAssign(value: T): Unit
|
||||||
|
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
|
||||||
|
|
||||||
|
public operator fun T.minus(other: TorchTensorType): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.minus(value: T): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.minusAssign(value: T): Unit
|
||||||
|
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
|
||||||
|
|
||||||
|
public operator fun T.times(other: TorchTensorType): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.times(value: T): TorchTensorType
|
||||||
|
public operator fun TorchTensorType.timesAssign(value: T): Unit
|
||||||
|
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
|
||||||
|
|
||||||
|
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
|
||||||
|
|
||||||
|
public fun diagonalEmbedding(
|
||||||
|
diagonalEntries: TorchTensorType,
|
||||||
|
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||||
|
): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
|
||||||
|
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.abs(): TorchTensorType
|
||||||
|
public fun TorchTensorType.sum(): TorchTensorType
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
|
||||||
|
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
|
||||||
|
|
||||||
|
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
|
||||||
|
|
||||||
|
public fun TorchTensorType.exp(): TorchTensorType
|
||||||
|
public fun TorchTensorType.log(): TorchTensorType
|
||||||
|
|
||||||
|
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
|
||||||
|
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
package kscience.kmath.structures
|
||||||
|
|
||||||
|
public abstract class TensorStructure<T>: MutableNDStructure<T> {
|
||||||
|
|
||||||
|
// A tensor can have empty shape, in which case it represents just a value
|
||||||
|
public abstract fun value(): T
|
||||||
|
|
||||||
|
// Tensors are mutable and might hold shared resources
|
||||||
|
override fun equals(other: Any?): Boolean = false
|
||||||
|
override fun hashCode(): Int = 0
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user