Move power to ExtendedFieldOps

This commit is contained in:
Alexander Nozik 2022-09-05 22:08:35 +03:00
parent 5042fda751
commit a9821772db
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
11 changed files with 320 additions and 301 deletions

View File

@ -11,6 +11,7 @@ import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.math.pow
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
public class DoubleBufferND( public class DoubleBufferND(
@ -165,6 +166,15 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
override fun atanh(arg: StructureND<Double>): DoubleBufferND = override fun atanh(arg: StructureND<Double>): DoubleBufferND =
mapInline(arg.toBufferND()) { kotlin.math.atanh(it) } mapInline(arg.toBufferND()) { kotlin.math.atanh(it) }
override fun power(
arg: StructureND<Double>,
pow: Number,
): StructureND<Double> = if (pow is Int) {
mapInline(arg.toBufferND()) { it.pow(pow) }
} else {
mapInline(arg.toBufferND()) { it.pow(pow.toDouble()) }
}
public companion object : DoubleFieldOpsND() public companion object : DoubleFieldOpsND()
} }
@ -181,7 +191,7 @@ public class DoubleFieldND(override val shape: Shape) :
it.kpow(pow) it.kpow(pow)
} }
override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND = if(pow.isInteger()){ override fun power(arg: StructureND<Double>, pow: Number): DoubleBufferND = if (pow.isInteger()) {
power(arg, pow.toInt()) power(arg, pow.toInt())
} else { } else {
val dpow = pow.toDouble() val dpow = pow.toDouble()

View File

@ -133,6 +133,12 @@ public abstract class DoubleBufferOps : BufferAlgebra<Double, DoubleField>, Exte
override fun scale(a: Buffer<Double>, value: Double): DoubleBuffer = a.mapInline { it * value } override fun scale(a: Buffer<Double>, value: Double): DoubleBuffer = a.mapInline { it * value }
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> = if (pow is Int) {
arg.mapInline { it.pow(pow) }
} else {
arg.mapInline { it.pow(pow.toDouble()) }
}
public companion object : DoubleBufferOps() { public companion object : DoubleBufferOps() {
public inline fun Buffer<Double>.mapInline(block: (Double) -> Double): DoubleBuffer = public inline fun Buffer<Double>.mapInline(block: (Double) -> Double): DoubleBuffer =
if (this is DoubleBuffer) { if (this is DoubleBuffer) {

View File

@ -2,12 +2,13 @@
* Copyright 2018-2022 KMath contributors. * Copyright 2018-2022 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/ */
@file:Suppress("NOTHING_TO_INLINE")
package space.kscience.kmath.operations package space.kscience.kmath.operations
import space.kscience.kmath.structures.* import space.kscience.kmath.structures.*
import kotlin.math.pow as kpow import kotlin.math.pow as kpow
/** /**
* Advanced Number-like semifield that implements basic operations. * Advanced Number-like semifield that implements basic operations.
*/ */
@ -15,7 +16,8 @@ public interface ExtendedFieldOps<T> :
FieldOps<T>, FieldOps<T>,
TrigonometricOperations<T>, TrigonometricOperations<T>,
ExponentialOperations<T>, ExponentialOperations<T>,
ScaleOperations<T> { ScaleOperations<T>,
PowerOperations<T> {
override fun tan(arg: T): T = sin(arg) / cos(arg) override fun tan(arg: T): T = sin(arg) / cos(arg)
override fun tanh(arg: T): T = sinh(arg) / cosh(arg) override fun tanh(arg: T): T = sinh(arg) / cosh(arg)
@ -41,7 +43,7 @@ public interface ExtendedFieldOps<T> :
/** /**
* Advanced Number-like field that implements basic operations. * Advanced Number-like field that implements basic operations.
*/ */
public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, PowerOperations<T>, NumericAlgebra<T> { public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T> {
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0 override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0 override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg)) override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
@ -64,7 +66,7 @@ public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, PowerOperatio
/** /**
* A field for [Double] without boxing. Does not produce appropriate field element. * A field for [Double] without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> { public object DoubleField : ExtendedField<Double>, Norm<Double, Double>, ScaleOperations<Double> {
override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory(::DoubleBuffer) override val bufferFactory: MutableBufferFactory<Double> = MutableBufferFactory(::DoubleBuffer)
@ -124,7 +126,7 @@ public val Double.Companion.algebra: DoubleField get() = DoubleField
/** /**
* A field for [Float] without boxing. Does not produce appropriate field element. * A field for [Float] without boxing. Does not produce appropriate field element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object FloatField : ExtendedField<Float>, Norm<Float, Float> { public object FloatField : ExtendedField<Float>, Norm<Float, Float> {
override val bufferFactory: MutableBufferFactory<Float> = MutableBufferFactory(::FloatBuffer) override val bufferFactory: MutableBufferFactory<Float> = MutableBufferFactory(::FloatBuffer)
@ -180,7 +182,7 @@ public val Float.Companion.algebra: FloatField get() = FloatField
/** /**
* A field for [Int] without boxing. Does not produce corresponding ring element. * A field for [Int] without boxing. Does not produce corresponding ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> { public object IntRing : Ring<Int>, Norm<Int, Int>, NumericAlgebra<Int> {
override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory(::IntBuffer) override val bufferFactory: MutableBufferFactory<Int> = MutableBufferFactory(::IntBuffer)
@ -203,7 +205,7 @@ public val Int.Companion.algebra: IntRing get() = IntRing
/** /**
* A field for [Short] without boxing. Does not produce appropriate ring element. * A field for [Short] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> { public object ShortRing : Ring<Short>, Norm<Short, Short>, NumericAlgebra<Short> {
override val bufferFactory: MutableBufferFactory<Short> = MutableBufferFactory(::ShortBuffer) override val bufferFactory: MutableBufferFactory<Short> = MutableBufferFactory(::ShortBuffer)
@ -226,7 +228,7 @@ public val Short.Companion.algebra: ShortRing get() = ShortRing
/** /**
* A field for [Byte] without boxing. Does not produce appropriate ring element. * A field for [Byte] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> { public object ByteRing : Ring<Byte>, Norm<Byte, Byte>, NumericAlgebra<Byte> {
override val bufferFactory: MutableBufferFactory<Byte> = MutableBufferFactory(::ByteBuffer) override val bufferFactory: MutableBufferFactory<Byte> = MutableBufferFactory(::ByteBuffer)
@ -249,7 +251,7 @@ public val Byte.Companion.algebra: ByteRing get() = ByteRing
/** /**
* A field for [Double] without boxing. Does not produce appropriate ring element. * A field for [Double] without boxing. Does not produce appropriate ring element.
*/ */
@Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE") @Suppress("EXTENSION_SHADOWED_BY_MEMBER", "OVERRIDE_BY_INLINE")
public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> { public object LongRing : Ring<Long>, Norm<Long, Long>, NumericAlgebra<Long> {
override val bufferFactory: MutableBufferFactory<Long> = MutableBufferFactory(::LongBuffer) override val bufferFactory: MutableBufferFactory<Long> = MutableBufferFactory(::LongBuffer)

View File

@ -138,6 +138,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap() override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap() override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap() override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
override fun power(arg: StructureND<T>, pow: Number): StructureND<T> = Transforms.pow(arg.ndArray, pow).wrap()
override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap() override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap() override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> = override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =

View File

@ -11,7 +11,6 @@ import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.broadcastTensors import space.kscience.kmath.tensors.core.internal.broadcastTensors
import space.kscience.kmath.tensors.core.internal.broadcastTo import space.kscience.kmath.tensors.core.internal.broadcastTo
import space.kscience.kmath.tensors.core.internal.tensor
/** /**
* Basic linear algebra operations implemented with broadcasting. * Basic linear algebra operations implemented with broadcasting.
@ -20,7 +19,7 @@ import space.kscience.kmath.tensors.core.internal.tensor
public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() { public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, arg.tensor) val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor())
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
@ -30,15 +29,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
override fun Tensor<Double>.plusAssign(arg: StructureND<Double>) { override fun Tensor<Double>.plusAssign(arg: StructureND<Double>) {
val newOther = broadcastTo(arg.tensor, tensor.shape) val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape)
for (i in 0 until tensor.indices.linearSize) { for (i in 0 until asDoubleTensor().indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] +=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, arg.tensor) val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor())
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
@ -48,15 +47,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
override fun Tensor<Double>.minusAssign(arg: StructureND<Double>) { override fun Tensor<Double>.minusAssign(arg: StructureND<Double>) {
val newOther = broadcastTo(arg.tensor, tensor.shape) val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape)
for (i in 0 until tensor.indices.linearSize) { for (i in 0 until asDoubleTensor().indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, arg.tensor) val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor())
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
@ -67,15 +66,15 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
override fun Tensor<Double>.timesAssign(arg: StructureND<Double>) { override fun Tensor<Double>.timesAssign(arg: StructureND<Double>) {
val newOther = broadcastTo(arg.tensor, tensor.shape) val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape)
for (i in 0 until tensor.indices.linearSize) { for (i in 0 until asDoubleTensor().indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor {
val broadcast = broadcastTensors(tensor, arg.tensor) val broadcast = broadcastTensors(asDoubleTensor(), arg.asDoubleTensor())
val newThis = broadcast[0] val newThis = broadcast[0]
val newOther = broadcast[1] val newOther = broadcast[1]
val resBuffer = DoubleArray(newThis.indices.linearSize) { i -> val resBuffer = DoubleArray(newThis.indices.linearSize) { i ->
@ -86,10 +85,10 @@ public object BroadcastDoubleTensorAlgebra : DoubleTensorAlgebra() {
} }
override fun Tensor<Double>.divAssign(arg: StructureND<Double>) { override fun Tensor<Double>.divAssign(arg: StructureND<Double>) {
val newOther = broadcastTo(arg.tensor, tensor.shape) val newOther = broadcastTo(arg.asDoubleTensor(), asDoubleTensor().shape)
for (i in 0 until tensor.indices.linearSize) { for (i in 0 until asDoubleTensor().indices.linearSize) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /=
newOther.mutableBuffer.array()[tensor.bufferStart + i] newOther.mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
} }

View File

@ -43,7 +43,7 @@ public open class DoubleTensorAlgebra :
@PerformancePitfall @PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): DoubleTensor { final override inline fun StructureND<Double>.map(transform: DoubleField.(Double) -> Double): DoubleTensor {
val tensor = this.tensor val tensor = this.asDoubleTensor()
//TODO remove additional copy //TODO remove additional copy
val sourceArray = tensor.copyArray() val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) } val array = DoubleArray(tensor.numElements) { DoubleField.transform(sourceArray[it]) }
@ -57,7 +57,7 @@ public open class DoubleTensorAlgebra :
@PerformancePitfall @PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Double>.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor { final override inline fun StructureND<Double>.mapIndexed(transform: DoubleField.(index: IntArray, Double) -> Double): DoubleTensor {
val tensor = this.tensor val tensor = this.asDoubleTensor()
//TODO remove additional copy //TODO remove additional copy
val sourceArray = tensor.copyArray() val sourceArray = tensor.copyArray()
val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) } val array = DoubleArray(tensor.numElements) { DoubleField.transform(tensor.indices.index(it), sourceArray[it]) }
@ -77,9 +77,9 @@ public open class DoubleTensorAlgebra :
require(left.shape.contentEquals(right.shape)) { require(left.shape.contentEquals(right.shape)) {
"The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}"
} }
val leftTensor = left.tensor val leftTensor = left.asDoubleTensor()
val leftArray = leftTensor.copyArray() val leftArray = leftTensor.copyArray()
val rightTensor = right.tensor val rightTensor = right.asDoubleTensor()
val rightArray = rightTensor.copyArray() val rightArray = rightTensor.copyArray()
val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) } val array = DoubleArray(leftTensor.numElements) { DoubleField.transform(leftArray[it], rightArray[it]) }
return DoubleTensor( return DoubleTensor(
@ -88,8 +88,8 @@ public open class DoubleTensorAlgebra :
) )
} }
override fun StructureND<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1)) override fun StructureND<Double>.valueOrNull(): Double? = if (asDoubleTensor().shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart] else null
override fun StructureND<Double>.value(): Double = valueOrNull() override fun StructureND<Double>.value(): Double = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
@ -121,10 +121,10 @@ public open class DoubleTensorAlgebra :
) )
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override operator fun Tensor<Double>.get(i: Int): DoubleTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = asDoubleTensor().shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + asDoubleTensor().bufferStart
return DoubleTensor(newShape, tensor.mutableBuffer.array(), newStart) return DoubleTensor(newShape, asDoubleTensor().mutableBuffer.array(), newStart)
} }
/** /**
@ -147,8 +147,8 @@ public open class DoubleTensorAlgebra :
* @return tensor with the `input` tensor shape and filled with [value]. * @return tensor with the `input` tensor shape and filled with [value].
*/ */
public fun Tensor<Double>.fullLike(value: Double): DoubleTensor { public fun Tensor<Double>.fullLike(value: Double): DoubleTensor {
val shape = tensor.shape val shape = asDoubleTensor().shape
val buffer = DoubleArray(tensor.numElements) { value } val buffer = DoubleArray(asDoubleTensor().numElements) { value }
return DoubleTensor(shape, buffer) return DoubleTensor(shape, buffer)
} }
@ -165,7 +165,7 @@ public open class DoubleTensorAlgebra :
* *
* @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `0.0`, with the same shape as `input` tensor.
*/ */
public fun StructureND<Double>.zeroesLike(): DoubleTensor = tensor.fullLike(0.0) public fun StructureND<Double>.zeroesLike(): DoubleTensor = asDoubleTensor().fullLike(0.0)
/** /**
* Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape]. * Returns a tensor filled with the scalar value `1.0`, with the shape defined by the variable argument [shape].
@ -180,7 +180,7 @@ public open class DoubleTensorAlgebra :
* *
* @return tensor filled with the scalar value `1.0`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `1.0`, with the same shape as `input` tensor.
*/ */
public fun Tensor<Double>.onesLike(): DoubleTensor = tensor.fullLike(1.0) public fun Tensor<Double>.onesLike(): DoubleTensor = asDoubleTensor().fullLike(1.0)
/** /**
* Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. * Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
@ -204,182 +204,182 @@ public open class DoubleTensorAlgebra :
* @return a copy of the `input` tensor with a copied buffer. * @return a copy of the `input` tensor with a copied buffer.
*/ */
public fun StructureND<Double>.copy(): DoubleTensor = public fun StructureND<Double>.copy(): DoubleTensor =
DoubleTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) DoubleTensor(asDoubleTensor().shape, asDoubleTensor().mutableBuffer.array().copyOf(), asDoubleTensor().bufferStart)
override fun Double.plus(arg: StructureND<Double>): DoubleTensor { override fun Double.plus(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(arg.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] + this
} }
return DoubleTensor(arg.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun StructureND<Double>.plus(arg: Double): DoubleTensor = arg + tensor override fun StructureND<Double>.plus(arg: Double): DoubleTensor = arg + asDoubleTensor()
override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.plus(arg: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, arg.tensor) checkShapesCompatible(asDoubleTensor(), arg.asDoubleTensor())
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] asDoubleTensor().mutableBuffer.array()[i] + arg.asDoubleTensor().mutableBuffer.array()[i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun Tensor<Double>.plusAssign(value: Double) { override fun Tensor<Double>.plusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] += value
} }
} }
override fun Tensor<Double>.plusAssign(arg: StructureND<Double>) { override fun Tensor<Double>.plusAssign(arg: StructureND<Double>) {
checkShapesCompatible(tensor, arg.tensor) checkShapesCompatible(asDoubleTensor(), arg.asDoubleTensor())
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] +=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun Double.minus(arg: StructureND<Double>): DoubleTensor { override fun Double.minus(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(arg.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i ->
this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] this - arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i]
} }
return DoubleTensor(arg.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun StructureND<Double>.minus(arg: Double): DoubleTensor { override fun StructureND<Double>.minus(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] - arg
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.minus(arg: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] asDoubleTensor().mutableBuffer.array()[i] - arg.asDoubleTensor().mutableBuffer.array()[i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun Tensor<Double>.minusAssign(value: Double) { override fun Tensor<Double>.minusAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -= value
} }
} }
override fun Tensor<Double>.minusAssign(arg: StructureND<Double>) { override fun Tensor<Double>.minusAssign(arg: StructureND<Double>) {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] -=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun Double.times(arg: StructureND<Double>): DoubleTensor { override fun Double.times(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(arg.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] * this
} }
return DoubleTensor(arg.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun StructureND<Double>.times(arg: Double): DoubleTensor = arg * tensor override fun StructureND<Double>.times(arg: Double): DoubleTensor = arg * asDoubleTensor()
override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.times(arg: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] * asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun Tensor<Double>.timesAssign(value: Double) { override fun Tensor<Double>.timesAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *= value
} }
} }
override fun Tensor<Double>.timesAssign(arg: StructureND<Double>) { override fun Tensor<Double>.timesAssign(arg: StructureND<Double>) {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] *=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun Double.div(arg: StructureND<Double>): DoubleTensor { override fun Double.div(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(arg.tensor.numElements) { i -> val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i ->
this / arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] this / arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i]
} }
return DoubleTensor(arg.shape, resBuffer) return DoubleTensor(arg.shape, resBuffer)
} }
override fun StructureND<Double>.div(arg: Double): DoubleTensor { override fun StructureND<Double>.div(arg: Double): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] / arg asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] / arg
} }
return DoubleTensor(shape, resBuffer) return DoubleTensor(shape, resBuffer)
} }
override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor { override fun StructureND<Double>.div(arg: StructureND<Double>): DoubleTensor {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] / asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i] /
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[arg.asDoubleTensor().bufferStart + i]
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun Tensor<Double>.divAssign(value: Double) { override fun Tensor<Double>.divAssign(value: Double) {
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= value asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /= value
} }
} }
override fun Tensor<Double>.divAssign(arg: StructureND<Double>) { override fun Tensor<Double>.divAssign(arg: StructureND<Double>) {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asDoubleTensor(), arg)
for (i in 0 until tensor.numElements) { for (i in 0 until asDoubleTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] /= asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i] /=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i]
} }
} }
override fun StructureND<Double>.unaryMinus(): DoubleTensor { override fun StructureND<Double>.unaryMinus(): DoubleTensor {
val resBuffer = DoubleArray(tensor.numElements) { i -> val resBuffer = DoubleArray(asDoubleTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + i].unaryMinus()
} }
return DoubleTensor(tensor.shape, resBuffer) return DoubleTensor(asDoubleTensor().shape, resBuffer)
} }
override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor { override fun Tensor<Double>.transpose(i: Int, j: Int): DoubleTensor {
val ii = tensor.minusIndex(i) val ii = asDoubleTensor().minusIndex(i)
val jj = tensor.minusIndex(j) val jj = asDoubleTensor().minusIndex(j)
checkTranspose(tensor.dimension, ii, jj) checkTranspose(asDoubleTensor().dimension, ii, jj)
val n = tensor.numElements val n = asDoubleTensor().numElements
val resBuffer = DoubleArray(n) val resBuffer = DoubleArray(n)
val resShape = tensor.shape.copyOf() val resShape = asDoubleTensor().shape.copyOf()
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] } resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = DoubleTensor(resShape, resBuffer) val resTensor = DoubleTensor(resShape, resBuffer)
for (offset in 0 until n) { for (offset in 0 until n) {
val oldMultiIndex = tensor.indices.index(offset) val oldMultiIndex = asDoubleTensor().indices.index(offset)
val newMultiIndex = oldMultiIndex.copyOf() val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.indices.offset(newMultiIndex) val linearIndex = resTensor.indices.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset] asDoubleTensor().mutableBuffer.array()[asDoubleTensor().bufferStart + offset]
} }
return resTensor return resTensor
} }
override fun Tensor<Double>.view(shape: IntArray): DoubleTensor { override fun Tensor<Double>.view(shape: IntArray): DoubleTensor {
checkView(tensor, shape) checkView(asDoubleTensor(), shape)
return DoubleTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) return DoubleTensor(shape, asDoubleTensor().mutableBuffer.array(), asDoubleTensor().bufferStart)
} }
override fun Tensor<Double>.viewAs(other: StructureND<Double>): DoubleTensor = override fun Tensor<Double>.viewAs(other: StructureND<Double>): DoubleTensor =
tensor.view(other.shape) asDoubleTensor().view(other.shape)
/** /**
* Broadcasting Matrix product of two tensors. * Broadcasting Matrix product of two tensors.
@ -412,25 +412,25 @@ public open class DoubleTensorAlgebra :
*/ */
@UnstableKMathAPI @UnstableKMathAPI
public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor { public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor {
if (tensor.shape.size == 1 && other.shape.size == 1) { if (asDoubleTensor().shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(tensor.times(other).tensor.mutableBuffer.array().sum())) return DoubleTensor(intArrayOf(1), doubleArrayOf(asDoubleTensor().times(other).asDoubleTensor().mutableBuffer.array().sum()))
} }
var newThis = tensor.copy() var newThis = asDoubleTensor().copy()
var newOther = other.copy() var newOther = other.copy()
var penultimateDim = false var penultimateDim = false
var lastDim = false var lastDim = false
if (tensor.shape.size == 1) { if (asDoubleTensor().shape.size == 1) {
penultimateDim = true penultimateDim = true
newThis = tensor.view(intArrayOf(1) + tensor.shape) newThis = asDoubleTensor().view(intArrayOf(1) + asDoubleTensor().shape)
} }
if (other.shape.size == 1) { if (other.shape.size == 1) {
lastDim = true lastDim = true
newOther = other.tensor.view(other.shape + intArrayOf(1)) newOther = other.asDoubleTensor().view(other.shape + intArrayOf(1))
} }
val broadcastTensors = broadcastOuterTensors(newThis.tensor, newOther.tensor) val broadcastTensors = broadcastOuterTensors(newThis.asDoubleTensor(), newOther.asDoubleTensor())
newThis = broadcastTensors[0] newThis = broadcastTensors[0]
newOther = broadcastTensors[1] newOther = broadcastTensors[1]
@ -497,8 +497,8 @@ public open class DoubleTensorAlgebra :
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray() diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
val resTensor = zeros(resShape) val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) { for (i in 0 until diagonalEntries.asDoubleTensor().numElements) {
val multiIndex = diagonalEntries.tensor.indices.index(i) val multiIndex = diagonalEntries.asDoubleTensor().indices.index(i)
var offset1 = 0 var offset1 = 0
var offset2 = abs(realOffset) var offset2 = abs(realOffset)
@ -514,7 +514,7 @@ public open class DoubleTensorAlgebra :
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex] resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
} }
return resTensor.tensor return resTensor.asDoubleTensor()
} }
/** /**
@ -525,7 +525,7 @@ public open class DoubleTensorAlgebra :
* @return true if two tensors have the same shape and elements, false otherwise. * @return true if two tensors have the same shape and elements, false otherwise.
*/ */
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean = public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
tensor.eq(other) { x, y -> abs(x - y) < epsilon } asDoubleTensor().eq(other) { x, y -> abs(x - y) < epsilon }
/** /**
* Compares element-wise two tensors. * Compares element-wise two tensors.
@ -534,21 +534,21 @@ public open class DoubleTensorAlgebra :
* @param other the tensor to compare with `input` tensor. * @param other the tensor to compare with `input` tensor.
* @return true if two tensors have the same shape and elements, false otherwise. * @return true if two tensors have the same shape and elements, false otherwise.
*/ */
public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = tensor.eq(other, 1e-5) public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = asDoubleTensor().eq(other, 1e-5)
private fun Tensor<Double>.eq( private fun Tensor<Double>.eq(
other: Tensor<Double>, other: Tensor<Double>,
eqFunction: (Double, Double) -> Boolean, eqFunction: (Double, Double) -> Boolean,
): Boolean { ): Boolean {
checkShapesCompatible(tensor, other) checkShapesCompatible(asDoubleTensor(), other)
val n = tensor.numElements val n = asDoubleTensor().numElements
if (n != other.tensor.numElements) { if (n != other.asDoubleTensor().numElements) {
return false return false
} }
for (i in 0 until n) { for (i in 0 until n) {
if (!eqFunction( if (!eqFunction(
tensor.mutableBuffer[tensor.bufferStart + i], asDoubleTensor().mutableBuffer[asDoubleTensor().bufferStart + i],
other.tensor.mutableBuffer[other.tensor.bufferStart + i] other.asDoubleTensor().mutableBuffer[other.asDoubleTensor().bufferStart + i]
) )
) { ) {
return false return false
@ -578,7 +578,7 @@ public open class DoubleTensorAlgebra :
* with `0.0` mean and `1.0` standard deviation. * with `0.0` mean and `1.0` standard deviation.
*/ */
public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor = public fun Tensor<Double>.randomNormalLike(seed: Long = 0): DoubleTensor =
DoubleTensor(tensor.shape, getRandomNormals(tensor.shape.reduce(Int::times), seed)) DoubleTensor(asDoubleTensor().shape, getRandomNormals(asDoubleTensor().shape.reduce(Int::times), seed))
/** /**
* Concatenates a sequence of tensors with equal shapes along the first dimension. * Concatenates a sequence of tensors with equal shapes along the first dimension.
@ -592,7 +592,7 @@ public open class DoubleTensorAlgebra :
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap { val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements) it.asDoubleTensor().mutableBuffer.array().drop(it.asDoubleTensor().bufferStart).take(it.asDoubleTensor().numElements)
}.toDoubleArray() }.toDoubleArray()
return DoubleTensor(resShape, resBuffer, 0) return DoubleTensor(resShape, resBuffer, 0)
} }
@ -606,7 +606,7 @@ public open class DoubleTensorAlgebra :
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] })
private inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = private inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(tensor.copyArray()) foldFunction(asDoubleTensor().copyArray())
private inline fun <reified R : Any> StructureND<Double>.foldDim( private inline fun <reified R : Any> StructureND<Double>.foldDim(
dim: Int, dim: Int,
@ -629,44 +629,44 @@ public open class DoubleTensorAlgebra :
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i -> resTensor[index] = foldFunction(DoubleArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix] asDoubleTensor()[prefix + intArrayOf(i) + suffix]
}) })
} }
return resTensor return resTensor
} }
override fun StructureND<Double>.sum(): Double = tensor.fold { it.sum() } override fun StructureND<Double>.sum(): Double = asDoubleTensor().fold { it.sum() }
override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.sum(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(dim, keepDim) { x -> x.sum() }.toDoubleTensor() foldDim(dim, keepDim) { x -> x.sum() }.asDoubleTensor()
override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! } override fun StructureND<Double>.min(): Double = this.fold { it.minOrNull()!! }
override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(dim, keepDim) { x -> x.minOrNull()!! }.toDoubleTensor() foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asDoubleTensor()
override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.toDoubleTensor() foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.asDoubleTensor()
override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Double>.argMax(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> foldDim(dim, keepDim) { x ->
x.withIndex().maxByOrNull { it.value }?.index!! x.withIndex().maxByOrNull { it.value }?.index!!
}.toIntTensor() }.asIntTensor()
override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / tensor.numElements } override fun StructureND<Double>.mean(): Double = this.fold { it.sum() / asDoubleTensor().numElements }
override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(dim, keepDim) { arr -> override fun StructureND<Double>.mean(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(dim, keepDim) { arr ->
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
arr.sum() / shape[dim] arr.sum() / shape[dim]
}.toDoubleTensor() }.asDoubleTensor()
override fun StructureND<Double>.std(): Double = fold { arr -> override fun StructureND<Double>.std(): Double = fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / asDoubleTensor().numElements
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1)) sqrt(arr.sumOf { (it - mean) * (it - mean) } / (asDoubleTensor().numElements - 1))
} }
override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.std(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
@ -676,11 +676,11 @@ public open class DoubleTensorAlgebra :
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)) sqrt(arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1))
}.toDoubleTensor() }.asDoubleTensor()
override fun StructureND<Double>.variance(): Double = fold { arr -> override fun StructureND<Double>.variance(): Double = fold { arr ->
val mean = arr.sum() / tensor.numElements val mean = arr.sum() / asDoubleTensor().numElements
arr.sumOf { (it - mean) * (it - mean) } / (tensor.numElements - 1) arr.sumOf { (it - mean) * (it - mean) } / (asDoubleTensor().numElements - 1)
} }
override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim( override fun StructureND<Double>.variance(dim: Int, keepDim: Boolean): DoubleTensor = foldDim(
@ -690,7 +690,7 @@ public open class DoubleTensorAlgebra :
check(dim < dimension) { "Dimension $dim out of range $dimension" } check(dim < dimension) { "Dimension $dim out of range $dimension" }
val mean = arr.sum() / shape[dim] val mean = arr.sum() / shape[dim]
arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1) arr.sumOf { (it - mean) * (it - mean) } / (shape[dim] - 1)
}.toDoubleTensor() }.asDoubleTensor()
private fun cov(x: DoubleTensor, y: DoubleTensor): Double { private fun cov(x: DoubleTensor, y: DoubleTensor): Double {
val n = x.shape[0] val n = x.shape[0]
@ -716,45 +716,51 @@ public open class DoubleTensorAlgebra :
) )
for (i in 0 until n) { for (i in 0 until n) {
for (j in 0 until n) { for (j in 0 until n) {
resTensor[intArrayOf(i, j)] = cov(tensors[i].tensor, tensors[j].tensor) resTensor[intArrayOf(i, j)] = cov(tensors[i].asDoubleTensor(), tensors[j].asDoubleTensor())
} }
} }
return resTensor return resTensor
} }
override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) } override fun StructureND<Double>.exp(): DoubleTensor = asDoubleTensor().map { exp(it) }
override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) } override fun StructureND<Double>.ln(): DoubleTensor = asDoubleTensor().map { ln(it) }
override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) } override fun StructureND<Double>.sqrt(): DoubleTensor = asDoubleTensor().map { sqrt(it) }
override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) } override fun StructureND<Double>.cos(): DoubleTensor = asDoubleTensor().map { cos(it) }
override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) } override fun StructureND<Double>.acos(): DoubleTensor = asDoubleTensor().map { acos(it) }
override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) } override fun StructureND<Double>.cosh(): DoubleTensor = asDoubleTensor().map { cosh(it) }
override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) } override fun StructureND<Double>.acosh(): DoubleTensor = asDoubleTensor().map { acosh(it) }
override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) } override fun StructureND<Double>.sin(): DoubleTensor = asDoubleTensor().map { sin(it) }
override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) } override fun StructureND<Double>.asin(): DoubleTensor = asDoubleTensor().map { asin(it) }
override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) } override fun StructureND<Double>.sinh(): DoubleTensor = asDoubleTensor().map { sinh(it) }
override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) } override fun StructureND<Double>.asinh(): DoubleTensor = asDoubleTensor().map { asinh(it) }
override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) } override fun StructureND<Double>.tan(): DoubleTensor = asDoubleTensor().map { tan(it) }
override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) } override fun StructureND<Double>.atan(): DoubleTensor = asDoubleTensor().map { atan(it) }
override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) } override fun StructureND<Double>.tanh(): DoubleTensor = asDoubleTensor().map { tanh(it) }
override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) } override fun StructureND<Double>.atanh(): DoubleTensor = asDoubleTensor().map { atanh(it) }
override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) } override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> = if (pow is Int) {
arg.map { it.pow(pow) }
} else {
arg.map { it.pow(pow.toDouble()) }
}
override fun StructureND<Double>.floor(): DoubleTensor = tensor.map { floor(it) } override fun StructureND<Double>.ceil(): DoubleTensor = asDoubleTensor().map { ceil(it) }
override fun StructureND<Double>.floor(): DoubleTensor = asDoubleTensor().map { floor(it) }
override fun StructureND<Double>.inv(): DoubleTensor = invLU(1e-9) override fun StructureND<Double>.inv(): DoubleTensor = invLU(1e-9)
@ -770,7 +776,7 @@ public open class DoubleTensorAlgebra :
* The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows. * The `pivots` has the shape ``(, min(m, n))``. `pivots` stores all the intermediate transpositions of rows.
*/ */
public fun StructureND<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> = public fun StructureND<Double>.luFactor(epsilon: Double): Pair<DoubleTensor, IntTensor> =
computeLU(tensor, epsilon) computeLU(asDoubleTensor(), epsilon)
?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon") ?: throw IllegalArgumentException("Tensor contains matrices which are singular at precision $epsilon")
/** /**
@ -809,7 +815,7 @@ public open class DoubleTensorAlgebra :
val pTensor = luTensor.zeroesLike() val pTensor = luTensor.zeroesLike()
pTensor pTensor
.matrixSequence() .matrixSequence()
.zip(pivotsTensor.tensor.vectorSequence()) .zip(pivotsTensor.asIntTensor().vectorSequence())
.forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) } .forEach { (p, pivot) -> pivInit(p.as2D(), pivot.as1D(), n) }
val lTensor = luTensor.zeroesLike() val lTensor = luTensor.zeroesLike()
@ -817,7 +823,7 @@ public open class DoubleTensorAlgebra :
lTensor.matrixSequence() lTensor.matrixSequence()
.zip(uTensor.matrixSequence()) .zip(uTensor.matrixSequence())
.zip(luTensor.tensor.matrixSequence()) .zip(luTensor.asDoubleTensor().matrixSequence())
.forEach { (pairLU, lu) -> .forEach { (pairLU, lu) ->
val (l, u) = pairLU val (l, u) = pairLU
luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n) luPivotHelper(l.as2D(), u.as2D(), lu.as2D(), n)
@ -841,12 +847,12 @@ public open class DoubleTensorAlgebra :
*/ */
public fun StructureND<Double>.cholesky(epsilon: Double): DoubleTensor { public fun StructureND<Double>.cholesky(epsilon: Double): DoubleTensor {
checkSquareMatrix(shape) checkSquareMatrix(shape)
checkPositiveDefinite(tensor, epsilon) checkPositiveDefinite(asDoubleTensor(), epsilon)
val n = shape.last() val n = shape.last()
val lTensor = zeroesLike() val lTensor = zeroesLike()
for ((a, l) in tensor.matrixSequence().zip(lTensor.matrixSequence())) for ((a, l) in asDoubleTensor().matrixSequence().zip(lTensor.matrixSequence()))
for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n) for (i in 0 until n) choleskyHelper(a.as2D(), l.as2D(), n)
return lTensor return lTensor
@ -858,13 +864,13 @@ public open class DoubleTensorAlgebra :
checkSquareMatrix(shape) checkSquareMatrix(shape)
val qTensor = zeroesLike() val qTensor = zeroesLike()
val rTensor = zeroesLike() val rTensor = zeroesLike()
tensor.matrixSequence() asDoubleTensor().matrixSequence()
.zip( .zip(
(qTensor.matrixSequence() (qTensor.matrixSequence()
.zip(rTensor.matrixSequence())) .zip(rTensor.matrixSequence()))
).forEach { (matrix, qr) -> ).forEach { (matrix, qr) ->
val (q, r) = qr val (q, r) = qr
qrHelper(matrix.asTensor(), q.asTensor(), r.as2D()) qrHelper(matrix.toTensor(), q.toTensor(), r.as2D())
} }
return qTensor to rTensor return qTensor to rTensor
@ -887,14 +893,14 @@ public open class DoubleTensorAlgebra :
* @return a triple `Triple(U, S, V)`. * @return a triple `Triple(U, S, V)`.
*/ */
public fun StructureND<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.svd(epsilon: Double): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val size = tensor.dimension val size = asDoubleTensor().dimension
val commonShape = tensor.shape.sliceArray(0 until size - 2) val commonShape = asDoubleTensor().shape.sliceArray(0 until size - 2)
val (n, m) = tensor.shape.sliceArray(size - 2 until size) val (n, m) = asDoubleTensor().shape.sliceArray(size - 2 until size)
val uTensor = zeros(commonShape + intArrayOf(min(n, m), n)) val uTensor = zeros(commonShape + intArrayOf(min(n, m), n))
val sTensor = zeros(commonShape + intArrayOf(min(n, m))) val sTensor = zeros(commonShape + intArrayOf(min(n, m)))
val vTensor = zeros(commonShape + intArrayOf(min(n, m), m)) val vTensor = zeros(commonShape + intArrayOf(min(n, m), m))
val matrices = tensor.matrices val matrices = asDoubleTensor().matrices
val uTensors = uTensor.matrices val uTensors = uTensor.matrices
val sTensorVectors = sTensor.vectors val sTensorVectors = sTensor.vectors
val vTensors = vTensor.matrices val vTensors = vTensor.matrices
@ -931,7 +937,7 @@ public open class DoubleTensorAlgebra :
* @return a pair `eigenvalues to eigenvectors`. * @return a pair `eigenvalues to eigenvectors`.
*/ */
public fun StructureND<Double>.symEigSvd(epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun StructureND<Double>.symEigSvd(epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(asDoubleTensor(), epsilon)
fun MutableStructure2D<Double>.cleanSym(n: Int) { fun MutableStructure2D<Double>.cleanSym(n: Int) {
for (i in 0 until n) { for (i in 0 until n) {
@ -945,7 +951,7 @@ public open class DoubleTensorAlgebra :
} }
} }
val (u, s, v) = tensor.svd(epsilon) val (u, s, v) = asDoubleTensor().svd(epsilon)
val shp = s.shape + intArrayOf(1) val shp = s.shape + intArrayOf(1)
val utv = u.transpose() matmul v val utv = u.transpose() matmul v
val n = s.shape.last() val n = s.shape.last()
@ -958,7 +964,7 @@ public open class DoubleTensorAlgebra :
} }
public fun StructureND<Double>.symEigJacobi(maxIteration: Int, epsilon: Double): Pair<DoubleTensor, DoubleTensor> { public fun StructureND<Double>.symEigJacobi(maxIteration: Int, epsilon: Double): Pair<DoubleTensor, DoubleTensor> {
checkSymmetric(tensor, epsilon) checkSymmetric(asDoubleTensor(), epsilon)
val size = this.dimension val size = this.dimension
val eigenvectors = zeros(this.shape) val eigenvectors = zeros(this.shape)
@ -966,7 +972,7 @@ public open class DoubleTensorAlgebra :
var eigenvalueStart = 0 var eigenvalueStart = 0
var eigenvectorStart = 0 var eigenvectorStart = 0
for (matrix in tensor.matrixSequence()) { for (matrix in asDoubleTensor().matrixSequence()) {
val matrix2D = matrix.as2D() val matrix2D = matrix.as2D()
val (d, v) = matrix2D.jacobiHelper(maxIteration, epsilon) val (d, v) = matrix2D.jacobiHelper(maxIteration, epsilon)
@ -1111,9 +1117,9 @@ public open class DoubleTensorAlgebra :
* @return the determinant. * @return the determinant.
*/ */
public fun StructureND<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor { public fun StructureND<Double>.detLU(epsilon: Double = 1e-9): DoubleTensor {
checkSquareMatrix(tensor.shape) checkSquareMatrix(asDoubleTensor().shape)
val luTensor = tensor.copy() val luTensor = asDoubleTensor().copy()
val pivotsTensor = tensor.setUpPivots() val pivotsTensor = asDoubleTensor().setUpPivots()
val n = shape.size val n = shape.size
@ -1169,14 +1175,14 @@ public open class DoubleTensorAlgebra :
* @return triple of `P`, `L` and `U` tensors. * @return triple of `P`, `L` and `U` tensors.
*/ */
public fun StructureND<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> { public fun StructureND<Double>.lu(epsilon: Double = 1e-9): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
val (lu, pivots) = tensor.luFactor(epsilon) val (lu, pivots) = asDoubleTensor().luFactor(epsilon)
return luPivot(lu, pivots) return luPivot(lu, pivots)
} }
override fun StructureND<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9) override fun StructureND<Double>.lu(): Triple<DoubleTensor, DoubleTensor, DoubleTensor> = lu(1e-9)
} }
public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra public val Double.Companion.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra
public val DoubleField.tensorAlgebra: DoubleTensorAlgebra.Companion get() = DoubleTensorAlgebra public val DoubleField.tensorAlgebra: DoubleTensorAlgebra get() = DoubleTensorAlgebra

View File

@ -39,7 +39,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
@PerformancePitfall @PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Int>.map(transform: IntRing.(Int) -> Int): IntTensor { final override inline fun StructureND<Int>.map(transform: IntRing.(Int) -> Int): IntTensor {
val tensor = this.tensor val tensor = this.asIntTensor()
//TODO remove additional copy //TODO remove additional copy
val sourceArray = tensor.copyArray() val sourceArray = tensor.copyArray()
val array = IntArray(tensor.numElements) { IntRing.transform(sourceArray[it]) } val array = IntArray(tensor.numElements) { IntRing.transform(sourceArray[it]) }
@ -53,7 +53,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
@PerformancePitfall @PerformancePitfall
@Suppress("OVERRIDE_BY_INLINE") @Suppress("OVERRIDE_BY_INLINE")
final override inline fun StructureND<Int>.mapIndexed(transform: IntRing.(index: IntArray, Int) -> Int): IntTensor { final override inline fun StructureND<Int>.mapIndexed(transform: IntRing.(index: IntArray, Int) -> Int): IntTensor {
val tensor = this.tensor val tensor = this.asIntTensor()
//TODO remove additional copy //TODO remove additional copy
val sourceArray = tensor.copyArray() val sourceArray = tensor.copyArray()
val array = IntArray(tensor.numElements) { IntRing.transform(tensor.indices.index(it), sourceArray[it]) } val array = IntArray(tensor.numElements) { IntRing.transform(tensor.indices.index(it), sourceArray[it]) }
@ -73,9 +73,9 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
require(left.shape.contentEquals(right.shape)) { require(left.shape.contentEquals(right.shape)) {
"The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}" "The shapes in zip are not equal: left - ${left.shape}, right - ${right.shape}"
} }
val leftTensor = left.tensor val leftTensor = left.asIntTensor()
val leftArray = leftTensor.copyArray() val leftArray = leftTensor.copyArray()
val rightTensor = right.tensor val rightTensor = right.asIntTensor()
val rightArray = rightTensor.copyArray() val rightArray = rightTensor.copyArray()
val array = IntArray(leftTensor.numElements) { IntRing.transform(leftArray[it], rightArray[it]) } val array = IntArray(leftTensor.numElements) { IntRing.transform(leftArray[it], rightArray[it]) }
return IntTensor( return IntTensor(
@ -84,8 +84,8 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
) )
} }
override fun StructureND<Int>.valueOrNull(): Int? = if (tensor.shape contentEquals intArrayOf(1)) override fun StructureND<Int>.valueOrNull(): Int? = if (asIntTensor().shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart] else null
override fun StructureND<Int>.value(): Int = valueOrNull() override fun StructureND<Int>.value(): Int = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]") ?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
@ -119,10 +119,10 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
) )
override operator fun Tensor<Int>.get(i: Int): IntTensor { override operator fun Tensor<Int>.get(i: Int): IntTensor {
val lastShape = tensor.shape.drop(1).toIntArray() val lastShape = asIntTensor().shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + tensor.bufferStart val newStart = newShape.reduce(Int::times) * i + asIntTensor().bufferStart
return IntTensor(newShape, tensor.mutableBuffer.array(), newStart) return IntTensor(newShape, asIntTensor().mutableBuffer.array(), newStart)
} }
/** /**
@ -145,8 +145,8 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
* @return tensor with the `input` tensor shape and filled with [value]. * @return tensor with the `input` tensor shape and filled with [value].
*/ */
public fun Tensor<Int>.fullLike(value: Int): IntTensor { public fun Tensor<Int>.fullLike(value: Int): IntTensor {
val shape = tensor.shape val shape = asIntTensor().shape
val buffer = IntArray(tensor.numElements) { value } val buffer = IntArray(asIntTensor().numElements) { value }
return IntTensor(shape, buffer) return IntTensor(shape, buffer)
} }
@ -163,7 +163,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
* *
* @return tensor filled with the scalar value `0`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `0`, with the same shape as `input` tensor.
*/ */
public fun StructureND<Int>.zeroesLike(): IntTensor = tensor.fullLike(0) public fun StructureND<Int>.zeroesLike(): IntTensor = asIntTensor().fullLike(0)
/** /**
* Returns a tensor filled with the scalar value `1`, with the shape defined by the variable argument [shape]. * Returns a tensor filled with the scalar value `1`, with the shape defined by the variable argument [shape].
@ -178,7 +178,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
* *
* @return tensor filled with the scalar value `1`, with the same shape as `input` tensor. * @return tensor filled with the scalar value `1`, with the same shape as `input` tensor.
*/ */
public fun Tensor<Int>.onesLike(): IntTensor = tensor.fullLike(1) public fun Tensor<Int>.onesLike(): IntTensor = asIntTensor().fullLike(1)
/** /**
* Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere. * Returns a 2D tensor with shape ([n], [n]), with ones on the diagonal and zeros elsewhere.
@ -202,145 +202,145 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
* @return a copy of the `input` tensor with a copied buffer. * @return a copy of the `input` tensor with a copied buffer.
*/ */
public fun StructureND<Int>.copy(): IntTensor = public fun StructureND<Int>.copy(): IntTensor =
IntTensor(tensor.shape, tensor.mutableBuffer.array().copyOf(), tensor.bufferStart) IntTensor(asIntTensor().shape, asIntTensor().mutableBuffer.array().copyOf(), asIntTensor().bufferStart)
override fun Int.plus(arg: StructureND<Int>): IntTensor { override fun Int.plus(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i -> val resBuffer = IntArray(arg.asIntTensor().numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] + this arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] + this
} }
return IntTensor(arg.shape, resBuffer) return IntTensor(arg.shape, resBuffer)
} }
override fun StructureND<Int>.plus(arg: Int): IntTensor = arg + tensor override fun StructureND<Int>.plus(arg: Int): IntTensor = arg + asIntTensor()
override fun StructureND<Int>.plus(arg: StructureND<Int>): IntTensor { override fun StructureND<Int>.plus(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg.tensor) checkShapesCompatible(asIntTensor(), arg.asIntTensor())
val resBuffer = IntArray(tensor.numElements) { i -> val resBuffer = IntArray(asIntTensor().numElements) { i ->
tensor.mutableBuffer.array()[i] + arg.tensor.mutableBuffer.array()[i] asIntTensor().mutableBuffer.array()[i] + arg.asIntTensor().mutableBuffer.array()[i]
} }
return IntTensor(tensor.shape, resBuffer) return IntTensor(asIntTensor().shape, resBuffer)
} }
override fun Tensor<Int>.plusAssign(value: Int) { override fun Tensor<Int>.plusAssign(value: Int) {
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += value asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] += value
} }
} }
override fun Tensor<Int>.plusAssign(arg: StructureND<Int>) { override fun Tensor<Int>.plusAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg.tensor) checkShapesCompatible(asIntTensor(), arg.asIntTensor())
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] += asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] +=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i]
} }
} }
override fun Int.minus(arg: StructureND<Int>): IntTensor { override fun Int.minus(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i -> val resBuffer = IntArray(arg.asIntTensor().numElements) { i ->
this - arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] this - arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i]
} }
return IntTensor(arg.shape, resBuffer) return IntTensor(arg.shape, resBuffer)
} }
override fun StructureND<Int>.minus(arg: Int): IntTensor { override fun StructureND<Int>.minus(arg: Int): IntTensor {
val resBuffer = IntArray(tensor.numElements) { i -> val resBuffer = IntArray(asIntTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] - arg asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] - arg
} }
return IntTensor(tensor.shape, resBuffer) return IntTensor(asIntTensor().shape, resBuffer)
} }
override fun StructureND<Int>.minus(arg: StructureND<Int>): IntTensor { override fun StructureND<Int>.minus(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asIntTensor(), arg)
val resBuffer = IntArray(tensor.numElements) { i -> val resBuffer = IntArray(asIntTensor().numElements) { i ->
tensor.mutableBuffer.array()[i] - arg.tensor.mutableBuffer.array()[i] asIntTensor().mutableBuffer.array()[i] - arg.asIntTensor().mutableBuffer.array()[i]
} }
return IntTensor(tensor.shape, resBuffer) return IntTensor(asIntTensor().shape, resBuffer)
} }
override fun Tensor<Int>.minusAssign(value: Int) { override fun Tensor<Int>.minusAssign(value: Int) {
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= value asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] -= value
} }
} }
override fun Tensor<Int>.minusAssign(arg: StructureND<Int>) { override fun Tensor<Int>.minusAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asIntTensor(), arg)
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] -= asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] -=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i]
} }
} }
override fun Int.times(arg: StructureND<Int>): IntTensor { override fun Int.times(arg: StructureND<Int>): IntTensor {
val resBuffer = IntArray(arg.tensor.numElements) { i -> val resBuffer = IntArray(arg.asIntTensor().numElements) { i ->
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] * this arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i] * this
} }
return IntTensor(arg.shape, resBuffer) return IntTensor(arg.shape, resBuffer)
} }
override fun StructureND<Int>.times(arg: Int): IntTensor = arg * tensor override fun StructureND<Int>.times(arg: Int): IntTensor = arg * asIntTensor()
override fun StructureND<Int>.times(arg: StructureND<Int>): IntTensor { override fun StructureND<Int>.times(arg: StructureND<Int>): IntTensor {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asIntTensor(), arg)
val resBuffer = IntArray(tensor.numElements) { i -> val resBuffer = IntArray(asIntTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i] * asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] *
arg.tensor.mutableBuffer.array()[arg.tensor.bufferStart + i] arg.asIntTensor().mutableBuffer.array()[arg.asIntTensor().bufferStart + i]
} }
return IntTensor(tensor.shape, resBuffer) return IntTensor(asIntTensor().shape, resBuffer)
} }
override fun Tensor<Int>.timesAssign(value: Int) { override fun Tensor<Int>.timesAssign(value: Int) {
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= value asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] *= value
} }
} }
override fun Tensor<Int>.timesAssign(arg: StructureND<Int>) { override fun Tensor<Int>.timesAssign(arg: StructureND<Int>) {
checkShapesCompatible(tensor, arg) checkShapesCompatible(asIntTensor(), arg)
for (i in 0 until tensor.numElements) { for (i in 0 until asIntTensor().numElements) {
tensor.mutableBuffer.array()[tensor.bufferStart + i] *= asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i] *=
arg.tensor.mutableBuffer.array()[tensor.bufferStart + i] arg.asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i]
} }
} }
override fun StructureND<Int>.unaryMinus(): IntTensor { override fun StructureND<Int>.unaryMinus(): IntTensor {
val resBuffer = IntArray(tensor.numElements) { i -> val resBuffer = IntArray(asIntTensor().numElements) { i ->
tensor.mutableBuffer.array()[tensor.bufferStart + i].unaryMinus() asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + i].unaryMinus()
} }
return IntTensor(tensor.shape, resBuffer) return IntTensor(asIntTensor().shape, resBuffer)
} }
override fun Tensor<Int>.transpose(i: Int, j: Int): IntTensor { override fun Tensor<Int>.transpose(i: Int, j: Int): IntTensor {
val ii = tensor.minusIndex(i) val ii = asIntTensor().minusIndex(i)
val jj = tensor.minusIndex(j) val jj = asIntTensor().minusIndex(j)
checkTranspose(tensor.dimension, ii, jj) checkTranspose(asIntTensor().dimension, ii, jj)
val n = tensor.numElements val n = asIntTensor().numElements
val resBuffer = IntArray(n) val resBuffer = IntArray(n)
val resShape = tensor.shape.copyOf() val resShape = asIntTensor().shape.copyOf()
resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] } resShape[ii] = resShape[jj].also { resShape[jj] = resShape[ii] }
val resTensor = IntTensor(resShape, resBuffer) val resTensor = IntTensor(resShape, resBuffer)
for (offset in 0 until n) { for (offset in 0 until n) {
val oldMultiIndex = tensor.indices.index(offset) val oldMultiIndex = asIntTensor().indices.index(offset)
val newMultiIndex = oldMultiIndex.copyOf() val newMultiIndex = oldMultiIndex.copyOf()
newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] } newMultiIndex[ii] = newMultiIndex[jj].also { newMultiIndex[jj] = newMultiIndex[ii] }
val linearIndex = resTensor.indices.offset(newMultiIndex) val linearIndex = resTensor.indices.offset(newMultiIndex)
resTensor.mutableBuffer.array()[linearIndex] = resTensor.mutableBuffer.array()[linearIndex] =
tensor.mutableBuffer.array()[tensor.bufferStart + offset] asIntTensor().mutableBuffer.array()[asIntTensor().bufferStart + offset]
} }
return resTensor return resTensor
} }
override fun Tensor<Int>.view(shape: IntArray): IntTensor { override fun Tensor<Int>.view(shape: IntArray): IntTensor {
checkView(tensor, shape) checkView(asIntTensor(), shape)
return IntTensor(shape, tensor.mutableBuffer.array(), tensor.bufferStart) return IntTensor(shape, asIntTensor().mutableBuffer.array(), asIntTensor().bufferStart)
} }
override fun Tensor<Int>.viewAs(other: StructureND<Int>): IntTensor = override fun Tensor<Int>.viewAs(other: StructureND<Int>): IntTensor =
tensor.view(other.shape) asIntTensor().view(other.shape)
override fun diagonalEmbedding( override fun diagonalEmbedding(
diagonalEntries: Tensor<Int>, diagonalEntries: Tensor<Int>,
@ -374,8 +374,8 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray() diagonalEntries.shape.slice(greaterDim - 1 until n - 1).toIntArray()
val resTensor = zeros(resShape) val resTensor = zeros(resShape)
for (i in 0 until diagonalEntries.tensor.numElements) { for (i in 0 until diagonalEntries.asIntTensor().numElements) {
val multiIndex = diagonalEntries.tensor.indices.index(i) val multiIndex = diagonalEntries.asIntTensor().indices.index(i)
var offset1 = 0 var offset1 = 0
var offset2 = abs(realOffset) var offset2 = abs(realOffset)
@ -391,19 +391,19 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex] resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
} }
return resTensor.tensor return resTensor.asIntTensor()
} }
private infix fun Tensor<Int>.eq( private infix fun Tensor<Int>.eq(
other: Tensor<Int>, other: Tensor<Int>,
): Boolean { ): Boolean {
checkShapesCompatible(tensor, other) checkShapesCompatible(asIntTensor(), other)
val n = tensor.numElements val n = asIntTensor().numElements
if (n != other.tensor.numElements) { if (n != other.asIntTensor().numElements) {
return false return false
} }
for (i in 0 until n) { for (i in 0 until n) {
if (tensor.mutableBuffer[tensor.bufferStart + i] != other.tensor.mutableBuffer[other.tensor.bufferStart + i]) { if (asIntTensor().mutableBuffer[asIntTensor().bufferStart + i] != other.asIntTensor().mutableBuffer[other.asIntTensor().bufferStart + i]) {
return false return false
} }
} }
@ -422,7 +422,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap { val resBuffer = tensors.flatMap {
it.tensor.mutableBuffer.array().drop(it.tensor.bufferStart).take(it.tensor.numElements) it.asIntTensor().mutableBuffer.array().drop(it.asIntTensor().bufferStart).take(it.asIntTensor().numElements)
}.toIntArray() }.toIntArray()
return IntTensor(resShape, resBuffer, 0) return IntTensor(resShape, resBuffer, 0)
} }
@ -436,7 +436,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
public fun Tensor<Int>.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { this[it] }) public fun Tensor<Int>.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { this[it] })
private inline fun StructureND<Int>.fold(foldFunction: (IntArray) -> Int): Int = private inline fun StructureND<Int>.fold(foldFunction: (IntArray) -> Int): Int =
foldFunction(tensor.copyArray()) foldFunction(asIntTensor().copyArray())
private inline fun <reified R : Any> StructureND<Int>.foldDim( private inline fun <reified R : Any> StructureND<Int>.foldDim(
dim: Int, dim: Int,
@ -459,35 +459,35 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
val prefix = index.take(dim).toIntArray() val prefix = index.take(dim).toIntArray()
val suffix = index.takeLast(dimension - dim - 1).toIntArray() val suffix = index.takeLast(dimension - dim - 1).toIntArray()
resTensor[index] = foldFunction(IntArray(shape[dim]) { i -> resTensor[index] = foldFunction(IntArray(shape[dim]) { i ->
tensor[prefix + intArrayOf(i) + suffix] asIntTensor()[prefix + intArrayOf(i) + suffix]
}) })
} }
return resTensor return resTensor
} }
override fun StructureND<Int>.sum(): Int = tensor.fold { it.sum() } override fun StructureND<Int>.sum(): Int = asIntTensor().fold { it.sum() }
override fun StructureND<Int>.sum(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.sum(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.sum() }.toIntTensor() foldDim(dim, keepDim) { x -> x.sum() }.asIntTensor()
override fun StructureND<Int>.min(): Int = this.fold { it.minOrNull()!! } override fun StructureND<Int>.min(): Int = this.fold { it.minOrNull()!! }
override fun StructureND<Int>.min(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.min(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.minOrNull()!! }.toIntTensor() foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asIntTensor()
override fun StructureND<Int>.max(): Int = this.fold { it.maxOrNull()!! } override fun StructureND<Int>.max(): Int = this.fold { it.maxOrNull()!! }
override fun StructureND<Int>.max(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.max(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.toIntTensor() foldDim(dim, keepDim) { x -> x.maxOrNull()!! }.asIntTensor()
override fun StructureND<Int>.argMax(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.argMax(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> foldDim(dim, keepDim) { x ->
x.withIndex().maxByOrNull { it.value }?.index!! x.withIndex().maxByOrNull { it.value }?.index!!
}.toIntTensor() }.asIntTensor()
} }
public val Int.Companion.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra public val Int.Companion.tensorAlgebra: IntTensorAlgebra get() = IntTensorAlgebra
public val IntRing.tensorAlgebra: IntTensorAlgebra.Companion get() = IntTensorAlgebra public val IntRing.tensorAlgebra: IntTensorAlgebra get() = IntTensorAlgebra

View File

@ -58,7 +58,7 @@ internal fun DoubleTensorAlgebra.checkSymmetric(
internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) { internal fun DoubleTensorAlgebra.checkPositiveDefinite(tensor: DoubleTensor, epsilon: Double = 1e-6) {
checkSymmetric(tensor, epsilon) checkSymmetric(tensor, epsilon)
for (mat in tensor.matrixSequence()) for (mat in tensor.matrixSequence())
check(mat.asTensor().detLU().value() > 0.0) { check(mat.toTensor().detLU().value() > 0.0) {
"Tensor contains matrices which are not positive definite ${mat.asTensor().detLU().value()}" "Tensor contains matrices which are not positive definite ${mat.toTensor().detLU().value()}"
} }
} }

View File

@ -13,16 +13,17 @@ import space.kscience.kmath.tensors.core.DoubleTensor
import space.kscience.kmath.tensors.core.IntTensor import space.kscience.kmath.tensors.core.IntTensor
import space.kscience.kmath.tensors.core.TensorLinearStructure import space.kscience.kmath.tensors.core.TensorLinearStructure
internal fun BufferedTensor<Int>.asTensor(): IntTensor = internal fun BufferedTensor<Int>.toTensor(): IntTensor =
IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) IntTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun BufferedTensor<Double>.asTensor(): DoubleTensor = internal fun BufferedTensor<Double>.toTensor(): DoubleTensor =
DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart) DoubleTensor(this.shape, this.mutableBuffer.array(), this.bufferStart)
internal fun <T> StructureND<T>.copyToBufferedTensor(): BufferedTensor<T> = internal fun <T> StructureND<T>.copyToBufferedTensor(): BufferedTensor<T> =
BufferedTensor( BufferedTensor(
this.shape, this.shape,
TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().asMutableBuffer(), 0 TensorLinearStructure(this.shape).asSequence().map(this::get).toMutableList().asMutableBuffer(),
0
) )
internal fun <T> StructureND<T>.toBufferedTensor(): BufferedTensor<T> = when (this) { internal fun <T> StructureND<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
@ -34,17 +35,3 @@ internal fun <T> StructureND<T>.toBufferedTensor(): BufferedTensor<T> = when (th
} }
else -> this.copyToBufferedTensor() else -> this.copyToBufferedTensor()
} }
@PublishedApi
internal val StructureND<Double>.tensor: DoubleTensor
get() = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().asTensor()
}
@PublishedApi
internal val StructureND<Int>.tensor: IntTensor
get() = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().asTensor()
}

