Merge branch 'dev' into feature/tensorflow

# Conflicts:
#	kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt
This commit is contained in:
Alexander Nozik 2021-10-28 09:55:43 +03:00
commit 7bdc54c818
29 changed files with 608 additions and 546 deletions

View File

@ -45,6 +45,7 @@
- Buffer algebra does not require size anymore - Buffer algebra does not require size anymore
- Operations -> Ops - Operations -> Ops
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes. - Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
- Tensor algebra takes read-only structures as input and inherits AlgebraND
### Deprecated ### Deprecated
- Specialized `DoubleBufferAlgebra` - Specialized `DoubleBufferAlgebra`

View File

@ -13,8 +13,7 @@ import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ones import org.jetbrains.kotlinx.multik.api.ones
import org.jetbrains.kotlinx.multik.ndarray.data.DN import org.jetbrains.kotlinx.multik.ndarray.data.DN
import org.jetbrains.kotlinx.multik.ndarray.data.DataType import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.multik.multikND import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.multik.multikTensorAlgebra
import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
@ -79,7 +78,7 @@ internal class NDFieldBenchmark {
} }
@Benchmark @Benchmark
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) { fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikAlgebra) {
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap() val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
repeat(n) { res += 1.0 } repeat(n) { res += 1.0 }
blackhole.consume(res) blackhole.consume(res)
@ -100,7 +99,7 @@ internal class NDFieldBenchmark {
private val specializedField = DoubleField.ndAlgebra private val specializedField = DoubleField.ndAlgebra
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing) private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
private val nd4jField = DoubleField.nd4j private val nd4jField = DoubleField.nd4j
private val multikField = DoubleField.multikND private val multikField = DoubleField.multikAlgebra
private val viktorField = DoubleField.viktorAlgebra private val viktorField = DoubleField.viktorAlgebra
} }
} }

View File

@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
this@StreamDoubleFieldND.shape, this@StreamDoubleFieldND.shape,
shape shape
) )
this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer this is BufferND && this.indices == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }

View File

@ -7,12 +7,12 @@ package space.kscience.kmath.tensors
import org.jetbrains.kotlinx.multik.api.Multik import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ndarray import org.jetbrains.kotlinx.multik.api.ndarray
import space.kscience.kmath.multik.multikND import space.kscience.kmath.multik.multikAlgebra
import space.kscience.kmath.nd.one import space.kscience.kmath.nd.one
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
fun main(): Unit = with(DoubleField.multikND) { fun main(): Unit = with(DoubleField.multikAlgebra) {
val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap() val a = Multik.ndarray(intArrayOf(1, 2, 3)).asType<Double>().wrap()
val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap() val b = Multik.ndarray(doubleArrayOf(1.0, 2.0, 3.0)).wrap()
one(a.shape) - a + b * 3 one(a.shape) - a + b * 3.0
} }

View File

@ -9,7 +9,7 @@ import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.toDoubleArray import space.kscience.kmath.tensors.core.copyArray
import kotlin.math.sqrt import kotlin.math.sqrt
const val seed = 100500L const val seed = 100500L
@ -111,7 +111,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra { private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra {
val onesForAnswers = yPred.zeroesLike() val onesForAnswers = yPred.zeroesLike()
yTrue.toDoubleArray().forEachIndexed { index, labelDouble -> yTrue.copyArray().forEachIndexed { index, labelDouble ->
val label = labelDouble.toInt() val label = labelDouble.toInt()
onesForAnswers[intArrayOf(index, label)] = 1.0 onesForAnswers[intArrayOf(index, label)] = 1.0
} }

View File

@ -131,19 +131,19 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
* Adds an element to ND structure of it. * Adds an element to ND structure of it.
* *
* @receiver the augend. * @receiver the augend.
* @param arg the addend. * @param other the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this public operator fun T.plus(other: StructureND<T>): StructureND<T> = other.map { value -> add(this@plus, value) }
/** /**
* Subtracts an ND structure from an element of it. * Subtracts an ND structure from an element of it.
* *
* @receiver the dividend. * @receiver the dividend.
* @param arg the divisor. * @param other the divisor.
* @return the quotient. * @return the quotient.
*/ */
public operator fun T.minus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(-this@minus, value) } public operator fun T.minus(other: StructureND<T>): StructureND<T> = other.map { value -> add(-this@minus, value) }
public companion object public companion object
} }

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.BufferFactory import space.kscience.kmath.structures.BufferFactory
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> { public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val indexerBuilder: (IntArray) -> ShapeIndex public val indexerBuilder: (IntArray) -> ShapeIndexer
public val bufferAlgebra: BufferAlgebra<T, A> public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
@ -43,7 +43,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
zipInline(left.toBufferND(), right.toBufferND(), transform) zipInline(left.toBufferND(), right.toBufferND(), transform)
public companion object { public companion object {
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke public val defaultIndexerBuilder: (IntArray) -> ShapeIndexer = DefaultStrides.Companion::invoke
} }
} }
@ -51,7 +51,7 @@ public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
arg: BufferND<T>, arg: BufferND<T>,
crossinline transform: A.(T) -> T crossinline transform: A.(T) -> T
): BufferND<T> { ): BufferND<T> {
val indexes = arg.indexes val indexes = arg.indices
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
} }
@ -59,7 +59,7 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
arg: BufferND<T>, arg: BufferND<T>,
crossinline transform: A.(index: IntArray, arg: T) -> T crossinline transform: A.(index: IntArray, arg: T) -> T
): BufferND<T> { ): BufferND<T> {
val indexes = arg.indexes val indexes = arg.indices
return BufferND( return BufferND(
indexes, indexes,
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
@ -73,32 +73,32 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
r: BufferND<T>, r: BufferND<T>,
crossinline block: A.(l: T, r: T) -> T crossinline block: A.(l: T, r: T) -> T
): BufferND<T> { ): BufferND<T> {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes val indexes = l.indices
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
} }
public open class BufferedGroupNDOps<T, out A : Group<T>>( public open class BufferedGroupNDOps<T, out A : Group<T>>(
override val bufferAlgebra: BufferAlgebra<T, A>, override val bufferAlgebra: BufferAlgebra<T, A>,
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder override val indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> { ) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it } override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
} }
public open class BufferedRingOpsND<T, out A : Ring<T>>( public open class BufferedRingOpsND<T, out A : Ring<T>>(
bufferAlgebra: BufferAlgebra<T, A>, bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A> ) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
public open class BufferedFieldOpsND<T, out A : Field<T>>( public open class BufferedFieldOpsND<T, out A : Field<T>>(
bufferAlgebra: BufferAlgebra<T, A>, bufferAlgebra: BufferAlgebra<T, A>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> { ) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
public constructor( public constructor(
elementAlgebra: A, elementAlgebra: A,
bufferFactory: BufferFactory<T>, bufferFactory: BufferFactory<T>,
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder indexerBuilder: (IntArray) -> ShapeIndexer = BufferAlgebraND.defaultIndexerBuilder
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder) ) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value } override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }

View File

@ -15,20 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Represents [StructureND] over [Buffer]. * Represents [StructureND] over [Buffer].
* *
* @param T the type of items. * @param T the type of items.
* @param indexes The strides to access elements of [Buffer] by linear indices. * @param indices The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val indexes: ShapeIndex, public val indices: ShapeIndexer,
public open val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)] override operator fun get(index: IntArray): T = buffer[indices.offset(index)]
override val shape: IntArray get() = indexes.shape override val shape: IntArray get() = indices.shape
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indices.indices().map {
it to this[it] it to this[it]
} }
@ -43,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): BufferND<R> { ): BufferND<R> {
return if (this is BufferND<T>) return if (this is BufferND<T>)
BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
@ -58,11 +58,11 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public class MutableBufferND<T>( public class MutableBufferND<T>(
strides: ShapeIndex, strides: ShapeIndexer,
override val buffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, buffer) { ) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
buffer[indexes.offset(index)] = value buffer[indices.offset(index)] = value
} }
} }
@ -74,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) MutableBufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -13,7 +13,7 @@ import kotlin.contracts.contract
import kotlin.math.pow import kotlin.math.pow
public class DoubleBufferND( public class DoubleBufferND(
indexes: ShapeIndex, indexes: ShapeIndexer,
override val buffer: DoubleBuffer, override val buffer: DoubleBuffer,
) : BufferND<Double>(indexes, buffer) ) : BufferND<Double>(indexes, buffer)
@ -33,7 +33,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
arg: DoubleBufferND, arg: DoubleBufferND,
transform: (Double) -> Double transform: (Double) -> Double
): DoubleBufferND { ): DoubleBufferND {
val indexes = arg.indexes val indexes = arg.indices
val array = arg.buffer.array val array = arg.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) }) return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
} }
@ -43,8 +43,8 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
r: DoubleBufferND, r: DoubleBufferND,
block: (l: Double, r: Double) -> Double block: (l: Double, r: Double) -> Double
): DoubleBufferND { ): DoubleBufferND {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes val indexes = l.indices
val lArray = l.buffer.array val lArray = l.buffer.array
val rArray = r.buffer.array val rArray = r.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) }) return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })

