forked from kscience/kmath
Tensor algebra refactoring
This commit is contained in:
parent
360e0e17e9
commit
a11711c336
@ -1,32 +1,36 @@
|
||||
package kscience.kmath.structures
|
||||
package kscience.kmath.tensors
|
||||
|
||||
public interface TensorStructure<T> : MutableNDStructure<T> {
|
||||
// A tensor can have empty shape, in which case it represents just a value
|
||||
public abstract fun value(): T
|
||||
}
|
||||
import kscience.kmath.operations.Ring
|
||||
import kscience.kmath.structures.MutableNDStructure
|
||||
|
||||
public typealias Tensor<T> = MutableNDStructure<T>
|
||||
|
||||
public val <T : Any> Tensor<T>.value: T
|
||||
get() {
|
||||
require(shape.contentEquals(intArrayOf(1))) { "Value available only for a tensor with no dimensions" }
|
||||
return get(intArrayOf(0))
|
||||
}
|
||||
|
||||
// https://proofwiki.org/wiki/Definition:Algebra_over_Ring
|
||||
|
||||
public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
/**
|
||||
* To be moved to a separate project
|
||||
*/
|
||||
public interface TensorAlgebra<T, TensorType : Tensor<T>> : Ring<TensorType> {
|
||||
|
||||
public operator fun T.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plus(value: T): TensorType
|
||||
public operator fun TensorType.plus(other: TensorType): TensorType
|
||||
public operator fun TensorType.plusAssign(value: T): Unit
|
||||
public operator fun TensorType.plusAssign(other: TensorType): Unit
|
||||
|
||||
public operator fun T.minus(other: TensorType): TensorType
|
||||
public operator fun TensorType.minus(value: T): TensorType
|
||||
public operator fun TensorType.minus(other: TensorType): TensorType
|
||||
public operator fun TensorType.minusAssign(value: T): Unit
|
||||
public operator fun TensorType.minusAssign(other: TensorType): Unit
|
||||
|
||||
public operator fun T.times(other: TensorType): TensorType
|
||||
public operator fun TensorType.times(value: T): TensorType
|
||||
public operator fun TensorType.times(other: TensorType): TensorType
|
||||
public operator fun TensorType.timesAssign(value: T): Unit
|
||||
public operator fun TensorType.timesAssign(other: TensorType): Unit
|
||||
public operator fun TensorType.unaryMinus(): TensorType
|
||||
|
||||
|
||||
public infix fun TensorType.dot(other: TensorType): TensorType
|
||||
@ -35,7 +39,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
public fun diagonalEmbedding(
|
||||
diagonalEntries: TensorType,
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1
|
||||
offset: Int = 0, dim1: Int = -2, dim2: Int = -1,
|
||||
): TensorType
|
||||
|
||||
public fun TensorType.transpose(i: Int, j: Int): TensorType
|
||||
@ -51,8 +55,7 @@ public interface TensorAlgebra<T, TensorType : TensorStructure<T>> {
|
||||
|
||||
// https://proofwiki.org/wiki/Definition:Division_Algebra
|
||||
|
||||
public interface TensorPartialDivisionAlgebra<T, TensorType : TensorStructure<T>> :
|
||||
TensorAlgebra<T, TensorType> {
|
||||
public interface TensorPartialDivisionAlgebra<T, TensorType : Tensor<T>> : TensorAlgebra<T, TensorType> {
|
||||
|
||||
public operator fun TensorType.div(other: TensorType): TensorType
|
||||
public operator fun TensorType.divAssign(other: TensorType)
|
Loading…
Reference in New Issue
Block a user