This commit is contained in:
Roland Grinis 2021-07-08 23:20:17 +01:00
parent 62b3ccd111
commit 280c4e97e2
2 changed files with 171 additions and 4 deletions

View File

@ -279,7 +279,7 @@ public class NoaDoubleAlgebra(scope: NoaScope) :
override val Tensor<Double>.tensor: NoaDoubleTensor override val Tensor<Double>.tensor: NoaDoubleTensor
get() = TODO("Not yet implemented") get() = TODO("Not yet implemented")
override fun wrap(tensorHandle: Long): NoaDoubleTensor = override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall @PerformancePitfall
@ -339,4 +339,170 @@ public class NoaDoubleAlgebra(scope: NoaScope) :
} }
public class NoaFloatAlgebra(scope: NoaScope) :
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
override val Tensor<Float>.tensor: NoaFloatTensor
get() = TODO("Not yet implemented")
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Float>.copyToArray(): FloatArray =
tensor.elements().map { it.second }.toList().toFloatArray()
public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randnFloat(shape, device.toInt()))
public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randFloat(shape, device.toInt()))
public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
wrap(JNoa.plusFloat(this, other.tensor.tensorHandle))
override fun Tensor<Float>.plus(value: Float): NoaFloatTensor =
wrap(space.kscience.kmath.noa.JNoa.plusFloat(value, tensor.tensorHandle))
override fun Tensor<Float>.plusAssign(value: Float): Unit =
space.kscience.kmath.noa.JNoa.plusFloatAssign(value, tensor.tensorHandle)
override operator fun Float.minus(other: Tensor<Float>): NoaFloatTensor =
wrap(JNoa.plusFloat(-this, other.tensor.tensorHandle))
override fun Tensor<Float>.minus(value: Float): NoaFloatTensor =
wrap(space.kscience.kmath.noa.JNoa.plusFloat(-value, tensor.tensorHandle))
override fun Tensor<Float>.minusAssign(value: Float): Unit =
space.kscience.kmath.noa.JNoa.plusFloatAssign(-value, tensor.tensorHandle)
override operator fun Float.times(other: Tensor<Float>): NoaFloatTensor =
wrap(JNoa.timesFloat(this, other.tensor.tensorHandle))
override fun Tensor<Float>.times(value: Float): NoaFloatTensor =
wrap(space.kscience.kmath.noa.JNoa.timesFloat(value, tensor.tensorHandle))
override fun Tensor<Float>.timesAssign(value: Float): Unit =
space.kscience.kmath.noa.JNoa.timesFloatAssign(value, tensor.tensorHandle)
override fun Float.div(other: Tensor<Float>): NoaFloatTensor =
other * (1 / this)
override fun Tensor<Float>.div(value: Float): NoaFloatTensor =
tensor * (1 / value)
override fun Tensor<Float>.divAssign(value: Float): Unit =
tensor.timesAssign(1 / value)
public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
wrap(JNoa.fullFloat(value, shape, device.toInt()))
}
public class NoaLongAlgebra(scope: NoaScope) :
NoaAlgebra<Long, NoaLongTensor>(scope) {
override val Tensor<Long>.tensor: NoaLongTensor
get() = TODO("Not yet implemented")
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Long>.copyToArray(): LongArray =
tensor.elements().map { it.second }.toList().toLongArray()
public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.randintLong(low, high, shape, device.toInt()))
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
wrap(JNoa.plusLong(this, other.tensor.tensorHandle))
override fun Tensor<Long>.plus(value: Long): NoaLongTensor =
wrap(space.kscience.kmath.noa.JNoa.plusLong(value, tensor.tensorHandle))
override fun Tensor<Long>.plusAssign(value: Long): Unit =
space.kscience.kmath.noa.JNoa.plusLongAssign(value, tensor.tensorHandle)
override operator fun Long.minus(other: Tensor<Long>): NoaLongTensor =
wrap(JNoa.plusLong(-this, other.tensor.tensorHandle))
override fun Tensor<Long>.minus(value: Long): NoaLongTensor =
wrap(space.kscience.kmath.noa.JNoa.plusLong(-value, tensor.tensorHandle))
override fun Tensor<Long>.minusAssign(value: Long): Unit =
space.kscience.kmath.noa.JNoa.plusLongAssign(-value, tensor.tensorHandle)
override operator fun Long.times(other: Tensor<Long>): NoaLongTensor =
wrap(JNoa.timesLong(this, other.tensor.tensorHandle))
override fun Tensor<Long>.times(value: Long): NoaLongTensor =
wrap(space.kscience.kmath.noa.JNoa.timesLong(value, tensor.tensorHandle))
override fun Tensor<Long>.timesAssign(value: Long): Unit =
space.kscience.kmath.noa.JNoa.timesLongAssign(value, tensor.tensorHandle)
public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
wrap(JNoa.fullLong(value, shape, device.toInt()))
}
public class NoaIntAlgebra(scope: NoaScope) :
NoaAlgebra<Int, NoaIntTensor>(scope) {
override val Tensor<Int>.tensor: NoaIntTensor
get() = TODO("Not yet implemented")
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
@PerformancePitfall
public fun Tensor<Int>.copyToArray(): IntArray =
tensor.elements().map { it.second }.toList().toIntArray()
public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.randintInt(low, high, shape, device.toInt()))
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
wrap(JNoa.plusInt(this, other.tensor.tensorHandle))
override fun Tensor<Int>.plus(value: Int): NoaIntTensor =
wrap(space.kscience.kmath.noa.JNoa.plusInt(value, tensor.tensorHandle))
override fun Tensor<Int>.plusAssign(value: Int): Unit =
space.kscience.kmath.noa.JNoa.plusIntAssign(value, tensor.tensorHandle)
override operator fun Int.minus(other: Tensor<Int>): NoaIntTensor =
wrap(JNoa.plusInt(-this, other.tensor.tensorHandle))
override fun Tensor<Int>.minus(value: Int): NoaIntTensor =
wrap(space.kscience.kmath.noa.JNoa.plusInt(-value, tensor.tensorHandle))
override fun Tensor<Int>.minusAssign(value: Int): Unit =
space.kscience.kmath.noa.JNoa.plusIntAssign(-value, tensor.tensorHandle)
override operator fun Int.times(other: Tensor<Int>): NoaIntTensor =
wrap(JNoa.timesInt(this, other.tensor.tensorHandle))
override fun Tensor<Int>.times(value: Int): NoaIntTensor =
wrap(space.kscience.kmath.noa.JNoa.timesInt(value, tensor.tensorHandle))
override fun Tensor<Int>.timesAssign(value: Int): Unit =
space.kscience.kmath.noa.JNoa.timesIntAssign(value, tensor.tensorHandle)
public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
wrap(JNoa.fullInt(value, shape, device.toInt()))
}

View File

@ -5,6 +5,8 @@
package space.kscience.kmath.noa package space.kscience.kmath.noa
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaScope
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
@ -26,6 +28,5 @@ class TestUtils {
setNumThreads(numThreads) setNumThreads(numThreads)
assertEquals(numThreads, getNumThreads()) assertEquals(numThreads, getNumThreads())
} }
}
}