Rename Tensor::get to Tensor::getTensor to avoid name clash.

This commit is contained in:
Alexander Nozik 2022-09-05 23:24:01 +03:00
parent a9821772db
commit 3729faf49b
No known key found for this signature in database
GPG Key ID: F7FCF2DD25C71357
12 changed files with 90 additions and 30 deletions

View File

@ -58,7 +58,7 @@ fun main(): Unit = Double.tensorAlgebra.withBroadcast { // work in context with
// and find out eigenvector of it // and find out eigenvector of it
val (_, evecs) = covMatrix.symEig() val (_, evecs) = covMatrix.symEig()
val v = evecs[0] val v = evecs.getTensor(0)
println("Eigenvector:\n$v") println("Eigenvector:\n$v")
// reduce dimension of dataset // reduce dimension of dataset
@ -68,7 +68,7 @@ fun main(): Unit = Double.tensorAlgebra.withBroadcast { // work in context with
// we can restore original data from reduced data; // we can restore original data from reduced data;
// for example, find 7th element of dataset. // for example, find 7th element of dataset.
val n = 7 val n = 7
val restored = (datasetReduced[n] dot v.view(intArrayOf(1, 2))) * std + mean val restored = (datasetReduced.getTensor(n) dot v.view(intArrayOf(1, 2))) * std + mean
println("Original value:\n${dataset[n]}") println("Original value:\n${dataset.getTensor(n)}")
println("Restored value:\n$restored") println("Restored value:\n$restored")
} }

View File

@ -66,7 +66,7 @@ fun main() = Double.tensorAlgebra.withBroadcast {// work in context with linear
val n = l.shape[0] val n = l.shape[0]
val x = zeros(intArrayOf(n)) val x = zeros(intArrayOf(n))
for (i in 0 until n) { for (i in 0 until n) {
x[intArrayOf(i)] = (b[intArrayOf(i)] - l[i].dot(x).value()) / l[intArrayOf(i, i)] x[intArrayOf(i)] = (b[intArrayOf(i)] - l.getTensor(i).dot(x).value()) / l[intArrayOf(i, i)]
} }
return x return x
} }

View File