View File

@ -10,7 +10,7 @@ import kotlin.native.concurrent.ThreadLocal
/** /**
* A converter from linear index to multivariate index * A converter from linear index to multivariate index
*/ */
public interface ShapeIndex{ public interface ShapeIndexer{
public val shape: Shape public val shape: Shape
/** /**
@ -42,7 +42,7 @@ public interface ShapeIndex{
/** /**
* Linear transformation of indexes * Linear transformation of indexes
*/ */
public abstract class Strides: ShapeIndex { public abstract class Strides: ShapeIndexer {
/** /**
* Array strides * Array strides
*/ */

View File

@ -100,7 +100,7 @@ public interface MutableStructure2D<T> : Structure2D<T>, MutableStructureND<T> {
*/ */
@JvmInline @JvmInline
private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> { private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : Structure2D<T> {
override val shape: IntArray get() = structure.shape override val shape: Shape get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1] override val colNum: Int get() = shape[1]
@ -116,9 +116,8 @@ private value class Structure2DWrapper<out T>(val structure: StructureND<T>) : S
/** /**
* A 2D wrapper for a mutable nd-structure * A 2D wrapper for a mutable nd-structure
*/ */
private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>): MutableStructure2D<T> private class MutableStructure2DWrapper<T>(val structure: MutableStructureND<T>) : MutableStructure2D<T> {
{ override val shape: Shape get() = structure.shape
override val shape: IntArray get() = structure.shape
override val rowNum: Int get() = shape[0] override val rowNum: Int get() = shape[0]
override val colNum: Int get() = shape[1] override val colNum: Int get() = shape[1]
@ -152,7 +151,8 @@ public fun <T> StructureND<T>.as2D(): Structure2D<T> = this as? Structure2D<T> ?
/** /**
* Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch. * Represents a [StructureND] as [Structure2D]. Throws runtime error in case of dimension mismatch.
*/ */
public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> = this as? MutableStructure2D<T> ?: when (shape.size) { public fun <T> MutableStructureND<T>.as2D(): MutableStructure2D<T> =
this as? MutableStructure2D<T> ?: when (shape.size) {
2 -> MutableStructure2DWrapper(this) 2 -> MutableStructure2DWrapper(this)
else -> error("Can't create 2d-structure from ${shape.size}d-structure") else -> error("Can't create 2d-structure from ${shape.size}d-structure")
} }

View File

@ -33,7 +33,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
* The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of * The shape of structure i.e., non-empty sequence of non-negative integers that specify sizes of dimensions of
* this structure. * this structure.
*/ */
public val shape: IntArray public val shape: Shape
/** /**
* The count of dimensions in this structure. It should be equal to size of [shape]. * The count of dimensions in this structure. It should be equal to size of [shape].
@ -71,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -87,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided

View File

@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
* Map one [BufferND] using function without indices. * Map one [BufferND] using function without indices.
*/ */
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> { public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) } val array = DoubleArray(indices.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(indexes, DoubleBuffer(array)) return BufferND(indices, DoubleBuffer(array))
} }
/** /**

View File

@ -1,137 +0,0 @@
package space.kscience.kmath.multik
import org.jetbrains.kotlinx.multik.api.math.cos
import org.jetbrains.kotlinx.multik.api.math.sin
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.RingOpsND
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.*
/**
* A ring algebra for Multik operations
*/
public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
public val type: DataType,
override val elementAlgebra: A
) : RingOpsND<T, A> {
public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val res = mk.zeros<T, DN>(shape, type).asDNArray()
for (index in res.multiIndices) {
res[index] = elementAlgebra.initializer(index)
}
return res.wrap()
}
public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
this
} else {
structureND(shape) { get(it) }
}
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el)
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> {
//taken directly from Multik sources
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
val indexIter = array.multiIndices.iterator()
var index = 0
for (item in array) {
if (indexIter.hasNext()) {
data[index++] = elementAlgebra.transform(indexIter.next(), item)
} else {
throw ArithmeticException("Index overflow has happened.")
}
}
return NDArray(data, shape = array.shape, dim = array.dim).wrap()
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type)
var counter = 0
val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator()
//iterating them together
while (leftIterator.hasNext()) {
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
}
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
}
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = asMultik().array.unaryMinus().wrap()
override fun add(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
(left.asMultik().array + right.asMultik().array).wrap()
override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
asMultik().array.plus(arg).wrap()
override fun StructureND<T>.minus(arg: T): MultikTensor<T> = asMultik().array.minus(arg).wrap()
override fun T.plus(arg: StructureND<T>): MultikTensor<T> = arg + this
override fun T.minus(arg: StructureND<T>): MultikTensor<T> = arg.map { this@minus - it }
override fun multiply(left: StructureND<T>, right: StructureND<T>): MultikTensor<T> =
left.asMultik().array.times(right.asMultik().array).wrap()
override fun StructureND<T>.times(arg: T): MultikTensor<T> =
asMultik().array.times(arg).wrap()
override fun T.times(arg: StructureND<T>): MultikTensor<T> = arg * this
override fun StructureND<T>.unaryPlus(): MultikTensor<T> = asMultik()
override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap()
override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap()
override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap()
}
/**
* A field algebra for multik operations
*/
public class MultikFieldOpsND<T, A : Field<T>> internal constructor(
type: DataType,
elementAlgebra: A
) : MultikRingOpsND<T, A>(type, elementAlgebra), FieldOpsND<T, A> {
override fun StructureND<T>.div(other: StructureND<T>): StructureND<T> =
asMultik().array.div(other.asMultik().array).wrap()
}
public val DoubleField.multikND: MultikFieldOpsND<Double, DoubleField>
get() = MultikFieldOpsND(DataType.DoubleDataType, DoubleField)
public val FloatField.multikND: MultikFieldOpsND<Float, FloatField>
get() = MultikFieldOpsND(DataType.FloatDataType, FloatField)
public val ShortRing.multikND: MultikRingOpsND<Short, ShortRing>
get() = MultikRingOpsND(DataType.ShortDataType, ShortRing)
public val IntRing.multikND: MultikRingOpsND<Int, IntRing>
get() = MultikRingOpsND(DataType.IntDataType, IntRing)
public val LongRing.multikND: MultikRingOpsND<Long, LongRing>
get() = MultikRingOpsND(DataType.LongDataType, LongRing)

View File

@ -15,14 +15,18 @@ import org.jetbrains.kotlinx.multik.api.zeros
import org.jetbrains.kotlinx.multik.ndarray.data.* import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
@JvmInline @JvmInline
public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> { public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) : Tensor<T> {
override val shape: IntArray get() = array.shape override val shape: Shape get() = array.shape
override fun get(index: IntArray): T = array[index] override fun get(index: IntArray): T = array[index]
@ -48,18 +52,70 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
else throw ClassCastException("Cannot cast MultiArray to NDArray.") else throw ClassCastException("Cannot cast MultiArray to NDArray.")
} }
public class MultikTensorAlgebra<T : Number> internal constructor( public abstract class MultikTensorAlgebra<T, A : Ring<T>> : TensorAlgebra<T, A> where T : Number, T : Comparable<T> {
public val type: DataType,
public val elementAlgebra: Ring<T>, public abstract val type: DataType
public val comparator: Comparator<T>
) : TensorAlgebra<T> { override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val strides = DefaultStrides(shape)
val memoryView = initMemoryView<T>(strides.linearSize, type)
strides.indices().forEachIndexed { linearIndex, tensorIndex ->
memoryView[linearIndex] = elementAlgebra.initializer(tensorIndex)
}
return MultikTensor(NDArray(memoryView, shape = shape, dim = DN(shape.size)))
}
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> = if (this is MultikTensor) {
val data = initMemoryView<T>(array.size, type)
var count = 0
for (el in array) data[count++] = elementAlgebra.transform(el)
NDArray(data, shape = shape, dim = array.dim).wrap()
} else {
structureND(shape) { index ->
transform(get(index))
}
}
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): MultikTensor<T> =
if (this is MultikTensor) {
val array = asMultik().array
val data = initMemoryView<T>(array.size, type)
val indexIter = array.multiIndices.iterator()
var index = 0
for (item in array) {
if (indexIter.hasNext()) {
data[index++] = elementAlgebra.transform(indexIter.next(), item)
} else {
throw ArithmeticException("Index overflow has happened.")
}
}
NDArray(data, shape = array.shape, dim = array.dim).wrap()
} else {
structureND(shape) { index ->
transform(index, get(index))
}
}
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): MultikTensor<T> {
require(left.shape.contentEquals(right.shape)) { "ND array shape mismatch" } //TODO replace by ShapeMismatchException
val leftArray = left.asMultik().array
val rightArray = right.asMultik().array
val data = initMemoryView<T>(leftArray.size, type)
var counter = 0
val leftIterator = leftArray.iterator()
val rightIterator = rightArray.iterator()
//iterating them together
while (leftIterator.hasNext()) {
data[counter++] = elementAlgebra.transform(leftIterator.next(), rightIterator.next())
}
return NDArray(data, shape = leftArray.shape, dim = leftArray.dim).wrap()
}
/** /**
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
* are not reflected back onto the source * are not reflected back onto the source
*/ */
public fun Tensor<T>.asMultik(): MultikTensor<T> { public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
return if (this is MultikTensor) {
this this
} else { } else {
val res = mk.zeros<T, DN>(shape, type).asDNArray() val res = mk.zeros<T, DN>(shape, type).asDNArray()
@ -68,21 +124,20 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
res.wrap() res.wrap()
} }
}
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this) public fun MutableMultiArray<T, *>.wrap(): MultikTensor<T> = MultikTensor(this.asDNArray())
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0)) get(intArrayOf(0))
} else null } else null
override fun T.plus(other: Tensor<T>): MultikTensor<T> = override fun T.plus(other: StructureND<T>): MultikTensor<T> =
other.plus(this) other.plus(this)
override fun Tensor<T>.plus(value: T): MultikTensor<T> = override fun StructureND<T>.plus(arg: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap() asMultik().array.deepCopy().apply { plusAssign(arg) }.wrap()
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap() asMultik().array.plus(other.asMultik().array).wrap()
override fun Tensor<T>.plusAssign(value: T) { override fun Tensor<T>.plusAssign(value: T) {
@ -93,7 +148,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.plusAssign(other: Tensor<T>) { override fun Tensor<T>.plusAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.plusAssign(other.asMultik().array) array.plusAssign(other.asMultik().array)
} else { } else {
@ -101,12 +156,12 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun T.minus(other: Tensor<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap() override fun T.minus(other: StructureND<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap()
override fun Tensor<T>.minus(value: T): MultikTensor<T> = override fun StructureND<T>.minus(arg: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap() asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap()
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap() asMultik().array.minus(other.asMultik().array).wrap()
override fun Tensor<T>.minusAssign(value: T) { override fun Tensor<T>.minusAssign(value: T) {
@ -117,7 +172,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.minusAssign(other: Tensor<T>) { override fun Tensor<T>.minusAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.minusAssign(other.asMultik().array) array.minusAssign(other.asMultik().array)
} else { } else {
@ -125,13 +180,13 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun T.times(other: Tensor<T>): MultikTensor<T> = override fun T.times(arg: StructureND<T>): MultikTensor<T> =
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
override fun Tensor<T>.times(value: T): Tensor<T> = override fun StructureND<T>.times(arg: T): Tensor<T> =
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap() asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap()
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap() asMultik().array.times(other.asMultik().array).wrap()
override fun Tensor<T>.timesAssign(value: T) { override fun Tensor<T>.timesAssign(value: T) {
@ -142,7 +197,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.timesAssign(other: Tensor<T>) { override fun Tensor<T>.timesAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.timesAssign(other.asMultik().array) array.timesAssign(other.asMultik().array)
} else { } else {
@ -150,14 +205,14 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.unaryMinus(): MultikTensor<T> = override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap() asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap() override fun StructureND<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap() override fun StructureND<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
override fun Tensor<T>.view(shape: IntArray): MultikTensor<T> { override fun StructureND<T>.view(shape: IntArray): MultikTensor<T> {
require(shape.all { it > 0 }) require(shape.all { it > 0 })
require(shape.fold(1, Int::times) == this.shape.size) { require(shape.fold(1, Int::times) == this.shape.size) {
"Cannot reshape array of size ${this.shape.size} into a new shape ${ "Cannot reshape array of size ${this.shape.size} into a new shape ${
@ -170,16 +225,15 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
val mt = asMultik().array val mt = asMultik().array
return if (mt.shape.contentEquals(shape)) { return if (mt.shape.contentEquals(shape)) {
@Suppress("UNCHECKED_CAST") mt
this as NDArray<T, DN>
} else { } else {
NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt) NDArray(mt.data, mt.offset, shape, dim = DN(shape.size), base = mt.base ?: mt)
}.wrap() }.wrap()
} }
override fun Tensor<T>.viewAs(other: Tensor<T>): MultikTensor<T> = view(other.shape) override fun StructureND<T>.viewAs(other: StructureND<T>): MultikTensor<T> = view(other.shape)
override fun Tensor<T>.dot(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.dot(other: StructureND<T>): MultikTensor<T> =
if (this.shape.size == 1 && other.shape.size == 1) { if (this.shape.size == 1 && other.shape.size == 1) {
Multik.ndarrayOf( Multik.ndarrayOf(
asMultik().array.asD1Array() dot other.asMultik().array.asD1Array() asMultik().array.asD1Array() dot other.asMultik().array.asD1Array()
@ -196,45 +250,95 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
TODO("Diagonal embedding not implemented") TODO("Diagonal embedding not implemented")
} }
override fun Tensor<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T -> override fun StructureND<T>.sum(): T = asMultik().array.reduceMultiIndexed { _: IntArray, acc: T, t: T ->
elementAlgebra.add(acc, t) elementAlgebra.add(acc, t)
} }
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.min(): T = override fun StructureND<T>.min(): T? = asMultik().array.min()
asMultik().array.minWith(comparator) ?: error("No elements in tensor")
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.max(): T = override fun StructureND<T>.max(): T? = asMultik().array.max()
asMultik().array.maxWith(comparator) ?: error("No elements in tensor")
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
} }
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double> public abstract class MultikDivisionTensorAlgebra<T, A : Field<T>>
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) } : MultikTensorAlgebra<T, A>(), TensorPartialDivisionAlgebra<T, A> where T : Number, T : Comparable<T> {
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float> override fun T.div(arg: StructureND<T>): MultikTensor<T> = arg.map { elementAlgebra.divide(this@div, it) }
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short> override fun StructureND<T>.div(arg: T): MultikTensor<T> =
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) } asMultik().array.deepCopy().apply { divAssign(arg) }.wrap()
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int> override fun StructureND<T>.div(other: StructureND<T>): MultikTensor<T> =
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) } asMultik().array.div(other.asMultik().array).wrap()
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long> override fun Tensor<T>.divAssign(value: T) {
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) } if (this is MultikTensor) {
array.divAssign(value)
} else {
mapInPlace { _, t -> elementAlgebra.divide(t, value) }
}
}
override fun Tensor<T>.divAssign(other: StructureND<T>) {
if (this is MultikTensor) {
array.divAssign(other.asMultik().array)
} else {
mapInPlace { index, t -> elementAlgebra.divide(t, other[index]) }
}
}
}
public object MultikDoubleAlgebra : MultikDivisionTensorAlgebra<Double, DoubleField>() {
override val elementAlgebra: DoubleField get() = DoubleField
override val type: DataType get() = DataType.DoubleDataType
}
public val Double.Companion.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
public val DoubleField.multikAlgebra: MultikTensorAlgebra<Double, DoubleField> get() = MultikDoubleAlgebra
public object MultikFloatAlgebra : MultikDivisionTensorAlgebra<Float, FloatField>() {
override val elementAlgebra: FloatField get() = FloatField
override val type: DataType get() = DataType.FloatDataType
}
public val Float.Companion.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public val FloatField.multikAlgebra: MultikTensorAlgebra<Float, FloatField> get() = MultikFloatAlgebra
public object MultikShortAlgebra : MultikTensorAlgebra<Short, ShortRing>() {
override val elementAlgebra: ShortRing get() = ShortRing
override val type: DataType get() = DataType.ShortDataType
}
public val Short.Companion.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public val ShortRing.multikAlgebra: MultikTensorAlgebra<Short, ShortRing> get() = MultikShortAlgebra
public object MultikIntAlgebra : MultikTensorAlgebra<Int, IntRing>() {
override val elementAlgebra: IntRing get() = IntRing
override val type: DataType get() = DataType.IntDataType
}
public val Int.Companion.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public val IntRing.multikAlgebra: MultikTensorAlgebra<Int, IntRing> get() = MultikIntAlgebra
public object MultikLongAlgebra : MultikTensorAlgebra<Long, LongRing>() {
override val elementAlgebra: LongRing get() = LongRing
override val type: DataType get() = DataType.LongDataType
}
public val Long.Companion.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra
public val LongRing.multikAlgebra: MultikTensorAlgebra<Long, LongRing> get() = MultikLongAlgebra

View File

@ -7,7 +7,7 @@ import space.kscience.kmath.operations.invoke
internal class MultikNDTest { internal class MultikNDTest {
@Test @Test
fun basicAlgebra(): Unit = DoubleField.multikND{ fun basicAlgebra(): Unit = DoubleField.multikAlgebra{
one(2,2) + 1.0 one(2,2) + 1.0
} }
} }

View File

@ -13,7 +13,11 @@ import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.ops.NDBase import org.nd4j.linalg.factory.ops.NDBase
import org.nd4j.linalg.ops.transforms.Transforms import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.DefaultStrides
import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.Field
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
@ -22,7 +26,8 @@ import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
/** /**
* ND4J based [TensorAlgebra] implementation. * ND4J based [TensorAlgebra] implementation.
*/ */
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> { public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTensorAlgebra<T, A> {
/** /**
* Wraps [INDArray] to [Nd4jArrayStructure]. * Wraps [INDArray] to [Nd4jArrayStructure].
*/ */
@ -33,105 +38,121 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap() override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T>
override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap()
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap() override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> =
structureND(shape) { index -> elementAlgebra.transform(get(index)) }
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> =
structureND(shape) { index -> elementAlgebra.transform(index, get(index)) }
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
require(left.shape.contentEquals(right.shape))
return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
}
override fun T.plus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.add(this).wrap()
override fun StructureND<T>.plus(arg: T): Nd4jArrayStructure<T> = ndArray.add(arg).wrap()
override fun StructureND<T>.plus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.add(other.ndArray).wrap()
override fun Tensor<T>.plusAssign(value: T) { override fun Tensor<T>.plusAssign(value: T) {
ndArray.addi(value) ndArray.addi(value)
} }
override fun Tensor<T>.plusAssign(other: Tensor<T>) { override fun Tensor<T>.plusAssign(other: StructureND<T>) {
ndArray.addi(other.ndArray) ndArray.addi(other.ndArray)
} }
override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap() override fun T.minus(other: StructureND<T>): Nd4jArrayStructure<T> = other.ndArray.rsub(this).wrap()
override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap() override fun StructureND<T>.minus(arg: T): Nd4jArrayStructure<T> = ndArray.sub(arg).wrap()
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap() override fun StructureND<T>.minus(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.sub(other.ndArray).wrap()
override fun Tensor<T>.minusAssign(value: T) { override fun Tensor<T>.minusAssign(value: T) {
ndArray.rsubi(value) ndArray.rsubi(value)
} }
override fun Tensor<T>.minusAssign(other: Tensor<T>) { override fun Tensor<T>.minusAssign(other: StructureND<T>) {
ndArray.subi(other.ndArray) ndArray.subi(other.ndArray)
} }
override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap() override fun T.times(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.mul(this).wrap()
override fun Tensor<T>.times(value: T): Tensor<T> = override fun StructureND<T>.times(arg: T): Nd4jArrayStructure<T> =
ndArray.mul(value).wrap() ndArray.mul(arg).wrap()
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap() override fun StructureND<T>.times(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mul(other.ndArray).wrap()
override fun Tensor<T>.timesAssign(value: T) { override fun Tensor<T>.timesAssign(value: T) {
ndArray.muli(value) ndArray.muli(value)
} }
override fun Tensor<T>.timesAssign(other: Tensor<T>) { override fun Tensor<T>.timesAssign(other: StructureND<T>) {
ndArray.mmuli(other.ndArray) ndArray.mmuli(other.ndArray)
} }
override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap() override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap() override fun StructureND<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap() override fun StructureND<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap() override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> = override fun StructureND<T>.min(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.min(keepDim, dim).wrap() ndArray.min(keepDim, dim).wrap()
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> = override fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.sum(keepDim, dim).wrap() ndArray.sum(keepDim, dim).wrap()
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> = override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.max(keepDim, dim).wrap() ndArray.max(keepDim, dim).wrap()
override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap() override fun StructureND<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape) override fun StructureND<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> = override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap() ndBase.get().argmax(ndArray, keepDim, dim).wrap()
override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap() override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
ndArray.mean(keepDim, dim).wrap()
override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap() override fun StructureND<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap() override fun StructureND<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap() override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap() override fun StructureND<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap() override fun StructureND<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap() override fun StructureND<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
override fun Tensor<T>.acosh(): Tensor<T> = override fun StructureND<T>.acosh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap() override fun StructureND<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap() override fun StructureND<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap() override fun StructureND<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
override fun Tensor<T>.asinh(): Tensor<T> = override fun StructureND<T>.asinh(): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap() Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap() override fun StructureND<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap() override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap() override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap() override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap() override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap() override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap() override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap() ndArray.std(true, keepDim, dim).wrap()
override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap() override fun T.div(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.rdiv(this).wrap()
override fun StructureND<T>.div(arg: T): Nd4jArrayStructure<T> = ndArray.div(arg).wrap()
override fun StructureND<T>.div(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.div(other.ndArray).wrap()
override fun Tensor<T>.divAssign(value: T) { override fun Tensor<T>.divAssign(value: T) {
ndArray.divi(value) ndArray.divi(value)
} }
override fun Tensor<T>.divAssign(other: Tensor<T>) { override fun Tensor<T>.divAssign(other: StructureND<T>) {
ndArray.divi(other.ndArray) ndArray.divi(other.ndArray)
} }
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> = override fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap() Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
private companion object { private companion object {
@ -142,9 +163,22 @@ public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T>
/** /**
* [Double] specialization of [Nd4jTensorAlgebra]. * [Double] specialization of [Nd4jTensorAlgebra].
*/ */
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> { public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
override val elementAlgebra: DoubleField get() = DoubleField
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure() override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure<Double> {
val array: INDArray = Nd4j.zeros(*shape)
val indices = DefaultStrides(shape)
indices.indices().forEach { index ->
array.putScalar(index, elementAlgebra.initializer(index))
}
return array.wrap()
}
@OptIn(PerformancePitfall::class) @OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray override val StructureND<Double>.ndArray: INDArray
get() = when (this) { get() = when (this) {
@ -154,7 +188,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
} }
} }
override fun Tensor<Double>.valueOrNull(): Double? = override fun StructureND<Double>.valueOrNull(): Double? =
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
// TODO rewrite // TODO rewrite
@ -165,10 +199,10 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
dim2: Int, dim2: Int,
): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2) ): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble() override fun StructureND<Double>.sum(): Double = ndArray.sumNumber().toDouble()
override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble() override fun StructureND<Double>.min(): Double = ndArray.minNumber().toDouble()
override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble() override fun StructureND<Double>.max(): Double = ndArray.maxNumber().toDouble()
override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble() override fun StructureND<Double>.mean(): Double = ndArray.meanNumber().toDouble()
override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble() override fun StructureND<Double>.std(): Double = ndArray.stdNumber().toDouble()
override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble() override fun StructureND<Double>.variance(): Double = ndArray.varNumber().toDouble()
} }

View File

@ -5,18 +5,21 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field
/** /**
* Analytic operations on [Tensor]. * Analytic operations on [Tensor].
* *
* @param T the type of items closed under analytic functions in the tensors. * @param T the type of items closed under analytic functions in the tensors.
*/ */
public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> { public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionAlgebra<T, A> {
/** /**
* @return the mean of all elements in the input tensor. * @return the mean of all elements in the input tensor.
*/ */
public fun Tensor<T>.mean(): T public fun StructureND<T>.mean(): T
/** /**
* Returns the mean of each row of the input tensor in the given dimension [dim]. * Returns the mean of each row of the input tensor in the given dimension [dim].
@ -29,12 +32,12 @@ public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the mean of each row of the input tensor in the given dimension [dim]. * @return the mean of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the standard deviation of all elements in the input tensor. * @return the standard deviation of all elements in the input tensor.
*/ */
public fun Tensor<T>.std(): T public fun StructureND<T>.std(): T
/** /**
* Returns the standard deviation of each row of the input tensor in the given dimension [dim]. * Returns the standard deviation of each row of the input tensor in the given dimension [dim].
@ -47,12 +50,12 @@ public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the standard deviation of each row of the input tensor in the given dimension [dim]. * @return the standard deviation of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.std(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the variance of all elements in the input tensor. * @return the variance of all elements in the input tensor.
*/ */
public fun Tensor<T>.variance(): T public fun StructureND<T>.variance(): T
/** /**
* Returns the variance of each row of the input tensor in the given dimension [dim]. * Returns the variance of each row of the input tensor in the given dimension [dim].
@ -65,57 +68,57 @@ public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the variance of each row of the input tensor in the given dimension [dim]. * @return the variance of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
public fun Tensor<T>.exp(): Tensor<T> public fun StructureND<T>.exp(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.log.html //For information: https://pytorch.org/docs/stable/generated/torch.log.html
public fun Tensor<T>.ln(): Tensor<T> public fun StructureND<T>.ln(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
public fun Tensor<T>.sqrt(): Tensor<T> public fun StructureND<T>.sqrt(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
public fun Tensor<T>.cos(): Tensor<T> public fun StructureND<T>.cos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
public fun Tensor<T>.acos(): Tensor<T> public fun StructureND<T>.acos(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
public fun Tensor<T>.cosh(): Tensor<T> public fun StructureND<T>.cosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
public fun Tensor<T>.acosh(): Tensor<T> public fun StructureND<T>.acosh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
public fun Tensor<T>.sin(): Tensor<T> public fun StructureND<T>.sin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
public fun Tensor<T>.asin(): Tensor<T> public fun StructureND<T>.asin(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
public fun Tensor<T>.sinh(): Tensor<T> public fun StructureND<T>.sinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
public fun Tensor<T>.asinh(): Tensor<T> public fun StructureND<T>.asinh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
public fun Tensor<T>.tan(): Tensor<T> public fun StructureND<T>.tan(): Tensor<T>
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
public fun Tensor<T>.atan(): Tensor<T> public fun StructureND<T>.atan(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
public fun Tensor<T>.tanh(): Tensor<T> public fun StructureND<T>.tanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
public fun Tensor<T>.atanh(): Tensor<T> public fun StructureND<T>.atanh(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
public fun Tensor<T>.ceil(): Tensor<T> public fun StructureND<T>.ceil(): Tensor<T>
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
public fun Tensor<T>.floor(): Tensor<T> public fun StructureND<T>.floor(): Tensor<T>
} }

View File

@ -5,12 +5,15 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field
/** /**
* Common linear algebra operations. Operates on [Tensor]. * Common linear algebra operations. Operates on [Tensor].
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
*/ */
public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> { public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionAlgebra<T, A> {
/** /**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input. * Computes the determinant of a square matrix input, or of each square matrix in a batched input.
@ -18,7 +21,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* *
* @return the determinant. * @return the determinant.
*/ */
public fun Tensor<T>.det(): Tensor<T> public fun StructureND<T>.det(): Tensor<T>
/** /**
* Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input.
@ -28,7 +31,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* *
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
*/ */
public fun Tensor<T>.inv(): Tensor<T> public fun StructureND<T>.inv(): Tensor<T>
/** /**
* Cholesky decomposition. * Cholesky decomposition.
@ -44,7 +47,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @receiver the `input`. * @receiver the `input`.
* @return the batch of `L` matrices. * @return the batch of `L` matrices.
*/ */
public fun Tensor<T>.cholesky(): Tensor<T> public fun StructureND<T>.cholesky(): Tensor<T>
/** /**
* QR decomposition. * QR decomposition.
@ -58,7 +61,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @receiver the `input`. * @receiver the `input`.
* @return pair of `Q` and `R` tensors. * @return pair of `Q` and `R` tensors.
*/ */
public fun Tensor<T>.qr(): Pair<Tensor<T>, Tensor<T>> public fun StructureND<T>.qr(): Pair<Tensor<T>, Tensor<T>>
/** /**
* LUP decomposition * LUP decomposition
@ -72,7 +75,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @receiver the `input`. * @receiver the `input`.
* @return triple of P, L and U tensors * @return triple of P, L and U tensors
*/ */
public fun Tensor<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>> public fun StructureND<T>.lu(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Singular Value Decomposition. * Singular Value Decomposition.
@ -88,7 +91,7 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @receiver the `input`. * @receiver the `input`.
* @return triple `Triple(U, S, V)`. * @return triple `Triple(U, S, V)`.
*/ */
public fun Tensor<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>> public fun StructureND<T>.svd(): Triple<Tensor<T>, Tensor<T>, Tensor<T>>
/** /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices, * Returns eigenvalues and eigenvectors of a real symmetric matrix `input` or a batch of real symmetric matrices,
@ -98,6 +101,6 @@ public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> {
* @receiver the `input`. * @receiver the `input`.
* @return a pair `eigenvalues to eigenvectors` * @return a pair `eigenvalues to eigenvectors`
*/ */
public fun Tensor<T>.symEig(): Pair<Tensor<T>, Tensor<T>> public fun StructureND<T>.symEig(): Pair<Tensor<T>, Tensor<T>>
} }

View File

@ -5,7 +5,9 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.RingOps import space.kscience.kmath.nd.RingOpsND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Ring
/** /**
* Algebra over a ring on [Tensor]. * Algebra over a ring on [Tensor].
@ -13,20 +15,20 @@ import space.kscience.kmath.operations.RingOps
* *
* @param T the type of items in the tensors. * @param T the type of items in the tensors.
*/ */
public interface TensorAlgebra<T> : RingOps<Tensor<T>> { public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
/** /**
* Returns a single tensor value of unit dimension if tensor shape equals to [1]. * Returns a single tensor value of unit dimension if tensor shape equals to [1].
* *
* @return a nullable value of a potentially scalar tensor. * @return a nullable value of a potentially scalar tensor.
*/ */
public fun Tensor<T>.valueOrNull(): T? public fun StructureND<T>.valueOrNull(): T?
/** /**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
* *
* @return the value of a scalar tensor. * @return the value of a scalar tensor.
*/ */
public fun Tensor<T>.value(): T = public fun StructureND<T>.value(): T =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/** /**
@ -36,15 +38,15 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this value and tensor [other]. * @return the sum of this value and tensor [other].
*/ */
public operator fun T.plus(other: Tensor<T>): Tensor<T> override operator fun T.plus(other: StructureND<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. * Adds the scalar [arg] to each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be added to each element of this tensor. * @param arg the number to be added to each element of this tensor.
* @return the sum of this tensor and [value]. * @return the sum of this tensor and [arg].
*/ */
public operator fun Tensor<T>.plus(value: T): Tensor<T> override operator fun StructureND<T>.plus(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is added to each element of this tensor. * Each element of the tensor [other] is added to each element of this tensor.
@ -53,7 +55,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this tensor and [other]. * @return the sum of this tensor and [other].
*/ */
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.plus(other: StructureND<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor. * Adds the scalar [value] to each element of this tensor.
@ -67,7 +69,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be added. * @param other tensor to be added.
*/ */
public operator fun Tensor<T>.plusAssign(other: Tensor<T>) public operator fun Tensor<T>.plusAssign(other: StructureND<T>)
/** /**
* Each element of the tensor [other] is subtracted from this value. * Each element of the tensor [other] is subtracted from this value.
@ -76,15 +78,15 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this value and tensor [other]. * @return the difference between this value and tensor [other].
*/ */
public operator fun T.minus(other: Tensor<T>): Tensor<T> override operator fun T.minus(other: StructureND<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be subtracted from each element of this tensor. * @param arg the number to be subtracted from each element of this tensor.
* @return the difference between this tensor and [value]. * @return the difference between this tensor and [arg].
*/ */
public operator fun Tensor<T>.minus(value: T): Tensor<T> override operator fun StructureND<T>.minus(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is subtracted from each element of this tensor. * Each element of the tensor [other] is subtracted from each element of this tensor.
@ -93,7 +95,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this tensor and [other]. * @return the difference between this tensor and [other].
*/ */
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.minus(other: StructureND<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor. * Subtracts the scalar [value] from each element of this tensor.
@ -107,25 +109,25 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
*/ */
public operator fun Tensor<T>.minusAssign(other: Tensor<T>) public operator fun Tensor<T>.minusAssign(other: StructureND<T>)
/** /**
* Each element of the tensor [other] is multiplied by this value. * Each element of the tensor [arg] is multiplied by this value.
* The resulting tensor is returned. * The resulting tensor is returned.
* *
* @param other tensor to be multiplied. * @param arg tensor to be multiplied.
* @return the product of this value and tensor [other]. * @return the product of this value and tensor [arg].
*/ */
public operator fun T.times(other: Tensor<T>): Tensor<T> override operator fun T.times(arg: StructureND<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. * Multiplies the scalar [arg] by each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be multiplied by each element of this tensor. * @param arg the number to be multiplied by each element of this tensor.
* @return the product of this tensor and [value]. * @return the product of this tensor and [arg].
*/ */
public operator fun Tensor<T>.times(value: T): Tensor<T> override operator fun StructureND<T>.times(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is multiplied by each element of this tensor. * Each element of the tensor [other] is multiplied by each element of this tensor.
@ -134,7 +136,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this tensor and [other]. * @return the product of this tensor and [other].
*/ */
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.times(other: StructureND<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor. * Multiplies the scalar [value] by each element of this tensor.
@ -148,14 +150,14 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
*/ */
public operator fun Tensor<T>.timesAssign(other: Tensor<T>) public operator fun Tensor<T>.timesAssign(other: StructureND<T>)
/** /**
* Numerical negative, element-wise. * Numerical negative, element-wise.
* *
* @return tensor negation of the original tensor. * @return tensor negation of the original tensor.
*/ */
override fun Tensor<T>.unaryMinus(): Tensor<T> override operator fun StructureND<T>.unaryMinus(): Tensor<T>
/** /**
* Returns the tensor at index i * Returns the tensor at index i
@ -164,7 +166,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param i index of the extractable tensor * @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i] * @return subtensor of the original tensor with index [i]
*/ */
public operator fun Tensor<T>.get(i: Int): Tensor<T> public operator fun StructureND<T>.get(i: Int): Tensor<T>
/** /**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
@ -174,7 +176,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param j the second dimension to be transposed * @param j the second dimension to be transposed
* @return transposed tensor * @return transposed tensor
*/ */
public fun Tensor<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T> public fun StructureND<T>.transpose(i: Int = -2, j: Int = -1): Tensor<T>
/** /**
* Returns a new tensor with the same data as the self tensor but of a different shape. * Returns a new tensor with the same data as the self tensor but of a different shape.
@ -184,7 +186,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param shape the desired size * @param shape the desired size
* @return tensor with new shape * @return tensor with new shape
*/ */
public fun Tensor<T>.view(shape: IntArray): Tensor<T> public fun StructureND<T>.view(shape: IntArray): Tensor<T>
/** /**
* View this tensor as the same size as [other]. * View this tensor as the same size as [other].
@ -194,7 +196,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other the result tensor has the same size as other. * @param other the result tensor has the same size as other.
* @return the result tensor with the same size as other. * @return the result tensor with the same size as other.
*/ */
public fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> public fun StructureND<T>.viewAs(other: StructureND<T>): Tensor<T>
/** /**
* Matrix product of two tensors. * Matrix product of two tensors.
@ -225,7 +227,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return a mathematical product of two tensors. * @return a mathematical product of two tensors.
*/ */
public infix fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> public infix fun StructureND<T>.dot(other: StructureND<T>): Tensor<T>
/** /**
* Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2])
@ -260,7 +262,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
/** /**
* @return the sum of all elements in the input tensor. * @return the sum of all elements in the input tensor.
*/ */
public fun Tensor<T>.sum(): T public fun StructureND<T>.sum(): T
/** /**
* Returns the sum of each row of the input tensor in the given dimension [dim]. * Returns the sum of each row of the input tensor in the given dimension [dim].
@ -273,12 +275,12 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the sum of each row of the input tensor in the given dimension [dim]. * @return the sum of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.sum(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* @return the minimum value of all elements in the input tensor. * @return the minimum value of all elements in the input tensor or null if there are no values
*/ */
public fun Tensor<T>.min(): T public fun StructureND<T>.min(): T?
/** /**
* Returns the minimum value of each row of the input tensor in the given dimension [dim]. * Returns the minimum value of each row of the input tensor in the given dimension [dim].
@ -291,12 +293,12 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the minimum value of each row of the input tensor in the given dimension [dim]. * @return the minimum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* Returns the maximum value of all elements in the input tensor. * Returns the maximum value of all elements in the input tensor or null if there are no values
*/ */
public fun Tensor<T>.max(): T public fun StructureND<T>.max(): T?
/** /**
* Returns the maximum value of each row of the input tensor in the given dimension [dim]. * Returns the maximum value of each row of the input tensor in the given dimension [dim].
@ -309,7 +311,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the maximum value of each row of the input tensor in the given dimension [dim]. * @return the maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T>
/** /**
* Returns the index of maximum value of each row of the input tensor in the given dimension [dim]. * Returns the index of maximum value of each row of the input tensor in the given dimension [dim].
@ -322,9 +324,9 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param keepDim whether the output tensor has [dim] retained or not. * @param keepDim whether the output tensor has [dim] retained or not.
* @return the index of maximum value of each row of the input tensor in the given dimension [dim]. * @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/ */
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right override fun multiply(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left * right
} }

View File

@ -5,30 +5,34 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field
/** /**
* Algebra over a field with partial division on [Tensor]. * Algebra over a field with partial division on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
*/ */
public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> { public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T, A>, FieldOpsND<T, A> {
/** /**
* Each element of the tensor [other] is divided by this value. * Each element of the tensor [arg] is divided by this value.
* The resulting tensor is returned. * The resulting tensor is returned.
* *
* @param other tensor to divide by. * @param arg tensor to divide by.
* @return the division of this value by the tensor [other]. * @return the division of this value by the tensor [arg].
*/ */
public operator fun T.div(other: Tensor<T>): Tensor<T> override operator fun T.div(arg: StructureND<T>): Tensor<T>
/** /**
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor. * Divide by the scalar [arg] each element of this tensor returns a new resulting tensor.
* *
* @param value the number to divide by each element of this tensor. * @param arg the number to divide by each element of this tensor.
* @return the division of this tensor by the [value]. * @return the division of this tensor by the [arg].
*/ */
public operator fun Tensor<T>.div(value: T): Tensor<T> override operator fun StructureND<T>.div(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is divided by each element of this tensor. * Each element of the tensor [other] is divided by each element of this tensor.
@ -37,7 +41,9 @@ public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
* @param other tensor to be divided by. * @param other tensor to be divided by.
* @return the division of this tensor by [other]. * @return the division of this tensor by [other].
*/ */
public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.div(other: StructureND<T>): Tensor<T>
override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> = left.div(right)
/** /**
* Divides by the scalar [value] each element of this tensor. * Divides by the scalar [value] each element of this tensor.
@ -51,5 +57,5 @@ public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
* *
* @param other tensor to be divided by. * @param other tensor to be divided by.
*/ */
public operator fun Tensor<T>.divAssign(other: Tensor<T>) public operator fun Tensor<T>.divAssign(other: StructureND<T>)
} }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTensors
@ -18,75 +19,75 @@ import space.kscience.kmath.tensors.core.internal.tensor
*/ */
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) { override fun Tensor<Double>.plusAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) { override fun Tensor<Double>.minusAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.times(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] * newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) { override fun Tensor<Double>.timesAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] / newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.divAssign(other: Tensor<Double>) { override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }

View File

@ -23,23 +23,23 @@ public open class BufferedTensor<T> internal constructor(
/** /**
* Buffer strides based on [TensorLinearStructure] implementation * Buffer strides based on [TensorLinearStructure] implementation
*/ */
public val linearStructure: Strides public val indices: Strides
get() = TensorLinearStructure(shape) get() = TensorLinearStructure(shape)
/** /**
* Number of elements in tensor * Number of elements in tensor
*/ */
public val numElements: Int public val numElements: Int
get() = linearStructure.linearSize get() = indices.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] override fun get(index: IntArray): T = mutableBuffer[bufferStart + indices.offset(index)]
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[bufferStart + linearStructure.offset(index)] = value mutableBuffer[bufferStart + indices.offset(index)] = value
} }
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indices.indices().map {
it to get(it) it to get(it)
} }
} }

View File

@ -6,8 +6,10 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
@ -20,16 +22,71 @@ import kotlin.math.*
* Implementation of basic operations over double tensors and basic algebra operations on them. * Implementation of basic operations over double tensors and basic algebra operations on them.
*/ */
public open class DoubleTensorAlgebra : public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>, TensorPartialDivisionAlgebra<Double, DoubleField>,
AnalyticTensorAlgebra<Double>, AnalyticTensorAlgebra<Double, DoubleField>,
LinearOpsTensorAlgebra<Double> { LinearOpsTensorAlgebra<Double, DoubleField> {
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) override val elementAlgebra: DoubleField
get() = DoubleField
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): DoubleTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) }
return DoubleTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) }
return DoubleTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double
): DoubleTensor {
require(left.shape.contentEquals(right.shape)){
"The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}"
}
val leftTensor = left.tensor
val leftArray = leftTensor.copyArray()
val rightTensor = right.tensor
val rightArray = rightTensor.copyArray()
val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) }
return DoubleTensor(
leftTensor.shape,
array
)
}
override fun StructureND<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun Tensor<Double>.value(): Double = valueOrNull() override fun StructureND<Double>.value(): Double = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
/** /**
@ -53,13 +110,12 @@ public open class DoubleTensorAlgebra :
* @param initializer mapping tensor indices to values. * @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by the [initializer]. * @return tensor with the [shape] shape and data generated by the [initializer].
*/ */
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor = override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray(
fromArray(
shape, shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray() TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
) )
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override operator fun StructureND<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = tensor.shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart
@ -104,7 +160,7 @@ public open class DoubleTensorAlgebra :
* *
* @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor.
*/ */
public fun Tensor<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun StructureND<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0)
/** /**
* Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape].
@ -142,20 +198,19 @@ public open class DoubleTensorAlgebra :
* *
* @return a copy of the `input` tensor with a copied buffer. * @return a copy of the `input` tensor with a copied buffer.
*/ */
public fun Tensor<Double>.copy(): DoubleTensor { public fun StructureND<Double>.copy(): DoubleTensor =
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
}
override fun Double.plus(other: Tensor<Double>): DoubleTensor { override fun Double.plus(other: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor override fun StructureND<Double>.plus(arg: Double): DoubleTensor = arg + tensor
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
@ -169,7 +224,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) { override fun Tensor<Double>.plusAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
@ -177,21 +232,21 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.minus(other: Tensor<Double>): DoubleTensor { override fun Double.minus(other: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun Tensor<Double>.minus(value: Double): DoubleTensor { override fun StructureND<Double>.minus(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
@ -205,7 +260,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) { override fun Tensor<Double>.minusAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
@ -213,16 +268,16 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.times(other: Tensor<Double>): DoubleTensor { override fun Double.times(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor override fun StructureND<Double>.times(arg: Double): DoubleTensor = arg * tensor
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.times(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] * tensor.mutableBuffer.array()[tensor.bufferStart + i] *
@ -237,7 +292,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) { override fun Tensor<Double>.timesAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
@ -245,21 +300,21 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.div(other: Tensor<Double>): DoubleTensor { override fun Double.div(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.tensor.numElements) { i ->
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun Tensor<Double>.div(value: Double): DoubleTensor { override fun StructureND<Double>.div(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg
} }
return DoubleTensor(shape, resBuffer) return DoubleTensor(shape, resBuffer)
} }
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
@ -274,7 +329,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.divAssign(other: Tensor<Double>) { override fun Tensor<Double>.divAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
@ -282,14 +337,14 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.unaryMinus(): DoubleTensor { override fun StructureND<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor { override fun StructureND<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i) val ii = tensor.minusIndex(i)
val jj = tensor.minusIndex(j) val jj = tensor.minusIndex(j)
checkTranspose(tensor.dimension, ii, jj) checkTranspose(tensor.dimension, ii, jj)
@ -302,26 +357,26 @@ public open class DoubleTensorAlgebra :
val resTensor = DoubleTensor(resShape, resBuffer) val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) { for (offset in 0 until n) {
val oldMultiIndex = tensor.linearStructure.index(offset) val oldMultiIndex = tensor.indices.index(offset)
val newMultiIndex = oldMultiIndex.copyOf() val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex) val linearIndex = resTensor.indices.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset] tensor.mutableBuffer.array()[tensor.bufferStart + offset]
} }
return resTensor return resTensor
} }
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor { override fun StructureND<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape) checkView(tensor, shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart)
} }
override fun Tensor<Double>.viewAs(other: Tensor<Double>): DoubleTensor = override fun StructureND<Double>.viewAs(other: StructureND<Double>): DoubleTensor =
tensor.view(other.shape) tensor.view(other.shape)
override infix fun Tensor<Double>.dot(other: Tensor<Double>): DoubleTensor { override infix fun StructureND<Double>.dot(other: StructureND<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (tensor.shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum()))
} }
@ -406,7 +461,7 @@ public open class DoubleTensorAlgebra :
val resTensor = zeros(resShape) val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) { for (i in 0 until diagonalEntries.tensor.numElements) {
val multiIndex = diagonalEntries.tensor.linearStructure.index(i) val multiIndex = diagonalEntries.tensor.indices.index(i)
var offset1 = 0 var offset1 = 0
var offset2 = abs(realOffset) var offset2 = abs(realOffset)
@ -425,18 +480,6 @@ public open class DoubleTensorAlgebra :
return resTensor.tensor return resTensor.tensor
} }
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
public inline fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor = DoubleTensor(
tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart
)
/** /**
* Compares element-wise two tensors with a specified precision. * Compares element-wise two tensors with a specified precision.
* *
@ -525,10 +568,10 @@ public open class DoubleTensorAlgebra :
*/ */
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray()) foldFunction(tensor.copyArray())
internal inline fun Tensor<Double>.foldDim( internal inline fun StructureND<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> Double,
dim: Int, dim: Int,
keepDim: Boolean, keepDim: Boolean,
@ -541,7 +584,7 @@ public open class DoubleTensorAlgebra :
} }
val resNumElements = resShape.reduce(Int::times) val resNumElements = resShape.reduce(Int::times)
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0) val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
for (index in resTensor.linearStructure.indices()) { for (index in resTensor.indices.indices()) {
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
@ -552,30 +595,30 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.sum(): Double = tensor.fold { it.sum() } override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() }
override fun Tensor<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.sum() }, dim, keepDim) foldDim({ x -> x.sum() }, dim, keepDim)
override fun Tensor<Double>.min(): Double = this.fold { it.minOrNull()!! } override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun Tensor<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.minOrNull()!! }, dim, keepDim) foldDim({ x -> x.minOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun Tensor<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> x.maxOrNull()!! }, dim, keepDim) foldDim({ x -> x.maxOrNull()!! }, dim, keepDim)
override fun Tensor<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim({ x -> foldDim({ x ->
x.withIndex().maxByOrNull { it.value }?.index!!.toDouble() x.withIndex().maxByOrNull { it.value }?.index!!.toDouble()
}, dim, keepDim) }, dim, keepDim)
override fun Tensor<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements } override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements }
override fun Tensor<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim( foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
@ -585,12 +628,12 @@ public open class DoubleTensorAlgebra :
keepDim keepDim
) )
override fun Tensor<Double>.std(): Double = this.fold { arr -> override fun StructureND<Double>.std(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1))
} }
override fun Tensor<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
@ -600,12 +643,12 @@ public open class DoubleTensorAlgebra :
keepDim keepDim
) )
override fun Tensor<Double>.variance(): Double = this.fold { arr -> override fun StructureND<Double>.variance(): Double = this.fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / tensor.numElements
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)
} }
override fun Tensor<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
{ arr -> { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
@ -628,7 +671,7 @@ public open class DoubleTensorAlgebra :
* @param tensors the [List] of 1-dimensional tensors with same shape * @param tensors the [List] of 1-dimensional tensors with same shape
* @return `M`. * @return `M`.
*/ */
public fun cov(tensors: List<Tensor<Double>>): DoubleTensor { public fun cov(tensors: List<StructureND<Double>>): DoubleTensor {
check(tensors.isNotEmpty()) { "List must have at least 1 element" } check(tensors.isNotEmpty()) { "List must have at least 1 element" }
val n = tensors.size val n = tensors.size
val m = tensors[0].shape[0] val m = tensors[0].shape[0]
@ -645,43 +688,43 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp) override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map(::ln) override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt) override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos) override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos) override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh) override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh) override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin) override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin) override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh) override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh) override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan) override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan) override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh) override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh) override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil) override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor) override fun StructureND<Double>.floor(): DoubleTensor = tensor.map { floor(it) }
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9) override fun StructureND<Double>.inv(): DoubleTensor = invLU(1e-9)
override fun Tensor<Double>.det(): DoubleTensor = detLU(1e-9) override fun StructureND<Double>.det(): DoubleTensor = detLU(1e-9)
/** /**
* Computes the LU factorization of a matrix or batches of matrices `input`. * Computes the LU factorization of a matrix or batches of matrices `input`.
@ -692,7 +735,7 @@ public open class DoubleTensorAlgebra :
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. * The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/ */
public fun Tensor<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun StructureND<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(tensor, epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
@ -705,7 +748,7 @@ public open class DoubleTensorAlgebra :
* The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor. * The `factorization` has the shape ``(*, m, n)``, where``(*, m, n)`` is the shape of the `input` tensor.
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. * The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/ */
public fun Tensor<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9) public fun StructureND<Double>.luFactor(): Pair<DoubleTensor, IntTensor> = luFactor(1e-9)
/** /**
* Unpacks the data and pivots from a LU factorization of a tensor. * Unpacks the data and pivots from a LU factorization of a tensor.
@ -719,7 +762,7 @@ public open class DoubleTensorAlgebra :
* @return triple of `P`, `L` and `U` tensors * @return triple of `P`, `L` and `U` tensors
*/ */
public fun luPivot( public fun luPivot(
luTensor: Tensor<Double>, luTensor: StructureND<Double>,
pivotsTensor: Tensor<Int>, pivotsTensor: Tensor<Int>,
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { ): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape) checkSquareMatrix(luTensor.shape)
@ -762,7 +805,7 @@ public open class DoubleTensorAlgebra :
* Used when checking the positive definiteness of the input matrix or matrices. * Used when checking the positive definiteness of the input matrix or matrices.
* @return a pair of `Q` and `R` tensors. * @return a pair of `Q` and `R` tensors.
*/ */
public fun Tensor<Double>.cholesky(epsilon: Double): DoubleTensor { public fun StructureND<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon) checkPositiveDefinite(tensor, epsilon)
@ -775,9 +818,9 @@ public open class DoubleTensorAlgebra :
return lTensor return lTensor
} }
override fun Tensor<Double>.cholesky(): DoubleTensor = cholesky(1e-6) override fun StructureND<Double>.cholesky(): DoubleTensor = cholesky(1e-6)
override fun Tensor<Double>.qr(): Pair<DoubleTensor, DoubleTensor> { override fun StructureND<Double>.qr(): Pair<DoubleTensor, DoubleTensor> {
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
val rTensor = zeroesLike() val rTensor = zeroesLike()
@ -793,7 +836,7 @@ public open class DoubleTensorAlgebra :
return qTensor to rTensor return qTensor to rTensor
} }
override fun Tensor<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = override fun StructureND<Double>.svd(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> =
svd(epsilon = 1e-10) svd(epsilon = 1e-10)
/** /**
@ -809,7 +852,7 @@ public open class DoubleTensorAlgebra :
* i.e., the precision with which the cosine approaches 1 in an iterative algorithm. * i.e., the precision with which the cosine approaches 1 in an iterative algorithm.
* @return a triple `Triple(U, S, V)`. * @return a triple `Triple(U, S, V)`.
*/ */
public fun Tensor<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.dimension val size = tensor.dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = tensor.shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = tensor.shape.sliceArray(size - 2 until size)
@ -842,7 +885,7 @@ public open class DoubleTensorAlgebra :
return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) return Triple(uTensor.transpose(), sTensor, vTensor.transpose())
} }
override fun Tensor<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15) override fun StructureND<Double>.symEig(): Pair<DoubleTensor, DoubleTensor> = symEig(epsilon = 1e-15)
/** /**
* Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices,
@ -852,7 +895,7 @@ public open class DoubleTensorAlgebra :
* and when the cosine approaches 1 in the SVD algorithm. * and when the cosine approaches 1 in the SVD algorithm.
* @return a pair `eigenvalues to eigenvectors`. * @return a pair `eigenvalues to eigenvectors`.
*/ */
public fun Tensor<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun StructureND<Double>.symEig(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(tensor, epsilon)
fun MutableStructure2D<Double>.cleanSym(n: Int) { fun MutableStructure2D<Double>.cleanSym(n: Int) {
@ -887,7 +930,7 @@ public open class DoubleTensorAlgebra :
* with zero. * with zero.
* @return the determinant. * @return the determinant.
*/ */
public fun Tensor<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun StructureND<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape) checkSquareMatrix(tensor.shape)
val luTensor = tensor.copy() val luTensor = tensor.copy()
val pivotsTensor = tensor.setUpPivots() val pivotsTensor = tensor.setUpPivots()
@ -920,7 +963,7 @@ public open class DoubleTensorAlgebra :
* @param epsilon error in the LU algorithm&mdash;permissible error when comparing the determinant of a matrix with zero * @param epsilon error in the LU algorithm&mdash;permissible error when comparing the determinant of a matrix with zero
* @return the multiplicative inverse of a matrix. * @return the multiplicative inverse of a matrix.
*/ */
public fun Tensor<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor { public fun StructureND<Double>.invLU(epsilon: Double = 1e-9): DoubleTensor {
val (luTensor, pivotsTensor) = luFactor(epsilon) val (luTensor, pivotsTensor) = luFactor(epsilon)
val invTensor = luTensor.zeroesLike() val invTensor = luTensor.zeroesLike()
@ -945,12 +988,12 @@ public open class DoubleTensorAlgebra :
* @param epsilon permissible error when comparing the determinant of a matrix with zero. * @param epsilon permissible error when comparing the determinant of a matrix with zero.
* @return triple of `P`, `L` and `U` tensors. * @return triple of `P`, `L` and `U` tensors.
*/ */
public fun Tensor<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = tensor.luFactor(epsilon) val (lu, pivots) = tensor.luFactor(epsilon)
return luPivot(lu, pivots) return luPivot(lu, pivots)
} }
override fun Tensor<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9) override fun StructureND<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
} }
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra

