From 80879d3736b687fe68c92d863e1b839405a941a5 Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Fri, 9 Jul 2021 07:55:15 +0100 Subject: [PATCH] tensor casting --- .../space/kscience/kmath/noa/algebras.kt | 45 +++++++++++++++++-- .../space/kscience/kmath/noa/TestUtils.kt | 1 + 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 93b27e246..039024861 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -11,6 +11,7 @@ import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.TensorAlgebra +import space.kscience.kmath.tensors.core.TensorLinearStructure public sealed class NoaAlgebra> @@ -276,8 +277,17 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear public class NoaDoubleAlgebra(scope: NoaScope) : NoaPartialDivisionAlgebra(scope) { + private fun Tensor.castHelper(): NoaDoubleTensor = + copyFromArray( + TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toDoubleArray(), + this.shape, Device.CPU + ) + override val Tensor.tensor: NoaDoubleTensor - get() = TODO("Not yet implemented") + get() = when (this) { + is NoaDoubleTensor -> this + else -> castHelper() + } override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor = NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) @@ -342,8 +352,17 @@ public class NoaDoubleAlgebra(scope: NoaScope) : public class NoaFloatAlgebra(scope: NoaScope) : NoaPartialDivisionAlgebra(scope) { + private fun Tensor.castHelper(): NoaFloatTensor = + copyFromArray( + TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toFloatArray(), + this.shape, Device.CPU + ) + override val Tensor.tensor: NoaFloatTensor - get() = TODO("Not yet implemented") + get() = when (this) { + is NoaFloatTensor -> this + else -> castHelper() + } override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor = NoaFloatTensor(scope = scope, tensorHandle = tensorHandle) @@ -408,8 +427,17 @@ public class NoaFloatAlgebra(scope: NoaScope) : public class NoaLongAlgebra(scope: NoaScope) : NoaAlgebra(scope) { + private fun Tensor.castHelper(): NoaLongTensor = + copyFromArray( + TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toLongArray(), + this.shape, Device.CPU + ) + override val Tensor.tensor: NoaLongTensor - get() = TODO("Not yet implemented") + get() = when (this) { + is NoaLongTensor -> this + else -> castHelper() + } override fun wrap(tensorHandle: TensorHandle): NoaLongTensor = NoaLongTensor(scope = scope, tensorHandle = tensorHandle) @@ -459,8 +487,17 @@ public class NoaLongAlgebra(scope: NoaScope) : public class NoaIntAlgebra(scope: NoaScope) : NoaAlgebra(scope) { + private fun Tensor.castHelper(): NoaIntTensor = + copyFromArray( + TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toIntArray(), + this.shape, Device.CPU + ) + override val Tensor.tensor: NoaIntTensor - get() = TODO("Not yet implemented") + get() = when (this) { + is NoaIntTensor -> this + else -> castHelper() + } override fun wrap(tensorHandle: TensorHandle): NoaIntTensor = NoaIntTensor(scope = scope, tensorHandle = tensorHandle) diff --git a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt index b16c84657..d24e9ffb8 100644 --- a/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt +++ b/kmath-noa/src/test/kotlin/space/kscience/kmath/noa/TestUtils.kt @@ -7,6 +7,7 @@ package space.kscience.kmath.noa import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.noa.memory.NoaScope +import space.kscience.kmath.tensors.core.TensorLinearStructure import kotlin.test.Test import kotlin.test.assertEquals