View File

@ -5,18 +5,26 @@
package space.kscience.kmath.tensors.core package space.kscience.kmath.tensors.core
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.tensor import space.kscience.kmath.tensors.core.internal.toBufferedTensor
import space.kscience.kmath.tensors.core.internal.toTensor
/** /**
* Casts [Tensor] of [Double] to [DoubleTensor] * Casts [Tensor] of [Double] to [DoubleTensor]
*/ */
public fun Tensor<Double>.toDoubleTensor(): DoubleTensor = this.tensor public fun StructureND<Double>.asDoubleTensor(): DoubleTensor = when (this) {
is DoubleTensor -> this
else -> this.toBufferedTensor().toTensor()
}
/** /**
* Casts [Tensor] of [Int] to [IntTensor] * Casts [Tensor] of [Int] to [IntTensor]
*/ */
public fun Tensor<Int>.toIntTensor(): IntTensor = this.tensor public fun StructureND<Int>.asIntTensor(): IntTensor = when (this) {
is IntTensor -> this
else -> this.toBufferedTensor().toTensor()
}
/** /**
* Returns a copy-protected [DoubleArray] of tensor elements * Returns a copy-protected [DoubleArray] of tensor elements

View File

@ -14,9 +14,9 @@ import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.DoubleBuffer import space.kscience.kmath.structures.DoubleBuffer
import space.kscience.kmath.structures.toDoubleArray import space.kscience.kmath.structures.toDoubleArray
import space.kscience.kmath.tensors.core.internal.array import space.kscience.kmath.tensors.core.internal.array
import space.kscience.kmath.tensors.core.internal.asTensor
import space.kscience.kmath.tensors.core.internal.matrixSequence import space.kscience.kmath.tensors.core.internal.matrixSequence
import space.kscience.kmath.tensors.core.internal.toBufferedTensor import space.kscience.kmath.tensors.core.internal.toBufferedTensor
import space.kscience.kmath.tensors.core.internal.toTensor
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -56,7 +56,7 @@ internal class TestDoubleTensor {
assertEquals(tensor[intArrayOf(0, 1, 0)], 109.56) assertEquals(tensor[intArrayOf(0, 1, 0)], 109.56)
tensor.matrixSequence().forEach { tensor.matrixSequence().forEach {
val a = it.asTensor() val a = it.toTensor()
val secondRow = a[1].as1D() val secondRow = a[1].as1D()
val secondColumn = a.transpose(0, 1)[1].as1D() val secondColumn = a.transpose(0, 1)[1].as1D()
assertEquals(secondColumn[0], 77.89) assertEquals(secondColumn[0], 77.89)
@ -75,10 +75,10 @@ internal class TestDoubleTensor {
// map to tensors // map to tensors
val bufferedTensorArray = ndArray.toBufferedTensor() // strides are flipped so data copied val bufferedTensorArray = ndArray.toBufferedTensor() // strides are flipped so data copied
val tensorArray = bufferedTensorArray.asTensor() // data not contiguous so copied again val tensorArray = bufferedTensorArray.toTensor() // data not contiguous so copied again
val tensorArrayPublic = ndArray.toDoubleTensor() // public API, data copied twice val tensorArrayPublic = ndArray.asDoubleTensor() // public API, data copied twice
val sharedTensorArray = tensorArrayPublic.toDoubleTensor() // no data copied by matching type val sharedTensorArray = tensorArrayPublic.asDoubleTensor() // no data copied by matching type
assertTrue(tensorArray.mutableBuffer.array() contentEquals sharedTensorArray.mutableBuffer.array()) assertTrue(tensorArray.mutableBuffer.array() contentEquals sharedTensorArray.mutableBuffer.array())