Refactor TensorAlgebra to take StructureND and inherit AlgebraND

This commit is contained in:
Alexander Nozik 2021-10-26 09:16:24 +03:00
parent 47aeb36979
commit 7e59ec5804
21 changed files with 258 additions and 197 deletions

View File

@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
this@StreamDoubleFieldND.shape, this@StreamDoubleFieldND.shape,
shape shape
) )
this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer this is BufferND && this.indices == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }

View File

@ -9,7 +9,7 @@ import space.kscience.kmath.operations.invoke
import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra import space.kscience.kmath.tensors.core.BroadcastDoubleTensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
import space.kscience.kmath.tensors.core.toDoubleArray import space.kscience.kmath.tensors.core.copyArray
import kotlin.math.sqrt import kotlin.math.sqrt
const val seed = 100500L const val seed = 100500L
@ -111,7 +111,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra { private fun softMaxLoss(yPred: DoubleTensor, yTrue: DoubleTensor): DoubleTensor = BroadcastDoubleTensorAlgebra {
val onesForAnswers = yPred.zeroesLike() val onesForAnswers = yPred.zeroesLike()
yTrue.toDoubleArray().forEachIndexed { index, labelDouble -> yTrue.copyArray().forEachIndexed { index, labelDouble ->
val label = labelDouble.toInt() val label = labelDouble.toInt()
onesForAnswers[intArrayOf(index, label)] = 1.0 onesForAnswers[intArrayOf(index, label)] = 1.0
} }

View File

@ -131,19 +131,19 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
* Adds an element to ND structure of it. * Adds an element to ND structure of it.
* *
* @receiver the augend. * @receiver the augend.
* @param arg the addend. * @param other the addend.
* @return the sum. * @return the sum.
*/ */
public operator fun T.plus(arg: StructureND<T>): StructureND<T> = arg + this public operator fun T.plus(other: StructureND<T>): StructureND<T> = other.map { value -> add(this@plus, value) }
/** /**
* Subtracts an ND structure from an element of it. * Subtracts an ND structure from an element of it.
* *
* @receiver the dividend. * @receiver the dividend.
* @param arg the divisor. * @param other the divisor.
* @return the quotient. * @return the quotient.
*/ */
public operator fun T.minus(arg: StructureND<T>): StructureND<T> = arg.map { value -> add(-this@minus, value) } public operator fun T.minus(other: StructureND<T>): StructureND<T> = other.map { value -> add(-this@minus, value) }
public companion object public companion object
} }

View File

@ -51,7 +51,7 @@ public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
arg: BufferND<T>, arg: BufferND<T>,
crossinline transform: A.(T) -> T crossinline transform: A.(T) -> T
): BufferND<T> { ): BufferND<T> {
val indexes = arg.indexes val indexes = arg.indices
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform)) return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
} }
@ -59,7 +59,7 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
arg: BufferND<T>, arg: BufferND<T>,
crossinline transform: A.(index: IntArray, arg: T) -> T crossinline transform: A.(index: IntArray, arg: T) -> T
): BufferND<T> { ): BufferND<T> {
val indexes = arg.indexes val indexes = arg.indices
return BufferND( return BufferND(
indexes, indexes,
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value -> bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
@ -73,8 +73,8 @@ internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
r: BufferND<T>, r: BufferND<T>,
crossinline block: A.(l: T, r: T) -> T crossinline block: A.(l: T, r: T) -> T
): BufferND<T> { ): BufferND<T> {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes val indexes = l.indices
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block)) return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
} }

View File

@ -15,20 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
* Represents [StructureND] over [Buffer]. * Represents [StructureND] over [Buffer].
* *
* @param T the type of items. * @param T the type of items.
* @param indexes The strides to access elements of [Buffer] by linear indices. * @param indices The strides to access elements of [Buffer] by linear indices.
* @param buffer The underlying buffer. * @param buffer The underlying buffer.
*/ */
public open class BufferND<out T>( public open class BufferND<out T>(
public val indexes: ShapeIndexer, public val indices: ShapeIndexer,
public open val buffer: Buffer<T>, public open val buffer: Buffer<T>,
) : StructureND<T> { ) : StructureND<T> {
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)] override operator fun get(index: IntArray): T = buffer[indices.offset(index)]
override val shape: IntArray get() = indexes.shape override val shape: IntArray get() = indices.shape
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indices.indices().map {
it to this[it] it to this[it]
} }
@ -43,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): BufferND<R> { ): BufferND<R> {
return if (this is BufferND<T>) return if (this is BufferND<T>)
BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) BufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
@ -62,7 +62,7 @@ public class MutableBufferND<T>(
override val buffer: MutableBuffer<T>, override val buffer: MutableBuffer<T>,
) : MutableStructureND<T>, BufferND<T>(strides, buffer) { ) : MutableStructureND<T>, BufferND<T>(strides, buffer) {
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
buffer[indexes.offset(index)] = value buffer[indices.offset(index)] = value
} }
} }
@ -74,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
crossinline transform: (T) -> R, crossinline transform: (T) -> R,
): MutableBufferND<R> { ): MutableBufferND<R> {
return if (this is MutableBufferND<T>) return if (this is MutableBufferND<T>)
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) }) MutableBufferND(this.indices, factory.invoke(indices.linearSize) { transform(buffer[it]) })
else { else {
val strides = DefaultStrides(shape) val strides = DefaultStrides(shape)
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) }) MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })

