From a9821772dbb021f52390428cc510c182f9a1bd16 Mon Sep 17 00:00:00 2001 From: Alexander Nozik Date: Mon, 5 Sep 2022 22:08:35 +0300 Subject: [PATCH] Move power to ExtendedFieldOps --- .../space/kscience/kmath/nd/DoubleFieldND.kt | 12 +- .../kmath/operations/DoubleBufferOps.kt | 6 + .../kscience/kmath/operations/numbers.kt | 20 +- .../kscience/kmath/nd4j/Nd4jTensorAlgebra.kt | 1 + .../core/BroadcastDoubleTensorAlgebra.kt | 41 ++- .../kmath/tensors/core/DoubleTensorAlgebra.kt | 320 +++++++++--------- .../kmath/tensors/core/IntTensorAlgebra.kt | 172 +++++----- .../kmath/tensors/core/internal/checks.kt | 4 +- .../tensors/core/internal/tensorCastsUtils.kt | 21 +- .../kmath/tensors/core/tensorCasts.kt | 14 +- .../kmath/tensors/core/TestDoubleTensor.kt | 10 +- 11 files changed, 320 insertions(+), 301 deletions(-) diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt index 4e8876731..aab137321 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd/DoubleFieldND.kt @@ -11,6 +11,7 @@ import space.kscience.kmath.operations.* import space.kscience.kmath.structures.DoubleBuffer import kotlin.contracts.InvocationKind import kotlin.contracts.contract +import kotlin.math.pow import kotlin.math.pow as kpow public class DoubleBufferND( @@ -165,6 +166,15 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND(D override fun atanh(arg: StructureND): DoubleBufferND = mapInline(arg.toBufferND()) { kotlin.math.atanh(it) } + override fun power( + arg: StructureND, + pow: Number, + ): StructureND = if (pow is Int) { + mapInline(arg.toBufferND()) { it.pow(pow) } + } else { + mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) } + } + public companion object : DoubleFieldOpsND() } @@ -181,7 +191,7 @@ public class DoubleFieldND(override val shape: Shape) : it.kpow(pow) } - override fun power(arg: StructureND, pow: Number): DoubleBufferND = if(pow.isInteger()){ + override fun power(arg: StructureND, pow: Number): DoubleBufferND = if (pow.isInteger()) { power(arg, pow.toInt()) } else { val dpow = pow.toDouble() diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt index 811926c23..ded8753a3 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/DoubleBufferOps.kt @@ -133,6 +133,12 @@ public abstract class DoubleBufferOps : BufferAlgebra, Exte override fun scale(a: Buffer, value: Double): DoubleBuffer = a.mapInline { it * value } + override fun power(arg: Buffer, pow: Number): Buffer = if (pow is Int) { + arg.mapInline { it.pow(pow) } + } else { + arg.mapInline { it.pow(pow.toDouble()) } + } + public companion object : DoubleBufferOps() { public inline fun Buffer.mapInline(block: (Double) -> Double): DoubleBuffer = if (this is DoubleBuffer) { diff --git a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt index 16088a825..224ca1daf 100644 --- a/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt +++ b/kmath-core/src/commonMain/kotlin/space/kscience/kmath/operations/numbers.kt @@ -2,12 +2,13 @@ * Copyright 2018-2022 KMath contributors. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ - +@file:Suppress("NOTHING_TO_INLINE") package space.kscience.kmath.operations import space.kscience.kmath.structures.* import kotlin.math.pow as kpow + /** * Advanced Number-like semifield that implements basic operations. */ @@ -15,7 +16,8 @@ public interface ExtendedFieldOps : FieldOps, TrigonometricOperations, ExponentialOperations, - ScaleOperations { + ScaleOperations, + PowerOperations { override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tanh(arg: T): T = sinh(arg) / cosh(arg) @@ -41,7 +43,7 @@ public interface ExtendedFieldOps : /** * Advanced Number-like field that implements basic operations. */ -public interface ExtendedField : ExtendedFieldOps, Field, PowerOperations, NumericAlgebra { +public interface ExtendedField : ExtendedFieldOps, Field, NumericAlgebra { override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) @@ -64,7 +66,7 @@ public interface ExtendedField : ExtendedFieldOps, Field, PowerOperatio /** * A field for [Double] without boxing. Does not produce appropriate field element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object DoubleField : ExtendedField, Norm, ScaleOperations { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::DoubleBuffer) @@ -124,7 +126,7 @@ public val Double.Companion.algebra: DoubleField get() = DoubleField /** * A field for [Float] without boxing. Does not produce appropriate field element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object FloatField : ExtendedField, Norm { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::FloatBuffer) @@ -180,7 +182,7 @@ public val Float.Companion.algebra: FloatField get() = FloatField /** * A field for [Int] without boxing. Does not produce corresponding ring element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object IntRing : Ring, Norm, NumericAlgebra { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::IntBuffer) @@ -203,7 +205,7 @@ public val Int.Companion.algebra: IntRing get() = IntRing /** * A field for [Short] without boxing. Does not produce appropriate ring element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object ShortRing : Ring, Norm, NumericAlgebra { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::ShortBuffer) @@ -226,7 +228,7 @@ public val Short.Companion.algebra: ShortRing get() = ShortRing /** * A field for [Byte] without boxing. Does not produce appropriate ring element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object ByteRing : Ring, Norm, NumericAlgebra { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::ByteBuffer) @@ -249,7 +251,7 @@ public val Byte.Companion.algebra: ByteRing get() = ByteRing /** * A field for [Double] without boxing. Does not produce appropriate ring element. */ -@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") +@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE") public object LongRing : Ring, Norm, NumericAlgebra { override val bufferFactory: MutableBufferFactory = MutableBufferFactory(::LongBuffer) diff --git a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt index b6a2029e7..18068eaec 100644 --- a/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt +++ b/kmath-nd4j/src/main/kotlin/space/kscience/kmath/nd4j/Nd4jTensorAlgebra.kt @@ -138,6 +138,7 @@ public sealed interface Nd4jTensorAlgebra> : AnalyticTe override fun StructureND.atan(): Nd4jArrayStructure = Transforms.atan(ndArray).wrap() override fun StructureND.tanh(): Nd4jArrayStructure = Transforms.tanh(ndArray).wrap() override fun StructureND.atanh(): Nd4jArrayStructure = Transforms.atanh(ndArray).wrap() + override fun power(arg: StructureND, pow: Number): StructureND = Transforms.pow(arg.ndArray, pow).wrap() override fun StructureND.ceil(): Nd4jArrayStructure = Transforms.ceil(ndArray).wrap() override fun StructureND.floor(): Nd4jArrayStructure = Transforms.floor(ndArray).wrap() override fun StructureND.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure = diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt index fce7d30fd..31e5b6e69 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/BroadcastDoubleTensorAlgebra.kt @@ -11,7 +11,6 @@ import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTo -import space.kscience.kmath.tensors.core.internal.tensor /** * Basic linear algebra operations implemented with broadcasting. @@ -20,7 +19,7 @@ import space.kscience.kmath.tensors.core.internal.tensor public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { override fun StructureND.plus(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor()) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -30,15 +29,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } override fun Tensor.plusAssign(arg: StructureND) { - val newOther = broadcastTo(arg.tensor, tensor.shape) - for (i in 0 until tensor.indices.linearSize) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] += - newOther.mutableBuffer.array()[tensor.bufferStart + i] + val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape) + for (i in 0 until asDoubleTensor().indices.linearSize) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] += + newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun StructureND.minus(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor()) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -48,15 +47,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } override fun Tensor.minusAssign(arg: StructureND) { - val newOther = broadcastTo(arg.tensor, tensor.shape) - for (i in 0 until tensor.indices.linearSize) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] -= - newOther.mutableBuffer.array()[tensor.bufferStart + i] + val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape) + for (i in 0 until asDoubleTensor().indices.linearSize) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -= + newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun StructureND.times(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor()) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -67,15 +66,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } override fun Tensor.timesAssign(arg: StructureND) { - val newOther = broadcastTo(arg.tensor, tensor.shape) - for (i in 0 until tensor.indices.linearSize) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] *= - newOther.mutableBuffer.array()[tensor.bufferStart + i] + val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape) + for (i in 0 until asDoubleTensor().indices.linearSize) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *= + newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun StructureND.div(arg: StructureND): DoubleTensor { - val broadcast = broadcastTensors(tensor, arg.tensor) + val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor()) val newThis = broadcast[0] val newOther = broadcast[1] val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> @@ -86,10 +85,10 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { } override fun Tensor.divAssign(arg: StructureND) { - val newOther = broadcastTo(arg.tensor, tensor.shape) - for (i in 0 until tensor.indices.linearSize) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] /= - newOther.mutableBuffer.array()[tensor.bufferStart + i] + val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape) + for (i in 0 until asDoubleTensor().indices.linearSize) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /= + newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt index 94ad3f43c..096c97324 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/DoubleTensorAlgebra.kt @@ -43,7 +43,7 @@ public open class DoubleTensorAlgebra : @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.map(transform: DoubleField.(Double) -> Double): DoubleTensor { - val tensor = this.tensor + val tensor = this.asDoubleTensor() //TODO remove additional copy val sourceArray = tensor.copyArray() val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) } @@ -57,7 +57,7 @@ public open class DoubleTensorAlgebra : @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor { - val tensor = this.tensor + val tensor = this.asDoubleTensor() //TODO remove additional copy val sourceArray = tensor.copyArray() val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) } @@ -77,9 +77,9 @@ public open class DoubleTensorAlgebra : require(left.shape.contentEquals(right.shape)) { "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" } - val leftTensor = left.tensor + val leftTensor = left.asDoubleTensor() val leftArray = leftTensor.copyArray() - val rightTensor = right.tensor + val rightTensor = right.asDoubleTensor() val rightArray = rightTensor.copyArray() val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) } return DoubleTensor( @@ -88,8 +88,8 @@ public open class DoubleTensorAlgebra : ) } - override fun StructureND.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) - tensor.mutableBuffer.array()[tensor.bufferStart] else null + override fun StructureND.valueOrNull(): Double? = if (asDoubleTensor().shape contentEquals intArrayOf(1)) + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart] else null override fun StructureND.value(): Double = valueOrNull() ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") @@ -121,10 +121,10 @@ public open class DoubleTensorAlgebra : ) override operator fun Tensor.get(i: Int): DoubleTensor { - val lastShape = tensor.shape.drop(1).toIntArray() + val lastShape = asDoubleTensor().shape.drop(1).toIntArray() val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) - val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart - return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart) + val newStart = newShape.reduce(Int::times) * i + asDoubleTensor().bufferStart + return DoubleTensor(newShape, asDoubleTensor().mutableBuffer.array(), newStart) } /** @@ -147,8 +147,8 @@ public open class DoubleTensorAlgebra : * @return tensor with the `input` tensor shape and filled with [value]. */ public fun Tensor.fullLike(value: Double): DoubleTensor { - val shape = tensor.shape - val buffer = DoubleArray(tensor.numElements) { value } + val shape = asDoubleTensor().shape + val buffer = DoubleArray(asDoubleTensor().numElements) { value } return DoubleTensor(shape, buffer) } @@ -165,7 +165,7 @@ public open class DoubleTensorAlgebra : * * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. */ - public fun StructureND.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) + public fun StructureND.zeroesLike(): DoubleTensor = asDoubleTensor().fullLike(0.0) /** * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. @@ -180,7 +180,7 @@ public open class DoubleTensorAlgebra : * * @return tensor filled with the scalar value `1.0`, with the same shape as `input` tensor. */ - public fun Tensor.onesLike(): DoubleTensor = tensor.fullLike(1.0) + public fun Tensor.onesLike(): DoubleTensor = asDoubleTensor().fullLike(1.0) /** * Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. @@ -204,182 +204,182 @@ public open class DoubleTensorAlgebra : * @return a copy of the `input` tensor with a copied buffer. */ public fun StructureND.copy(): DoubleTensor = - DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) + DoubleTensor(asDoubleTensor().shape, asDoubleTensor().mutableBuffer.array().copyOf(), asDoubleTensor().bufferStart) override fun Double.plus(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this + val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i -> + arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] + this } return DoubleTensor(arg.shape, resBuffer) } - override fun StructureND.plus(arg: Double): DoubleTensor = arg + tensor + override fun StructureND.plus(arg: Double): DoubleTensor = arg + asDoubleTensor() override fun StructureND.plus(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg.tensor) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] + checkShapesCompatible(asDoubleTensor(), arg.asDoubleTensor()) + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[i] + arg.asDoubleTensor().mutableBuffer.array()[i] } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun Tensor.plusAssign(value: Double) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] += value + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] += value } } override fun Tensor.plusAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg.tensor) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] += - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg.asDoubleTensor()) + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] += + arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun Double.minus(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i -> + this - arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] } return DoubleTensor(arg.shape, resBuffer) } override fun StructureND.minus(arg: Double): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] - arg } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun StructureND.minus(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] + checkShapesCompatible(asDoubleTensor(), arg) + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[i] - arg.asDoubleTensor().mutableBuffer.array()[i] } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun Tensor.minusAssign(value: Double) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -= value } } override fun Tensor.minusAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] -= - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg) + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -= + arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun Double.times(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this + val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i -> + arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] * this } return DoubleTensor(arg.shape, resBuffer) } - override fun StructureND.times(arg: Double): DoubleTensor = arg * tensor + override fun StructureND.times(arg: Double): DoubleTensor = arg * asDoubleTensor() override fun StructureND.times(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] * - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg) + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] * + arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun Tensor.timesAssign(value: Double) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *= value } } override fun Tensor.timesAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] *= - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg) + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *= + arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun Double.div(arg: StructureND): DoubleTensor { - val resBuffer = DoubleArray(arg.tensor.numElements) { i -> - this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i -> + this / arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] } return DoubleTensor(arg.shape, resBuffer) } override fun StructureND.div(arg: Double): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] / arg } return DoubleTensor(shape, resBuffer) } override fun StructureND.div(arg: StructureND): DoubleTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] / - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg) + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] / + arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun Tensor.divAssign(value: Double) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /= value } } override fun Tensor.divAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] /= - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asDoubleTensor(), arg) + for (i in 0 until asDoubleTensor().numElements) { + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /= + arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] } } override fun StructureND.unaryMinus(): DoubleTensor { - val resBuffer = DoubleArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() + val resBuffer = DoubleArray(asDoubleTensor().numElements) { i -> + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i].unaryMinus() } - return DoubleTensor(tensor.shape, resBuffer) + return DoubleTensor(asDoubleTensor().shape, resBuffer) } override fun Tensor.transpose(i: Int, j: Int): DoubleTensor { - val ii = tensor.minusIndex(i) - val jj = tensor.minusIndex(j) - checkTranspose(tensor.dimension, ii, jj) - val n = tensor.numElements + val ii = asDoubleTensor().minusIndex(i) + val jj = asDoubleTensor().minusIndex(j) + checkTranspose(asDoubleTensor().dimension, ii, jj) + val n = asDoubleTensor().numElements val resBuffer = DoubleArray(n) - val resShape = tensor.shape.copyOf() + val resShape = asDoubleTensor().shape.copyOf() resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] } val resTensor = DoubleTensor(resShape, resBuffer) for (offset in 0 until n) { - val oldMultiIndex = tensor.indices.index(offset) + val oldMultiIndex = asDoubleTensor().indices.index(offset) val newMultiIndex = oldMultiIndex.copyOf() newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } val linearIndex = resTensor.indices.offset(newMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = - tensor.mutableBuffer.array()[tensor.bufferStart + offset] + asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + offset] } return resTensor } override fun Tensor.view(shape: IntArray): DoubleTensor { - checkView(tensor, shape) - return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) + checkView(asDoubleTensor(), shape) + return DoubleTensor(shape, asDoubleTensor().mutableBuffer.array(), asDoubleTensor().bufferStart) } override fun Tensor.viewAs(other: StructureND): DoubleTensor = - tensor.view(other.shape) + asDoubleTensor().view(other.shape) /** * Broadcasting Matrix product of two tensors. @@ -412,25 +412,25 @@ public open class DoubleTensorAlgebra : */ @UnstableKMathAPI public infix fun StructureND.matmul(other: StructureND): DoubleTensor { - if (tensor.shape.size == 1 && other.shape.size == 1) { - return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) + if (asDoubleTensor().shape.size == 1 && other.shape.size == 1) { + return DoubleTensor(intArrayOf(1), doubleArrayOf(asDoubleTensor().times(other).asDoubleTensor().mutableBuffer.array().sum())) } - var newThis = tensor.copy() + var newThis = asDoubleTensor().copy() var newOther = other.copy() var penultimateDim = false var lastDim = false - if (tensor.shape.size == 1) { + if (asDoubleTensor().shape.size == 1) { penultimateDim = true - newThis = tensor.view(intArrayOf(1) + tensor.shape) + newThis = asDoubleTensor().view(intArrayOf(1) + asDoubleTensor().shape) } if (other.shape.size == 1) { lastDim = true - newOther = other.tensor.view(other.shape + intArrayOf(1)) + newOther = other.asDoubleTensor().view(other.shape + intArrayOf(1)) } - val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor) + val broadcastTensors = broadcastOuterTensors(newThis.asDoubleTensor(), newOther.asDoubleTensor()) newThis = broadcastTensors[0] newOther = broadcastTensors[1] @@ -497,8 +497,8 @@ public open class DoubleTensorAlgebra : diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray() val resTensor = zeros(resShape) - for (i in 0 until diagonalEntries.tensor.numElements) { - val multiIndex = diagonalEntries.tensor.indices.index(i) + for (i in 0 until diagonalEntries.asDoubleTensor().numElements) { + val multiIndex = diagonalEntries.asDoubleTensor().indices.index(i) var offset1 = 0 var offset2 = abs(realOffset) @@ -514,7 +514,7 @@ public open class DoubleTensorAlgebra : resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex] } - return resTensor.tensor + return resTensor.asDoubleTensor() } /** @@ -525,7 +525,7 @@ public open class DoubleTensorAlgebra : * @return true if two tensors have the same shape and elements, false otherwise. */ public fun Tensor.eq(other: Tensor, epsilon: Double): Boolean = - tensor.eq(other) { x, y -> abs(x - y) < epsilon } + asDoubleTensor().eq(other) { x, y -> abs(x - y) < epsilon } /** * Compares element-wise two tensors. @@ -534,21 +534,21 @@ public open class DoubleTensorAlgebra : * @param other the tensor to compare with `input` tensor. * @return true if two tensors have the same shape and elements, false otherwise. */ - public infix fun Tensor.eq(other: Tensor): Boolean = tensor.eq(other, 1e-5) + public infix fun Tensor.eq(other: Tensor): Boolean = asDoubleTensor().eq(other, 1e-5) private fun Tensor.eq( other: Tensor, eqFunction: (Double, Double) -> Boolean, ): Boolean { - checkShapesCompatible(tensor, other) - val n = tensor.numElements - if (n != other.tensor.numElements) { + checkShapesCompatible(asDoubleTensor(), other) + val n = asDoubleTensor().numElements + if (n != other.asDoubleTensor().numElements) { return false } for (i in 0 until n) { if (!eqFunction( - tensor.mutableBuffer[tensor.bufferStart + i], - other.tensor.mutableBuffer[other.tensor.bufferStart + i] + asDoubleTensor().mutableBuffer[asDoubleTensor().bufferStart + i], + other.asDoubleTensor().mutableBuffer[other.asDoubleTensor().bufferStart + i] ) ) { return false @@ -578,7 +578,7 @@ public open class DoubleTensorAlgebra : * with `0.0` mean and `1.0` standard deviation. */ public fun Tensor.randomNormalLike(seed: Long = 0): DoubleTensor = - DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) + DoubleTensor(asDoubleTensor().shape, getRandomNormals(asDoubleTensor().shape.reduce(Int::times), seed)) /** * Concatenates a sequence of tensors with equal shapes along the first dimension. @@ -592,7 +592,7 @@ public open class DoubleTensorAlgebra : check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } val resShape = intArrayOf(tensors.size) + shape val resBuffer = tensors.flatMap { - it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements) + it.asDoubleTensor().mutableBuffer.array().drop(it.asDoubleTensor().bufferStart).take(it.asDoubleTensor().numElements) }.toDoubleArray() return DoubleTensor(resShape, resBuffer, 0) } @@ -606,7 +606,7 @@ public open class DoubleTensorAlgebra : public fun Tensor.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) private inline fun StructureND.fold(foldFunction: (DoubleArray) -> Double): Double = - foldFunction(tensor.copyArray()) + foldFunction(asDoubleTensor().copyArray()) private inline fun StructureND.foldDim( dim: Int, @@ -629,44 +629,44 @@ public open class DoubleTensorAlgebra : val prefix = index.take(dim).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray() resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> - tensor[prefix + intArrayOf(i) + suffix] + asDoubleTensor()[prefix + intArrayOf(i) + suffix] }) } return resTensor } - override fun StructureND.sum(): Double = tensor.fold { it.sum() } + override fun StructureND.sum(): Double = asDoubleTensor().fold { it.sum() } override fun StructureND.sum(dim: Int, keepDim: Boolean): DoubleTensor = - foldDim(dim, keepDim) { x -> x.sum() }.toDoubleTensor() + foldDim(dim, keepDim) { x -> x.sum() }.asDoubleTensor() override fun StructureND.min(): Double = this.fold { it.minOrNull()!! } override fun StructureND.min(dim: Int, keepDim: Boolean): DoubleTensor = - foldDim(dim, keepDim) { x -> x.minOrNull()!! }.toDoubleTensor() + foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asDoubleTensor() override fun StructureND.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND.max(dim: Int, keepDim: Boolean): DoubleTensor = - foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.toDoubleTensor() + foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.asDoubleTensor() override fun StructureND.argMax(dim: Int, keepDim: Boolean): IntTensor = foldDim(dim, keepDim) { x -> x.withIndex().maxByOrNull { it.value }?.index!! - }.toIntTensor() + }.asIntTensor() - override fun StructureND.mean(): Double = this.fold { it.sum() / tensor.numElements } + override fun StructureND.mean(): Double = this.fold { it.sum() / asDoubleTensor().numElements } override fun StructureND.mean(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(dim, keepDim) { arr -> check(dim < dimension) { "Dimension $dim out of range $dimension" } arr.sum() / shape[dim] - }.toDoubleTensor() + }.asDoubleTensor() override fun StructureND.std(): Double = fold { arr -> - val mean = arr.sum() / tensor.numElements - sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) + val mean = arr.sum() / asDoubleTensor().numElements + sqrt(arr.sumOf { (it - mean) * (it - mean) } / (asDoubleTensor().numElements - 1)) } override fun StructureND.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( @@ -676,11 +676,11 @@ public open class DoubleTensorAlgebra : check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)) - }.toDoubleTensor() + }.asDoubleTensor() override fun StructureND.variance(): Double = fold { arr -> - val mean = arr.sum() / tensor.numElements - arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) + val mean = arr.sum() / asDoubleTensor().numElements + arr.sumOf { (it - mean) * (it - mean) } / (asDoubleTensor().numElements - 1) } override fun StructureND.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( @@ -690,7 +690,7 @@ public open class DoubleTensorAlgebra : check(dim < dimension) { "Dimension $dim out of range $dimension" } val mean = arr.sum() / shape[dim] arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1) - }.toDoubleTensor() + }.asDoubleTensor() private fun cov(x: DoubleTensor, y: DoubleTensor): Double { val n = x.shape[0] @@ -716,45 +716,51 @@ public open class DoubleTensorAlgebra : ) for (i in 0 until n) { for (j in 0 until n) { - resTensor[intArrayOf(i, j)] = cov(tensors[i].tensor, tensors[j].tensor) + resTensor[intArrayOf(i, j)] = cov(tensors[i].asDoubleTensor(), tensors[j].asDoubleTensor()) } } return resTensor } - override fun StructureND.exp(): DoubleTensor = tensor.map { exp(it) } + override fun StructureND.exp(): DoubleTensor = asDoubleTensor().map { exp(it) } - override fun StructureND.ln(): DoubleTensor = tensor.map { ln(it) } + override fun StructureND.ln(): DoubleTensor = asDoubleTensor().map { ln(it) } - override fun StructureND.sqrt(): DoubleTensor = tensor.map { sqrt(it) } + override fun StructureND.sqrt(): DoubleTensor = asDoubleTensor().map { sqrt(it) } - override fun StructureND.cos(): DoubleTensor = tensor.map { cos(it) } + override fun StructureND.cos(): DoubleTensor = asDoubleTensor().map { cos(it) } - override fun StructureND.acos(): DoubleTensor = tensor.map { acos(it) } + override fun StructureND.acos(): DoubleTensor = asDoubleTensor().map { acos(it) } - override fun StructureND.cosh(): DoubleTensor = tensor.map { cosh(it) } + override fun StructureND.cosh(): DoubleTensor = asDoubleTensor().map { cosh(it) } - override fun StructureND.acosh(): DoubleTensor = tensor.map { acosh(it) } + override fun StructureND.acosh(): DoubleTensor = asDoubleTensor().map { acosh(it) } - override fun StructureND.sin(): DoubleTensor = tensor.map { sin(it) } + override fun StructureND.sin(): DoubleTensor = asDoubleTensor().map { sin(it) } - override fun StructureND.asin(): DoubleTensor = tensor.map { asin(it) } + override fun StructureND.asin(): DoubleTensor = asDoubleTensor().map { asin(it) } - override fun StructureND.sinh(): DoubleTensor = tensor.map { sinh(it) } + override fun StructureND.sinh(): DoubleTensor = asDoubleTensor().map { sinh(it) } - override fun StructureND.asinh(): DoubleTensor = tensor.map { asinh(it) } + override fun StructureND.asinh(): DoubleTensor = asDoubleTensor().map { asinh(it) } - override fun StructureND.tan(): DoubleTensor = tensor.map { tan(it) } + override fun StructureND.tan(): DoubleTensor = asDoubleTensor().map { tan(it) } - override fun StructureND.atan(): DoubleTensor = tensor.map { atan(it) } + override fun StructureND.atan(): DoubleTensor = asDoubleTensor().map { atan(it) } - override fun StructureND.tanh(): DoubleTensor = tensor.map { tanh(it) } + override fun StructureND.tanh(): DoubleTensor = asDoubleTensor().map { tanh(it) } - override fun StructureND.atanh(): DoubleTensor = tensor.map { atanh(it) } + override fun StructureND.atanh(): DoubleTensor = asDoubleTensor().map { atanh(it) } - override fun StructureND.ceil(): DoubleTensor = tensor.map { ceil(it) } + override fun power(arg: StructureND, pow: Number): StructureND = if (pow is Int) { + arg.map { it.pow(pow) } + } else { + arg.map { it.pow(pow.toDouble()) } + } - override fun StructureND.floor(): DoubleTensor = tensor.map { floor(it) } + override fun StructureND.ceil(): DoubleTensor = asDoubleTensor().map { ceil(it) } + + override fun StructureND.floor(): DoubleTensor = asDoubleTensor().map { floor(it) } override fun StructureND.inv(): DoubleTensor = invLU(1e-9) @@ -770,7 +776,7 @@ public open class DoubleTensorAlgebra : * The `pivots` has the shape ``(∗, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. */ public fun StructureND.luFactor(epsilon: Double): Pair = - computeLU(tensor, epsilon) + computeLU(asDoubleTensor(), epsilon) ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") /** @@ -809,7 +815,7 @@ public open class DoubleTensorAlgebra : val pTensor = luTensor.zeroesLike() pTensor .matrixSequence() - .zip(pivotsTensor.tensor.vectorSequence()) + .zip(pivotsTensor.asIntTensor().vectorSequence()) .forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) } val lTensor = luTensor.zeroesLike() @@ -817,7 +823,7 @@ public open class DoubleTensorAlgebra : lTensor.matrixSequence() .zip(uTensor.matrixSequence()) - .zip(luTensor.tensor.matrixSequence()) + .zip(luTensor.asDoubleTensor().matrixSequence()) .forEach { (pairLU, lu) -> val (l, u) = pairLU luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n) @@ -841,12 +847,12 @@ public open class DoubleTensorAlgebra : */ public fun StructureND.cholesky(epsilon: Double): DoubleTensor { checkSquareMatrix(shape) - checkPositiveDefinite(tensor, epsilon) + checkPositiveDefinite(asDoubleTensor(), epsilon) val n = shape.last() val lTensor = zeroesLike() - for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence())) + for ((a, l) in asDoubleTensor().matrixSequence().zip(lTensor.matrixSequence())) for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n) return lTensor @@ -858,13 +864,13 @@ public open class DoubleTensorAlgebra : checkSquareMatrix(shape) val qTensor = zeroesLike() val rTensor = zeroesLike() - tensor.matrixSequence() + asDoubleTensor().matrixSequence() .zip( (qTensor.matrixSequence() .zip(rTensor.matrixSequence())) ).forEach { (matrix, qr) -> val (q, r) = qr - qrHelper(matrix.asTensor(), q.asTensor(), r.as2D()) + qrHelper(matrix.toTensor(), q.toTensor(), r.as2D()) } return qTensor to rTensor @@ -887,14 +893,14 @@ public open class DoubleTensorAlgebra : * @return a triple `Triple(U, S, V)`. */ public fun StructureND.svd(epsilon: Double): Triple { - val size = tensor.dimension - val commonShape = tensor.shape.sliceArray(0 until size - 2) - val (n, m) = tensor.shape.sliceArray(size - 2 until size) + val size = asDoubleTensor().dimension + val commonShape = asDoubleTensor().shape.sliceArray(0 until size - 2) + val (n, m) = asDoubleTensor().shape.sliceArray(size - 2 until size) val uTensor = zeros(commonShape + intArrayOf(min(n, m), n)) val sTensor = zeros(commonShape + intArrayOf(min(n, m))) val vTensor = zeros(commonShape + intArrayOf(min(n, m), m)) - val matrices = tensor.matrices + val matrices = asDoubleTensor().matrices val uTensors = uTensor.matrices val sTensorVectors = sTensor.vectors val vTensors = vTensor.matrices @@ -931,7 +937,7 @@ public open class DoubleTensorAlgebra : * @return a pair `eigenvalues to eigenvectors`. */ public fun StructureND.symEigSvd(epsilon: Double): Pair { - checkSymmetric(tensor, epsilon) + checkSymmetric(asDoubleTensor(), epsilon) fun MutableStructure2D.cleanSym(n: Int) { for (i in 0 until n) { @@ -945,7 +951,7 @@ public open class DoubleTensorAlgebra : } } - val (u, s, v) = tensor.svd(epsilon) + val (u, s, v) = asDoubleTensor().svd(epsilon) val shp = s.shape + intArrayOf(1) val utv = u.transpose() matmul v val n = s.shape.last() @@ -958,7 +964,7 @@ public open class DoubleTensorAlgebra : } public fun StructureND.symEigJacobi(maxIteration: Int, epsilon: Double): Pair { - checkSymmetric(tensor, epsilon) + checkSymmetric(asDoubleTensor(), epsilon) val size = this.dimension val eigenvectors = zeros(this.shape) @@ -966,7 +972,7 @@ public open class DoubleTensorAlgebra : var eigenvalueStart = 0 var eigenvectorStart = 0 - for (matrix in tensor.matrixSequence()) { + for (matrix in asDoubleTensor().matrixSequence()) { val matrix2D = matrix.as2D() val (d, v) = matrix2D.jacobiHelper(maxIteration, epsilon) @@ -1111,9 +1117,9 @@ public open class DoubleTensorAlgebra : * @return the determinant. */ public fun StructureND.detLU(epsilon: Double = 1e-9): DoubleTensor { - checkSquareMatrix(tensor.shape) - val luTensor = tensor.copy() - val pivotsTensor = tensor.setUpPivots() + checkSquareMatrix(asDoubleTensor().shape) + val luTensor = asDoubleTensor().copy() + val pivotsTensor = asDoubleTensor().setUpPivots() val n = shape.size @@ -1169,14 +1175,14 @@ public open class DoubleTensorAlgebra : * @return triple of `P`, `L` and `U` tensors. */ public fun StructureND.lu(epsilon: Double = 1e-9): Triple { - val (lu, pivots) = tensor.luFactor(epsilon) + val (lu, pivots) = asDoubleTensor().luFactor(epsilon) return luPivot(lu, pivots) } override fun StructureND.lu(): Triple = lu(1e-9) } -public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra -public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra +public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra +public val DoubleField.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt index 0c8ddcf87..194f2bd18 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/IntTensorAlgebra.kt @@ -39,7 +39,7 @@ public open class IntTensorAlgebra : TensorAlgebra { @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.map(transform: IntRing.(Int) -> Int): IntTensor { - val tensor = this.tensor + val tensor = this.asIntTensor() //TODO remove additional copy val sourceArray = tensor.copyArray() val array = IntArray(tensor.numElements) { IntRing.transform(sourceArray[it]) } @@ -53,7 +53,7 @@ public open class IntTensorAlgebra : TensorAlgebra { @PerformancePitfall @Suppress("OVERRIDE_BY_INLINE") final override inline fun StructureND.mapIndexed(transform: IntRing.(index: IntArray, Int) -> Int): IntTensor { - val tensor = this.tensor + val tensor = this.asIntTensor() //TODO remove additional copy val sourceArray = tensor.copyArray() val array = IntArray(tensor.numElements) { IntRing.transform(tensor.indices.index(it), sourceArray[it]) } @@ -73,9 +73,9 @@ public open class IntTensorAlgebra : TensorAlgebra { require(left.shape.contentEquals(right.shape)) { "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" } - val leftTensor = left.tensor + val leftTensor = left.asIntTensor() val leftArray = leftTensor.copyArray() - val rightTensor = right.tensor + val rightTensor = right.asIntTensor() val rightArray = rightTensor.copyArray() val array = IntArray(leftTensor.numElements) { IntRing.transform(leftArray[it], rightArray[it]) } return IntTensor( @@ -84,8 +84,8 @@ public open class IntTensorAlgebra : TensorAlgebra { ) } - override fun StructureND.valueOrNull(): Int? = if (tensor.shape contentEquals intArrayOf(1)) - tensor.mutableBuffer.array()[tensor.bufferStart] else null + override fun StructureND.valueOrNull(): Int? = if (asIntTensor().shape contentEquals intArrayOf(1)) + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart] else null override fun StructureND.value(): Int = valueOrNull() ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") @@ -119,10 +119,10 @@ public open class IntTensorAlgebra : TensorAlgebra { ) override operator fun Tensor.get(i: Int): IntTensor { - val lastShape = tensor.shape.drop(1).toIntArray() + val lastShape = asIntTensor().shape.drop(1).toIntArray() val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) - val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart - return IntTensor(newShape, tensor.mutableBuffer.array(), newStart) + val newStart = newShape.reduce(Int::times) * i + asIntTensor().bufferStart + return IntTensor(newShape, asIntTensor().mutableBuffer.array(), newStart) } /** @@ -145,8 +145,8 @@ public open class IntTensorAlgebra : TensorAlgebra { * @return tensor with the `input` tensor shape and filled with [value]. */ public fun Tensor.fullLike(value: Int): IntTensor { - val shape = tensor.shape - val buffer = IntArray(tensor.numElements) { value } + val shape = asIntTensor().shape + val buffer = IntArray(asIntTensor().numElements) { value } return IntTensor(shape, buffer) } @@ -163,7 +163,7 @@ public open class IntTensorAlgebra : TensorAlgebra { * * @return tensor filled with the scalar value `0`, with the same shape as `input` tensor. */ - public fun StructureND.zeroesLike(): IntTensor = tensor.fullLike(0) + public fun StructureND.zeroesLike(): IntTensor = asIntTensor().fullLike(0) /** * Returns a tensor filled with the scalar value `1`, with the shape defined by the variable argument [shape]. @@ -178,7 +178,7 @@ public open class IntTensorAlgebra : TensorAlgebra { * * @return tensor filled with the scalar value `1`, with the same shape as `input` tensor. */ - public fun Tensor.onesLike(): IntTensor = tensor.fullLike(1) + public fun Tensor.onesLike(): IntTensor = asIntTensor().fullLike(1) /** * Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. @@ -202,145 +202,145 @@ public open class IntTensorAlgebra : TensorAlgebra { * @return a copy of the `input` tensor with a copied buffer. */ public fun StructureND.copy(): IntTensor = - IntTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) + IntTensor(asIntTensor().shape, asIntTensor().mutableBuffer.array().copyOf(), asIntTensor().bufferStart) override fun Int.plus(arg: StructureND): IntTensor { - val resBuffer = IntArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this + val resBuffer = IntArray(arg.asIntTensor().numElements) { i -> + arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] + this } return IntTensor(arg.shape, resBuffer) } - override fun StructureND.plus(arg: Int): IntTensor = arg + tensor + override fun StructureND.plus(arg: Int): IntTensor = arg + asIntTensor() override fun StructureND.plus(arg: StructureND): IntTensor { - checkShapesCompatible(tensor, arg.tensor) - val resBuffer = IntArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] + checkShapesCompatible(asIntTensor(), arg.asIntTensor()) + val resBuffer = IntArray(asIntTensor().numElements) { i -> + asIntTensor().mutableBuffer.array()[i] + arg.asIntTensor().mutableBuffer.array()[i] } - return IntTensor(tensor.shape, resBuffer) + return IntTensor(asIntTensor().shape, resBuffer) } override fun Tensor.plusAssign(value: Int) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] += value + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] += value } } override fun Tensor.plusAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg.tensor) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] += - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asIntTensor(), arg.asIntTensor()) + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] += + arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] } } override fun Int.minus(arg: StructureND): IntTensor { - val resBuffer = IntArray(arg.tensor.numElements) { i -> - this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + val resBuffer = IntArray(arg.asIntTensor().numElements) { i -> + this - arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] } return IntTensor(arg.shape, resBuffer) } override fun StructureND.minus(arg: Int): IntTensor { - val resBuffer = IntArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg + val resBuffer = IntArray(asIntTensor().numElements) { i -> + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] - arg } - return IntTensor(tensor.shape, resBuffer) + return IntTensor(asIntTensor().shape, resBuffer) } override fun StructureND.minus(arg: StructureND): IntTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = IntArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] + checkShapesCompatible(asIntTensor(), arg) + val resBuffer = IntArray(asIntTensor().numElements) { i -> + asIntTensor().mutableBuffer.array()[i] - arg.asIntTensor().mutableBuffer.array()[i] } - return IntTensor(tensor.shape, resBuffer) + return IntTensor(asIntTensor().shape, resBuffer) } override fun Tensor.minusAssign(value: Int) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] -= value } } override fun Tensor.minusAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] -= - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asIntTensor(), arg) + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] -= + arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] } } override fun Int.times(arg: StructureND): IntTensor { - val resBuffer = IntArray(arg.tensor.numElements) { i -> - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this + val resBuffer = IntArray(arg.asIntTensor().numElements) { i -> + arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] * this } return IntTensor(arg.shape, resBuffer) } - override fun StructureND.times(arg: Int): IntTensor = arg * tensor + override fun StructureND.times(arg: Int): IntTensor = arg * asIntTensor() override fun StructureND.times(arg: StructureND): IntTensor { - checkShapesCompatible(tensor, arg) - val resBuffer = IntArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i] * - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + checkShapesCompatible(asIntTensor(), arg) + val resBuffer = IntArray(asIntTensor().numElements) { i -> + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] * + arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] } - return IntTensor(tensor.shape, resBuffer) + return IntTensor(asIntTensor().shape, resBuffer) } override fun Tensor.timesAssign(value: Int) { - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] *= value } } override fun Tensor.timesAssign(arg: StructureND) { - checkShapesCompatible(tensor, arg) - for (i in 0 until tensor.numElements) { - tensor.mutableBuffer.array()[tensor.bufferStart + i] *= - arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] + checkShapesCompatible(asIntTensor(), arg) + for (i in 0 until asIntTensor().numElements) { + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] *= + arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] } } override fun StructureND.unaryMinus(): IntTensor { - val resBuffer = IntArray(tensor.numElements) { i -> - tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() + val resBuffer = IntArray(asIntTensor().numElements) { i -> + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i].unaryMinus() } - return IntTensor(tensor.shape, resBuffer) + return IntTensor(asIntTensor().shape, resBuffer) } override fun Tensor.transpose(i: Int, j: Int): IntTensor { - val ii = tensor.minusIndex(i) - val jj = tensor.minusIndex(j) - checkTranspose(tensor.dimension, ii, jj) - val n = tensor.numElements + val ii = asIntTensor().minusIndex(i) + val jj = asIntTensor().minusIndex(j) + checkTranspose(asIntTensor().dimension, ii, jj) + val n = asIntTensor().numElements val resBuffer = IntArray(n) - val resShape = tensor.shape.copyOf() + val resShape = asIntTensor().shape.copyOf() resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] } val resTensor = IntTensor(resShape, resBuffer) for (offset in 0 until n) { - val oldMultiIndex = tensor.indices.index(offset) + val oldMultiIndex = asIntTensor().indices.index(offset) val newMultiIndex = oldMultiIndex.copyOf() newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } val linearIndex = resTensor.indices.offset(newMultiIndex) resTensor.mutableBuffer.array()[linearIndex] = - tensor.mutableBuffer.array()[tensor.bufferStart + offset] + asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + offset] } return resTensor } override fun Tensor.view(shape: IntArray): IntTensor { - checkView(tensor, shape) - return IntTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) + checkView(asIntTensor(), shape) + return IntTensor(shape, asIntTensor().mutableBuffer.array(), asIntTensor().bufferStart) } override fun Tensor.viewAs(other: StructureND): IntTensor = - tensor.view(other.shape) + asIntTensor().view(other.shape) override fun diagonalEmbedding( diagonalEntries: Tensor, @@ -374,8 +374,8 @@ public open class IntTensorAlgebra : TensorAlgebra { diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray() val resTensor = zeros(resShape) - for (i in 0 until diagonalEntries.tensor.numElements) { - val multiIndex = diagonalEntries.tensor.indices.index(i) + for (i in 0 until diagonalEntries.asIntTensor().numElements) { + val multiIndex = diagonalEntries.asIntTensor().indices.index(i) var offset1 = 0 var offset2 = abs(realOffset) @@ -391,19 +391,19 @@ public open class IntTensorAlgebra : TensorAlgebra { resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex] } - return resTensor.tensor + return resTensor.asIntTensor() } private infix fun Tensor.eq( other: Tensor, ): Boolean { - checkShapesCompatible(tensor, other) - val n = tensor.numElements - if (n != other.tensor.numElements) { + checkShapesCompatible(asIntTensor(), other) + val n = asIntTensor().numElements + if (n != other.asIntTensor().numElements) { return false } for (i in 0 until n) { - if (tensor.mutableBuffer[tensor.bufferStart + i] != other.tensor.mutableBuffer[other.tensor.bufferStart + i]) { + if (asIntTensor().mutableBuffer[asIntTensor().bufferStart + i] != other.asIntTensor().mutableBuffer[other.asIntTensor().bufferStart + i]) { return false } } @@ -422,7 +422,7 @@ public open class IntTensorAlgebra : TensorAlgebra { check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } val resShape = intArrayOf(tensors.size) + shape val resBuffer = tensors.flatMap { - it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements) + it.asIntTensor().mutableBuffer.array().drop(it.asIntTensor().bufferStart).take(it.asIntTensor().numElements) }.toIntArray() return IntTensor(resShape, resBuffer, 0) } @@ -436,7 +436,7 @@ public open class IntTensorAlgebra : TensorAlgebra { public fun Tensor.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { this[it] }) private inline fun StructureND.fold(foldFunction: (IntArray) -> Int): Int = - foldFunction(tensor.copyArray()) + foldFunction(asIntTensor().copyArray()) private inline fun StructureND.foldDim( dim: Int, @@ -459,35 +459,35 @@ public open class IntTensorAlgebra : TensorAlgebra { val prefix = index.take(dim).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray() resTensor[index] = foldFunction(IntArray(shape[dim]) { i -> - tensor[prefix + intArrayOf(i) + suffix] + asIntTensor()[prefix + intArrayOf(i) + suffix] }) } return resTensor } - override fun StructureND.sum(): Int = tensor.fold { it.sum() } + override fun StructureND.sum(): Int = asIntTensor().fold { it.sum() } override fun StructureND.sum(dim: Int, keepDim: Boolean): IntTensor = - foldDim(dim, keepDim) { x -> x.sum() }.toIntTensor() + foldDim(dim, keepDim) { x -> x.sum() }.asIntTensor() override fun StructureND.min(): Int = this.fold { it.minOrNull()!! } override fun StructureND.min(dim: Int, keepDim: Boolean): IntTensor = - foldDim(dim, keepDim) { x -> x.minOrNull()!! }.toIntTensor() + foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asIntTensor() override fun StructureND.max(): Int = this.fold { it.maxOrNull()!! } override fun StructureND.max(dim: Int, keepDim: Boolean): IntTensor = - foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.toIntTensor() + foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.asIntTensor() override fun StructureND.argMax(dim: Int, keepDim: Boolean): IntTensor = foldDim(dim, keepDim) { x -> x.withIndex().maxByOrNull { it.value }?.index!! - }.toIntTensor() + }.asIntTensor() } -public val Int.Companion.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra -public val IntRing.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra +public val Int.Companion.tensorAlgebra: IntTensorAlgebra get() = IntTensorAlgebra +public val IntRing.tensorAlgebra: IntTensorAlgebra get() = IntTensorAlgebra diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt index e997dfcca..62cc85396 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/checks.kt @@ -58,7 +58,7 @@ internal fun DoubleTensorAlgebra.checkSymmetric( internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) { checkSymmetric(tensor, epsilon) for (mat in tensor.matrixSequence()) - check(mat.asTensor().detLU().value() > 0.0) { - "Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" + check(mat.toTensor().detLU().value() > 0.0) { + "Tensor contains matrices which are not positive definite ${mat.toTensor().detLU().value()}" } } \ No newline at end of file diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt index c19b8bf95..e7704b257 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/internal/tensorCastsUtils.kt @@ -13,16 +13,17 @@ import space.kscience.kmath.tensors.core.DoubleTensor import space.kscience.kmath.tensors.core.IntTensor import space.kscience.kmath.tensors.core.TensorLinearStructure -internal fun BufferedTensor.asTensor(): IntTensor = +internal fun BufferedTensor.toTensor(): IntTensor = IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) -internal fun BufferedTensor.asTensor(): DoubleTensor = +internal fun BufferedTensor.toTensor(): DoubleTensor = DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) internal fun StructureND.copyToBufferedTensor(): BufferedTensor = BufferedTensor( this.shape, - TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().asMutableBuffer(), 0 + TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().asMutableBuffer(), + 0 ) internal fun StructureND.toBufferedTensor(): BufferedTensor = when (this) { @@ -34,17 +35,3 @@ internal fun StructureND.toBufferedTensor(): BufferedTensor = when (th } else -> this.copyToBufferedTensor() } - -@PublishedApi -internal val StructureND.tensor: DoubleTensor - get() = when (this) { - is DoubleTensor -> this - else -> this.toBufferedTensor().asTensor() - } - -@PublishedApi -internal val StructureND.tensor: IntTensor - get() = when (this) { - is IntTensor -> this - else -> this.toBufferedTensor().asTensor() - } diff --git a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt index 968efa468..1c914d474 100644 --- a/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt +++ b/kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors/core/tensorCasts.kt @@ -5,18 +5,26 @@ package space.kscience.kmath.tensors.core +import space.kscience.kmath.nd.StructureND import space.kscience.kmath.tensors.api.Tensor -import space.kscience.kmath.tensors.core.internal.tensor +import space.kscience.kmath.tensors.core.internal.toBufferedTensor +import space.kscience.kmath.tensors.core.internal.toTensor /** * Casts [Tensor] of [Double] to [DoubleTensor] */ -public fun Tensor.toDoubleTensor(): DoubleTensor = this.tensor +public fun StructureND.asDoubleTensor(): DoubleTensor = when (this) { + is DoubleTensor -> this + else -> this.toBufferedTensor().toTensor() +} /** * Casts [Tensor] of [Int] to [IntTensor] */ -public fun Tensor.toIntTensor(): IntTensor = this.tensor +public fun StructureND.asIntTensor(): IntTensor = when (this) { + is IntTensor -> this + else -> this.toBufferedTensor().toTensor() +} /** * Returns a copy-protected [DoubleArray] of tensor elements diff --git a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt index 5e5ef95be..61248cbb8 100644 --- a/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt +++ b/kmath-tensors/src/commonTest/kotlin/space/kscience/kmath/tensors/core/TestDoubleTensor.kt @@ -14,9 +14,9 @@ import space.kscience.kmath.operations.invoke import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.tensors.core.internal.array -import space.kscience.kmath.tensors.core.internal.asTensor import space.kscience.kmath.tensors.core.internal.matrixSequence import space.kscience.kmath.tensors.core.internal.toBufferedTensor +import space.kscience.kmath.tensors.core.internal.toTensor import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -56,7 +56,7 @@ internal class TestDoubleTensor { assertEquals(tensor[intArrayOf(0, 1, 0)], 109.56) tensor.matrixSequence().forEach { - val a = it.asTensor() + val a = it.toTensor() val secondRow = a[1].as1D() val secondColumn = a.transpose(0, 1)[1].as1D() assertEquals(secondColumn[0], 77.89) @@ -75,10 +75,10 @@ internal class TestDoubleTensor { // map to tensors val bufferedTensorArray = ndArray.toBufferedTensor() // strides are flipped so data copied - val tensorArray = bufferedTensorArray.asTensor() // data not contiguous so copied again + val tensorArray = bufferedTensorArray.toTensor() // data not contiguous so copied again - val tensorArrayPublic = ndArray.toDoubleTensor() // public API, data copied twice - val sharedTensorArray = tensorArrayPublic.toDoubleTensor() // no data copied by matching type + val tensorArrayPublic = ndArray.asDoubleTensor() // public API, data copied twice + val sharedTensorArray = tensorArrayPublic.asDoubleTensor() // no data copied by matching type assertTrue(tensorArray.mutableBuffer.array() contentEquals sharedTensorArray.mutableBuffer.array())