forked from kscience/kmath
Fixed getting value test for tensors
This commit is contained in:
parent
6298189fb3
commit
3a37b88b5c
@ -3,8 +3,6 @@ package space.kscience.kmath.tensors
|
||||
import space.kscience.kmath.nd.MutableNDBuffer
|
||||
import space.kscience.kmath.structures.RealBuffer
|
||||
import space.kscience.kmath.structures.array
|
||||
import kotlin.js.JsName
|
||||
import kotlin.math.abs
|
||||
|
||||
|
||||
public class RealTensor(
|
||||
@ -19,11 +17,9 @@ public class RealTensor(
|
||||
|
||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||
|
||||
//rename to item?
|
||||
override fun RealTensor.value(): Double {
|
||||
check(this.dimension == 0) {
|
||||
// todo change message
|
||||
"This tensor has shape ${shape.toList()}"
|
||||
check(this.shape contentEquals intArrayOf(1)) {
|
||||
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||
}
|
||||
return this.buffer.array[0]
|
||||
}
|
||||
|
@ -48,5 +48,5 @@ public class TensorStrides(override val shape: IntArray): Strides
|
||||
indexFromOffset(offset, strides, shape.size)
|
||||
|
||||
override val linearSize: Int
|
||||
get() = shape.fold(1) { acc, i -> acc * i }
|
||||
get() = shape.reduce(Int::times)
|
||||
}
|
@ -9,9 +9,9 @@ import kotlin.test.assertTrue
|
||||
class TestRealTensor {
|
||||
|
||||
@Test
|
||||
fun valueTest(){
|
||||
fun valueTest() = RealTensorAlgebra {
|
||||
val value = 12.5
|
||||
val tensor = RealTensor(IntArray(0), doubleArrayOf(value))
|
||||
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(value))
|
||||
assertEquals(tensor.value(), value)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user