forked from kscience/kmath
Fixed getting value test for tensors
This commit is contained in:
parent
6298189fb3
commit
3a37b88b5c
@ -3,8 +3,6 @@ package space.kscience.kmath.tensors
|
|||||||
import space.kscience.kmath.nd.MutableNDBuffer
|
import space.kscience.kmath.nd.MutableNDBuffer
|
||||||
import space.kscience.kmath.structures.RealBuffer
|
import space.kscience.kmath.structures.RealBuffer
|
||||||
import space.kscience.kmath.structures.array
|
import space.kscience.kmath.structures.array
|
||||||
import kotlin.js.JsName
|
|
||||||
import kotlin.math.abs
|
|
||||||
|
|
||||||
|
|
||||||
public class RealTensor(
|
public class RealTensor(
|
||||||
@ -19,11 +17,9 @@ public class RealTensor(
|
|||||||
|
|
||||||
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
|
|
||||||
//rename to item?
|
|
||||||
override fun RealTensor.value(): Double {
|
override fun RealTensor.value(): Double {
|
||||||
check(this.dimension == 0) {
|
check(this.shape contentEquals intArrayOf(1)) {
|
||||||
// todo change message
|
"Inconsistent value for tensor of shape ${shape.toList()}"
|
||||||
"This tensor has shape ${shape.toList()}"
|
|
||||||
}
|
}
|
||||||
return this.buffer.array[0]
|
return this.buffer.array[0]
|
||||||
}
|
}
|
||||||
|
@ -48,5 +48,5 @@ public class TensorStrides(override val shape: IntArray): Strides
|
|||||||
indexFromOffset(offset, strides, shape.size)
|
indexFromOffset(offset, strides, shape.size)
|
||||||
|
|
||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = shape.fold(1) { acc, i -> acc * i }
|
get() = shape.reduce(Int::times)
|
||||||
}
|
}
|
@ -9,9 +9,9 @@ import kotlin.test.assertTrue
|
|||||||
class TestRealTensor {
|
class TestRealTensor {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun valueTest(){
|
fun valueTest() = RealTensorAlgebra {
|
||||||
val value = 12.5
|
val value = 12.5
|
||||||
val tensor = RealTensor(IntArray(0), doubleArrayOf(value))
|
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(value))
|
||||||
assertEquals(tensor.value(), value)
|
assertEquals(tensor.value(), value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user