View File

@ -33,7 +33,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
arg: DoubleBufferND, arg: DoubleBufferND,
transform: (Double) -> Double transform: (Double) -> Double
): DoubleBufferND { ): DoubleBufferND {
val indexes = arg.indexes val indexes = arg.indices
val array = arg.buffer.array val array = arg.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) }) return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { transform(array[it]) })
} }
@ -43,8 +43,8 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
r: DoubleBufferND, r: DoubleBufferND,
block: (l: Double, r: Double) -> Double block: (l: Double, r: Double) -> Double
): DoubleBufferND { ): DoubleBufferND {
require(l.indexes == r.indexes) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" } require(l.indices == r.indices) { "Zip requires the same shapes, but found ${l.shape} on the left and ${r.shape} on the right" }
val indexes = l.indexes val indexes = l.indices
val lArray = l.buffer.array val lArray = l.buffer.array
val rArray = r.buffer.array val rArray = r.buffer.array
return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) }) return DoubleBufferND(indexes, DoubleBuffer(indexes.linearSize) { block(lArray[it], rArray[it]) })

View File

@ -71,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided
@ -87,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
if (st1 === st2) return true if (st1 === st2) return true
// fast comparison of buffers if possible // fast comparison of buffers if possible
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes) if (st1 is BufferND && st2 is BufferND && st1.indices == st2.indices)
return Buffer.contentEquals(st1.buffer, st2.buffer) return Buffer.contentEquals(st1.buffer, st2.buffer)
//element by element comparison if it could not be avoided //element by element comparison if it could not be avoided

View File

@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
* Map one [BufferND] using function without indices. * Map one [BufferND] using function without indices.
*/ */
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> { public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) } val array = DoubleArray(indices.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
return BufferND(indexes, DoubleBuffer(array)) return BufferND(indices, DoubleBuffer(array))
} }
/** /**

View File

@ -16,6 +16,7 @@ import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.Shape import space.kscience.kmath.nd.Shape
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.* import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
@ -49,17 +50,17 @@ private fun <T, D : Dimension> MultiArray<T, D>.asD2Array(): D2Array<T> {
else throw ClassCastException("Cannot cast MultiArray to NDArray.") else throw ClassCastException("Cannot cast MultiArray to NDArray.")
} }
public class MultikTensorAlgebra<T : Number> internal constructor( public class MultikTensorAlgebra<T : Number, A: Ring<T>> internal constructor(
public val type: DataType, public val type: DataType,
public val elementAlgebra: Ring<T>, override val elementAlgebra: A,
public val comparator: Comparator<T> public val comparator: Comparator<T>
) : TensorAlgebra<T> { ) : TensorAlgebra<T, A> {
/** /**
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
* are not reflected back onto the source * are not reflected back onto the source
*/ */
public fun Tensor<T>.asMultik(): MultikTensor<T> { public fun StructureND<T>.asMultik(): MultikTensor<T> {
return if (this is MultikTensor) { return if (this is MultikTensor) {
this this
} else { } else {
@ -73,17 +74,17 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this) public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { override fun StructureND<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0)) get(intArrayOf(0))
} else null } else null
override fun T.plus(other: Tensor<T>): MultikTensor<T> = override fun T.plus(other: StructureND<T>): MultikTensor<T> =
other.plus(this) other.plus(this)
override fun Tensor<T>.plus(value: T): MultikTensor<T> = override fun StructureND<T>.plus(value: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { plusAssign(value) }.wrap() asMultik().array.deepCopy().apply { plusAssign(value) }.wrap()
override fun Tensor<T>.plus(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.plus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.plus(other.asMultik().array).wrap() asMultik().array.plus(other.asMultik().array).wrap()
override fun Tensor<T>.plusAssign(value: T) { override fun Tensor<T>.plusAssign(value: T) {
@ -94,7 +95,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.plusAssign(other: Tensor<T>) { override fun Tensor<T>.plusAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.plusAssign(other.asMultik().array) array.plusAssign(other.asMultik().array)
} else { } else {
@ -102,12 +103,12 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun T.minus(other: Tensor<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap() override fun T.minus(other: StructureND<T>): MultikTensor<T> = (-(other.asMultik().array - this)).wrap()
override fun Tensor<T>.minus(value: T): MultikTensor<T> = override fun StructureND<T>.minus(arg: T): MultikTensor<T> =
asMultik().array.deepCopy().apply { minusAssign(value) }.wrap() asMultik().array.deepCopy().apply { minusAssign(arg) }.wrap()
override fun Tensor<T>.minus(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.minus(other: StructureND<T>): MultikTensor<T> =
asMultik().array.minus(other.asMultik().array).wrap() asMultik().array.minus(other.asMultik().array).wrap()
override fun Tensor<T>.minusAssign(value: T) { override fun Tensor<T>.minusAssign(value: T) {
@ -118,7 +119,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.minusAssign(other: Tensor<T>) { override fun Tensor<T>.minusAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.minusAssign(other.asMultik().array) array.minusAssign(other.asMultik().array)
} else { } else {
@ -126,13 +127,13 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun T.times(other: Tensor<T>): MultikTensor<T> = override fun T.times(arg: StructureND<T>): MultikTensor<T> =
other.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap() arg.asMultik().array.deepCopy().apply { timesAssign(this@times) }.wrap()
override fun Tensor<T>.times(value: T): Tensor<T> = override fun StructureND<T>.times(arg: T): Tensor<T> =
asMultik().array.deepCopy().apply { timesAssign(value) }.wrap() asMultik().array.deepCopy().apply { timesAssign(arg) }.wrap()
override fun Tensor<T>.times(other: Tensor<T>): MultikTensor<T> = override fun StructureND<T>.times(other: StructureND<T>): MultikTensor<T> =
asMultik().array.times(other.asMultik().array).wrap() asMultik().array.times(other.asMultik().array).wrap()
override fun Tensor<T>.timesAssign(value: T) { override fun Tensor<T>.timesAssign(value: T) {
@ -143,7 +144,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.timesAssign(other: Tensor<T>) { override fun Tensor<T>.timesAssign(other: StructureND<T>) {
if (this is MultikTensor) { if (this is MultikTensor) {
array.timesAssign(other.asMultik().array) array.timesAssign(other.asMultik().array)
} else { } else {
@ -151,7 +152,7 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
override fun Tensor<T>.unaryMinus(): MultikTensor<T> = override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap() asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap() override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
@ -224,17 +225,17 @@ public class MultikTensorAlgebra<T : Number> internal constructor(
} }
} }
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double> public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double, DoubleField>
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) } get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) }
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float> public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float, FloatField>
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) } get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short> public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short, ShortRing>
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) } get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) }
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int> public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int, IntRing>
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) } get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) }
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long> public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long, LongRing>
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) } get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) }

