inheritance is back
This commit is contained in:
parent
e6e117f694
commit
6d5e4a5776
@ -14,7 +14,7 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
|
|||||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||||
|
|
||||||
|
|
||||||
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
public sealed class NoaAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||||
protected constructor(protected val scope: NoaScope) :
|
protected constructor(protected val scope: NoaScope) :
|
||||||
TensorAlgebra<T> {
|
TensorAlgebra<T> {
|
||||||
|
|
||||||
@ -36,6 +36,15 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
override fun Tensor<T>.value(): T = tensor.item()
|
override fun Tensor<T>.value(): T = tensor.item()
|
||||||
|
|
||||||
|
public abstract fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): TensorType
|
||||||
|
|
||||||
|
@PerformancePitfall
|
||||||
|
public abstract fun Tensor<T>.copyToArray(): PrimitiveArray
|
||||||
|
|
||||||
|
public abstract fun copyFromArray(array: PrimitiveArray, shape: IntArray, device: Device): TensorType
|
||||||
|
|
||||||
|
public abstract fun full(value: T, shape: IntArray, device: Device): TensorType
|
||||||
|
|
||||||
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
|
||||||
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
|
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
|
||||||
}
|
}
|
||||||
@ -132,9 +141,9 @@ protected constructor(protected val scope: NoaScope) :
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaPartialDivisionAlgebra<T, TensorType : NoaTensor<T>>
|
public sealed class NoaPartialDivisionAlgebra<T, PrimitiveArray, TensorType : NoaTensor<T>>
|
||||||
protected constructor(scope: NoaScope) :
|
protected constructor(scope: NoaScope) :
|
||||||
NoaAlgebra<T, TensorType>(scope),
|
NoaAlgebra<T, PrimitiveArray, TensorType>(scope),
|
||||||
LinearOpsTensorAlgebra<T>,
|
LinearOpsTensorAlgebra<T>,
|
||||||
AnalyticTensorAlgebra<T> {
|
AnalyticTensorAlgebra<T> {
|
||||||
|
|
||||||
@ -161,6 +170,10 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
|
override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): TensorType =
|
||||||
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
|
wrap(JNoa.varDimTensor(tensor.tensorHandle, dim, keepDim))
|
||||||
|
|
||||||
|
public abstract fun randNormal(shape: IntArray, device: Device): TensorType
|
||||||
|
|
||||||
|
public abstract fun randUniform(shape: IntArray, device: Device): TensorType
|
||||||
|
|
||||||
public fun Tensor<T>.randUniform(): TensorType =
|
public fun Tensor<T>.randUniform(): TensorType =
|
||||||
wrap(JNoa.randLike(tensor.tensorHandle))
|
wrap(JNoa.randLike(tensor.tensorHandle))
|
||||||
|
|
||||||
@ -278,7 +291,7 @@ protected constructor(scope: NoaScope) :
|
|||||||
|
|
||||||
public sealed class NoaDoubleAlgebra
|
public sealed class NoaDoubleAlgebra
|
||||||
protected constructor(scope: NoaScope) :
|
protected constructor(scope: NoaScope) :
|
||||||
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
|
NoaPartialDivisionAlgebra<Double, DoubleArray, NoaDoubleTensor>(scope) {
|
||||||
|
|
||||||
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
|
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
|
||||||
copyFromArray(
|
copyFromArray(
|
||||||
@ -296,19 +309,19 @@ protected constructor(scope: NoaScope) :
|
|||||||
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public fun Tensor<Double>.copyToArray(): DoubleArray =
|
override fun Tensor<Double>.copyToArray(): DoubleArray =
|
||||||
tensor.elements().map { it.second }.toList().toDoubleArray()
|
tensor.elements().map { it.second }.toList().toDoubleArray()
|
||||||
|
|
||||||
public fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
|
override fun copyFromArray(array: DoubleArray, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||||
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
|
wrap(JNoa.fromBlobDouble(array, shape, device.toInt()))
|
||||||
|
|
||||||
public fun randNormalDouble(shape: IntArray, device: Device): NoaDoubleTensor =
|
override fun randNormal(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||||
wrap(JNoa.randnDouble(shape, device.toInt()))
|
wrap(JNoa.randnDouble(shape, device.toInt()))
|
||||||
|
|
||||||
public fun randUniformDouble(shape: IntArray, device: Device): NoaDoubleTensor =
|
override fun randUniform(shape: IntArray, device: Device): NoaDoubleTensor =
|
||||||
wrap(JNoa.randDouble(shape, device.toInt()))
|
wrap(JNoa.randDouble(shape, device.toInt()))
|
||||||
|
|
||||||
public fun randDiscreteDouble(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
|
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||||
wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
|
wrap(JNoa.randintDouble(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
|
override operator fun Double.plus(other: Tensor<Double>): NoaDoubleTensor =
|
||||||
@ -347,14 +360,14 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<Double>.divAssign(value: Double): Unit =
|
override fun Tensor<Double>.divAssign(value: Double): Unit =
|
||||||
tensor.timesAssign(1 / value)
|
tensor.timesAssign(1 / value)
|
||||||
|
|
||||||
public fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
|
override fun full(value: Double, shape: IntArray, device: Device): NoaDoubleTensor =
|
||||||
wrap(JNoa.fullDouble(value, shape, device.toInt()))
|
wrap(JNoa.fullDouble(value, shape, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaFloatAlgebra
|
public sealed class NoaFloatAlgebra
|
||||||
protected constructor(scope: NoaScope) :
|
protected constructor(scope: NoaScope) :
|
||||||
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
|
NoaPartialDivisionAlgebra<Float, FloatArray, NoaFloatTensor>(scope) {
|
||||||
|
|
||||||
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
|
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
|
||||||
copyFromArray(
|
copyFromArray(
|
||||||
@ -372,19 +385,19 @@ protected constructor(scope: NoaScope) :
|
|||||||
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public fun Tensor<Float>.copyToArray(): FloatArray =
|
override fun Tensor<Float>.copyToArray(): FloatArray =
|
||||||
tensor.elements().map { it.second }.toList().toFloatArray()
|
tensor.elements().map { it.second }.toList().toFloatArray()
|
||||||
|
|
||||||
public fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
|
override fun copyFromArray(array: FloatArray, shape: IntArray, device: Device): NoaFloatTensor =
|
||||||
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
|
wrap(JNoa.fromBlobFloat(array, shape, device.toInt()))
|
||||||
|
|
||||||
public fun randNormalFloat(shape: IntArray, device: Device): NoaFloatTensor =
|
override fun randNormal(shape: IntArray, device: Device): NoaFloatTensor =
|
||||||
wrap(JNoa.randnFloat(shape, device.toInt()))
|
wrap(JNoa.randnFloat(shape, device.toInt()))
|
||||||
|
|
||||||
public fun randUniformFloat(shape: IntArray, device: Device): NoaFloatTensor =
|
override fun randUniform(shape: IntArray, device: Device): NoaFloatTensor =
|
||||||
wrap(JNoa.randFloat(shape, device.toInt()))
|
wrap(JNoa.randFloat(shape, device.toInt()))
|
||||||
|
|
||||||
public fun randDiscreteFloat(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
|
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaFloatTensor =
|
||||||
wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
|
wrap(JNoa.randintFloat(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
|
override operator fun Float.plus(other: Tensor<Float>): NoaFloatTensor =
|
||||||
@ -423,14 +436,14 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<Float>.divAssign(value: Float): Unit =
|
override fun Tensor<Float>.divAssign(value: Float): Unit =
|
||||||
tensor.timesAssign(1 / value)
|
tensor.timesAssign(1 / value)
|
||||||
|
|
||||||
public fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
|
override fun full(value: Float, shape: IntArray, device: Device): NoaFloatTensor =
|
||||||
wrap(JNoa.fullFloat(value, shape, device.toInt()))
|
wrap(JNoa.fullFloat(value, shape, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaLongAlgebra
|
public sealed class NoaLongAlgebra
|
||||||
protected constructor(scope: NoaScope) :
|
protected constructor(scope: NoaScope) :
|
||||||
NoaAlgebra<Long, NoaLongTensor>(scope) {
|
NoaAlgebra<Long, LongArray, NoaLongTensor>(scope) {
|
||||||
|
|
||||||
private fun Tensor<Long>.castHelper(): NoaLongTensor =
|
private fun Tensor<Long>.castHelper(): NoaLongTensor =
|
||||||
copyFromArray(
|
copyFromArray(
|
||||||
@ -448,13 +461,13 @@ protected constructor(scope: NoaScope) :
|
|||||||
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public fun Tensor<Long>.copyToArray(): LongArray =
|
override fun Tensor<Long>.copyToArray(): LongArray =
|
||||||
tensor.elements().map { it.second }.toList().toLongArray()
|
tensor.elements().map { it.second }.toList().toLongArray()
|
||||||
|
|
||||||
public fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
|
override fun copyFromArray(array: LongArray, shape: IntArray, device: Device): NoaLongTensor =
|
||||||
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
|
wrap(JNoa.fromBlobLong(array, shape, device.toInt()))
|
||||||
|
|
||||||
public fun randDiscreteLong(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
|
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||||
wrap(JNoa.randintLong(low, high, shape, device.toInt()))
|
wrap(JNoa.randintLong(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
|
override operator fun Long.plus(other: Tensor<Long>): NoaLongTensor =
|
||||||
@ -484,14 +497,14 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<Long>.timesAssign(value: Long): Unit =
|
override fun Tensor<Long>.timesAssign(value: Long): Unit =
|
||||||
JNoa.timesLongAssign(value, tensor.tensorHandle)
|
JNoa.timesLongAssign(value, tensor.tensorHandle)
|
||||||
|
|
||||||
public fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
|
override fun full(value: Long, shape: IntArray, device: Device): NoaLongTensor =
|
||||||
wrap(JNoa.fullLong(value, shape, device.toInt()))
|
wrap(JNoa.fullLong(value, shape, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class NoaIntAlgebra
|
public sealed class NoaIntAlgebra
|
||||||
protected constructor(scope: NoaScope) :
|
protected constructor(scope: NoaScope) :
|
||||||
NoaAlgebra<Int, NoaIntTensor>(scope) {
|
NoaAlgebra<Int, IntArray, NoaIntTensor>(scope) {
|
||||||
|
|
||||||
private fun Tensor<Int>.castHelper(): NoaIntTensor =
|
private fun Tensor<Int>.castHelper(): NoaIntTensor =
|
||||||
copyFromArray(
|
copyFromArray(
|
||||||
@ -509,13 +522,13 @@ protected constructor(scope: NoaScope) :
|
|||||||
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public fun Tensor<Int>.copyToArray(): IntArray =
|
override fun Tensor<Int>.copyToArray(): IntArray =
|
||||||
tensor.elements().map { it.second }.toList().toIntArray()
|
tensor.elements().map { it.second }.toList().toIntArray()
|
||||||
|
|
||||||
public fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
|
override fun copyFromArray(array: IntArray, shape: IntArray, device: Device): NoaIntTensor =
|
||||||
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
|
wrap(JNoa.fromBlobInt(array, shape, device.toInt()))
|
||||||
|
|
||||||
public fun randDiscreteInt(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
|
override fun randDiscrete(low: Long, high: Long, shape: IntArray, device: Device): NoaIntTensor =
|
||||||
wrap(JNoa.randintInt(low, high, shape, device.toInt()))
|
wrap(JNoa.randintInt(low, high, shape, device.toInt()))
|
||||||
|
|
||||||
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
|
override operator fun Int.plus(other: Tensor<Int>): NoaIntTensor =
|
||||||
@ -545,7 +558,7 @@ protected constructor(scope: NoaScope) :
|
|||||||
override fun Tensor<Int>.timesAssign(value: Int): Unit =
|
override fun Tensor<Int>.timesAssign(value: Int): Unit =
|
||||||
JNoa.timesIntAssign(value, tensor.tensorHandle)
|
JNoa.timesIntAssign(value, tensor.tensorHandle)
|
||||||
|
|
||||||
public fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
|
override fun full(value: Int, shape: IntArray, device: Device): NoaIntTensor =
|
||||||
wrap(JNoa.fullInt(value, shape, device.toInt()))
|
wrap(JNoa.fullInt(value, shape, device.toInt()))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user