Tensor algebra generified

This commit is contained in:
Alexander Nozik 2021-10-27 14:48:36 +03:00
parent 4635cd3fb3
commit 29a90efca5
12 changed files with 323 additions and 347 deletions

View File

@ -45,6 +45,7 @@
- Buffer algebra does not require size anymore - Buffer algebra does not require size anymore
- Operations -> Ops - Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
- Tensor algebra takes read-only structures as input and inherits AlgebraND
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`

View File

@ -13,8 +13,7 @@ import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ones import org.jetbrains.kotlinx.multik.api.ones
import org.jetbrains.kotlinx.multik.ndarray.data.DN import org.jetbrains.kotlinx.multik.ndarray.data.DN
import org.jetbrains.kotlinx.multik.ndarray.data.DataType import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.multik.multikND import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.multik.multikTensorAlgebra
import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
@ -79,7 +78,7 @@ internal class NDFieldBenchmark {
} }
@Benchmark @Benchmark
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) { fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap() val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
@ -100,7 +99,7 @@ internal class NDFieldBenchmark {
private val specializedField = DoubleField.ndAlgebra private val specializedField = DoubleField.ndAlgebra
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
private val nd4jField = DoubleField.nd4j private val nd4jField = DoubleField.nd4j
private val multikField = DoubleField.multikND private val multikField = DoubleField.multikAlgebra
private val viktorField = DoubleField.viktorAlgebra private val viktorField = DoubleField.viktorAlgebra
} }
} }

View File

@ -7,12 +7,12 @@ package space.kscience.kmath.tensors
import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarray import org.jetbrains.kotlinx.multik.api.ndarray
import space.kscience.kmath.multik.multikND import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.nd.one import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
fun main(): Unit = with(DoubleField.multikND) { fun main(): Unit = with(DoubleField.multikAlgebra) {
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap() val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap() val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
one(a.shape) - a + b * 3 one(a.shape) - a + b * 3.0
} }

View File

@ -1,137 +0,0 @@
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.math.cos
import org.jetbrains.kotlinx.multik.api.math.sin
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.RingOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.*
/**
* A ring algebra for Multik operations
*/
public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
public val type: DataType,
override val elementAlgebra: A
) : RingOpsND<T, A> {
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val res = mk.zeros<T, DN>(shape, type).asDNArray()
for (index in res.multiIndices) {
res[index] = elementAlgebra.initializer(index)
}
return res.wrap()
}
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
this
} else {
structureND(shape) { get(it) }
}
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el)
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
val indexIter = array.multiIndices.iterator()
var index = 0
for (item in array) {
if (indexIter.hasNext()) {
data[index++] = elementAlgebra.transform(indexIter.next(), item)
} else {
throw ArithmeticException("Index overflow has happened.")
}
}
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type)
var counter = 0
val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator()
//iterating them together
while (leftIterator.hasNext()) {
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
}
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
}
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
(left.asMultik().array + right.asMultik().array).wrap()
override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
asMultik().array.plus(arg).wrap()
override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
left.asMultik().array.times(right.asMultik().array).wrap()
override fun StructureND<T>.times(arg: T): MultikTensor<T> =
asMultik().array.times(arg).wrap()
override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap()
override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap()
override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap()
}
/**
* A field algebra for multik operations
*/
public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
type: DataType,
elementAlgebra: A
) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
asMultik().array.div(other.asMultik().array).wrap()
}
public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
get() = MultikRingOpsND(DataType.IntDataType, IntRing)
public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
get() = MultikRingOpsND(DataType.LongDataType, LongRing)

View File

@ -15,12 +15,14 @@ import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
@JvmInline @JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> { public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
@ -50,29 +52,80 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
else throw ClassCastException("Cannot cast MultiArray to NDArray.") else throw ClassCastException("Cannot cast MultiArray to NDArray.")
} }
public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor( public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
public val type: DataType,
override val elementAlgebra: A, public abstract val type: DataType
public val comparator: Comparator<T>
) : TensorAlgebra<T, A> { override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val strides = DefaultStrides(shape)
val memoryView = initMemoryView<T>(strides.linearSize, type)
strides.indices().forEachIndexed { linearIndex, tensorIndex ->
memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
}
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
}
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
val data = initMemoryView<T>(array.size, type)
var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el)
NDArray(data, shape = shape, dim = array.dim).wrap()
} else {
structureND(shape) { index ->
transform(get(index))
}
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
if (this is MultikTensor) {
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
val indexIter = array.multiIndices.iterator()
var index = 0
for (item in array) {
if (indexIter.hasNext()) {
data[index++] = elementAlgebra.transform(indexIter.next(), item)
} else {
throw ArithmeticException("Index overflow has happened.")
}
}
NDArray(data, shape = array.shape, dim = array.dim).wrap()
} else {
structureND(shape) { index ->
transform(index, get(index))
}
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type)
var counter = 0
val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator()
//iterating them together
while (leftIterator.hasNext()) {
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
}
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
}
/** /**
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
* are not reflected back onto the source * are not reflected back onto the source
*/ */
public fun StructureND<T>.asMultik(): MultikTensor<T> { public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
return if (this is MultikTensor) { this
this } else {
} else { val res = mk.zeros<T, DN>(shape, type).asDNArray()
val res = mk.zeros<T, DN>(shape, type).asDNArray() for (index in res.multiIndices) {
for (index in res.multiIndices) { res[index] = this[index]
res[index] = this[index]
}
res.wrap()
} }
res.wrap()
} }
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this) public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0)) get(intArrayOf(0))
@ -81,8 +134,8 @@ public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor(
override fun T.plus(other: StructureND<T>): MultikTensor<T> = override fun T.plus(other: StructureND<T>): MultikTensor<T> =
other.plus(this) other.plus(this)
override fun StructureND<T>.plus(value: T): MultikTensor<T> = override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap() asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap()
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> = override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap() asMultik().array.plus(other.asMultik().array).wrap()
@ -155,11 +208,11 @@ public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor(
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap() asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap() override fun StructureND<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap() override fun StructureND<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> { override fun StructureND<T>.view(shape: IntArray): MultikTensor<T> {
require(shape.all { it > 0 }) require(shape.all { it > 0 })
require(shape.fold(1, Int::times) == this.shape.size) { require(shape.fold(1, Int::times) == this.shape.size) {
"Cannot reshape array of size ${this.shape.size} into a new shape ${ "Cannot reshape array of size ${this.shape.size} into a new shape ${
@ -178,9 +231,9 @@ public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor(
}.wrap() }.wrap()
} }
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape) override fun StructureND<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
if (this.shape.size == 1 && other.shape.size == 1) { if (this.shape.size == 1 && other.shape.size == 1) {
Multik.ndarrayOf( Multik.ndarrayOf(
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array() asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
@ -197,45 +250,95 @@ public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor(
TODO("Diagonal embedding not implemented") TODO("Diagonal embedding not implemented")
} }
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
elementAlgebra.add(acc, t) elementAlgebra.add(acc, t)
} }
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.min(): T = override fun StructureND<T>.min(): T? = asMultik().array.min()
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.max(): T = override fun StructureND<T>.max(): T? = asMultik().array.max()
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
} }
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double, DoubleField> public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) } : MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float, FloatField> override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short, ShortRing> override fun StructureND<T>.div(arg: T): MultikTensor<T> =
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) } asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int, IntRing> override fun StructureND<T>.div(other: StructureND<T>): MultikTensor<T> =
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) } asMultik().array.div(other.asMultik().array).wrap()
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long, LongRing> override fun Tensor<T>.divAssign(value: T) {
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) } if (this is MultikTensor) {
array.divAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.divide(t, value) }
}
}
override fun Tensor<T>.divAssign(other: StructureND<T>) {
if (this is MultikTensor) {
array.divAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.divide(t, other[index]) }
}
}
}
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
override val elementAlgebra: DoubleField get() = DoubleField
override val type: DataType get() = DataType.DoubleDataType
}
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
override val elementAlgebra: FloatField get() = FloatField
override val type: DataType get() = DataType.FloatDataType
}
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
override val elementAlgebra: ShortRing get() = ShortRing
override val type: DataType get() = DataType.ShortDataType
}
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
override val elementAlgebra: IntRing get() = IntRing
override val type: DataType get() = DataType.IntDataType
}
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
override val elementAlgebra: LongRing get() = LongRing
override val type: DataType get() = DataType.LongDataType
}
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra

View File

@ -7,7 +7,7 @@ import space.kscience.kmath.operations.invoke
internal class MultikNDTest { internal class MultikNDTest {
@Test @Test
fun basicAlgebra(): Unit = DoubleField.multikND{ fun basicAlgebra(): Unit = DoubleField.multikAlgebra{
one(2,2) + 1.0 one(2,2) + 1.0
} }
} }

View File

@ -13,6 +13,7 @@ import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.factory.ops.NDBase
import org.nd4j.linalg.ops.transforms.Transforms import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
@ -27,22 +28,6 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
*/ */
public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> { public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T> {
val array =
}
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
TODO("Not yet implemented")
}
/** /**
* Wraps [INDArray] to [Nd4jArrayStructure]. * Wraps [INDArray] to [Nd4jArrayStructure].
*/ */
@ -53,8 +38,21 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T>
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> =
structureND(shape) { index -> elementAlgebra.transform(get(index)) }
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> =
structureND(shape) { index -> elementAlgebra.transform(index, get(index)) }
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape))
return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
}
override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap() override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
override fun StructureND<T>.plus(value: T): Nd4jArrayStructure<T> = ndArray.add(value).wrap() override fun StructureND<T>.plus(arg: T): Nd4jArrayStructure<T> = ndArray.add(arg).wrap()
override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap() override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
@ -94,51 +92,52 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
} }
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap() override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap() override fun StructureND<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap() override fun StructureND<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
override fun Tensor<T>.dot(other: Tensor<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap() override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.min(keepDim, dim).wrap() ndArray.min(keepDim, dim).wrap()
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.sum(keepDim, dim).wrap() ndArray.sum(keepDim, dim).wrap()
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.max(keepDim, dim).wrap() ndArray.max(keepDim, dim).wrap()
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap() override fun StructureND<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: Tensor<T>): Nd4jArrayStructure<T> = view(other.shape) override fun StructureND<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap() ndBase.get().argmax(ndArray, keepDim, dim).wrap()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = ndArray.mean(keepDim, dim).wrap() override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.mean(keepDim, dim).wrap()
override fun Tensor<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap() override fun StructureND<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
override fun Tensor<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap() override fun StructureND<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
override fun Tensor<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap() override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
override fun Tensor<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap() override fun StructureND<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
override fun Tensor<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap() override fun StructureND<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
override fun Tensor<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap() override fun StructureND<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
override fun Tensor<T>.acosh(): Nd4jArrayStructure<T> = override fun StructureND<T>.acosh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap() override fun StructureND<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
override fun Tensor<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap() override fun StructureND<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap() override fun StructureND<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
override fun Tensor<T>.asinh(): Nd4jArrayStructure<T> = override fun StructureND<T>.asinh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap() override fun StructureND<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
override fun Tensor<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap() override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun Tensor<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap() override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun Tensor<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap() override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun Tensor<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap() override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun Tensor<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap() override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.std(true, keepDim, dim).wrap() ndArray.std(true, keepDim, dim).wrap()
override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap() override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
@ -153,7 +152,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
ndArray.divi(other.ndArray) ndArray.divi(other.ndArray)
} }
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
private companion object { private companion object {
@ -170,6 +169,16 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure<Double> {
val array: INDArray = Nd4j.zeros(*shape)
val indices = DefaultStrides(shape)
indices.indices().forEach { index ->
array.putScalar(index, elementAlgebra.initializer(index))
}
return array.wrap()
}
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray override val StructureND<Double>.ndArray: INDArray
get() = when (this) { get() = when (this) {
@ -190,10 +199,10 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
dim2: Int, dim2: Int,
): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2) ): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble() override fun StructureND<Double>.sum(): Double = ndArray.sumNumber().toDouble()
override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble() override fun StructureND<Double>.min(): Double = ndArray.minNumber().toDouble()
override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble() override fun StructureND<Double>.max(): Double = ndArray.maxNumber().toDouble()
override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble() override fun StructureND<Double>.mean(): Double = ndArray.meanNumber().toDouble()
override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble() override fun StructureND<Double>.std(): Double = ndArray.stdNumber().toDouble()
override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble() override fun StructureND<Double>.variance(): Double = ndArray.varNumber().toDouble()
} }

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.Field
@ -18,7 +19,7 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
/** /**
* @return the mean of all elements in the input tensor. * @return the mean of all elements in the input tensor.
*/ */
public fun Tensor<T>.mean(): T public fun StructureND<T>.mean(): T
/** /**
* Returns the mean of each row of the input tensor in the given dimension [dim]. * Returns the mean of each row of the input tensor in the given dimension [dim].
@ -31,12 +32,12 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the mean of each row of the input tensor in the given dimension [dim]. * @return the mean of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the standard deviation of all elements in the input tensor. * @return the standard deviation of all elements in the input tensor.
*/ */
public fun Tensor<T>.std(): T public fun StructureND<T>.std(): T
/** /**
* Returns the standard deviation of each row of the input tensor in the given dimension [dim]. * Returns the standard deviation of each row of the input tensor in the given dimension [dim].
@ -49,12 +50,12 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the standard deviation of each row of the input tensor in the given dimension [dim]. * @return the standard deviation of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the variance of all elements in the input tensor. * @return the variance of all elements in the input tensor.
*/ */
public fun Tensor<T>.variance(): T public fun StructureND<T>.variance(): T
/** /**
* Returns the variance of each row of the input tensor in the given dimension [dim]. * Returns the variance of each row of the input tensor in the given dimension [dim].
@ -67,57 +68,57 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionA
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the variance of each row of the input tensor in the given dimension [dim]. * @return the variance of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun Tensor<T>.exp(): Tensor<T> public fun StructureND<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html //For information: https://pytorch.org/docs/stable/generated/torch.log.html
public fun Tensor<T>.ln(): Tensor<T> public fun StructureND<T>.ln(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun Tensor<T>.sqrt(): Tensor<T> public fun StructureND<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun Tensor<T>.cos(): Tensor<T> public fun StructureND<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun Tensor<T>.acos(): Tensor<T> public fun StructureND<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun Tensor<T>.cosh(): Tensor<T> public fun StructureND<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun Tensor<T>.acosh(): Tensor<T> public fun StructureND<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun Tensor<T>.sin(): Tensor<T> public fun StructureND<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun Tensor<T>.asin(): Tensor<T> public fun StructureND<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun Tensor<T>.sinh(): Tensor<T> public fun StructureND<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun Tensor<T>.asinh(): Tensor<T> public fun StructureND<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun Tensor<T>.tan(): Tensor<T> public fun StructureND<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun Tensor<T>.atan(): Tensor<T> public fun StructureND<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun Tensor<T>.tanh(): Tensor<T> public fun StructureND<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun Tensor<T>.atanh(): Tensor<T> public fun StructureND<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun Tensor<T>.ceil(): Tensor<T> public fun StructureND<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun Tensor<T>.floor(): Tensor<T> public fun StructureND<T>.floor(): Tensor<T>
} }

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field import space.kscience.kmath.operations.Field
/** /**
@ -20,7 +21,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* *
* @return the determinant. * @return the determinant.
*/ */
public fun Tensor<T>.det(): Tensor<T> public fun StructureND<T>.det(): Tensor<T>
/** /**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
@ -30,7 +31,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* *
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
*/ */
public fun Tensor<T>.inv(): Tensor<T> public fun StructureND<T>.inv(): Tensor<T>
/** /**
* Cholesky decomposition. * Cholesky decomposition.
@ -46,7 +47,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* @receiver the `input`. * @receiver the `input`.
* @return the batch of `L` matrices. * @return the batch of `L` matrices.
*/ */
public fun Tensor<T>.cholesky(): Tensor<T> public fun StructureND<T>.cholesky(): Tensor<T>
/** /**
* QR decomposition. * QR decomposition.
@ -60,7 +61,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* @receiver the `input`. * @receiver the `input`.
* @return pair of `Q` and `R` tensors. * @return pair of `Q` and `R` tensors.
*/ */
public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>> public fun StructureND<T>.qr(): Pair<Tensor<T>, Tensor<T>>
/** /**
* LUP decomposition * LUP decomposition
@ -74,7 +75,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* @receiver the `input`. * @receiver the `input`.
* @return triple of P, L and U tensors * @return triple of P, L and U tensors
*/ */
public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>> public fun StructureND<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Singular Value Decomposition. * Singular Value Decomposition.
@ -90,7 +91,7 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* @receiver the `input`. * @receiver the `input`.
* @return triple `Triple(U, S, V)`. * @return triple `Triple(U, S, V)`.
*/ */
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>> public fun StructureND<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices, * Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices,
@ -100,6 +101,6 @@ public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivision
* @receiver the `input`. * @receiver the `input`.
* @return a pair `eigenvalues to eigenvectors` * @return a pair `eigenvalues to eigenvectors`
*/ */
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>> public fun StructureND<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
} }

View File

@ -41,12 +41,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
override operator fun T.plus(other: StructureND<T>): Tensor<T> override operator fun T.plus(other: StructureND<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be added to each element of this tensor. * @param arg the number to be added to each element of this tensor.
* @return the sum of this tensor and [value]. * @return the sum of this tensor and [arg].
*/ */
override operator fun StructureND<T>.plus(value: T): Tensor<T> override operator fun StructureND<T>.plus(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is added to each element of this tensor. * Each element of the tensor [other] is added to each element of this tensor.
@ -166,7 +166,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param i index of the extractable tensor * @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i] * @return subtensor of the original tensor with index [i]
*/ */
public operator fun Tensor<T>.get(i: Int): Tensor<T> public operator fun StructureND<T>.get(i: Int): Tensor<T>
/** /**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
@ -176,7 +176,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param j the second dimension to be transposed * @param j the second dimension to be transposed
* @return transposed tensor * @return transposed tensor
*/ */
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T> public fun StructureND<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
/** /**
* Returns a new tensor with the same data as the self tensor but of a different shape. * Returns a new tensor with the same data as the self tensor but of a different shape.
@ -186,7 +186,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param shape the desired size * @param shape the desired size
* @return tensor with new shape * @return tensor with new shape
*/ */
public fun Tensor<T>.view(shape: IntArray): Tensor<T> public fun StructureND<T>.view(shape: IntArray): Tensor<T>
/** /**
* View this tensor as the same size as [other]. * View this tensor as the same size as [other].
@ -196,7 +196,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param other the result tensor has the same size as other. * @param other the result tensor has the same size as other.
* @return the result tensor with the same size as other. * @return the result tensor with the same size as other.
*/ */
public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> public fun StructureND<T>.viewAs(other: StructureND<T>): Tensor<T>
/** /**
* Matrix product of two tensors. * Matrix product of two tensors.
@ -227,7 +227,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return a mathematical product of two tensors. * @return a mathematical product of two tensors.
*/ */
public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> public infix fun StructureND<T>.dot(other: StructureND<T>): Tensor<T>
/** /**
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
@ -262,7 +262,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
/** /**
* @return the sum of all elements in the input tensor. * @return the sum of all elements in the input tensor.
*/ */
public fun Tensor<T>.sum(): T public fun StructureND<T>.sum(): T
/** /**
* Returns the sum of each row of the input tensor in the given dimension [dim]. * Returns the sum of each row of the input tensor in the given dimension [dim].
@ -275,12 +275,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the sum of each row of the input tensor in the given dimension [dim]. * @return the sum of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the minimum value of all elements in the input tensor. * @return the minimum value of all elements in the input tensor or null if there are no values
*/ */
public fun Tensor<T>.min(): T public fun StructureND<T>.min(): T?
/** /**
* Returns the minimum value of each row of the input tensor in the given dimension [dim]. * Returns the minimum value of each row of the input tensor in the given dimension [dim].
@ -293,12 +293,12 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim]. * @return the minimum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* Returns the maximum value of all elements in the input tensor. * Returns the maximum value of all elements in the input tensor or null if there are no values
*/ */
public fun Tensor<T>.max(): T public fun StructureND<T>.max(): T?
/** /**
* Returns the maximum value of each row of the input tensor in the given dimension [dim]. * Returns the maximum value of each row of the input tensor in the given dimension [dim].
@ -311,7 +311,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim]. * @return the maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim]. * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
@ -324,7 +324,7 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the index of maximum value of each row of the input tensor in the given dimension [dim]. * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right

View File

@ -85,7 +85,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.divAssign(other: Tensor<Double>) { override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.indices.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=

View File

@ -115,7 +115,7 @@ public open class DoubleTensorAlgebra :
TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray() TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
) )
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override operator fun StructureND<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
@ -160,7 +160,7 @@ public open class DoubleTensorAlgebra :
* *
* @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor.
*/ */
public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun StructureND<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
/** /**
* Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape].
@ -198,9 +198,8 @@ public open class DoubleTensorAlgebra :
* *
* @return a copy of the `input` tensor with a copied buffer. * @return a copy of the `input` tensor with a copied buffer.
*/ */
public fun Tensor<Double>.copy(): DoubleTensor { public fun StructureND<Double>.copy(): DoubleTensor =
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
}
override fun Double.plus(other: StructureND<Double>): DoubleTensor { override fun Double.plus(other: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
@ -209,7 +208,7 @@ public open class DoubleTensorAlgebra :
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun StructureND<Double>.plus(value: Double): DoubleTensor = value + tensor override fun StructureND<Double>.plus(arg: Double): DoubleTensor = arg + tensor
override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
@ -345,7 +344,7 @@ public open class DoubleTensorAlgebra :
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor { override fun StructureND<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i) val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j) val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj) checkTranspose(tensor.dimension, ii, jj)
@ -369,15 +368,15 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor { override fun StructureND<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape) checkView(tensor, shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
} }
override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor = override fun StructureND<Double>.viewAs(other: StructureND<Double>): DoubleTensor =
tensor.view(other.shape) tensor.view(other.shape)
override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor { override infix fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
} }
@ -569,10 +568,10 @@ public open class DoubleTensorAlgebra :
*/ */
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.copyArray()) foldFunction(tensor.copyArray())
internal inline fun Tensor<Double>.foldDim( internal inline fun StructureND<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> Double,
dim: Int, dim: Int,
keepDim: Boolean, keepDim: Boolean,
@ -596,30 +595,30 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() } override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim) foldDim({ x -> x.sum() }, dim, keepDim)
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! } override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim) foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim) foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble() x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
}, dim, keepDim) }, dim, keepDim)
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements } override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim( foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
@ -629,12 +628,12 @@ public open class DoubleTensorAlgebra :
keepDim keepDim
) )
override fun Tensor<Double>.std(): Double = this.fold { arr -> override fun StructureND<Double>.std(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
} }
override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
@ -644,12 +643,12 @@ public open class DoubleTensorAlgebra :
keepDim keepDim
) )
override fun Tensor<Double>.variance(): Double = this.fold { arr -> override fun StructureND<Double>.variance(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
} }
override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
@ -672,7 +671,7 @@ public open class DoubleTensorAlgebra :
* @param tensors the [List] of 1-dimensional tensors with same shape * @param tensors the [List] of 1-dimensional tensors with same shape
* @return `M`. * @return `M`.
*/ */
public fun cov(tensors: List<Tensor<Double>>): DoubleTensor { public fun cov(tensors: List<StructureND<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" } check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val n = tensors.size val n = tensors.size
val m = tensors[0].shape[0] val m = tensors[0].shape[0]
@ -689,43 +688,43 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map { exp(it) } override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map { ln(it) } override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) } override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map { cos(it) } override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map { acos(it) } override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) } override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) } override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map { sin(it) } override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map { asin(it) } override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) } override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) } override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map { tan(it) } override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map { atan(it) } override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) } override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) } override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) } override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map { floor(it) } override fun StructureND<Double>.floor(): DoubleTensor = tensor.map { floor(it) }
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9) override fun StructureND<Double>.inv(): DoubleTensor = invLU(1e-9)
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9) override fun StructureND<Double>.det(): DoubleTensor = detLU(1e-9)
/** /**
* Computes the LU factorization of a matrix or batches of matrices `input`. * Computes the LU factorization of a matrix or batches of matrices `input`.
@ -736,7 +735,7 @@ public open class DoubleTensorAlgebra :
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. * The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/ */
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun StructureND<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
@ -749,7 +748,7 @@ public open class DoubleTensorAlgebra :
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. * The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/ */
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9) public fun StructureND<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
/** /**
* Unpacks the data and pivots from a LU factorization of a tensor. * Unpacks the data and pivots from a LU factorization of a tensor.
@ -763,7 +762,7 @@ public open class DoubleTensorAlgebra :
* @return triple of `P`, `L` and `U` tensors * @return triple of `P`, `L` and `U` tensors
*/ */
public fun luPivot( public fun luPivot(
luTensor: Tensor<Double>, luTensor: StructureND<Double>,
pivotsTensor: Tensor<Int>, pivotsTensor: Tensor<Int>,
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape) checkSquareMatrix(luTensor.shape)
@ -806,7 +805,7 @@ public open class DoubleTensorAlgebra :
* Used when checking the positive definiteness of the input matrix or matrices. * Used when checking the positive definiteness of the input matrix or matrices.
* @return a pair of `Q` and `R` tensors. * @return a pair of `Q` and `R` tensors.
*/ */
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor { public fun StructureND<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon) checkPositiveDefinite(tensor, epsilon)
@ -819,9 +818,9 @@ public open class DoubleTensorAlgebra :
return lTensor return lTensor
} }
override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6) override fun StructureND<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> { override fun StructureND<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
val rTensor = zeroesLike() val rTensor = zeroesLike()
@ -837,7 +836,7 @@ public open class DoubleTensorAlgebra :
return qTensor to rTensor return qTensor to rTensor
} }
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = override fun StructureND<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10) svd(epsilon = 1e-10)
/** /**
@ -853,7 +852,7 @@ public open class DoubleTensorAlgebra :
* i.e., the precision with which the cosine approaches 1 in an iterative algorithm. * i.e., the precision with which the cosine approaches 1 in an iterative algorithm.
* @return a triple `Triple(U, S, V)`. * @return a triple `Triple(U, S, V)`.
*/ */
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.dimension val size = tensor.dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = tensor.shape.sliceArray(size - 2 until size)
@ -886,7 +885,7 @@ public open class DoubleTensorAlgebra :
return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
} }
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15) override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15)
/** /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
@ -896,7 +895,7 @@ public open class DoubleTensorAlgebra :
* and when the cosine approaches 1 in the SVD algorithm. * and when the cosine approaches 1 in the SVD algorithm.
* @return a pair `eigenvalues to eigenvectors`. * @return a pair `eigenvalues to eigenvectors`.
*/ */
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun StructureND<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(tensor, epsilon)
fun MutableStructure2D<Double>.cleanSym(n: Int) { fun MutableStructure2D<Double>.cleanSym(n: Int) {
@ -931,7 +930,7 @@ public open class DoubleTensorAlgebra :
* with zero. * with zero.
* @return the determinant. * @return the determinant.
*/ */
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun StructureND<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape) checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy() val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots() val pivotsTensor = tensor.setUpPivots()
@ -964,7 +963,7 @@ public open class DoubleTensorAlgebra :
* @param epsilon error in the LU algorithm&mdash;permissible error when comparing the determinant of a matrix with zero * @param epsilon error in the LU algorithm&mdash;permissible error when comparing the determinant of a matrix with zero
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
*/ */
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor { public fun StructureND<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon) val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
@ -989,12 +988,12 @@ public open class DoubleTensorAlgebra :
* @param epsilon permissible error when comparing the determinant of a matrix with zero. * @param epsilon permissible error when comparing the determinant of a matrix with zero.
* @return triple of `P`, `L` and `U` tensors. * @return triple of `P`, `L` and `U` tensors.
*/ */
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = tensor.luFactor(epsilon) val (lu, pivots) = tensor.luFactor(epsilon)
return luPivot(lu, pivots) return luPivot(lu, pivots)
} }
override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9) override fun StructureND<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
} }
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra