KMP library for tensors #300

Merged
grinisrit merged 215 commits from feature/tensor-algebra into dev 2021-05-08 09:48:04 +03:00
3 changed files with 5 additions and 9 deletions
Showing only changes of commit 3a37b88b5c - Show all commits

View File

@ -3,8 +3,6 @@ package space.kscience.kmath.tensors
import space.kscience.kmath.nd.MutableNDBuffer import space.kscience.kmath.nd.MutableNDBuffer
import space.kscience.kmath.structures.RealBuffer import space.kscience.kmath.structures.RealBuffer
import space.kscience.kmath.structures.array import space.kscience.kmath.structures.array
import kotlin.js.JsName
import kotlin.math.abs
public class RealTensor( public class RealTensor(
@ -19,11 +17,9 @@ public class RealTensor(
public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
//rename to item?
override fun RealTensor.value(): Double { override fun RealTensor.value(): Double {
check(this.dimension == 0) { check(this.shape contentEquals intArrayOf(1)) {
// todo change message "Inconsistent value for tensor of shape ${shape.toList()}"
"This tensor has shape ${shape.toList()}"
} }
return this.buffer.array[0] return this.buffer.array[0]
} }

View File

@ -48,5 +48,5 @@ public class TensorStrides(override val shape: IntArray): Strides
indexFromOffset(offset, strides, shape.size) indexFromOffset(offset, strides, shape.size)
override val linearSize: Int override val linearSize: Int
get() = shape.fold(1) { acc, i -> acc * i } get() = shape.reduce(Int::times)
} }

View File

@ -9,9 +9,9 @@ import kotlin.test.assertTrue
class TestRealTensor { class TestRealTensor {
@Test @Test
fun valueTest(){ fun valueTest() = RealTensorAlgebra {
val value = 12.5 val value = 12.5
val tensor = RealTensor(IntArray(0), doubleArrayOf(value)) val tensor = RealTensor(intArrayOf(1), doubleArrayOf(value))
assertEquals(tensor.value(), value) assertEquals(tensor.value(), value)
} }