@ -197,7 +197,7 @@ fun main() = BroadcastDoubleTensorAlgebra {
val y = fromArray( val y = fromArray(
intArrayOf(sampleSize, 1), intArrayOf(sampleSize, 1),
DoubleArray(sampleSize) { i -> DoubleArray(sampleSize) { i ->
if (x[i].sum() > 0.0) { if (x.getTensor(i).sum() > 0.0) {
1.0 1.0
} else { } else {
0.0 0.0

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.misc
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.Ring
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.structures.Buffer
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
/** /**
@ -42,8 +43,8 @@ public inline fun <T, R> List<T>.cumulative(initial: R, crossinline operation: (
/** /**
* Cumulative sum with custom space * Cumulative sum with custom space
*/ */
public fun <T> Iterable<T>.cumulativeSum(group: Ring<T>): Iterable<T> = public fun <T> Iterable<T>.cumulativeSum(ring: Ring<T>): Iterable<T> =
group { cumulative(zero) { element: T, sum: T -> sum + element } } ring { cumulative(zero) { element: T, sum: T -> sum + element } }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
public fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element } public fun Iterable<Double>.cumulativeSum(): Iterable<Double> = cumulative(0.0) { element, sum -> sum + element }
@ -54,8 +55,8 @@ public fun Iterable<Int>.cumulativeSum(): Iterable<Int> = cumulative(0) { elemen
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
public fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element } public fun Iterable<Long>.cumulativeSum(): Iterable<Long> = cumulative(0L) { element, sum -> sum + element }
public fun <T> Sequence<T>.cumulativeSum(group: Ring<T>): Sequence<T> = public fun <T> Sequence<T>.cumulativeSum(ring: Ring<T>): Sequence<T> =
group { cumulative(zero) { element: T, sum: T -> sum + element } } ring { cumulative(zero) { element: T, sum: T -> sum + element } }
@JvmName("cumulativeSumOfDouble") @JvmName("cumulativeSumOfDouble")
public fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element } public fun Sequence<Double>.cumulativeSum(): Sequence<Double> = cumulative(0.0) { element, sum -> sum + element }
@ -77,3 +78,12 @@ public fun List<Int>.cumulativeSum(): List<Int> = cumulative(0) { element, sum -
@JvmName("cumulativeSumOfLong") @JvmName("cumulativeSumOfLong")
public fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element } public fun List<Long>.cumulativeSum(): List<Long> = cumulative(0L) { element, sum -> sum + element }
public fun <T> Buffer<T>.cumulativeSum(ring: Ring<T>): Buffer<T> = with(ring) {
var accumulator: T = zero
return bufferFactory(size) {
accumulator += get(it)
accumulator
}
}

View File

@ -185,7 +185,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
override fun StructureND<T>.unaryMinus(): MultikTensor<T> = override fun StructureND<T>.unaryMinus(): MultikTensor<T> =
asMultik().array.unaryMinus().wrap() asMultik().array.unaryMinus().wrap()
override fun Tensor<T>.get(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap() override fun Tensor<T>.getTensor(i: Int): MultikTensor<T> = asMultik().array.mutableView(i).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap() override fun Tensor<T>.transpose(i: Int, j: Int): MultikTensor<T> = asMultik().array.transpose(i, j).wrap()
@ -246,6 +246,12 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
return multikMath.minDN(asMultik().array, dim).wrap() return multikMath.minDN(asMultik().array, dim).wrap()
} }
override fun StructureND<T>.argMin(dim: Int, keepDim: Boolean): Tensor<Int> {
if (keepDim) TODO("keepDim not implemented")
val res = multikMath.argMinDN(asMultik().array, dim)
return with(MultikIntAlgebra(multikEngine)) { res.wrap() }
}
override fun StructureND<T>.max(): T? = asMultik().array.max() override fun StructureND<T>.max(): T? = asMultik().array.max()
override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> { override fun StructureND<T>.max(dim: Int, keepDim: Boolean): Tensor<T> {

View File

@ -95,7 +95,7 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
} }
override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap() override fun StructureND<T>.unaryMinus(): Nd4jArrayStructure<T> = ndArray.neg().wrap()
override fun Tensor<T>.get(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap() override fun Tensor<T>.getTensor(i: Int): Nd4jArrayStructure<T> = ndArray.slice(i.toLong()).wrap()
override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap() override fun Tensor<T>.transpose(i: Int, j: Int): Nd4jArrayStructure<T> = ndArray.swapAxes(i, j).wrap()
override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap() override fun StructureND<T>.dot(other: StructureND<T>): Nd4jArrayStructure<T> = ndArray.mmul(other.ndArray).wrap()
@ -111,6 +111,9 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap() override fun Tensor<T>.view(shape: IntArray): Nd4jArrayStructure<T> = ndArray.reshape(shape).wrap()
override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape) override fun Tensor<T>.viewAs(other: StructureND<T>): Nd4jArrayStructure<T> = view(other.shape)
override fun StructureND<T>.argMin(dim: Int, keepDim: Boolean): Tensor<Int> =
ndBase.get().argmin(ndArray, keepDim, dim).asIntStructure()
override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> = override fun StructureND<T>.argMax(dim: Int, keepDim: Boolean): Tensor<Int> =
ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure() ndBase.get().argmax(ndArray, keepDim, dim).asIntStructure()

View File

@ -184,7 +184,7 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg) override fun StructureND<T>.unaryMinus(): TensorFlowOutput<T, TT> = operate(ops.math::neg)
override fun Tensor<T>.get(i: Int): Tensor<T> = operate { override fun Tensor<T>.getTensor(i: Int): Tensor<T> = operate {
StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong())) StridedSliceHelper.stridedSlice(ops.scope(), it, Indices.at(i.toLong()))
} }
@ -238,6 +238,11 @@ public abstract class TensorFlowAlgebra<T, TT : TNumber, A : Ring<T>> internal c
ops.min(it, ops.constant(dim), Min.keepDims(keepDim)) ops.min(it, ops.constant(dim), Min.keepDims(keepDim))
} }
override fun StructureND<T>.argMin(dim: Int, keepDim: Boolean): Tensor<Int> = IntTensorFlowOutput(
graph,
ops.math.argMin(asTensorFlow().output, ops.constant(dim), TInt32::class.java).output()
).actualTensor
override fun StructureND<T>.max(): T = operate { override fun StructureND<T>.max(): T = operate {
ops.max(it, ops.constant(intArrayOf())) ops.max(it, ops.constant(intArrayOf()))
}.value() }.value()

