diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt index 5c92c56c4..cd13e0752 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/AnalyticTensorAlgebra.kt @@ -6,7 +6,7 @@ package space.kscience.kmath.tensors.api /** - * Element-wise analytic operations on [TensorStructure]. + * Element-wise analytic operations on [Tensor]. * * @param T the type of items closed under analytic functions in the tensors. */ @@ -14,54 +14,54 @@ public interface AnalyticTensorAlgebra : TensorPartialDivisionAlgebra { //For information: https://pytorch.org/docs/stable/generated/torch.exp.html - public fun TensorStructure.exp(): TensorStructure + public fun Tensor.exp(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.log.html - public fun TensorStructure.log(): TensorStructure + public fun Tensor.log(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html - public fun TensorStructure.sqrt(): TensorStructure + public fun Tensor.sqrt(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos - public fun TensorStructure.cos(): TensorStructure + public fun Tensor.cos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos - public fun TensorStructure.acos(): TensorStructure + public fun Tensor.acos(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh - public fun TensorStructure.cosh(): TensorStructure + public fun Tensor.cosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh - public fun TensorStructure.acosh(): TensorStructure + public fun Tensor.acosh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin - public fun TensorStructure.sin(): TensorStructure + public fun Tensor.sin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin - public fun TensorStructure.asin(): TensorStructure + public fun Tensor.asin(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh - public fun TensorStructure.sinh(): TensorStructure + public fun Tensor.sinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh - public fun TensorStructure.asinh(): TensorStructure + public fun Tensor.asinh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan - public fun TensorStructure.tan(): TensorStructure + public fun Tensor.tan(): Tensor //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan - public fun TensorStructure.atan(): TensorStructure + public fun Tensor.atan(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh - public fun TensorStructure.tanh(): TensorStructure + public fun Tensor.tanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh - public fun TensorStructure.atanh(): TensorStructure + public fun Tensor.atanh(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil - public fun TensorStructure.ceil(): TensorStructure + public fun Tensor.ceil(): Tensor //For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor - public fun TensorStructure.floor(): TensorStructure + public fun Tensor.floor(): Tensor } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt index bcbb52a1b..527e5d386 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/LinearOpsTensorAlgebra.kt @@ -6,7 +6,7 @@ package space.kscience.kmath.tensors.api /** - * Common linear algebra operations. Operates on [TensorStructure]. + * Common linear algebra operations. Operates on [Tensor]. * * @param T the type of items closed under division in the tensors. */ @@ -19,7 +19,7 @@ public interface LinearOpsTensorAlgebra : * * @return the determinant. */ - public fun TensorStructure.det(): TensorStructure + public fun Tensor.det(): Tensor /** * Computes the multiplicative inverse matrix of a square matrix input, or of each square matrix in a batched input. @@ -29,7 +29,7 @@ public interface LinearOpsTensorAlgebra : * * @return the multiplicative inverse of a matrix. */ - public fun TensorStructure.inv(): TensorStructure + public fun Tensor.inv(): Tensor /** * Cholesky decomposition. @@ -44,7 +44,7 @@ public interface LinearOpsTensorAlgebra : * * @return the batch of L matrices. */ - public fun TensorStructure.cholesky(): TensorStructure + public fun Tensor.cholesky(): Tensor /** * QR decomposition. @@ -57,7 +57,7 @@ public interface LinearOpsTensorAlgebra : * * @return pair of Q and R tensors. */ - public fun TensorStructure.qr(): Pair, TensorStructure> + public fun Tensor.qr(): Pair, Tensor> /** * LUP decomposition @@ -70,7 +70,7 @@ public interface LinearOpsTensorAlgebra : * * * @return triple of P, L and U tensors */ - public fun TensorStructure.lu(): Triple, TensorStructure, TensorStructure> + public fun Tensor.lu(): Triple, Tensor, Tensor> /** * Singular Value Decomposition. @@ -83,7 +83,7 @@ public interface LinearOpsTensorAlgebra : * * @return the determinant. */ - public fun TensorStructure.svd(): Triple, TensorStructure, TensorStructure> + public fun Tensor.svd(): Triple, Tensor, Tensor> /** * Returns eigenvalues and eigenvectors of a real symmetric matrix input or a batch of real symmetric matrices, @@ -92,6 +92,6 @@ public interface LinearOpsTensorAlgebra : * * @return a pair (eigenvalues, eigenvectors) */ - public fun TensorStructure.symEig(): Pair, TensorStructure> + public fun Tensor.symEig(): Pair, Tensor> } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorStructure.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt similarity index 60% rename from kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorStructure.kt rename to kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt index edecd6383..179787684 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorStructure.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/Tensor.kt @@ -2,4 +2,4 @@ package space.kscience.kmath.tensors.api import space.kscience.kmath.nd.MutableStructureND -public typealias TensorStructure = MutableStructureND +public typealias Tensor = MutableStructureND diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt index 96d6985d8..b9c707c0b 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorAlgebra.kt @@ -8,19 +8,19 @@ package space.kscience.kmath.tensors.api import space.kscience.kmath.operations.Algebra /** - * Algebra over a ring on [TensorStructure]. + * Algebra over a ring on [Tensor]. * For more information: https://proofwiki.org/wiki/Definition:Algebra_over_Ring * * @param T the type of items in the tensors. */ -public interface TensorAlgebra: Algebra> { +public interface TensorAlgebra: Algebra> { /** * Returns a single tensor value of unit dimension. The tensor shape must be equal to [1]. * * @return the value of a scalar tensor. */ - public fun TensorStructure.value(): T + public fun Tensor.value(): T /** * Each element of the tensor [other] is added to this value. @@ -29,7 +29,7 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be added. * @return the sum of this value and tensor [other]. */ - public operator fun T.plus(other: TensorStructure): TensorStructure + public operator fun T.plus(other: Tensor): Tensor /** * Adds the scalar [value] to each element of this tensor and returns a new resulting tensor. @@ -37,7 +37,7 @@ public interface TensorAlgebra: Algebra> { * @param value the number to be added to each element of this tensor. * @return the sum of this tensor and [value]. */ - public operator fun TensorStructure.plus(value: T): TensorStructure + public operator fun Tensor.plus(value: T): Tensor /** * Each element of the tensor [other] is added to each element of this tensor. @@ -46,21 +46,21 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be added. * @return the sum of this tensor and [other]. */ - public operator fun TensorStructure.plus(other: TensorStructure): TensorStructure + public operator fun Tensor.plus(other: Tensor): Tensor /** * Adds the scalar [value] to each element of this tensor. * * @param value the number to be added to each element of this tensor. */ - public operator fun TensorStructure.plusAssign(value: T): Unit + public operator fun Tensor.plusAssign(value: T): Unit /** * Each element of the tensor [other] is added to each element of this tensor. * * @param other tensor to be added. */ - public operator fun TensorStructure.plusAssign(other: TensorStructure): Unit + public operator fun Tensor.plusAssign(other: Tensor): Unit /** @@ -70,7 +70,7 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be subtracted. * @return the difference between this value and tensor [other]. */ - public operator fun T.minus(other: TensorStructure): TensorStructure + public operator fun T.minus(other: Tensor): Tensor /** * Subtracts the scalar [value] from each element of this tensor and returns a new resulting tensor. @@ -78,7 +78,7 @@ public interface TensorAlgebra: Algebra> { * @param value the number to be subtracted from each element of this tensor. * @return the difference between this tensor and [value]. */ - public operator fun TensorStructure.minus(value: T): TensorStructure + public operator fun Tensor.minus(value: T): Tensor /** * Each element of the tensor [other] is subtracted from each element of this tensor. @@ -87,21 +87,21 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be subtracted. * @return the difference between this tensor and [other]. */ - public operator fun TensorStructure.minus(other: TensorStructure): TensorStructure + public operator fun Tensor.minus(other: Tensor): Tensor /** * Subtracts the scalar [value] from each element of this tensor. * * @param value the number to be subtracted from each element of this tensor. */ - public operator fun TensorStructure.minusAssign(value: T): Unit + public operator fun Tensor.minusAssign(value: T): Unit /** * Each element of the tensor [other] is subtracted from each element of this tensor. * * @param other tensor to be subtracted. */ - public operator fun TensorStructure.minusAssign(other: TensorStructure): Unit + public operator fun Tensor.minusAssign(other: Tensor): Unit /** @@ -111,7 +111,7 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be multiplied. * @return the product of this value and tensor [other]. */ - public operator fun T.times(other: TensorStructure): TensorStructure + public operator fun T.times(other: Tensor): Tensor /** * Multiplies the scalar [value] by each element of this tensor and returns a new resulting tensor. @@ -119,7 +119,7 @@ public interface TensorAlgebra: Algebra> { * @param value the number to be multiplied by each element of this tensor. * @return the product of this tensor and [value]. */ - public operator fun TensorStructure.times(value: T): TensorStructure + public operator fun Tensor.times(value: T): Tensor /** * Each element of the tensor [other] is multiplied by each element of this tensor. @@ -128,28 +128,28 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be multiplied. * @return the product of this tensor and [other]. */ - public operator fun TensorStructure.times(other: TensorStructure): TensorStructure + public operator fun Tensor.times(other: Tensor): Tensor /** * Multiplies the scalar [value] by each element of this tensor. * * @param value the number to be multiplied by each element of this tensor. */ - public operator fun TensorStructure.timesAssign(value: T): Unit + public operator fun Tensor.timesAssign(value: T): Unit /** * Each element of the tensor [other] is multiplied by each element of this tensor. * * @param other tensor to be multiplied. */ - public operator fun TensorStructure.timesAssign(other: TensorStructure): Unit + public operator fun Tensor.timesAssign(other: Tensor): Unit /** * Numerical negative, element-wise. * * @return tensor negation of the original tensor. */ - public operator fun TensorStructure.unaryMinus(): TensorStructure + public operator fun Tensor.unaryMinus(): Tensor /** * Returns the tensor at index i @@ -158,7 +158,7 @@ public interface TensorAlgebra: Algebra> { * @param i index of the extractable tensor * @return subtensor of the original tensor with index [i] */ - public operator fun TensorStructure.get(i: Int): TensorStructure + public operator fun Tensor.get(i: Int): Tensor /** * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. @@ -168,7 +168,7 @@ public interface TensorAlgebra: Algebra> { * @param j the second dimension to be transposed * @return transposed tensor */ - public fun TensorStructure.transpose(i: Int = -2, j: Int = -1): TensorStructure + public fun Tensor.transpose(i: Int = -2, j: Int = -1): Tensor /** * Returns a new tensor with the same data as the self tensor but of a different shape. @@ -178,7 +178,7 @@ public interface TensorAlgebra: Algebra> { * @param shape the desired size * @return tensor with new shape */ - public fun TensorStructure.view(shape: IntArray): TensorStructure + public fun Tensor.view(shape: IntArray): Tensor /** * View this tensor as the same size as [other]. @@ -188,7 +188,7 @@ public interface TensorAlgebra: Algebra> { * @param other the result tensor has the same size as other. * @return the result tensor with the same size as other. */ - public fun TensorStructure.viewAs(other: TensorStructure): TensorStructure + public fun Tensor.viewAs(other: Tensor): Tensor /** * Matrix product of two tensors. @@ -219,7 +219,7 @@ public interface TensorAlgebra: Algebra> { * @param other tensor to be multiplied * @return mathematical product of two tensors */ - public infix fun TensorStructure.dot(other: TensorStructure): TensorStructure + public infix fun Tensor.dot(other: Tensor): Tensor /** * Creates a tensor whose diagonals of certain 2D planes (specified by [dim1] and [dim2]) @@ -245,10 +245,10 @@ public interface TensorAlgebra: Algebra> { * are filled by [diagonalEntries] */ public fun diagonalEmbedding( - diagonalEntries: TensorStructure, + diagonalEntries: Tensor, offset: Int = 0, dim1: Int = -2, dim2: Int = -1 - ): TensorStructure + ): Tensor } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt index eccc6fa0d..921157963 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/api/TensorPartialDivisionAlgebra.kt @@ -6,7 +6,7 @@ package space.kscience.kmath.tensors.api /** - * Algebra over a field with partial division on [TensorStructure]. + * Algebra over a field with partial division on [Tensor]. * For more information: https://proofwiki.org/wiki/Definition:Division_Algebra * * @param T the type of items closed under division in the tensors. @@ -21,7 +21,7 @@ public interface TensorPartialDivisionAlgebra : * @param other tensor to divide by. * @return the division of this value by the tensor [other]. */ - public operator fun T.div(other: TensorStructure): TensorStructure + public operator fun T.div(other: Tensor): Tensor /** * Divide by the scalar [value] each element of this tensor returns a new resulting tensor. @@ -29,7 +29,7 @@ public interface TensorPartialDivisionAlgebra : * @param value the number to divide by each element of this tensor. * @return the division of this tensor by the [value]. */ - public operator fun TensorStructure.div(value: T): TensorStructure + public operator fun Tensor.div(value: T): Tensor /** * Each element of the tensor [other] is divided by each element of this tensor. @@ -38,19 +38,19 @@ public interface TensorPartialDivisionAlgebra : * @param other tensor to be divided by. * @return the division of this tensor by [other]. */ - public operator fun TensorStructure.div(other: TensorStructure): TensorStructure + public operator fun Tensor.div(other: Tensor): Tensor /** * Divides by the scalar [value] each element of this tensor. * * @param value the number to divide by each element of this tensor. */ - public operator fun TensorStructure.divAssign(value: T) + public operator fun Tensor.divAssign(value: T) /** * Each element of this tensor is divided by each element of the [other] tensor. * * @param other tensor to be divide by. */ - public operator fun TensorStructure.divAssign(other: TensorStructure) + public operator fun Tensor.divAssign(other: Tensor) } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt index 9541a97f9..d0882efb8 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BufferedTensor.kt @@ -2,7 +2,7 @@ package space.kscience.kmath.tensors.core import space.kscience.kmath.nd.MutableBufferND import space.kscience.kmath.structures.* -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.algebras.TensorLinearStructure @@ -10,7 +10,7 @@ public open class BufferedTensor( override val shape: IntArray, internal val mutableBuffer: MutableBuffer, internal val bufferStart: Int -) : TensorStructure { +) : Tensor { public val linearStructure: TensorLinearStructure get() = TensorLinearStructure(shape) @@ -53,33 +53,33 @@ internal fun BufferedTensor.asTensor(): IntTensor = internal fun BufferedTensor.asTensor(): DoubleTensor = DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) -internal fun TensorStructure.copyToBufferedTensor(): BufferedTensor = +internal fun Tensor.copyToBufferedTensor(): BufferedTensor = BufferedTensor( this.shape, TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().asMutableBuffer(), 0 ) -internal fun TensorStructure.toBufferedTensor(): BufferedTensor = when (this) { +internal fun Tensor.toBufferedTensor(): BufferedTensor = when (this) { is BufferedTensor -> this is MutableBufferND -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides) BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor() else -> this.copyToBufferedTensor() } -internal val TensorStructure.tensor: DoubleTensor +internal val Tensor.tensor: DoubleTensor get() = when (this) { is DoubleTensor -> this else -> this.toBufferedTensor().asTensor() } -internal val TensorStructure.tensor: IntTensor +internal val Tensor.tensor: IntTensor get() = when (this) { is IntTensor -> this else -> this.toBufferedTensor().asTensor() } -public fun TensorStructure.toDoubleTensor(): DoubleTensor = this.tensor -public fun TensorStructure.toIntTensor(): IntTensor = this.tensor +public fun Tensor.toDoubleTensor(): DoubleTensor = this.tensor +public fun Tensor.toIntTensor(): IntTensor = this.tensor public fun Array.toDoubleTensor(): DoubleTensor { val n = size diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt index 9b97d5ef2..873ec9027 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/BroadcastDoubleTensorAlgebra.kt @@ -5,7 +5,7 @@ package space.kscience.kmath.tensors.core.algebras -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.broadcastTensors import space.kscience.kmath.tensors.core.broadcastTo @@ -16,7 +16,7 @@ import space.kscience.kmath.tensors.core.broadcastTo */ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { - override fun TensorStructure.plus(other: TensorStructure): DoubleTensor { + override fun Tensor.plus(other: Tensor): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] @@ -26,7 +26,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun TensorStructure.plusAssign(other: TensorStructure) { + override fun Tensor.plusAssign(other: Tensor) { val newOther = broadcastTo(other.tensor, tensor.shape) for (i in 0 until tensor.linearStructure.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += @@ -34,7 +34,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun TensorStructure.minus(other: TensorStructure): DoubleTensor { + override fun Tensor.minus(other: Tensor): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] @@ -44,7 +44,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun TensorStructure.minusAssign(other: TensorStructure) { + override fun Tensor.minusAssign(other: Tensor) { val newOther = broadcastTo(other.tensor, tensor.shape) for (i in 0 until tensor.linearStructure.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= @@ -52,7 +52,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun TensorStructure.times(other: TensorStructure): DoubleTensor { + override fun Tensor.times(other: Tensor): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] @@ -63,7 +63,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun TensorStructure.timesAssign(other: TensorStructure) { + override fun Tensor.timesAssign(other: Tensor) { val newOther = broadcastTo(other.tensor, tensor.shape) for (i in 0 until tensor.linearStructure.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= @@ -71,7 +71,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } } - override fun TensorStructure.div(other: TensorStructure): DoubleTensor { + override fun Tensor.div(other: Tensor): DoubleTensor { val broadcast = broadcastTensors(tensor, other.tensor) val newThis = broadcast[0] val newOther = broadcast[1] @@ -82,7 +82,7 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { return DoubleTensor(newThis.shape, resBuffer) } - override fun TensorStructure.divAssign(other: TensorStructure) { + override fun Tensor.divAssign(other: Tensor) { val newOther = broadcastTo(other.tensor, tensor.shape) for (i in 0 until tensor.linearStructure.linearSize) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleAnalyticTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleAnalyticTensorAlgebra.kt index 4a942df84..9aa6f093e 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleAnalyticTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleAnalyticTensorAlgebra.kt @@ -6,7 +6,7 @@ package space.kscience.kmath.tensors.core.algebras import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.tensor import kotlin.math.* @@ -14,38 +14,38 @@ import kotlin.math.* public object DoubleAnalyticTensorAlgebra : AnalyticTensorAlgebra, DoubleTensorAlgebra() { - override fun TensorStructure.exp(): DoubleTensor = tensor.map(::exp) + override fun Tensor.exp(): DoubleTensor = tensor.map(::exp) - override fun TensorStructure.log(): DoubleTensor = tensor.map(::ln) + override fun Tensor.log(): DoubleTensor = tensor.map(::ln) - override fun TensorStructure.sqrt(): DoubleTensor = tensor.map(::sqrt) + override fun Tensor.sqrt(): DoubleTensor = tensor.map(::sqrt) - override fun TensorStructure.cos(): DoubleTensor = tensor.map(::cos) + override fun Tensor.cos(): DoubleTensor = tensor.map(::cos) - override fun TensorStructure.acos(): DoubleTensor = tensor.map(::acos) + override fun Tensor.acos(): DoubleTensor = tensor.map(::acos) - override fun TensorStructure.cosh(): DoubleTensor = tensor.map(::cosh) + override fun Tensor.cosh(): DoubleTensor = tensor.map(::cosh) - override fun TensorStructure.acosh(): DoubleTensor = tensor.map(::acosh) + override fun Tensor.acosh(): DoubleTensor = tensor.map(::acosh) - override fun TensorStructure.sin(): DoubleTensor = tensor.map(::sin) + override fun Tensor.sin(): DoubleTensor = tensor.map(::sin) - override fun TensorStructure.asin(): DoubleTensor = tensor.map(::asin) + override fun Tensor.asin(): DoubleTensor = tensor.map(::asin) - override fun TensorStructure.sinh(): DoubleTensor = tensor.map(::sinh) + override fun Tensor.sinh(): DoubleTensor = tensor.map(::sinh) - override fun TensorStructure.asinh(): DoubleTensor = tensor.map(::asinh) + override fun Tensor.asinh(): DoubleTensor = tensor.map(::asinh) - override fun TensorStructure.tan(): DoubleTensor = tensor.map(::tan) + override fun Tensor.tan(): DoubleTensor = tensor.map(::tan) - override fun TensorStructure.atan(): DoubleTensor = tensor.map(::atan) + override fun Tensor.atan(): DoubleTensor = tensor.map(::atan) - override fun TensorStructure.tanh(): DoubleTensor = tensor.map(::tanh) + override fun Tensor.tanh(): DoubleTensor = tensor.map(::tanh) - override fun TensorStructure.atanh(): DoubleTensor = tensor.map(::atanh) + override fun Tensor.atanh(): DoubleTensor = tensor.map(::atanh) - override fun TensorStructure.ceil(): DoubleTensor = tensor.map(::ceil) + override fun Tensor.ceil(): DoubleTensor = tensor.map(::ceil) - override fun TensorStructure.floor(): DoubleTensor = tensor.map(::floor) + override fun Tensor.floor(): DoubleTensor = tensor.map(::floor) } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt index dd5ad5a61..89345e315 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleLinearOpsTensorAlgebra.kt @@ -8,7 +8,7 @@ package space.kscience.kmath.tensors.core.algebras import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.nd.as1D import space.kscience.kmath.nd.as2D -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.checkSquareMatrix import space.kscience.kmath.tensors.core.choleskyHelper @@ -25,19 +25,19 @@ public object DoubleLinearOpsTensorAlgebra : LinearOpsTensorAlgebra, DoubleTensorAlgebra() { - override fun TensorStructure.inv(): DoubleTensor = invLU(1e-9) + override fun Tensor.inv(): DoubleTensor = invLU(1e-9) - override fun TensorStructure.det(): DoubleTensor = detLU(1e-9) + override fun Tensor.det(): DoubleTensor = detLU(1e-9) - public fun TensorStructure.luFactor(epsilon: Double): Pair = + public fun Tensor.luFactor(epsilon: Double): Pair = computeLU(tensor, epsilon) ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") - public fun TensorStructure.luFactor(): Pair = luFactor(1e-9) + public fun Tensor.luFactor(): Pair = luFactor(1e-9) public fun luPivot( - luTensor: TensorStructure, - pivotsTensor: TensorStructure + luTensor: Tensor, + pivotsTensor: Tensor ): Triple { checkSquareMatrix(luTensor.shape) check( @@ -66,7 +66,7 @@ public object DoubleLinearOpsTensorAlgebra : return Triple(pTensor, lTensor, uTensor) } - public fun TensorStructure.cholesky(epsilon: Double): DoubleTensor { + public fun Tensor.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) checkPositiveDefinite(tensor, epsilon) @@ -79,9 +79,9 @@ public object DoubleLinearOpsTensorAlgebra : return lTensor } - override fun TensorStructure.cholesky(): DoubleTensor = cholesky(1e-6) + override fun Tensor.cholesky(): DoubleTensor = cholesky(1e-6) - override fun TensorStructure.qr(): Pair { + override fun Tensor.qr(): Pair { checkSquareMatrix(shape) val qTensor = zeroesLike() val rTensor = zeroesLike() @@ -95,10 +95,10 @@ public object DoubleLinearOpsTensorAlgebra : return qTensor to rTensor } - override fun TensorStructure.svd(): Triple = + override fun Tensor.svd(): Triple = svd(epsilon = 1e-10) - public fun TensorStructure.svd(epsilon: Double): Triple { + public fun Tensor.svd(epsilon: Double): Triple { val size = tensor.linearStructure.dim val commonShape = tensor.shape.sliceArray(0 until size - 2) val (n, m) = tensor.shape.sliceArray(size - 2 until size) @@ -122,11 +122,11 @@ public object DoubleLinearOpsTensorAlgebra : return Triple(uTensor.transpose(), sTensor, vTensor.transpose()) } - override fun TensorStructure.symEig(): Pair = + override fun Tensor.symEig(): Pair = symEig(epsilon = 1e-15) //For information: http://hua-zhou.github.io/teaching/biostatm280-2017spring/slides/16-eigsvd/eigsvd.html - public fun TensorStructure.symEig(epsilon: Double): Pair { + public fun Tensor.symEig(epsilon: Double): Pair { checkSymmetric(tensor, epsilon) val (u, s, v) = tensor.svd(epsilon) val shp = s.shape + intArrayOf(1) @@ -139,7 +139,7 @@ public object DoubleLinearOpsTensorAlgebra : return eig to v } - public fun TensorStructure.detLU(epsilon: Double = 1e-9): DoubleTensor { + public fun Tensor.detLU(epsilon: Double = 1e-9): DoubleTensor { checkSquareMatrix(tensor.shape) val luTensor = tensor.copy() @@ -164,7 +164,7 @@ public object DoubleLinearOpsTensorAlgebra : return detTensor } - public fun TensorStructure.invLU(epsilon: Double = 1e-9): DoubleTensor { + public fun Tensor.invLU(epsilon: Double = 1e-9): DoubleTensor { val (luTensor, pivotsTensor) = luFactor(epsilon) val invTensor = luTensor.zeroesLike() @@ -177,11 +177,11 @@ public object DoubleLinearOpsTensorAlgebra : return invTensor } - public fun TensorStructure.lu(epsilon: Double = 1e-9): Triple { + public fun Tensor.lu(epsilon: Double = 1e-9): Triple { val (lu, pivots) = this.luFactor(epsilon) return luPivot(lu, pivots) } - override fun TensorStructure.lu(): Triple = lu(1e-9) + override fun Tensor.lu(): Triple = lu(1e-9) } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt index d3e8bf175..4009f7b45 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/algebras/DoubleTensorAlgebra.kt @@ -7,7 +7,7 @@ package space.kscience.kmath.tensors.core.algebras import space.kscience.kmath.nd.as2D import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.* import space.kscience.kmath.tensors.core.broadcastOuterTensors import space.kscience.kmath.tensors.core.checkBufferShapeConsistency @@ -25,7 +25,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { public companion object : DoubleTensorAlgebra() - override fun TensorStructure.value(): Double { + override fun Tensor.value(): Double { check(tensor.shape contentEquals intArrayOf(1)) { "Inconsistent value for tensor of shape ${shape.toList()}" } @@ -39,7 +39,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(shape, buffer, 0) } - override operator fun TensorStructure.get(i: Int): DoubleTensor { + override operator fun Tensor.get(i: Int): DoubleTensor { val lastShape = tensor.shape.drop(1).toIntArray() val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart @@ -52,7 +52,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(shape, buffer) } - public fun TensorStructure.fullLike(value: Double): DoubleTensor { + public fun Tensor.fullLike(value: Double): DoubleTensor { val shape = tensor.shape val buffer = DoubleArray(tensor.numElements) { value } return DoubleTensor(shape, buffer) @@ -60,11 +60,11 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { public fun zeros(shape: IntArray): DoubleTensor = full(0.0, shape) - public fun TensorStructure.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) + public fun Tensor.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun ones(shape: IntArray): DoubleTensor = full(1.0, shape) - public fun TensorStructure.onesLike(): DoubleTensor = tensor.fullLike(1.0) + public fun Tensor.onesLike(): DoubleTensor = tensor.fullLike(1.0) public fun eye(n: Int): DoubleTensor { val shape = intArrayOf(n, n) @@ -76,20 +76,20 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return res } - public fun TensorStructure.copy(): DoubleTensor { + public fun Tensor.copy(): DoubleTensor { return DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) } - override fun Double.plus(other: TensorStructure): DoubleTensor { + override fun Double.plus(other: Tensor): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] + this } return DoubleTensor(other.shape, resBuffer) } - override fun TensorStructure.plus(value: Double): DoubleTensor = value + tensor + override fun Tensor.plus(value: Double): DoubleTensor = value + tensor - override fun TensorStructure.plus(other: TensorStructure): DoubleTensor { + override fun Tensor.plus(other: Tensor): DoubleTensor { checkShapesCompatible(tensor, other.tensor) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[i] + other.tensor.mutableBuffer.array()[i] @@ -97,13 +97,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.plusAssign(value: Double) { + override fun Tensor.plusAssign(value: Double) { for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += value } } - override fun TensorStructure.plusAssign(other: TensorStructure) { + override fun Tensor.plusAssign(other: Tensor) { checkShapesCompatible(tensor, other.tensor) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] += @@ -111,21 +111,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } } - override fun Double.minus(other: TensorStructure): DoubleTensor { + override fun Double.minus(other: Tensor): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> this - other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } - override fun TensorStructure.minus(value: Double): DoubleTensor { + override fun Tensor.minus(value: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i] - value } return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.minus(other: TensorStructure): DoubleTensor { + override fun Tensor.minus(other: Tensor): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[i] - other.tensor.mutableBuffer.array()[i] @@ -133,13 +133,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.minusAssign(value: Double) { + override fun Tensor.minusAssign(value: Double) { for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value } } - override fun TensorStructure.minusAssign(other: TensorStructure) { + override fun Tensor.minusAssign(other: Tensor) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] -= @@ -147,16 +147,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } } - override fun Double.times(other: TensorStructure): DoubleTensor { + override fun Double.times(other: Tensor): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] * this } return DoubleTensor(other.shape, resBuffer) } - override fun TensorStructure.times(value: Double): DoubleTensor = value * tensor + override fun Tensor.times(value: Double): DoubleTensor = value * tensor - override fun TensorStructure.times(other: TensorStructure): DoubleTensor { + override fun Tensor.times(other: Tensor): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i] * @@ -165,13 +165,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.timesAssign(value: Double) { + override fun Tensor.timesAssign(value: Double) { for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value } } - override fun TensorStructure.timesAssign(other: TensorStructure) { + override fun Tensor.timesAssign(other: Tensor) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] *= @@ -179,21 +179,21 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } } - override fun Double.div(other: TensorStructure): DoubleTensor { + override fun Double.div(other: Tensor): DoubleTensor { val resBuffer = DoubleArray(other.tensor.numElements) { i -> this / other.tensor.mutableBuffer.array()[other.tensor.bufferStart + i] } return DoubleTensor(other.shape, resBuffer) } - override fun TensorStructure.div(value: Double): DoubleTensor { + override fun Tensor.div(value: Double): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i] / value } return DoubleTensor(shape, resBuffer) } - override fun TensorStructure.div(other: TensorStructure): DoubleTensor { + override fun Tensor.div(other: Tensor): DoubleTensor { checkShapesCompatible(tensor, other) val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[other.tensor.bufferStart + i] / @@ -202,13 +202,13 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.divAssign(value: Double) { + override fun Tensor.divAssign(value: Double) { for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value } } - override fun TensorStructure.divAssign(other: TensorStructure) { + override fun Tensor.divAssign(other: Tensor) { checkShapesCompatible(tensor, other) for (i in 0 until tensor.numElements) { tensor.mutableBuffer.array()[tensor.bufferStart + i] /= @@ -216,14 +216,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } } - override fun TensorStructure.unaryMinus(): DoubleTensor { + override fun Tensor.unaryMinus(): DoubleTensor { val resBuffer = DoubleArray(tensor.numElements) { i -> tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() } return DoubleTensor(tensor.shape, resBuffer) } - override fun TensorStructure.transpose(i: Int, j: Int): DoubleTensor { + override fun Tensor.transpose(i: Int, j: Int): DoubleTensor { val ii = tensor.minusIndex(i) val jj = tensor.minusIndex(j) checkTranspose(tensor.dimension, ii, jj) @@ -248,16 +248,16 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } - override fun TensorStructure.view(shape: IntArray): DoubleTensor { + override fun Tensor.view(shape: IntArray): DoubleTensor { checkView(tensor, shape) return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) } - override fun TensorStructure.viewAs(other: TensorStructure): DoubleTensor { + override fun Tensor.viewAs(other: Tensor): DoubleTensor { return tensor.view(other.shape) } - override infix fun TensorStructure.dot(other: TensorStructure): DoubleTensor { + override infix fun Tensor.dot(other: Tensor): DoubleTensor { if (tensor.shape.size == 1 && other.shape.size == 1) { return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) } @@ -309,7 +309,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { return resTensor } - override fun diagonalEmbedding(diagonalEntries: TensorStructure, offset: Int, dim1: Int, dim2: Int): + override fun diagonalEmbedding(diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int): DoubleTensor { val n = diagonalEntries.shape.size val d1 = minusIndexFrom(n + 1, dim1) @@ -358,7 +358,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } - public fun TensorStructure.map(transform: (Double) -> Double): DoubleTensor { + public fun Tensor.map(transform: (Double) -> Double): DoubleTensor { return DoubleTensor( tensor.shape, tensor.mutableBuffer.array().map { transform(it) }.toDoubleArray(), @@ -366,14 +366,14 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { ) } - public fun TensorStructure.eq(other: TensorStructure, epsilon: Double): Boolean { + public fun Tensor.eq(other: Tensor, epsilon: Double): Boolean { return tensor.eq(other) { x, y -> abs(x - y) < epsilon } } - public infix fun TensorStructure.eq(other: TensorStructure): Boolean = tensor.eq(other, 1e-5) + public infix fun Tensor.eq(other: Tensor): Boolean = tensor.eq(other, 1e-5) - private fun TensorStructure.eq( - other: TensorStructure, + private fun Tensor.eq( + other: Tensor, eqFunction: (Double, Double) -> Boolean ): Boolean { checkShapesCompatible(tensor, other) @@ -396,7 +396,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { public fun randomNormal(shape: IntArray, seed: Long = 0): DoubleTensor = DoubleTensor(shape, getRandomNormals(shape.reduce(Int::times), seed)) - public fun TensorStructure.randomNormalLike(seed: Long = 0): DoubleTensor = + public fun Tensor.randomNormalLike(seed: Long = 0): DoubleTensor = DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) // stack tensors by axis 0 @@ -412,7 +412,7 @@ public open class DoubleTensorAlgebra : TensorPartialDivisionAlgebra { } // build tensor from this rows by given indices - public fun TensorStructure.rowsByIndices(indices: IntArray): DoubleTensor { + public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor { return stack(indices.map { this[it] }) } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt index b1c12ccde..f8bd5027a 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/checks.kt @@ -1,6 +1,6 @@ package space.kscience.kmath.tensors.core -import space.kscience.kmath.tensors.api.TensorStructure +import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.algebras.DoubleLinearOpsTensorAlgebra import space.kscience.kmath.tensors.core.algebras.DoubleTensorAlgebra @@ -20,7 +20,7 @@ internal fun checkBufferShapeConsistency(shape: IntArray, buffer: DoubleArray) = "Inconsistent shape ${shape.toList()} for buffer of size ${buffer.size} provided" } -internal fun checkShapesCompatible(a: TensorStructure, b: TensorStructure) = +internal fun checkShapesCompatible(a: Tensor, b: Tensor) = check(a.shape contentEquals b.shape) { "Incompatible shapes ${a.shape.toList()} and ${b.shape.toList()} " } @@ -30,7 +30,7 @@ internal fun checkTranspose(dim: Int, i: Int, j: Int) = "Cannot transpose $i to $j for a tensor of dim $dim" } -internal fun checkView(a: TensorStructure, shape: IntArray) = +internal fun checkView(a: Tensor, shape: IntArray) = check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) internal fun checkSquareMatrix(shape: IntArray) { @@ -44,7 +44,7 @@ internal fun checkSquareMatrix(shape: IntArray) { } internal fun DoubleTensorAlgebra.checkSymmetric( - tensor: TensorStructure, epsilon: Double = 1e-6 + tensor: Tensor, epsilon: Double = 1e-6 ) = check(tensor.eq(tensor.transpose(), epsilon)) { "Tensor is not symmetric about the last 2 dimensions at precision $epsilon"