View File

@ -10,7 +10,7 @@ import kotlin.math.max
internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) { internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
for (linearIndex in 0 until linearSize) { for (linearIndex in 0 until linearSize) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex) val totalMultiIndex = resTensor.indices.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf() val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size val offset = totalMultiIndex.size - curMultiIndex.size
@ -23,7 +23,7 @@ internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTenso
} }
} }
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex) val curLinearIndex = tensor.indices.offset(curMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex] tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
} }
@ -112,7 +112,7 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTen
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize)) val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
for (linearIndex in 0 until n) { for (linearIndex in 0 until n) {
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex) val totalMultiIndex = outerTensor.indices.index(linearIndex)
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf() var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
@ -127,13 +127,13 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTen
} }
for (i in 0 until matrixSize) { for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset( val curLinearIndex = newTensor.indices.offset(
curMultiIndex + curMultiIndex +
matrix.linearStructure.index(i) matrix.indices.index(i)
) )
val newLinearIndex = resTensor.linearStructure.offset( val newLinearIndex = resTensor.indices.offset(
totalMultiIndex + totalMultiIndex +
matrix.linearStructure.index(i) matrix.indices.index(i)
) )
resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] = resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.tensors.core.internal package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
@ -25,7 +26,7 @@ internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided" "Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
} }
internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) = internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>) =
check(a.shape contentEquals b.shape) { check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
} }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core.internal package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableBufferND import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.structures.asMutableBuffer import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor import space.kscience.kmath.tensors.core.BufferedTensor
@ -18,15 +19,15 @@ internal fun BufferedTensor<Int>.asTensor(): IntTensor =
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> = internal fun <T> StructureND<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor( BufferedTensor(
this.shape, this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
) )
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> StructureND<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) { is MutableBufferND<T> -> if (this.indices == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.buffer, 0) BufferedTensor(this.shape, this.buffer, 0)
} else { } else {
this.copyToBufferedTensor() this.copyToBufferedTensor()
@ -35,7 +36,7 @@ internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
} }
@PublishedApi @PublishedApi
internal val Tensor<Double>.tensor: DoubleTensor internal val StructureND<Double>.tensor: DoubleTensor
get() = when (this) { get() = when (this) {
is DoubleTensor -> this is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor() else -> this.toBufferedTensor().asTensor()

View File

@ -85,7 +85,7 @@ internal fun format(value: Double, digits: Int = 4): String = buildString {
internal fun DoubleTensor.toPrettyString(): String = buildString { internal fun DoubleTensor.toPrettyString(): String = buildString {
var offset = 0 var offset = 0
val shape = this@toPrettyString.shape val shape = this@toPrettyString.shape
val linearStructure = this@toPrettyString.linearStructure val linearStructure = this@toPrettyString.indices
val vectorSize = shape.last() val vectorSize = shape.last()
append("DoubleTensor(\n") append("DoubleTensor(\n")
var charOffset = 3 var charOffset = 3

View File

@ -19,18 +19,19 @@ public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
/** /**
* Returns [DoubleArray] of tensor elements * Returns a copy-protected [DoubleArray] of tensor elements
*/ */
public fun DoubleTensor.toDoubleArray(): DoubleArray { public fun DoubleTensor.copyArray(): DoubleArray {
//TODO use ArrayCopy
return DoubleArray(numElements) { i -> return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i] mutableBuffer[bufferStart + i]
} }
} }
/** /**
* Returns [IntArray] of tensor elements * Returns a copy-protected [IntArray] of tensor elements
*/ */
public fun IntTensor.toIntArray(): IntArray { public fun IntTensor.copyArray(): IntArray {
return IntArray(numElements) { i -> return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i] mutableBuffer[bufferStart + i]
} }