From 280c4e97e2c3ec03463d4c583098b752a1ee52bd Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 23:20:17 +0100 Subject: [PATCH] algebras --- .../space/kscience/kmath/noa/algebras.kt | 168 +++++++++++++++++- .../space/kscience/kmath/noa/TestUtils.kt | 7 +- 2 files changed, 171 insertions(+), 4 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 1ae6ae2e9..93b27e246 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -279,7 +279,7 @@ public class NoaDoubleAlgebra(scope: NoaScope) : override val Tensor.tensor: NoaDoubleTensor get() = TODO("Not yet implemented") - override fun wrap(tensorHandle: Long): NoaDoubleTensor = + override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor = NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) @PerformancePitfall @@ -339,4 +339,170 @@ public class NoaDoubleAlgebra(scope: NoaScope) : } +public class NoaFloatAlgebra(scope: NoaScope) : + NoaPartialDivisionAlgebra(scope) { + override val Tensor.tensor: NoaFloatTensor + get() = TODO("Not yet implemented") + + override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor = + NoaFloatTensor(scope = scope, tensorHandle = tensorHandle) + + @PerformancePitfall + public fun Tensor.copyToArray(): FloatArray = + tensor.elements().map { it.second }.toList().toFloatArray() + + public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor = + wrap(JNoa.fromBlobFloat(array, shape, device.toInt())) + + public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor = + wrap(JNoa.randnFloat(shape, device.toInt())) + + public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor = + wrap(JNoa.randFloat(shape, device.toInt())) + + public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor = + wrap(JNoa.randintFloat(low, high, shape, device.toInt())) + + override operator fun Float.plus(other: Tensor): NoaFloatTensor = + wrap(JNoa.plusFloat(this, other.tensor.tensorHandle)) + + override fun Tensor.plus(value: Float): NoaFloatTensor = + wrap(space.kscience.kmath.noa.JNoa.plusFloat(value, tensor.tensorHandle)) + + override fun Tensor.plusAssign(value: Float): Unit = + space.kscience.kmath.noa.JNoa.plusFloatAssign(value, tensor.tensorHandle) + + override operator fun Float.minus(other: Tensor): NoaFloatTensor = + wrap(JNoa.plusFloat(-this, other.tensor.tensorHandle)) + + override fun Tensor.minus(value: Float): NoaFloatTensor = + wrap(space.kscience.kmath.noa.JNoa.plusFloat(-value, tensor.tensorHandle)) + + override fun Tensor.minusAssign(value: Float): Unit = + space.kscience.kmath.noa.JNoa.plusFloatAssign(-value, tensor.tensorHandle) + + override operator fun Float.times(other: Tensor): NoaFloatTensor = + wrap(JNoa.timesFloat(this, other.tensor.tensorHandle)) + + override fun Tensor.times(value: Float): NoaFloatTensor = + wrap(space.kscience.kmath.noa.JNoa.timesFloat(value, tensor.tensorHandle)) + + override fun Tensor.timesAssign(value: Float): Unit = + space.kscience.kmath.noa.JNoa.timesFloatAssign(value, tensor.tensorHandle) + + override fun Float.div(other: Tensor): NoaFloatTensor = + other * (1 / this) + + override fun Tensor.div(value: Float): NoaFloatTensor = + tensor * (1 / value) + + override fun Tensor.divAssign(value: Float): Unit = + tensor.timesAssign(1 / value) + + public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor = + wrap(JNoa.fullFloat(value, shape, device.toInt())) + +} + +public class NoaLongAlgebra(scope: NoaScope) : + NoaAlgebra(scope) { + + override val Tensor.tensor: NoaLongTensor + get() = TODO("Not yet implemented") + + override fun wrap(tensorHandle: TensorHandle): NoaLongTensor = + NoaLongTensor(scope = scope, tensorHandle = tensorHandle) + + @PerformancePitfall + public fun Tensor.copyToArray(): LongArray = + tensor.elements().map { it.second }.toList().toLongArray() + + public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor = + wrap(JNoa.fromBlobLong(array, shape, device.toInt())) + + public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor = + wrap(JNoa.randintLong(low, high, shape, device.toInt())) + + override operator fun Long.plus(other: Tensor): NoaLongTensor = + wrap(JNoa.plusLong(this, other.tensor.tensorHandle)) + + override fun Tensor.plus(value: Long): NoaLongTensor = + wrap(space.kscience.kmath.noa.JNoa.plusLong(value, tensor.tensorHandle)) + + override fun Tensor.plusAssign(value: Long): Unit = + space.kscience.kmath.noa.JNoa.plusLongAssign(value, tensor.tensorHandle) + + override operator fun Long.minus(other: Tensor): NoaLongTensor = + wrap(JNoa.plusLong(-this, other.tensor.tensorHandle)) + + override fun Tensor.minus(value: Long): NoaLongTensor = + wrap(space.kscience.kmath.noa.JNoa.plusLong(-value, tensor.tensorHandle)) + + override fun Tensor.minusAssign(value: Long): Unit = + space.kscience.kmath.noa.JNoa.plusLongAssign(-value, tensor.tensorHandle) + + override operator fun Long.times(other: Tensor): NoaLongTensor = + wrap(JNoa.timesLong(this, other.tensor.tensorHandle)) + + override fun Tensor.times(value: Long): NoaLongTensor = + wrap(space.kscience.kmath.noa.JNoa.timesLong(value, tensor.tensorHandle)) + + override fun Tensor.timesAssign(value: Long): Unit = + space.kscience.kmath.noa.JNoa.timesLongAssign(value, tensor.tensorHandle) + + public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor = + wrap(JNoa.fullLong(value, shape, device.toInt())) + +} + +public class NoaIntAlgebra(scope: NoaScope) : + NoaAlgebra(scope) { + + override val Tensor.tensor: NoaIntTensor + get() = TODO("Not yet implemented") + + override fun wrap(tensorHandle: TensorHandle): NoaIntTensor = + NoaIntTensor(scope = scope, tensorHandle = tensorHandle) + + @PerformancePitfall + public fun Tensor.copyToArray(): IntArray = + tensor.elements().map { it.second }.toList().toIntArray() + + public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor = + wrap(JNoa.fromBlobInt(array, shape, device.toInt())) + + public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor = + wrap(JNoa.randintInt(low, high, shape, device.toInt())) + + override operator fun Int.plus(other: Tensor): NoaIntTensor = + wrap(JNoa.plusInt(this, other.tensor.tensorHandle)) + + override fun Tensor.plus(value: Int): NoaIntTensor = + wrap(space.kscience.kmath.noa.JNoa.plusInt(value, tensor.tensorHandle)) + + override fun Tensor.plusAssign(value: Int): Unit = + space.kscience.kmath.noa.JNoa.plusIntAssign(value, tensor.tensorHandle) + + override operator fun Int.minus(other: Tensor): NoaIntTensor = + wrap(JNoa.plusInt(-this, other.tensor.tensorHandle)) + + override fun Tensor.minus(value: Int): NoaIntTensor = + wrap(space.kscience.kmath.noa.JNoa.plusInt(-value, tensor.tensorHandle)) + + override fun Tensor.minusAssign(value: Int): Unit = + space.kscience.kmath.noa.JNoa.plusIntAssign(-value, tensor.tensorHandle) + + override operator fun Int.times(other: Tensor): NoaIntTensor = + wrap(JNoa.timesInt(this, other.tensor.tensorHandle)) + + override fun Tensor.times(value: Int): NoaIntTensor = + wrap(space.kscience.kmath.noa.JNoa.timesInt(value, tensor.tensorHandle)) + + override fun Tensor.timesAssign(value: Int): Unit = + space.kscience.kmath.noa.JNoa.timesIntAssign(value, tensor.tensorHandle) + + public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor = + wrap(JNoa.fullInt(value, shape, device.toInt())) + +} diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt index ff37da3e3..b16c84657 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt @@ -5,6 +5,8 @@ package space.kscience.kmath.noa +import space.kscience.kmath.misc.PerformancePitfall +import space.kscience.kmath.noa.memory.NoaScope import kotlin.test.Test import kotlin.test.assertEquals @@ -26,6 +28,5 @@ class TestUtils { setNumThreads(numThreads) assertEquals(numThreads, getNumThreads()) } - - -} \ No newline at end of file + +}