View File

@ -166,7 +166,11 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
* @param i index of the extractable tensor * @param i index of the extractable tensor
* @return subtensor of the original tensor with index [i] * @return subtensor of the original tensor with index [i]
*/ */
public operator fun Tensor<T>.get(i: Int): Tensor<T> public fun Tensor<T>.getTensor(i: Int): Tensor<T>
public fun Tensor<T>.getTensor(first: Int, second: Int): Tensor<T> {
return getTensor(first).getTensor(second)
}
/** /**
* Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped. * Returns a tensor that is a transposed version of this tensor. The given dimensions [i] and [j] are swapped.
@ -286,6 +290,19 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
*/ */
public fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T> public fun StructureND<T>.min(dim: Int, keepDim: Boolean): Tensor<T>
/**
* Returns the index of minimum value of each row of the input tensor in the given dimension [dim].
*
* If [keepDim] is true, the output tensor is of the same size as
* input except in the dimension [dim] where it is of size 1.
* Otherwise, [dim] is squeezed, resulting in the output tensor having 1 fewer dimension.
*
* @param dim the dimension to reduce.
* @param keepDim whether the output tensor has [dim] retained or not.
* @return the index of maximum value of each row of the input tensor in the given dimension [dim].
*/
public fun StructureND<T>.argMin(dim: Int, keepDim: Boolean): Tensor<Int>
/** /**
* Returns the maximum value of all elements in the input tensor or null if there are no values * Returns the maximum value of all elements in the input tensor or null if there are no values
*/ */
@ -320,4 +337,4 @@ public interface TensorAlgebra<T, A : Ring<T>> : RingOpsND<T, A> {
override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right override fun add(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left + right
override fun multiply(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left * right override fun multiply(left: StructureND<T>, right: StructureND<T>): Tensor<T> = left * right
} }

View File

@ -24,6 +24,7 @@ import kotlin.math.*
/** /**
* Implementation of basic operations over double tensors and basic algebra operations on them. * Implementation of basic operations over double tensors and basic algebra operations on them.
*/ */
@OptIn(PerformancePitfall::class)
public open class DoubleTensorAlgebra : public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double, DoubleField>, TensorPartialDivisionAlgebra<Double, DoubleField>,
AnalyticTensorAlgebra<Double, DoubleField>, AnalyticTensorAlgebra<Double, DoubleField>,
@ -120,7 +121,7 @@ public open class DoubleTensorAlgebra :
TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray() TensorLinearStructure(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
) )
override operator fun Tensor<Double>.get(i: Int): DoubleTensor { override fun Tensor<Double>.getTensor(i: Int): DoubleTensor {
val lastShape = asDoubleTensor().shape.drop(1).toIntArray() val lastShape = asDoubleTensor().shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + asDoubleTensor().bufferStart val newStart = newShape.reduce(Int::times) * i + asDoubleTensor().bufferStart
@ -204,7 +205,11 @@ public open class DoubleTensorAlgebra :
* @return a copy of the `input` tensor with a copied buffer. * @return a copy of the `input` tensor with a copied buffer.
*/ */
public fun StructureND<Double>.copy(): DoubleTensor = public fun StructureND<Double>.copy(): DoubleTensor =
DoubleTensor(asDoubleTensor().shape, asDoubleTensor().mutableBuffer.array().copyOf(), asDoubleTensor().bufferStart) DoubleTensor(
asDoubleTensor().shape,
asDoubleTensor().mutableBuffer.array().copyOf(),
asDoubleTensor().bufferStart
)
override fun Double.plus(arg: StructureND<Double>): DoubleTensor { override fun Double.plus(arg: StructureND<Double>): DoubleTensor {
val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i -> val resBuffer = DoubleArray(arg.asDoubleTensor().numElements) { i ->
@ -413,7 +418,10 @@ public open class DoubleTensorAlgebra :
@UnstableKMathAPI @UnstableKMathAPI
public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor { public infix fun StructureND<Double>.matmul(other: StructureND<Double>): DoubleTensor {
if (asDoubleTensor().shape.size == 1 && other.shape.size == 1) { if (asDoubleTensor().shape.size == 1 && other.shape.size == 1) {
return DoubleTensor(intArrayOf(1), doubleArrayOf(asDoubleTensor().times(other).asDoubleTensor().mutableBuffer.array().sum())) return DoubleTensor(
intArrayOf(1),
doubleArrayOf(asDoubleTensor().times(other).asDoubleTensor().mutableBuffer.array().sum())
)
} }
var newThis = asDoubleTensor().copy() var newThis = asDoubleTensor().copy()
@ -592,7 +600,8 @@ public open class DoubleTensorAlgebra :
check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" } check(tensors.all { it.shape contentEquals shape }) { "Tensors must have same shapes" }
val resShape = intArrayOf(tensors.size) + shape val resShape = intArrayOf(tensors.size) + shape
val resBuffer = tensors.flatMap { val resBuffer = tensors.flatMap {
it.asDoubleTensor().mutableBuffer.array().drop(it.asDoubleTensor().bufferStart).take(it.asDoubleTensor().numElements) it.asDoubleTensor().mutableBuffer.array().drop(it.asDoubleTensor().bufferStart)
.take(it.asDoubleTensor().numElements)
}.toDoubleArray() }.toDoubleArray()
return DoubleTensor(resShape, resBuffer, 0) return DoubleTensor(resShape, resBuffer, 0)
} }
@ -603,7 +612,7 @@ public open class DoubleTensorAlgebra :
* @param indices the [IntArray] of 1-dimensional indices * @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to row by [indices] * @return tensor with rows corresponding to row by [indices]
*/ */
public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { this[it] }) public fun Tensor<Double>.rowsByIndices(indices: IntArray): DoubleTensor = stack(indices.map { getTensor(it) })
private inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double = private inline fun StructureND<Double>.fold(foldFunction: (DoubleArray) -> Double): Double =
foldFunction(asDoubleTensor().copyArray()) foldFunction(asDoubleTensor().copyArray())
@ -645,6 +654,10 @@ public open class DoubleTensorAlgebra :
override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.min(dim: Int, keepDim: Boolean): DoubleTensor =
foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asDoubleTensor() foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asDoubleTensor()
override fun StructureND<Double>.argMin(dim: Int, keepDim: Boolean): Tensor<Int> = foldDim(dim, keepDim) { x ->
x.withIndex().minByOrNull { it.value }?.index!!
}.asIntTensor()
override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! } override fun StructureND<Double>.max(): Double = this.fold { it.maxOrNull()!! }
override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor = override fun StructureND<Double>.max(dim: Int, keepDim: Boolean): DoubleTensor =

View File

@ -118,7 +118,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
TensorLinearStructure(shape).asSequence().map { IntRing.initializer(it) }.toMutableList().toIntArray() TensorLinearStructure(shape).asSequence().map { IntRing.initializer(it) }.toMutableList().toIntArray()
) )
override operator fun Tensor<Int>.get(i: Int): IntTensor { override fun Tensor<Int>.getTensor(i: Int): IntTensor {
val lastShape = asIntTensor().shape.drop(1).toIntArray() val lastShape = asIntTensor().shape.drop(1).toIntArray()
val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1) val newShape = if (lastShape.isNotEmpty()) lastShape else intArrayOf(1)
val newStart = newShape.reduce(Int::times) * i + asIntTensor().bufferStart val newStart = newShape.reduce(Int::times) * i + asIntTensor().bufferStart
@ -433,7 +433,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
* @param indices the [IntArray] of 1-dimensional indices * @param indices the [IntArray] of 1-dimensional indices
* @return tensor with rows corresponding to row by [indices] * @return tensor with rows corresponding to row by [indices]
*/ */
public fun Tensor<Int>.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { this[it] }) public fun Tensor<Int>.rowsByIndices(indices: IntArray): IntTensor = stack(indices.map { getTensor(it) })
private inline fun StructureND<Int>.fold(foldFunction: (IntArray) -> Int): Int = private inline fun StructureND<Int>.fold(foldFunction: (IntArray) -> Int): Int =
foldFunction(asIntTensor().copyArray()) foldFunction(asIntTensor().copyArray())
@ -475,6 +475,11 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
override fun StructureND<Int>.min(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.min(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asIntTensor() foldDim(dim, keepDim) { x -> x.minOrNull()!! }.asIntTensor()
override fun StructureND<Int>.argMin(dim: Int, keepDim: Boolean): IntTensor =
foldDim(dim, keepDim) { x ->
x.withIndex().minByOrNull { it.value }?.index!!
}.asIntTensor()
override fun StructureND<Int>.max(): Int = this.fold { it.maxOrNull()!! } override fun StructureND<Int>.max(): Int = this.fold { it.maxOrNull()!! }
override fun StructureND<Int>.max(dim: Int, keepDim: Boolean): IntTensor = override fun StructureND<Int>.max(dim: Int, keepDim: Boolean): IntTensor =

View File

@ -257,13 +257,13 @@ internal fun DoubleTensorAlgebra.qrHelper(
val qT = q.transpose(0, 1) val qT = q.transpose(0, 1)
for (j in 0 until n) { for (j in 0 until n) {
val v = matrixT[j] val v = matrixT.getTensor(j)
val vv = v.as1D() val vv = v.as1D()
if (j > 0) { if (j > 0) {
for (i in 0 until j) { for (i in 0 until j) {
r[i, j] = (qT[i] dot matrixT[j]).value() r[i, j] = (qT.getTensor(i) dot matrixT.getTensor(j)).value()
for (k in 0 until n) { for (k in 0 until n) {
val qTi = qT[i].as1D() val qTi = qT.getTensor(i).as1D()
vv[k] = vv[k] - r[i, j] * qTi[k] vv[k] = vv[k] - r[i, j] * qTi[k]
} }
} }
@ -313,7 +313,7 @@ internal fun DoubleTensorAlgebra.svdHelper(
val outerProduct = DoubleArray(u.shape[0] * v.shape[0]) val outerProduct = DoubleArray(u.shape[0] * v.shape[0])
for (i in 0 until u.shape[0]) { for (i in 0 until u.shape[0]) {
for (j in 0 until v.shape[0]) { for (j in 0 until v.shape[0]) {
outerProduct[i * v.shape[0] + j] = u[i].value() * v[j].value() outerProduct[i * v.shape[0] + j] = u.getTensor(i).value() * v.getTensor(j).value()
} }
} }
a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct)) a = a - singularValue.times(DoubleTensor(intArrayOf(u.shape[0], v.shape[0]), outerProduct))

View File

@ -36,17 +36,18 @@ internal class TestDoubleTensor {
val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4)) val tensor = fromArray(intArrayOf(2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
assertEquals(tensor[intArrayOf(0, 1)], 5.8) assertEquals(tensor[intArrayOf(0, 1)], 5.8)
assertTrue( assertTrue(
tensor.elements().map { it.second }.toList().toDoubleArray() contentEquals tensor.mutableBuffer.toDoubleArray() tensor.elements().map { it.second }.toList()
.toDoubleArray() contentEquals tensor.mutableBuffer.toDoubleArray()
) )
} }
@Test @Test
fun testGet() = DoubleTensorAlgebra { fun testGet() = DoubleTensorAlgebra {
val tensor = fromArray(intArrayOf(1, 2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4)) val tensor = fromArray(intArrayOf(1, 2, 2), doubleArrayOf(3.5, 5.8, 58.4, 2.4))
val matrix = tensor[0].as2D() val matrix = tensor.getTensor(0).as2D()
assertEquals(matrix[0, 1], 5.8) assertEquals(matrix[0, 1], 5.8)
val vector = tensor[0][1].as1D() val vector = tensor.getTensor(0, 1).as1D()
assertEquals(vector[0], 58.4) assertEquals(vector[0], 58.4)
matrix[0, 1] = 77.89 matrix[0, 1] = 77.89
@ -57,8 +58,8 @@ internal class TestDoubleTensor {
tensor.matrixSequence().forEach { tensor.matrixSequence().forEach {
val a = it.toTensor() val a = it.toTensor()
val secondRow = a[1].as1D() val secondRow = a.getTensor(1).as1D()
val secondColumn = a.transpose(0, 1)[1].as1D() val secondColumn = a.transpose(0, 1).getTensor(1).as1D()
assertEquals(secondColumn[0], 77.89) assertEquals(secondColumn[0], 77.89)
assertEquals(secondRow[1], secondColumn[1]) assertEquals(secondRow[1], secondColumn[1])
} }