Initial draft for TensorAlgebra

This commit is contained in:
Roland Grinis 2021-01-16 19:29:47 +00:00
parent 758508ba96
commit 835d64d797
3 changed files with 61 additions and 0 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
build/
out/
.idea/
.vscode/
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
!gradle-wrapper.jar

View File

@ -0,0 +1,49 @@
package kscience.kmath.structures
import kscience.kmath.operations.*
public interface TensorAlgebra<T, TorchTensorType : TensorStructure<T>> : Ring<TorchTensorType> {
public operator fun T.plus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.plus(value: T): TorchTensorType
public operator fun TorchTensorType.plusAssign(value: T): Unit
public operator fun TorchTensorType.plusAssign(b: TorchTensorType): Unit
public operator fun T.minus(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.minus(value: T): TorchTensorType
public operator fun TorchTensorType.minusAssign(value: T): Unit
public operator fun TorchTensorType.minusAssign(b: TorchTensorType): Unit
public operator fun T.times(other: TorchTensorType): TorchTensorType
public operator fun TorchTensorType.times(value: T): TorchTensorType
public operator fun TorchTensorType.timesAssign(value: T): Unit
public operator fun TorchTensorType.timesAssign(b: TorchTensorType): Unit
public infix fun TorchTensorType.dot(b: TorchTensorType): TorchTensorType
public fun diagonalEmbedding(
diagonalEntries: TorchTensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
): TorchTensorType
public fun TorchTensorType.transpose(i: Int, j: Int): TorchTensorType
public fun TorchTensorType.view(shape: IntArray): TorchTensorType
public fun TorchTensorType.abs(): TorchTensorType
public fun TorchTensorType.sum(): TorchTensorType
}
public interface TensorFieldAlgebra<T, TorchTensorType : TensorStructure<T>> :
TensorAlgebra<T, TorchTensorType>, Field<TorchTensorType> {
public operator fun TorchTensorType.divAssign(b: TorchTensorType)
public fun TorchTensorType.exp(): TorchTensorType
public fun TorchTensorType.log(): TorchTensorType
public fun TorchTensorType.svd(): Triple<TorchTensorType, TorchTensorType, TorchTensorType>
public fun TorchTensorType.symEig(eigenvectors: Boolean = true): Pair<TorchTensorType, TorchTensorType>
}

View File

@ -0,0 +1,11 @@
package kscience.kmath.structures
public abstract class TensorStructure<T>: MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
// Tensors are mutable and might hold shared resources
override fun equals(other: Any?): Boolean = false
override fun hashCode(): Int = 0
}