diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/AnalyticTensorAlgebra.kt index fa1de81fa..2a8b228f4 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/AnalyticTensorAlgebra.kt @@ -1,58 +1,58 @@ package space.kscience.kmath.tensors -public interface AnalyticTensorAlgebra> : - TensorPartialDivisionAlgebra { +public interface AnalyticTensorAlgebra : + TensorPartialDivisionAlgebra { //https://pytorch.org/docs/stable/generated/torch.exp.html - public fun TensorType.exp(): TensorType + public fun TensorStructure.exp(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.log.html - public fun TensorType.log(): TensorType + public fun TensorStructure.log(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.sqrt.html - public fun TensorType.sqrt(): TensorType + public fun TensorStructure.sqrt(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos - public fun TensorType.cos(): TensorType + public fun TensorStructure.cos(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos - public fun TensorType.acos(): TensorType + public fun TensorStructure.acos(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh - public fun TensorType.cosh(): TensorType + public fun TensorStructure.cosh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh - public fun TensorType.acosh(): TensorType + public fun TensorStructure.acosh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin - public fun TensorType.sin(): TensorType + public fun TensorStructure.sin(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin - public fun TensorType.asin(): TensorType + public fun TensorStructure.asin(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh - public fun TensorType.sinh(): TensorType + public fun TensorStructure.sinh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh - public fun TensorType.asinh(): TensorType + public fun TensorStructure.asinh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan - public fun TensorType.tan(): TensorType + public fun TensorStructure.tan(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan - public fun TensorType.atan(): TensorType + public fun TensorStructure.atan(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh - public fun TensorType.tanh(): TensorType + public fun TensorStructure.tanh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh - public fun TensorType.atanh(): TensorType + public fun TensorStructure.atanh(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil - public fun TensorType.ceil(): TensorType + public fun TensorStructure.ceil(): TensorStructure //https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor - public fun TensorType.floor(): TensorType + public fun TensorStructure.floor(): TensorStructure } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt index 87c459f35..8ca8945a5 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/LinearOpsTensorAlgebra.kt @@ -1,31 +1,32 @@ package space.kscience.kmath.tensors -public interface LinearOpsTensorAlgebra, IndexTensorType : TensorStructure> : - TensorPartialDivisionAlgebra { +public interface LinearOpsTensorAlgebra : + TensorPartialDivisionAlgebra { //https://pytorch.org/docs/stable/linalg.html#torch.linalg.det - public fun TensorType.det(): TensorType + public fun TensorStructure.det(): TensorStructure //https://pytorch.org/docs/stable/linalg.html#torch.linalg.inv - public fun TensorType.inv(): TensorType + public fun TensorStructure.inv(): TensorStructure //https://pytorch.org/docs/stable/linalg.html#torch.linalg.cholesky - public fun TensorType.cholesky(): TensorType + public fun TensorStructure.cholesky(): TensorStructure //https://pytorch.org/docs/stable/linalg.html#torch.linalg.qr - public fun TensorType.qr(): Pair + public fun TensorStructure.qr(): Pair, TensorStructure> //https://pytorch.org/docs/stable/generated/torch.lu.html - public fun TensorType.lu(): Pair + public fun TensorStructure.lu(): Pair, TensorStructure> //https://pytorch.org/docs/stable/generated/torch.lu_unpack.html - public fun luPivot(luTensor: TensorType, pivotsTensor: IndexTensorType): Triple + public fun luPivot(luTensor: TensorStructure, pivotsTensor: TensorStructure): + Triple, TensorStructure, TensorStructure> //https://pytorch.org/docs/stable/linalg.html#torch.linalg.svd - public fun TensorType.svd(): Triple + public fun TensorStructure.svd(): Triple, TensorStructure, TensorStructure> //https://pytorch.org/docs/stable/generated/torch.symeig.html - public fun TensorType.symEig(): Pair + public fun TensorStructure.symEig(): Pair, TensorStructure> } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt index 8fd1cf2ed..f64d49dd8 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorAlgebra.kt @@ -1,46 +1,46 @@ package space.kscience.kmath.tensors // https://proofwiki.org/wiki/Definition:Algebra_over_Ring -public interface TensorAlgebra> { +public interface TensorAlgebra { - public fun TensorType.value(): T + public fun TensorStructure.value(): T - public operator fun T.plus(other: TensorType): TensorType - public operator fun TensorType.plus(value: T): TensorType - public operator fun TensorType.plus(other: TensorType): TensorType - public operator fun TensorType.plusAssign(value: T): Unit - public operator fun TensorType.plusAssign(other: TensorType): Unit + public operator fun T.plus(other: TensorStructure): TensorStructure + public operator fun TensorStructure.plus(value: T): TensorStructure + public operator fun TensorStructure.plus(other: TensorStructure): TensorStructure + public operator fun TensorStructure.plusAssign(value: T): Unit + public operator fun TensorStructure.plusAssign(other: TensorStructure): Unit - public operator fun T.minus(other: TensorType): TensorType - public operator fun TensorType.minus(value: T): TensorType - public operator fun TensorType.minus(other: TensorType): TensorType - public operator fun TensorType.minusAssign(value: T): Unit - public operator fun TensorType.minusAssign(other: TensorType): Unit + public operator fun T.minus(other: TensorStructure): TensorStructure + public operator fun TensorStructure.minus(value: T): TensorStructure + public operator fun TensorStructure.minus(other: TensorStructure): TensorStructure + public operator fun TensorStructure.minusAssign(value: T): Unit + public operator fun TensorStructure.minusAssign(other: TensorStructure): Unit - public operator fun T.times(other: TensorType): TensorType - public operator fun TensorType.times(value: T): TensorType - public operator fun TensorType.times(other: TensorType): TensorType - public operator fun TensorType.timesAssign(value: T): Unit - public operator fun TensorType.timesAssign(other: TensorType): Unit - public operator fun TensorType.unaryMinus(): TensorType + public operator fun T.times(other: TensorStructure): TensorStructure + public operator fun TensorStructure.times(value: T): TensorStructure + public operator fun TensorStructure.times(other: TensorStructure): TensorStructure + public operator fun TensorStructure.timesAssign(value: T): Unit + public operator fun TensorStructure.timesAssign(other: TensorStructure): Unit + public operator fun TensorStructure.unaryMinus(): TensorStructure //https://pytorch.org/cppdocs/notes/tensor_indexing.html - public operator fun TensorType.get(i: Int): TensorType + public operator fun TensorStructure.get(i: Int): TensorStructure //https://pytorch.org/docs/stable/generated/torch.transpose.html - public fun TensorType.transpose(i: Int = -2, j: Int = -1): TensorType + public fun TensorStructure.transpose(i: Int = -2, j: Int = -1): TensorStructure //https://pytorch.org/docs/stable/tensor_view.html - public fun TensorType.view(shape: IntArray): TensorType - public fun TensorType.viewAs(other: TensorType): TensorType + public fun TensorStructure.view(shape: IntArray): TensorStructure + public fun TensorStructure.viewAs(other: TensorStructure): TensorStructure //https://pytorch.org/docs/stable/generated/torch.matmul.html - public infix fun TensorType.dot(other: TensorType): TensorType + public infix fun TensorStructure.dot(other: TensorStructure): TensorStructure //https://pytorch.org/docs/stable/generated/torch.diag_embed.html public fun diagonalEmbedding( - diagonalEntries: TensorType, + diagonalEntries: TensorStructure, offset: Int = 0, dim1: Int = -2, dim2: Int = -1 - ): TensorType + ): TensorStructure } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt index 0b9079967..b37ac6b6e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/TensorPartialDivisionAlgebra.kt @@ -1,10 +1,10 @@ package space.kscience.kmath.tensors // https://proofwiki.org/wiki/Definition:Division_Algebra -public interface TensorPartialDivisionAlgebra> : - TensorAlgebra { - public operator fun TensorType.div(value: T): TensorType - public operator fun TensorType.div(other: TensorType): TensorType - public operator fun TensorType.divAssign(value: T) - public operator fun TensorType.divAssign(other: TensorType) +public interface TensorPartialDivisionAlgebra : + TensorAlgebra { + public operator fun TensorStructure.div(value: T): TensorStructure + public operator fun TensorStructure.div(other: TensorStructure): TensorStructure + public operator fun TensorStructure.divAssign(value: T) + public operator fun TensorStructure.divAssign(other: TensorStructure) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index f315f6b51..53c385197 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -1,11 +1,12 @@ package space.kscience.kmath.tensors.core +import space.kscience.kmath.tensors.TensorStructure import kotlin.math.max public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { - override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor { - val broadcast = broadcastTensors(this, other) + override fun TensorStructure.plus(other: TensorStructure): DoubleTensor { + val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> @@ -14,16 +15,16 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun DoubleTensor.plusAssign(other: DoubleTensor) { - val newOther = broadcastTo(other, this.shape) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] += - newOther.buffer.array()[this.bufferStart + i] + override fun TensorStructure.plusAssign(other: TensorStructure) { + val newOther = broadcastTo(other.tensor, tensor.shape) + for (i in 0 until tensor.linearStructure.size) { + tensor.buffer.array()[tensor.bufferStart + i] += + newOther.buffer.array()[tensor.bufferStart + i] } } - override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor { - val broadcast = broadcastTensors(this, other) + override fun TensorStructure.minus(other: TensorStructure): DoubleTensor { + val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> @@ -32,16 +33,16 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun DoubleTensor.minusAssign(other: DoubleTensor) { - val newOther = broadcastTo(other, this.shape) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] -= - newOther.buffer.array()[this.bufferStart + i] + override fun TensorStructure.minusAssign(other: TensorStructure) { + val newOther = broadcastTo(other.tensor, tensor.shape) + for (i in 0 until tensor.linearStructure.size) { + tensor.buffer.array()[tensor.bufferStart + i] -= + newOther.buffer.array()[tensor.bufferStart + i] } } - override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { - val broadcast = broadcastTensors(this, other) + override fun TensorStructure.times(other: TensorStructure): DoubleTensor { + val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> @@ -51,16 +52,16 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun DoubleTensor.timesAssign(other: DoubleTensor) { - val newOther = broadcastTo(other, this.shape) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] *= - newOther.buffer.array()[this.bufferStart + i] + override fun TensorStructure.timesAssign(other: TensorStructure) { + val newOther = broadcastTo(other.tensor, tensor.shape) + for (i in 0 until tensor.linearStructure.size) { + tensor.buffer.array()[tensor.bufferStart + i] *= + newOther.buffer.array()[tensor.bufferStart + i] } } - override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor { - val broadcast = broadcastTensors(this, other) + override fun TensorStructure.div(other: TensorStructure): DoubleTensor { + val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.linearStructure.size) { i -> @@ -70,11 +71,11 @@ public class BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun DoubleTensor.divAssign(other: DoubleTensor) { - val newOther = broadcastTo(other, this.shape) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] /= - newOther.buffer.array()[this.bufferStart + i] + override fun TensorStructure.divAssign(other: TensorStructure) { + val newOther = broadcastTo(other.tensor, tensor.shape) + for (i in 0 until tensor.linearStructure.size) { + tensor.buffer.array()[tensor.bufferStart + i] /= + newOther.buffer.array()[tensor.bufferStart + i] } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index 9e393f1a8..2b61e10b2 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -6,13 +6,13 @@ import space.kscience.kmath.tensors.TensorStructure public open class BufferedTensor( override val shape: IntArray, - public val buffer: MutableBuffer, + internal val buffer: MutableBuffer, internal val bufferStart: Int ) : TensorStructure { public val linearStructure: TensorLinearStructure get() = TensorLinearStructure(shape) - public val numel: Int + public val numElements: Int get() = linearStructure.size override fun get(index: IntArray): T = buffer[bufferStart + linearStructure.offset(index)] @@ -41,26 +41,6 @@ public class IntTensor internal constructor( this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart) } -public class LongTensor internal constructor( - shape: IntArray, - buffer: LongArray, - offset: Int = 0 -) : BufferedTensor(shape, LongBuffer(buffer), offset) -{ - internal constructor(bufferedTensor: BufferedTensor): - this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart) -} - -public class FloatTensor internal constructor( - shape: IntArray, - buffer: FloatArray, - offset: Int = 0 -) : BufferedTensor(shape, FloatBuffer(buffer), offset) -{ - internal constructor(bufferedTensor: BufferedTensor): - this(bufferedTensor.shape, bufferedTensor.buffer.array(), bufferedTensor.bufferStart) -} - public class DoubleTensor internal constructor( shape: IntArray, buffer: DoubleArray, @@ -74,7 +54,23 @@ public class DoubleTensor internal constructor( } -internal fun BufferedTensor.asTensor(): IntTensor = IntTensor(this) -internal fun BufferedTensor.asTensor(): LongTensor = LongTensor(this) -internal fun BufferedTensor.asTensor(): FloatTensor = FloatTensor(this) -internal fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this) +internal inline fun BufferedTensor.asTensor(): IntTensor = IntTensor(this) +internal inline fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this) + +internal inline fun TensorStructure.toBufferedTensor(): BufferedTensor = when (this) { + is BufferedTensor -> this + else -> BufferedTensor(this.shape, this.elements().map{ it.second }.toMutableList().asMutableBuffer(), 0) +} + +internal val TensorStructure.tensor: DoubleTensor + get() = when (this) { + is DoubleTensor -> this + else -> this.toBufferedTensor().asTensor() + } + +internal val TensorStructure.tensor: IntTensor + get() = when (this) { + is IntTensor -> this + else -> this.toBufferedTensor().asTensor() + } + diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra.kt index b46c7fa9c..233217f2f 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleAnalyticTensorAlgebra.kt @@ -1,45 +1,46 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.tensors.AnalyticTensorAlgebra +import space.kscience.kmath.tensors.TensorStructure import kotlin.math.* public class DoubleAnalyticTensorAlgebra: - AnalyticTensorAlgebra, + AnalyticTensorAlgebra, DoubleTensorAlgebra() { - override fun DoubleTensor.exp(): DoubleTensor = this.map(::exp) + override fun TensorStructure.exp(): DoubleTensor = tensor.map(::exp) - override fun DoubleTensor.log(): DoubleTensor = this.map(::ln) + override fun TensorStructure.log(): DoubleTensor = tensor.map(::ln) - override fun DoubleTensor.sqrt(): DoubleTensor = this.map(::sqrt) + override fun TensorStructure.sqrt(): DoubleTensor = tensor.map(::sqrt) - override fun DoubleTensor.cos(): DoubleTensor = this.map(::cos) + override fun TensorStructure.cos(): DoubleTensor = tensor.map(::cos) - override fun DoubleTensor.acos(): DoubleTensor = this.map(::acos) + override fun TensorStructure.acos(): DoubleTensor = tensor.map(::acos) - override fun DoubleTensor.cosh(): DoubleTensor = this.map(::cosh) + override fun TensorStructure.cosh(): DoubleTensor = tensor.map(::cosh) - override fun DoubleTensor.acosh(): DoubleTensor = this.map(::acosh) + override fun TensorStructure.acosh(): DoubleTensor = tensor.map(::acosh) - override fun DoubleTensor.sin(): DoubleTensor = this.map(::sin) + override fun TensorStructure.sin(): DoubleTensor = tensor.map(::sin) - override fun DoubleTensor.asin(): DoubleTensor = this.map(::asin) + override fun TensorStructure.asin(): DoubleTensor = tensor.map(::asin) - override fun DoubleTensor.sinh(): DoubleTensor = this.map(::sinh) + override fun TensorStructure.sinh(): DoubleTensor = tensor.map(::sinh) - override fun DoubleTensor.asinh(): DoubleTensor = this.map(::asinh) + override fun TensorStructure.asinh(): DoubleTensor = tensor.map(::asinh) - override fun DoubleTensor.tan(): DoubleTensor = this.map(::tan) + override fun TensorStructure.tan(): DoubleTensor = tensor.map(::tan) - override fun DoubleTensor.atan(): DoubleTensor = this.map(::atan) + override fun TensorStructure.atan(): DoubleTensor = tensor.map(::atan) - override fun DoubleTensor.tanh(): DoubleTensor = this.map(::tanh) + override fun TensorStructure.tanh(): DoubleTensor = tensor.map(::tanh) - override fun DoubleTensor.atanh(): DoubleTensor = this.map(::atanh) + override fun TensorStructure.atanh(): DoubleTensor = tensor.map(::atanh) - override fun DoubleTensor.ceil(): DoubleTensor = this.map(::ceil) + override fun TensorStructure.ceil(): DoubleTensor = tensor.map(::ceil) - override fun DoubleTensor.floor(): DoubleTensor = this.map(::floor) + override fun TensorStructure.floor(): DoubleTensor = tensor.map(::floor) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt index e72948c84..386fc2a1b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleLinearOpsTensorAlgebra.kt @@ -3,44 +3,45 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.tensors.LinearOpsTensorAlgebra import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D +import space.kscience.kmath.tensors.TensorStructure import kotlin.math.min public class DoubleLinearOpsTensorAlgebra : - LinearOpsTensorAlgebra, + LinearOpsTensorAlgebra, DoubleTensorAlgebra() { - override fun DoubleTensor.inv(): DoubleTensor = invLU(1e-9) + override fun TensorStructure.inv(): DoubleTensor = invLU(1e-9) - override fun DoubleTensor.det(): DoubleTensor = detLU(1e-9) + override fun TensorStructure.det(): DoubleTensor = detLU(1e-9) - public fun DoubleTensor.lu(epsilon: Double): Pair = - computeLU(this, epsilon) ?: + public fun TensorStructure.lu(epsilon: Double): Pair = + computeLU(tensor, epsilon) ?: throw RuntimeException("Tensor contains matrices which are singular at precision $epsilon") - override fun DoubleTensor.lu(): Pair = lu(1e-9) + override fun TensorStructure.lu(): Pair = lu(1e-9) override fun luPivot( - luTensor: DoubleTensor, - pivotsTensor: IntTensor + luTensor: TensorStructure, + pivotsTensor: TensorStructure ): Triple { //todo checks checkSquareMatrix(luTensor.shape) check( luTensor.shape.dropLast(2).toIntArray() contentEquals pivotsTensor.shape.dropLast(1).toIntArray() || luTensor.shape.last() == pivotsTensor.shape.last() - 1 - ) { "Bed shapes ((" } //todo rewrite + ) { "Bad shapes ((" } //todo rewrite val n = luTensor.shape.last() val pTensor = luTensor.zeroesLike() - for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.vectorSequence())) + for ((p, pivot) in pTensor.matrixSequence().zip(pivotsTensor.tensor.vectorSequence())) pivInit(p.as2D(), pivot.as1D(), n) val lTensor = luTensor.zeroesLike() val uTensor = luTensor.zeroesLike() for ((pairLU, lu) in lTensor.matrixSequence().zip(uTensor.matrixSequence()) - .zip(luTensor.matrixSequence())) { + .zip(luTensor.tensor.matrixSequence())) { val (l, u) = pairLU luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n) } @@ -49,26 +50,26 @@ public class DoubleLinearOpsTensorAlgebra : } - public fun DoubleTensor.cholesky(epsilon: Double): DoubleTensor { + public fun TensorStructure.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) - checkPositiveDefinite(this, epsilon) + checkPositiveDefinite(tensor, epsilon) val n = shape.last() val lTensor = zeroesLike() - for ((a, l) in this.matrixSequence().zip(lTensor.matrixSequence())) + for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence())) for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n) return lTensor } - override fun DoubleTensor.cholesky(): DoubleTensor = cholesky(1e-6) + override fun TensorStructure.cholesky(): DoubleTensor = cholesky(1e-6) - override fun DoubleTensor.qr(): Pair { + override fun TensorStructure.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() val rTensor = zeroesLike() - val seq = matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence()))) + val seq = tensor.matrixSequence().zip((qTensor.matrixSequence().zip(rTensor.matrixSequence()))) for ((matrix, qr) in seq) { val (q, r) = qr qrHelper(matrix.asTensor(), q.asTensor(), r.as2D()) @@ -76,18 +77,18 @@ public class DoubleLinearOpsTensorAlgebra : return qTensor to rTensor } - override fun DoubleTensor.svd(): Triple = + override fun TensorStructure.svd(): Triple = svd(epsilon = 1e-10) - public fun DoubleTensor.svd(epsilon: Double): Triple { - val size = this.linearStructure.dim - val commonShape = this.shape.sliceArray(0 until size - 2) - val (n, m) = this.shape.sliceArray(size - 2 until size) + public fun TensorStructure.svd(epsilon: Double): Triple { + val size = tensor.linearStructure.dim + val commonShape = tensor.shape.sliceArray(0 until size - 2) + val (n, m) = tensor.shape.sliceArray(size - 2 until size) val resU = zeros(commonShape + intArrayOf(min(n, m), n)) val resS = zeros(commonShape + intArrayOf(min(n, m))) val resV = zeros(commonShape + intArrayOf(min(n, m), m)) - for ((matrix, USV) in this.matrixSequence() + for ((matrix, USV) in tensor.matrixSequence() .zip(resU.matrixSequence().zip(resS.vectorSequence().zip(resV.matrixSequence())))) { val size = matrix.shape.reduce { acc, i -> acc * i } val curMatrix = DoubleTensor( @@ -99,13 +100,13 @@ public class DoubleLinearOpsTensorAlgebra : return Triple(resU.transpose(), resS, resV.transpose()) } - override fun DoubleTensor.symEig(): Pair = + override fun TensorStructure.symEig(): Pair = symEig(epsilon = 1e-15) //http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html - public fun DoubleTensor.symEig(epsilon: Double): Pair { - checkSymmetric(this, epsilon) - val (u, s, v) = this.svd(epsilon) + public fun TensorStructure.symEig(epsilon: Double): Pair { + checkSymmetric(tensor, epsilon) + val (u, s, v) = tensor.svd(epsilon) val shp = s.shape + intArrayOf(1) val utv = u.transpose() dot v val n = s.shape.last() @@ -116,11 +117,11 @@ public class DoubleLinearOpsTensorAlgebra : return Pair(eig, v) } - public fun DoubleTensor.detLU(epsilon: Double = 1e-9): DoubleTensor { + public fun TensorStructure.detLU(epsilon: Double = 1e-9): DoubleTensor { - checkSquareMatrix(this.shape) - val luTensor = this.copy() - val pivotsTensor = this.setUpPivots() + checkSquareMatrix(tensor.shape) + val luTensor = tensor.copy() + val pivotsTensor = tensor.setUpPivots() val n = shape.size @@ -141,7 +142,7 @@ public class DoubleLinearOpsTensorAlgebra : return detTensor } - public fun DoubleTensor.invLU(epsilon: Double = 1e-9): DoubleTensor { + public fun TensorStructure.invLU(epsilon: Double = 1e-9): DoubleTensor { val (luTensor, pivotsTensor) = lu(epsilon) val invTensor = luTensor.zeroesLike() diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 8f355dd3d..7544bf68e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -1,16 +1,18 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra import space.kscience.kmath.nd.as2D +import space.kscience.kmath.tensors.TensorPartialDivisionAlgebra +import space.kscience.kmath.tensors.TensorStructure import kotlin.math.abs -public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { +public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { - override fun DoubleTensor.value(): Double { - check(this.shape contentEquals intArrayOf(1)) { + + override fun TensorStructure.value(): Double { + check(tensor.shape contentEquals intArrayOf(1)) { "Inconsistent value for tensor of shape ${shape.toList()}" } - return this.buffer.array()[this.bufferStart] + return tensor.buffer.array()[tensor.bufferStart] } public fun fromArray(shape: IntArray, buffer: DoubleArray): DoubleTensor { @@ -20,11 +22,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra.get(i: Int): DoubleTensor { + val lastShape = tensor.shape.drop(1).toIntArray() val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) - val newStart = newShape.reduce(Int::times) * i + this.bufferStart - return DoubleTensor(newShape, this.buffer.array(), newStart) + val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart + return DoubleTensor(newShape, tensor.buffer.array(), newStart) } public fun full(value: Double, shape: IntArray): DoubleTensor { @@ -33,19 +35,19 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra.fullLike(value: Double): DoubleTensor { + val shape = tensor.shape + val buffer = DoubleArray(tensor.numElements) { value } return DoubleTensor(shape, buffer) } public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) - public fun DoubleTensor.zeroesLike(): DoubleTensor = this.fullLike(0.0) + public fun TensorStructure.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) - public fun DoubleTensor.onesLike(): DoubleTensor = this.fullLike(1.0) + public fun TensorStructure.onesLike(): DoubleTensor = tensor.fullLike(1.0) public fun eye(n: Int): DoubleTensor { val shape = intArrayOf(n, n) @@ -57,200 +59,200 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra.copy(): DoubleTensor { + return DoubleTensor(tensor.shape, tensor.buffer.array().copyOf(), tensor.bufferStart) } - override fun Double.plus(other: DoubleTensor): DoubleTensor { - val resBuffer = DoubleArray(other.linearStructure.size) { i -> - other.buffer.array()[other.bufferStart + i] + this + override fun Double.plus(other: TensorStructure): DoubleTensor { + val resBuffer = DoubleArray(other.tensor.numElements) { i -> + other.tensor.buffer.array()[other.tensor.bufferStart + i] + this } return DoubleTensor(other.shape, resBuffer) } - override fun DoubleTensor.plus(value: Double): DoubleTensor = value + this + override fun TensorStructure.plus(value: Double): DoubleTensor = value + tensor - override fun DoubleTensor.plus(other: DoubleTensor): DoubleTensor { - checkShapesCompatible(this, other) - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[i] + other.buffer.array()[i] + override fun TensorStructure.plus(other: TensorStructure): DoubleTensor { + checkShapesCompatible(tensor, other.tensor) + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[i] + other.tensor.buffer.array()[i] } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.plusAssign(value: Double) { - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] += value + override fun TensorStructure.plusAssign(value: Double) { + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] += value } } - override fun DoubleTensor.plusAssign(other: DoubleTensor) { - checkShapesCompatible(this, other) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] += - other.buffer.array()[this.bufferStart + i] + override fun TensorStructure.plusAssign(other: TensorStructure) { + checkShapesCompatible(tensor, other.tensor) + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] += + other.tensor.buffer.array()[tensor.bufferStart + i] } } - override fun Double.minus(other: DoubleTensor): DoubleTensor { - val resBuffer = DoubleArray(other.linearStructure.size) { i -> - this - other.buffer.array()[other.bufferStart + i] + override fun Double.minus(other: TensorStructure): DoubleTensor { + val resBuffer = DoubleArray(other.tensor.numElements) { i -> + this - other.tensor.buffer.array()[other.tensor.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } - override fun DoubleTensor.minus(value: Double): DoubleTensor { - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[this.bufferStart + i] - value + override fun TensorStructure.minus(value: Double): DoubleTensor { + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[tensor.bufferStart + i] - value } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.minus(other: DoubleTensor): DoubleTensor { - checkShapesCompatible(this, other) - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[i] - other.buffer.array()[i] + override fun TensorStructure.minus(other: TensorStructure): DoubleTensor { + checkShapesCompatible(tensor, other) + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[i] - other.tensor.buffer.array()[i] } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.minusAssign(value: Double) { - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] -= value + override fun TensorStructure.minusAssign(value: Double) { + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] -= value } } - override fun DoubleTensor.minusAssign(other: DoubleTensor) { - checkShapesCompatible(this, other) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] -= - other.buffer.array()[this.bufferStart + i] + override fun TensorStructure.minusAssign(other: TensorStructure) { + checkShapesCompatible(tensor, other) + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] -= + other.tensor.buffer.array()[tensor.bufferStart + i] } } - override fun Double.times(other: DoubleTensor): DoubleTensor { - val resBuffer = DoubleArray(other.linearStructure.size) { i -> - other.buffer.array()[other.bufferStart + i] * this + override fun Double.times(other: TensorStructure): DoubleTensor { + val resBuffer = DoubleArray(other.tensor.numElements) { i -> + other.tensor.buffer.array()[other.tensor.bufferStart + i] * this } return DoubleTensor(other.shape, resBuffer) } - override fun DoubleTensor.times(value: Double): DoubleTensor = value * this + override fun TensorStructure.times(value: Double): DoubleTensor = value * tensor - override fun DoubleTensor.times(other: DoubleTensor): DoubleTensor { - checkShapesCompatible(this, other) - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[this.bufferStart + i] * - other.buffer.array()[other.bufferStart + i] + override fun TensorStructure.times(other: TensorStructure): DoubleTensor { + checkShapesCompatible(tensor, other) + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[tensor.bufferStart + i] * + other.tensor.buffer.array()[other.tensor.bufferStart + i] } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.timesAssign(value: Double) { - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] *= value + override fun TensorStructure.timesAssign(value: Double) { + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] *= value } } - override fun DoubleTensor.timesAssign(other: DoubleTensor) { - checkShapesCompatible(this, other) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] *= - other.buffer.array()[this.bufferStart + i] + override fun TensorStructure.timesAssign(other: TensorStructure) { + checkShapesCompatible(tensor, other) + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] *= + other.tensor.buffer.array()[tensor.bufferStart + i] } } - override fun DoubleTensor.div(value: Double): DoubleTensor { - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[this.bufferStart + i] / value + override fun TensorStructure.div(value: Double): DoubleTensor { + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[tensor.bufferStart + i] / value } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.div(other: DoubleTensor): DoubleTensor { - checkShapesCompatible(this, other) - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[other.bufferStart + i] / - other.buffer.array()[other.bufferStart + i] + override fun TensorStructure.div(other: TensorStructure): DoubleTensor { + checkShapesCompatible(tensor, other) + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[other.tensor.bufferStart + i] / + other.tensor.buffer.array()[other.tensor.bufferStart + i] } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.divAssign(value: Double) { - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] /= value + override fun TensorStructure.divAssign(value: Double) { + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] /= value } } - override fun DoubleTensor.divAssign(other: DoubleTensor) { - checkShapesCompatible(this, other) - for (i in 0 until this.linearStructure.size) { - this.buffer.array()[this.bufferStart + i] /= - other.buffer.array()[this.bufferStart + i] + override fun TensorStructure.divAssign(other: TensorStructure) { + checkShapesCompatible(tensor, other) + for (i in 0 until tensor.numElements) { + tensor.buffer.array()[tensor.bufferStart + i] /= + other.tensor.buffer.array()[tensor.bufferStart + i] } } - override fun DoubleTensor.unaryMinus(): DoubleTensor { - val resBuffer = DoubleArray(this.linearStructure.size) { i -> - this.buffer.array()[this.bufferStart + i].unaryMinus() + override fun TensorStructure.unaryMinus(): DoubleTensor { + val resBuffer = DoubleArray(tensor.numElements) { i -> + tensor.buffer.array()[tensor.bufferStart + i].unaryMinus() } - return DoubleTensor(this.shape, resBuffer) + return DoubleTensor(tensor.shape, resBuffer) } - override fun DoubleTensor.transpose(i: Int, j: Int): DoubleTensor { - val ii = minusIndex(i) - val jj = minusIndex(j) - checkTranspose(this.dimension, ii, jj) - val n = this.linearStructure.size + override fun TensorStructure.transpose(i: Int, j: Int): DoubleTensor { + val ii = tensor.minusIndex(i) + val jj = tensor.minusIndex(j) + checkTranspose(tensor.dimension, ii, jj) + val n = tensor.numElements val resBuffer = DoubleArray(n) - val resShape = this.shape.copyOf() + val resShape = tensor.shape.copyOf() resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] } val resTensor = DoubleTensor(resShape, resBuffer) for (offset in 0 until n) { - val oldMultiIndex = this.linearStructure.index(offset) + val oldMultiIndex = tensor.linearStructure.index(offset) val newMultiIndex = oldMultiIndex.copyOf() newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } val linearIndex = resTensor.linearStructure.offset(newMultiIndex) resTensor.buffer.array()[linearIndex] = - this.buffer.array()[this.bufferStart + offset] + tensor.buffer.array()[tensor.bufferStart + offset] } return resTensor } - override fun DoubleTensor.view(shape: IntArray): DoubleTensor { - checkView(this, shape) - return DoubleTensor(shape, this.buffer.array(), this.bufferStart) + override fun TensorStructure.view(shape: IntArray): DoubleTensor { + checkView(tensor, shape) + return DoubleTensor(shape, tensor.buffer.array(), tensor.bufferStart) } - override fun DoubleTensor.viewAs(other: DoubleTensor): DoubleTensor { - return this.view(other.shape) + override fun TensorStructure.viewAs(other: TensorStructure): DoubleTensor { + return tensor.view(other.shape) } - override infix fun DoubleTensor.dot(other: DoubleTensor): DoubleTensor { - if (this.shape.size == 1 && other.shape.size == 1) { - return DoubleTensor(intArrayOf(1), doubleArrayOf(this.times(other).buffer.array().sum())) + override infix fun TensorStructure.dot(other: TensorStructure): DoubleTensor { + if (tensor.shape.size == 1 && other.shape.size == 1) { + return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.buffer.array().sum())) } - var newThis = this.copy() + var newThis = tensor.copy() var newOther = other.copy() var penultimateDim = false var lastDim = false - if (this.shape.size == 1) { + if (tensor.shape.size == 1) { penultimateDim = true - newThis = this.view(intArrayOf(1) + this.shape) + newThis = tensor.view(intArrayOf(1) + tensor.shape) } if (other.shape.size == 1) { lastDim = true - newOther = other.view(other.shape + intArrayOf(1)) + newOther = other.tensor.view(other.shape + intArrayOf(1)) } - val broadcastTensors = broadcastOuterTensors(newThis, newOther) + val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor) newThis = broadcastTensors[0] newOther = broadcastTensors[1] @@ -284,10 +286,12 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra, offset: Int, dim1: Int, dim2: Int): + DoubleTensor { val n = diagonalEntries.shape.size + val d1 = minusIndexFrom(n + 1, dim1) + val d2 = minusIndexFrom(n + 1, dim2) + if (d1 == d2) { throw RuntimeException("Diagonal dimensions cannot be identical $d1, $d2") } @@ -300,7 +304,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra greaterDim) { realOffset *= -1 - lessDim = greaterDim.also {greaterDim = lessDim} + lessDim = greaterDim.also { greaterDim = lessDim } } val resShape = diagonalEntries.shape.slice(0 until lessDim).toIntArray() + @@ -310,13 +314,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra Double): DoubleTensor { + public fun TensorStructure.map(transform: (Double) -> Double): DoubleTensor { return DoubleTensor( - this.shape, - this.buffer.array().map { transform(it) }.toDoubleArray(), - this.bufferStart + tensor.shape, + tensor.buffer.array().map { transform(it) }.toDoubleArray(), + tensor.bufferStart ) } - public fun DoubleTensor.eq(other: DoubleTensor, delta: Double): Boolean { - return this.eq(other) { x, y -> abs(x - y) < delta } + public fun TensorStructure.eq(other: TensorStructure, delta: Double): Boolean { + return tensor.eq(other) { x, y -> abs(x - y) < delta } } - public fun DoubleTensor.eq(other: DoubleTensor): Boolean = this.eq(other, 1e-5) + public fun TensorStructure.eq(other: TensorStructure): Boolean = tensor.eq(other, 1e-5) - private fun DoubleTensor.eq(other: DoubleTensor, eqFunction: (Double, Double) -> Boolean): Boolean { - checkShapesCompatible(this, other) - val n = this.linearStructure.size - if (n != other.linearStructure.size) { + private fun TensorStructure.eq( + other: TensorStructure, + eqFunction: (Double, Double) -> Boolean + ): Boolean { + checkShapesCompatible(tensor, other) + val n = tensor.numElements + if (n != other.tensor.numElements) { return false } for (i in 0 until n) { - if (!eqFunction(this.buffer[this.bufferStart + i], other.buffer[other.bufferStart + i])) { + if (!eqFunction(tensor.buffer[tensor.bufferStart + i], other.tensor.buffer[other.tensor.bufferStart + i])) { return false } } @@ -362,8 +369,8 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra.randNormalLike(seed: Long = 0): DoubleTensor = + DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index 730b6ed9a..28828eb09 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -1,54 +1,40 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.tensors.TensorAlgebra import space.kscience.kmath.tensors.TensorStructure -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkEmptyShape(shape: IntArray): Unit = +internal inline fun checkEmptyShape(shape: IntArray): Unit = check(shape.isNotEmpty()) { "Illegal empty shape provided" } -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkEmptyDoubleBuffer(buffer: DoubleArray): Unit = +internal inline fun checkEmptyDoubleBuffer(buffer: DoubleArray): Unit = check(buffer.isNotEmpty()) { "Illegal empty buffer provided" } -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit = +internal inline fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray): Unit = check(buffer.size == shape.reduce(Int::times)) { "Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided" } -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkShapesCompatible(a: TensorType, b: TensorType): Unit = +internal inline fun checkShapesCompatible(a: TensorStructure, b: TensorStructure): Unit = check(a.shape contentEquals b.shape) { "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " } -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkTranspose(dim: Int, i: Int, j: Int): Unit = +internal inline fun checkTranspose(dim: Int, i: Int, j: Int): Unit = check((i < dim) and (j < dim)) { "Cannot transpose $i to $j for a tensor of dim $dim" } -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit = +internal inline fun checkView(a: TensorStructure, shape: IntArray): Unit = check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) -internal inline fun , - TorchTensorAlgebraType : TensorAlgebra> - TorchTensorAlgebraType.checkSquareMatrix(shape: IntArray): Unit { + +internal inline fun checkSquareMatrix(shape: IntArray): Unit { val n = shape.size check(n >= 2) { "Expected tensor with 2 or more dimensions, got size $n instead" @@ -58,16 +44,19 @@ internal inline fun , } } -internal inline fun DoubleTensorAlgebra.checkSymmetric(tensor: DoubleTensor, epsilon: Double = 1e-6): Unit = +internal inline fun DoubleTensorAlgebra.checkSymmetric( + tensor: TensorStructure, epsilon: Double = 1e-6 +): Unit = check(tensor.eq(tensor.transpose(), epsilon)) { "Tensor is not symmetric about the last 2 dimensions at precision $epsilon" } internal inline fun DoubleLinearOpsTensorAlgebra.checkPositiveDefinite( - tensor: DoubleTensor, epsilon: Double = 1e-6): Unit { + tensor: DoubleTensor, epsilon: Double = 1e-6 +): Unit { checkSymmetric(tensor, epsilon) - for( mat in tensor.matrixSequence()) - check(mat.asTensor().detLU().value() > 0.0){ + for (mat in tensor.matrixSequence()) + check(mat.asTensor().detLU().value() > 0.0) { "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" } } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt index ba5e0caaf..166f1d6c4 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/linutils.kt @@ -14,7 +14,7 @@ internal inline fun BufferedTensor.vectorSequence(): Sequence BufferedTensor.matrixSequence(): Sequence= 0) i else { ii } -internal inline fun BufferedTensor.minusIndex(i: Int): Int = minusIndexFrom(this.linearStructure.dim, i) +internal inline fun BufferedTensor.minusIndex(i: Int): Int = minusIndexFrom(this.dimension, i) internal inline fun format(value: Double, digits: Int = 4): String { val ten = 10.0 @@ -111,7 +111,7 @@ internal inline fun DoubleTensor.toPrettyString(): String = buildString { } offset += vectorSize // todo refactor - if (this@toPrettyString.numel == offset) { + if (this@toPrettyString.numElements == offset) { break } append(",\n")