diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt index 19bc44787..dffb14886 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/tensors.kt @@ -28,6 +28,7 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : get() = (1..dimension).map { JNoa.getStrideAt(tensorHandle, it - 1) }.toIntArray() public val numElements: Int get() = JNoa.getNumel(tensorHandle) + public val device: Device get() = Device.fromInt(JNoa.getDevice(tensorHandle)) override fun toString(): String = JNoa.tensorToString(tensorHandle) @@ -48,21 +49,21 @@ constructor(scope: NoaScope, internal val tensorHandle: TensorHandle) : scope = scope, tensorHandle = JNoa.copyToDouble(this.tensorHandle) ) -/* - public fun copyToFloat(): TorchTensorFloat = TorchTensorFloat( + + public fun copyToFloat(): NoaFloatTensor = NoaFloatTensor( scope = scope, - tensorHandle = JTorch.copyToFloat(this.tensorHandle) + tensorHandle = JNoa.copyToFloat(this.tensorHandle) ) - public fun copyToLong(): TorchTensorLong = TorchTensorLong( + public fun copyToLong(): NoaLongTensor = NoaLongTensor( scope = scope, - tensorHandle = JTorch.copyToLong(this.tensorHandle) + tensorHandle = JNoa.copyToLong(this.tensorHandle) ) - public fun copyToInt(): TorchTensorInt = TorchTensorInt( + public fun copyToInt(): NoaIntTensor = NoaIntTensor( scope = scope, - tensorHandle = JTorch.copyToInt(this.tensorHandle) - )*/ + tensorHandle = JNoa.copyToInt(this.tensorHandle) + ) } public sealed class NoaTensorOverField @@ -134,9 +135,6 @@ internal constructor(scope: NoaScope, tensorHandle: TensorHandle) : } } - - - public sealed class Device { public object CPU : Device() { override fun toString(): String {