Tensor algebra refactoring

This commit is contained in:
Alexander Nozik 2021-01-19 22:25:04 +03:00
parent 360e0e17e9
commit a11711c336

View File

@ -1,32 +1,36 @@
package kscience.kmath.structures
package kscience.kmath.tensors
public interface TensorStructure<T> : MutableNDStructure<T> {
// A tensor can have empty shape, in which case it represents just a value
public abstract fun value(): T
}
import kscience.kmath.operations.Ring
import kscience.kmath.structures.MutableNDStructure
public typealias Tensor<T> = MutableNDStructure<T>
public val <T : Any> Tensor<T>.value: T
get() {
require(shape.contentEquals(intArrayOf(1))) { "Value available only for a tensor with no dimensions" }
return get(intArrayOf(0))
}
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
/**
* To be moved to a separate project
*/
public interface TensorAlgebra<T, TensorType : Tensor<T>> : Ring<TensorType> {
public operator fun T.plus(other: TensorType): TensorType
public operator fun TensorType.plus(value: T): TensorType
public operator fun TensorType.plus(other: TensorType): TensorType
public operator fun TensorType.plusAssign(value: T): Unit
public operator fun TensorType.plusAssign(other: TensorType): Unit
public operator fun T.minus(other: TensorType): TensorType
public operator fun TensorType.minus(value: T): TensorType
public operator fun TensorType.minus(other: TensorType): TensorType
public operator fun TensorType.minusAssign(value: T): Unit
public operator fun TensorType.minusAssign(other: TensorType): Unit
public operator fun T.times(other: TensorType): TensorType
public operator fun TensorType.times(value: T): TensorType
public operator fun TensorType.times(other: TensorType): TensorType
public operator fun TensorType.timesAssign(value: T): Unit
public operator fun TensorType.timesAssign(other: TensorType): Unit
public operator fun TensorType.unaryMinus(): TensorType
public infix fun TensorType.dot(other: TensorType): TensorType
@ -35,7 +39,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
public fun diagonalEmbedding(
diagonalEntries: TensorType,
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
offset: Int = 0, dim1: Int = -2, dim2: Int = -1,
): TensorType
public fun TensorType.transpose(i: Int, j: Int): TensorType
@ -51,8 +55,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
// https://proofwiki.org/wiki/Definition:Division_Algebra
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
TensorAlgebra<T, TensorType> {
public interface TensorPartialDivisionAlgebra<T, TensorType : Tensor<T>> : TensorAlgebra<T, TensorType> {
public operator fun TensorType.div(other: TensorType): TensorType
public operator fun TensorType.divAssign(other: TensorType)