View File

@ -5,13 +5,15 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Field
/** /**
* Analytic operations on [Tensor]. * Analytic operations on [Tensor].
* *
* @param T the type of items closed under analytic functions in the tensors. * @param T the type of items closed under analytic functions in the tensors.
*/ */
public interface AnalyticTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> { public interface AnalyticTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionAlgebra<T, A> {
/** /**
* @return the mean of all elements in the input tensor. * @return the mean of all elements in the input tensor.

View File

@ -5,12 +5,14 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.Field
/** /**
* Common linear algebra operations. Operates on [Tensor]. * Common linear algebra operations. Operates on [Tensor].
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
*/ */
public interface LinearOpsTensorAlgebra<T> : TensorPartialDivisionAlgebra<T> { public interface LinearOpsTensorAlgebra<T, A : Field<T>> : TensorPartialDivisionAlgebra<T, A> {
/** /**
* Computes the determinant of a square matrix input, or of each square matrix in a batched input. * Computes the determinant of a square matrix input, or of each square matrix in a batched input.

View File

@ -5,7 +5,9 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.operations.RingOps import space.kscience.kmath.nd.RingOpsND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Ring
/** /**
* Algebra over a ring on [Tensor]. * Algebra over a ring on [Tensor].
@ -13,20 +15,20 @@ import space.kscience.kmath.operations.RingOps
* *
* @param T the type of items in the tensors. * @param T the type of items in the tensors.
*/ */
public interface TensorAlgebra<T> : RingOps<Tensor<T>> { public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
/** /**
* Returns a single tensor value of unit dimension if tensor shape equals to [1]. * Returns a single tensor value of unit dimension if tensor shape equals to [1].
* *
* @return a nullable value of a potentially scalar tensor. * @return a nullable value of a potentially scalar tensor.
*/ */
public fun Tensor<T>.valueOrNull(): T? public fun StructureND<T>.valueOrNull(): T?
/** /**
* Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1].
* *
* @return the value of a scalar tensor. * @return the value of a scalar tensor.
*/ */
public fun Tensor<T>.value(): T = public fun StructureND<T>.value(): T =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape") valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/** /**
@ -36,7 +38,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this value and tensor [other]. * @return the sum of this value and tensor [other].
*/ */
public operator fun T.plus(other: Tensor<T>): Tensor<T> override operator fun T.plus(other: StructureND<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor.
@ -44,7 +46,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param value the number to be added to each element of this tensor. * @param value the number to be added to each element of this tensor.
* @return the sum of this tensor and [value]. * @return the sum of this tensor and [value].
*/ */
public operator fun Tensor<T>.plus(value: T): Tensor<T> override operator fun StructureND<T>.plus(value: T): Tensor<T>
/** /**
* Each element of the tensor [other] is added to each element of this tensor. * Each element of the tensor [other] is added to each element of this tensor.
@ -53,7 +55,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be added. * @param other tensor to be added.
* @return the sum of this tensor and [other]. * @return the sum of this tensor and [other].
*/ */
override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.plus(other: StructureND<T>): Tensor<T>
/** /**
* Adds the scalar [value] to each element of this tensor. * Adds the scalar [value] to each element of this tensor.
@ -67,7 +69,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be added. * @param other tensor to be added.
*/ */
public operator fun Tensor<T>.plusAssign(other: Tensor<T>) public operator fun Tensor<T>.plusAssign(other: StructureND<T>)
/** /**
* Each element of the tensor [other] is subtracted from this value. * Each element of the tensor [other] is subtracted from this value.
@ -76,15 +78,15 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this value and tensor [other]. * @return the difference between this value and tensor [other].
*/ */
public operator fun T.minus(other: Tensor<T>): Tensor<T> override operator fun T.minus(other: StructureND<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. * Subtracts the scalar [arg] from each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be subtracted from each element of this tensor. * @param arg the number to be subtracted from each element of this tensor.
* @return the difference between this tensor and [value]. * @return the difference between this tensor and [arg].
*/ */
public operator fun Tensor<T>.minus(value: T): Tensor<T> override operator fun StructureND<T>.minus(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is subtracted from each element of this tensor. * Each element of the tensor [other] is subtracted from each element of this tensor.
@ -93,7 +95,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
* @return the difference between this tensor and [other]. * @return the difference between this tensor and [other].
*/ */
override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.minus(other: StructureND<T>): Tensor<T>
/** /**
* Subtracts the scalar [value] from each element of this tensor. * Subtracts the scalar [value] from each element of this tensor.
@ -107,25 +109,25 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be subtracted. * @param other tensor to be subtracted.
*/ */
public operator fun Tensor<T>.minusAssign(other: Tensor<T>) public operator fun Tensor<T>.minusAssign(other: StructureND<T>)
/** /**
* Each element of the tensor [other] is multiplied by this value. * Each element of the tensor [arg] is multiplied by this value.
* The resulting tensor is returned. * The resulting tensor is returned.
* *
* @param other tensor to be multiplied. * @param arg tensor to be multiplied.
* @return the product of this value and tensor [other]. * @return the product of this value and tensor [arg].
*/ */
public operator fun T.times(other: Tensor<T>): Tensor<T> override operator fun T.times(arg: StructureND<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. * Multiplies the scalar [arg] by each element of this tensor and returns a new resulting tensor.
* *
* @param value the number to be multiplied by each element of this tensor. * @param arg the number to be multiplied by each element of this tensor.
* @return the product of this tensor and [value]. * @return the product of this tensor and [arg].
*/ */
public operator fun Tensor<T>.times(value: T): Tensor<T> override operator fun StructureND<T>.times(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is multiplied by each element of this tensor. * Each element of the tensor [other] is multiplied by each element of this tensor.
@ -134,7 +136,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
* @return the product of this tensor and [other]. * @return the product of this tensor and [other].
*/ */
override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.times(other: StructureND<T>): Tensor<T>
/** /**
* Multiplies the scalar [value] by each element of this tensor. * Multiplies the scalar [value] by each element of this tensor.
@ -148,14 +150,14 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
* *
* @param other tensor to be multiplied. * @param other tensor to be multiplied.
*/ */
public operator fun Tensor<T>.timesAssign(other: Tensor<T>) public operator fun Tensor<T>.timesAssign(other: StructureND<T>)
/** /**
* Numerical negative, element-wise. * Numerical negative, element-wise.
* *
* @return tensor negation of the original tensor. * @return tensor negation of the original tensor.
*/ */
override fun Tensor<T>.unaryMinus(): Tensor<T> override operator fun StructureND<T>.unaryMinus(): Tensor<T>
/** /**
* Returns the tensor at index i * Returns the tensor at index i
@ -324,7 +326,7 @@ public interface TensorAlgebra<T> : RingOps<Tensor<T>> {
*/ */
public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> public fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T>
override fun add(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left + right override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
override fun multiply(left: Tensor<T>, right: Tensor<T>): Tensor<T> = left * right override fun multiply(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left * right
} }

View File

@ -5,30 +5,34 @@
package space.kscience.kmath.tensors.api package space.kscience.kmath.tensors.api
import space.kscience.kmath.nd.FieldOpsND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.operations.Field
/** /**
* Algebra over a field with partial division on [Tensor]. * Algebra over a field with partial division on [Tensor].
* For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra
* *
* @param T the type of items closed under division in the tensors. * @param T the type of items closed under division in the tensors.
*/ */
public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> { public interface TensorPartialDivisionAlgebra<T, A : Field<T>> : TensorAlgebra<T, A>, FieldOpsND<T, A> {
/** /**
* Each element of the tensor [other] is divided by this value. * Each element of the tensor [arg] is divided by this value.
* The resulting tensor is returned. * The resulting tensor is returned.
* *
* @param other tensor to divide by. * @param arg tensor to divide by.
* @return the division of this value by the tensor [other]. * @return the division of this value by the tensor [arg].
*/ */
public operator fun T.div(other: Tensor<T>): Tensor<T> override operator fun T.div(arg: StructureND<T>): Tensor<T>
/** /**
* Divide by the scalar [value] each element of this tensor returns a new resulting tensor. * Divide by the scalar [arg] each element of this tensor returns a new resulting tensor.
* *
* @param value the number to divide by each element of this tensor. * @param arg the number to divide by each element of this tensor.
* @return the division of this tensor by the [value]. * @return the division of this tensor by the [arg].
*/ */
public operator fun Tensor<T>.div(value: T): Tensor<T> override operator fun StructureND<T>.div(arg: T): Tensor<T>
/** /**
* Each element of the tensor [other] is divided by each element of this tensor. * Each element of the tensor [other] is divided by each element of this tensor.
@ -37,7 +41,9 @@ public interface TensorPartialDivisionAlgebra<T> : TensorAlgebra<T> {
* @param other tensor to be divided by. * @param other tensor to be divided by.
* @return the division of this tensor by [other]. * @return the division of this tensor by [other].
*/ */
public operator fun Tensor<T>.div(other: Tensor<T>): Tensor<T> override operator fun StructureND<T>.div(other: StructureND<T>): Tensor<T>
override fun divide(left: StructureND<T>, right: StructureND<T>): StructureND<T> = left.div(right)
/** /**
* Divides by the scalar [value] each element of this tensor. * Divides by the scalar [value] each element of this tensor.

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTensors
@ -18,66 +19,66 @@ import space.kscience.kmath.tensors.core.internal.tensor
*/ */
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] + newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) { override fun Tensor<Double>.plusAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i] newThis.mutableBuffer.array()[i] - newOther.mutableBuffer.array()[i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) { override fun Tensor<Double>.minusAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.times(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[newThis.bufferStart + i] * newThis.mutableBuffer.array()[newThis.bufferStart + i] *
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
return DoubleTensor(newThis.shape, resBuffer) return DoubleTensor(newThis.shape, resBuffer)
} }
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) { override fun Tensor<Double>.timesAssign(other: StructureND<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }
} }
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, other.tensor) val broadcast = broadcastTensors(tensor, other.tensor)
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.linearStructure.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
newThis.mutableBuffer.array()[newOther.bufferStart + i] / newThis.mutableBuffer.array()[newOther.bufferStart + i] /
newOther.mutableBuffer.array()[newOther.bufferStart + i] newOther.mutableBuffer.array()[newOther.bufferStart + i]
} }
@ -86,7 +87,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun Tensor<Double>.divAssign(other: Tensor<Double>) { override fun Tensor<Double>.divAssign(other: Tensor<Double>) {
val newOther = broadcastTo(other.tensor, tensor.shape) val newOther = broadcastTo(other.tensor, tensor.shape)
for (i in 0 until tensor.linearStructure.linearSize) { for (i in 0 until tensor.indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= tensor.mutableBuffer.array()[tensor.bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[tensor.bufferStart + i]
} }

View File

@ -23,23 +23,23 @@ public open class BufferedTensor<T> internal constructor(
/** /**
* Buffer strides based on [TensorLinearStructure] implementation * Buffer strides based on [TensorLinearStructure] implementation
*/ */
public val linearStructure: Strides public val indices: Strides
get() = TensorLinearStructure(shape) get() = TensorLinearStructure(shape)
/** /**
* Number of elements in tensor * Number of elements in tensor
*/ */
public val numElements: Int public val numElements: Int
get() = linearStructure.linearSize get() = indices.linearSize
override fun get(index: IntArray): T = mutableBuffer[bufferStart + linearStructure.offset(index)] override fun get(index: IntArray): T = mutableBuffer[bufferStart + indices.offset(index)]
override fun set(index: IntArray, value: T) { override fun set(index: IntArray, value: T) {
mutableBuffer[bufferStart + linearStructure.offset(index)] = value mutableBuffer[bufferStart + indices.offset(index)] = value
} }
@PerformancePitfall @PerformancePitfall
override fun elements(): Sequence<Pair<IntArray, T>> = linearStructure.indices().map { override fun elements(): Sequence<Pair<IntArray, T>> = indices.indices().map {
it to get(it) it to get(it)
} }
} }

View File

@ -6,8 +6,10 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.MutableStructure2D import space.kscience.kmath.nd.MutableStructure2D
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.structures.indices import space.kscience.kmath.structures.indices
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
@ -20,16 +22,71 @@ import kotlin.math.*
* Implementation of basic operations over double tensors and basic algebra operations on them. * Implementation of basic operations over double tensors and basic algebra operations on them.
*/ */
public open class DoubleTensorAlgebra : public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>, TensorPartialDivisionAlgebra<Double, DoubleField>,
AnalyticTensorAlgebra<Double>, AnalyticTensorAlgebra<Double, DoubleField>,
LinearOpsTensorAlgebra<Double>{ LinearOpsTensorAlgebra<Double, DoubleField> {
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) override val elementAlgebra: DoubleField
get() = DoubleField
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): DoubleTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) }
return DoubleTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
@Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor {
val tensor = this.tensor
//TODO remove additional copy
val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) }
return DoubleTensor(
tensor.shape,
array,
tensor.bufferStart
)
}
override fun zip(
left: StructureND<Double>,
right: StructureND<Double>,
transform: DoubleField.(Double, Double) -> Double
): DoubleTensor {
require(left.shape.contentEquals(right.shape)){
"The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}"
}
val leftTensor = left.tensor
val leftArray = leftTensor.copyArray()
val rightTensor = right.tensor
val rightArray = rightTensor.copyArray()
val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) }
return DoubleTensor(
leftTensor.shape,
array
)
}
override fun StructureND<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun Tensor<Double>.value(): Double = valueOrNull() override fun StructureND<Double>.value(): Double = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
/** /**
@ -53,10 +110,9 @@ public open class DoubleTensorAlgebra :
* @param initializer mapping tensor indices to values. * @param initializer mapping tensor indices to values.
* @return tensor with the [shape] shape and data generated by the [initializer]. * @return tensor with the [shape] shape and data generated by the [initializer].
*/ */
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): DoubleTensor = override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray(
fromArray(
shape, shape,
TensorLinearStructure(shape).indices().map(initializer).toMutableList().toDoubleArray() TensorLinearStructure(shape).indices().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
) )
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
@ -146,16 +202,16 @@ public open class DoubleTensorAlgebra :
return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart)
} }
override fun Double.plus(other: Tensor<Double>): DoubleTensor { override fun Double.plus(other: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun Tensor<Double>.plus(value: Double): DoubleTensor = value + tensor override fun StructureND<Double>.plus(value: Double): DoubleTensor = value + tensor
override fun Tensor<Double>.plus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.plus(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i]
@ -169,7 +225,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.plusAssign(other: Tensor<Double>) { override fun Tensor<Double>.plusAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other.tensor) checkShapesCompatible(tensor, other.tensor)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += tensor.mutableBuffer.array()[tensor.bufferStart + i] +=
@ -177,21 +233,21 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.minus(other: Tensor<Double>): DoubleTensor { override fun Double.minus(other: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(other.tensor.numElements) { i ->
this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(other.shape, resBuffer)
} }
override fun Tensor<Double>.minus(value: Double): DoubleTensor { override fun StructureND<Double>.minus(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - value tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(tensor.shape, resBuffer)
} }
override fun Tensor<Double>.minus(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.minus(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i]
@ -205,7 +261,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.minusAssign(other: Tensor<Double>) { override fun Tensor<Double>.minusAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= tensor.mutableBuffer.array()[tensor.bufferStart + i] -=
@ -213,16 +269,16 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.times(other: Tensor<Double>): DoubleTensor { override fun Double.times(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.tensor.numElements) { i ->
other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun Tensor<Double>.times(value: Double): DoubleTensor = value * tensor override fun StructureND<Double>.times(arg: Double): DoubleTensor = arg * tensor
override fun Tensor<Double>.times(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.times(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] * tensor.mutableBuffer.array()[tensor.bufferStart + i] *
@ -237,7 +293,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.timesAssign(other: Tensor<Double>) { override fun Tensor<Double>.timesAssign(other: StructureND<Double>) {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
for (i in 0 until tensor.numElements) { for (i in 0 until tensor.numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= tensor.mutableBuffer.array()[tensor.bufferStart + i] *=
@ -245,21 +301,21 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Double.div(other: Tensor<Double>): DoubleTensor { override fun Double.div(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(other.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.tensor.numElements) { i ->
this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i]
} }
return DoubleTensor(other.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun Tensor<Double>.div(value: Double): DoubleTensor { override fun StructureND<Double>.div(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / value tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg
} }
return DoubleTensor(shape, resBuffer) return DoubleTensor(shape, resBuffer)
} }
override fun Tensor<Double>.div(other: Tensor<Double>): DoubleTensor { override fun StructureND<Double>.div(other: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, other) checkShapesCompatible(tensor, other)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / tensor.mutableBuffer.array()[other.tensor.bufferStart + i] /
@ -282,7 +338,7 @@ public open class DoubleTensorAlgebra :
} }
} }
override fun Tensor<Double>.unaryMinus(): DoubleTensor { override fun StructureND<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(tensor.numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus()
} }
@ -302,11 +358,11 @@ public open class DoubleTensorAlgebra :
val resTensor = DoubleTensor(resShape, resBuffer) val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) { for (offset in 0 until n) {
val oldMultiIndex = tensor.linearStructure.index(offset) val oldMultiIndex = tensor.indices.index(offset)
val newMultiIndex = oldMultiIndex.copyOf() val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.linearStructure.offset(newMultiIndex) val linearIndex = resTensor.indices.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset] tensor.mutableBuffer.array()[tensor.bufferStart + offset]
} }
@ -406,7 +462,7 @@ public open class DoubleTensorAlgebra :
val resTensor = zeros(resShape) val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) { for (i in 0 until diagonalEntries.tensor.numElements) {
val multiIndex = diagonalEntries.tensor.linearStructure.index(i) val multiIndex = diagonalEntries.tensor.indices.index(i)
var offset1 = 0 var offset1 = 0
var offset2 = abs(realOffset) var offset2 = abs(realOffset)
@ -425,18 +481,6 @@ public open class DoubleTensorAlgebra :
return resTensor.tensor return resTensor.tensor
} }
/**
* Applies the [transform] function to each element of the tensor and returns the resulting modified tensor.
*
* @param transform the function to be applied to each element of the tensor.
* @return the resulting tensor after applying the function.
*/
public inline fun Tensor<Double>.map(transform: (Double) -> Double): DoubleTensor = DoubleTensor(
tensor.shape,
tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(),
tensor.bufferStart
)
/** /**
* Compares element-wise two tensors with a specified precision. * Compares element-wise two tensors with a specified precision.
* *
@ -526,7 +570,7 @@ public open class DoubleTensorAlgebra :
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = internal inline fun Tensor<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.toDoubleArray()) foldFunction(tensor.copyArray())
internal inline fun Tensor<Double>.foldDim( internal inline fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double, foldFunction: (DoubleArray) -> Double,
@ -541,7 +585,7 @@ public open class DoubleTensorAlgebra :
} }
val resNumElements = resShape.reduce(Int::times) val resNumElements = resShape.reduce(Int::times)
val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0) val resTensor = DoubleTensor(resShape, DoubleArray(resNumElements) { 0.0 }, 0)
for (index in resTensor.linearStructure.indices()) { for (index in resTensor.indices.indices()) {
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
@ -645,39 +689,39 @@ public open class DoubleTensorAlgebra :
return resTensor return resTensor
} }
override fun Tensor<Double>.exp(): DoubleTensor = tensor.map(::exp) override fun Tensor<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
override fun Tensor<Double>.ln(): DoubleTensor = tensor.map(::ln) override fun Tensor<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map(::sqrt) override fun Tensor<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
override fun Tensor<Double>.cos(): DoubleTensor = tensor.map(::cos) override fun Tensor<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
override fun Tensor<Double>.acos(): DoubleTensor = tensor.map(::acos) override fun Tensor<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map(::cosh) override fun Tensor<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map(::acosh) override fun Tensor<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
override fun Tensor<Double>.sin(): DoubleTensor = tensor.map(::sin) override fun Tensor<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
override fun Tensor<Double>.asin(): DoubleTensor = tensor.map(::asin) override fun Tensor<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map(::sinh) override fun Tensor<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map(::asinh) override fun Tensor<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
override fun Tensor<Double>.tan(): DoubleTensor = tensor.map(::tan) override fun Tensor<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
override fun Tensor<Double>.atan(): DoubleTensor = tensor.map(::atan) override fun Tensor<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map(::tanh) override fun Tensor<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map(::atanh) override fun Tensor<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map(::ceil) override fun Tensor<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
override fun Tensor<Double>.floor(): DoubleTensor = tensor.map(::floor) override fun Tensor<Double>.floor(): DoubleTensor = tensor.map { floor(it) }
override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9) override fun Tensor<Double>.inv(): DoubleTensor = invLU(1e-9)

View File

@ -10,7 +10,7 @@ import kotlin.math.max
internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) { internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTensor, linearSize: Int) {
for (linearIndex in 0 until linearSize) { for (linearIndex in 0 until linearSize) {
val totalMultiIndex = resTensor.linearStructure.index(linearIndex) val totalMultiIndex = resTensor.indices.index(linearIndex)
val curMultiIndex = tensor.shape.copyOf() val curMultiIndex = tensor.shape.copyOf()
val offset = totalMultiIndex.size - curMultiIndex.size val offset = totalMultiIndex.size - curMultiIndex.size
@ -23,7 +23,7 @@ internal fun multiIndexBroadCasting(tensor: DoubleTensor, resTensor: DoubleTenso
} }
} }
val curLinearIndex = tensor.linearStructure.offset(curMultiIndex) val curLinearIndex = tensor.indices.offset(curMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex] tensor.mutableBuffer.array()[tensor.bufferStart + curLinearIndex]
} }
@ -112,7 +112,7 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTen
val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize)) val resTensor = DoubleTensor(totalShape + matrixShape, DoubleArray(n * matrixSize))
for (linearIndex in 0 until n) { for (linearIndex in 0 until n) {
val totalMultiIndex = outerTensor.linearStructure.index(linearIndex) val totalMultiIndex = outerTensor.indices.index(linearIndex)
var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf() var curMultiIndex = tensor.shape.sliceArray(0..tensor.shape.size - 3).copyOf()
curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex curMultiIndex = IntArray(totalMultiIndex.size - curMultiIndex.size) { 1 } + curMultiIndex
@ -127,13 +127,13 @@ internal fun broadcastOuterTensors(vararg tensors: DoubleTensor): List<DoubleTen
} }
for (i in 0 until matrixSize) { for (i in 0 until matrixSize) {
val curLinearIndex = newTensor.linearStructure.offset( val curLinearIndex = newTensor.indices.offset(
curMultiIndex + curMultiIndex +
matrix.linearStructure.index(i) matrix.indices.index(i)
) )
val newLinearIndex = resTensor.linearStructure.offset( val newLinearIndex = resTensor.indices.offset(
totalMultiIndex + totalMultiIndex +
matrix.linearStructure.index(i) matrix.indices.index(i)
) )
resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] = resTensor.mutableBuffer.array()[resTensor.bufferStart + newLinearIndex] =

View File

@ -5,6 +5,7 @@
package space.kscience.kmath.tensors.core.internal package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
@ -25,7 +26,7 @@ internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) =
"Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided" "Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided"
} }
internal fun <T> checkShapesCompatible(a: Tensor<T>, b: Tensor<T>) = internal fun <T> checkShapesCompatible(a: StructureND<T>, b: StructureND<T>) =
check(a.shape contentEquals b.shape) { check(a.shape contentEquals b.shape) {
"Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} "
} }

View File

@ -6,6 +6,7 @@
package space.kscience.kmath.tensors.core.internal package space.kscience.kmath.tensors.core.internal
import space.kscience.kmath.nd.MutableBufferND import space.kscience.kmath.nd.MutableBufferND
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.structures.asMutableBuffer import space.kscience.kmath.structures.asMutableBuffer
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.BufferedTensor import space.kscience.kmath.tensors.core.BufferedTensor
@ -18,15 +19,15 @@ internal fun BufferedTensor<Int>.asTensor(): IntTensor =
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = internal fun BufferedTensor<Double>.asTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> = internal fun <T> StructureND<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor( BufferedTensor(
this.shape, this.shape,
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0
) )
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> StructureND<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
is BufferedTensor<T> -> this is BufferedTensor<T> -> this
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) { is MutableBufferND<T> -> if (this.indices == TensorLinearStructure(this.shape)) {
BufferedTensor(this.shape, this.buffer, 0) BufferedTensor(this.shape, this.buffer, 0)
} else { } else {
this.copyToBufferedTensor() this.copyToBufferedTensor()
@ -35,7 +36,7 @@ internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
} }
@PublishedApi @PublishedApi
internal val Tensor<Double>.tensor: DoubleTensor internal val StructureND<Double>.tensor: DoubleTensor
get() = when (this) { get() = when (this) {
is DoubleTensor -> this is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor() else -> this.toBufferedTensor().asTensor()

View File

@ -85,7 +85,7 @@ internal fun format(value: Double, digits: Int = 4): String = buildString {
internal fun DoubleTensor.toPrettyString(): String = buildString { internal fun DoubleTensor.toPrettyString(): String = buildString {
var offset = 0 var offset = 0
val shape = this@toPrettyString.shape val shape = this@toPrettyString.shape
val linearStructure = this@toPrettyString.linearStructure val linearStructure = this@toPrettyString.indices
val vectorSize = shape.last() val vectorSize = shape.last()
append("DoubleTensor(\n") append("DoubleTensor(\n")
var charOffset = 3 var charOffset = 3

View File

@ -19,18 +19,19 @@ public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor
/** /**
* Returns [DoubleArray] of tensor elements * Returns a copy-protected [DoubleArray] of tensor elements
*/ */
public fun DoubleTensor.toDoubleArray(): DoubleArray { public fun DoubleTensor.copyArray(): DoubleArray {
//TODO use ArrayCopy
return DoubleArray(numElements) { i -> return DoubleArray(numElements) { i ->
mutableBuffer[bufferStart + i] mutableBuffer[bufferStart + i]
} }
} }
/** /**
* Returns [IntArray] of tensor elements * Returns a copy-protected [IntArray] of tensor elements
*/ */
public fun IntTensor.toIntArray(): IntArray { public fun IntTensor.copyArray(): IntArray {
return IntArray(numElements) { i -> return IntArray(numElements) { i ->
mutableBuffer[bufferStart + i] mutableBuffer[bufferStart + i]
} }