tensor casting

This commit is contained in:
Roland Grinis 2021-07-09 07:55:15 +01:00
parent bea6ed4d65
commit 80879d3736
2 changed files with 42 additions and 4 deletions

View File

@ -11,6 +11,7 @@ import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.core.TensorLinearStructure
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>> public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
@ -276,8 +277,17 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
public class NoaDoubleAlgebra(scope: NoaScope) : public class NoaDoubleAlgebra(scope: NoaScope) :
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) { NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
copyFromArray(
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toDoubleArray(),
this.shape, Device.CPU
)
override val Tensor<Double>.tensor: NoaDoubleTensor override val Tensor<Double>.tensor: NoaDoubleTensor
get() = TODO("Not yet implemented") get() = when (this) {
is NoaDoubleTensor -> this
else -> castHelper()
}
override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor = override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle) NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
@ -342,8 +352,17 @@ public class NoaDoubleAlgebra(scope: NoaScope) :
public class NoaFloatAlgebra(scope: NoaScope) : public class NoaFloatAlgebra(scope: NoaScope) :
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) { NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
copyFromArray(
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toFloatArray(),
this.shape, Device.CPU
)
override val Tensor<Float>.tensor: NoaFloatTensor override val Tensor<Float>.tensor: NoaFloatTensor
get() = TODO("Not yet implemented") get() = when (this) {
is NoaFloatTensor -> this
else -> castHelper()
}
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor = override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle) NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
@ -408,8 +427,17 @@ public class NoaFloatAlgebra(scope: NoaScope) :
public class NoaLongAlgebra(scope: NoaScope) : public class NoaLongAlgebra(scope: NoaScope) :
NoaAlgebra<Long, NoaLongTensor>(scope) { NoaAlgebra<Long, NoaLongTensor>(scope) {
private fun Tensor<Long>.castHelper(): NoaLongTensor =
copyFromArray(
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toLongArray(),
this.shape, Device.CPU
)
override val Tensor<Long>.tensor: NoaLongTensor override val Tensor<Long>.tensor: NoaLongTensor
get() = TODO("Not yet implemented") get() = when (this) {
is NoaLongTensor -> this
else -> castHelper()
}
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor = override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
NoaLongTensor(scope = scope, tensorHandle = tensorHandle) NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
@ -459,8 +487,17 @@ public class NoaLongAlgebra(scope: NoaScope) :
public class NoaIntAlgebra(scope: NoaScope) : public class NoaIntAlgebra(scope: NoaScope) :
NoaAlgebra<Int, NoaIntTensor>(scope) { NoaAlgebra<Int, NoaIntTensor>(scope) {
private fun Tensor<Int>.castHelper(): NoaIntTensor =
copyFromArray(
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toIntArray(),
this.shape, Device.CPU
)
override val Tensor<Int>.tensor: NoaIntTensor override val Tensor<Int>.tensor: NoaIntTensor
get() = TODO("Not yet implemented") get() = when (this) {
is NoaIntTensor -> this
else -> castHelper()
}
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor = override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
NoaIntTensor(scope = scope, tensorHandle = tensorHandle) NoaIntTensor(scope = scope, tensorHandle = tensorHandle)

View File

@ -7,6 +7,7 @@ package space.kscience.kmath.noa
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.noa.memory.NoaScope import space.kscience.kmath.noa.memory.NoaScope
import space.kscience.kmath.tensors.core.TensorLinearStructure
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals