tensor casting
This commit is contained in:
parent
bea6ed4d65
commit
80879d3736
@ -11,6 +11,7 @@ import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
|||||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||||
import space.kscience.kmath.tensors.api.Tensor
|
import space.kscience.kmath.tensors.api.Tensor
|
||||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||||
|
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||||
|
|
||||||
|
|
||||||
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
||||||
@ -276,8 +277,17 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
|
|||||||
public class NoaDoubleAlgebra(scope: NoaScope) :
|
public class NoaDoubleAlgebra(scope: NoaScope) :
|
||||||
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
|
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
|
||||||
|
|
||||||
|
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
|
||||||
|
copyFromArray(
|
||||||
|
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toDoubleArray(),
|
||||||
|
this.shape, Device.CPU
|
||||||
|
)
|
||||||
|
|
||||||
override val Tensor<Double>.tensor: NoaDoubleTensor
|
override val Tensor<Double>.tensor: NoaDoubleTensor
|
||||||
get() = TODO("Not yet implemented")
|
get() = when (this) {
|
||||||
|
is NoaDoubleTensor -> this
|
||||||
|
else -> castHelper()
|
||||||
|
}
|
||||||
|
|
||||||
override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
|
override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
|
||||||
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
@ -342,8 +352,17 @@ public class NoaDoubleAlgebra(scope: NoaScope) :
|
|||||||
public class NoaFloatAlgebra(scope: NoaScope) :
|
public class NoaFloatAlgebra(scope: NoaScope) :
|
||||||
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
|
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
|
||||||
|
|
||||||
|
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
|
||||||
|
copyFromArray(
|
||||||
|
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toFloatArray(),
|
||||||
|
this.shape, Device.CPU
|
||||||
|
)
|
||||||
|
|
||||||
override val Tensor<Float>.tensor: NoaFloatTensor
|
override val Tensor<Float>.tensor: NoaFloatTensor
|
||||||
get() = TODO("Not yet implemented")
|
get() = when (this) {
|
||||||
|
is NoaFloatTensor -> this
|
||||||
|
else -> castHelper()
|
||||||
|
}
|
||||||
|
|
||||||
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
|
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
|
||||||
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
@ -408,8 +427,17 @@ public class NoaFloatAlgebra(scope: NoaScope) :
|
|||||||
public class NoaLongAlgebra(scope: NoaScope) :
|
public class NoaLongAlgebra(scope: NoaScope) :
|
||||||
NoaAlgebra<Long, NoaLongTensor>(scope) {
|
NoaAlgebra<Long, NoaLongTensor>(scope) {
|
||||||
|
|
||||||
|
private fun Tensor<Long>.castHelper(): NoaLongTensor =
|
||||||
|
copyFromArray(
|
||||||
|
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toLongArray(),
|
||||||
|
this.shape, Device.CPU
|
||||||
|
)
|
||||||
|
|
||||||
override val Tensor<Long>.tensor: NoaLongTensor
|
override val Tensor<Long>.tensor: NoaLongTensor
|
||||||
get() = TODO("Not yet implemented")
|
get() = when (this) {
|
||||||
|
is NoaLongTensor -> this
|
||||||
|
else -> castHelper()
|
||||||
|
}
|
||||||
|
|
||||||
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
|
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
|
||||||
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
@ -459,8 +487,17 @@ public class NoaLongAlgebra(scope: NoaScope) :
|
|||||||
public class NoaIntAlgebra(scope: NoaScope) :
|
public class NoaIntAlgebra(scope: NoaScope) :
|
||||||
NoaAlgebra<Int, NoaIntTensor>(scope) {
|
NoaAlgebra<Int, NoaIntTensor>(scope) {
|
||||||
|
|
||||||
|
private fun Tensor<Int>.castHelper(): NoaIntTensor =
|
||||||
|
copyFromArray(
|
||||||
|
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toIntArray(),
|
||||||
|
this.shape, Device.CPU
|
||||||
|
)
|
||||||
|
|
||||||
override val Tensor<Int>.tensor: NoaIntTensor
|
override val Tensor<Int>.tensor: NoaIntTensor
|
||||||
get() = TODO("Not yet implemented")
|
get() = when (this) {
|
||||||
|
is NoaIntTensor -> this
|
||||||
|
else -> castHelper()
|
||||||
|
}
|
||||||
|
|
||||||
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
|
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
|
||||||
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
||||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.noa
|
|||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.noa.memory.NoaScope
|
import space.kscience.kmath.noa.memory.NoaScope
|
||||||
|
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user