forked from kscience/kmath
tensor casting
This commit is contained in:
parent
bea6ed4d65
commit
80879d3736
@ -11,6 +11,7 @@ import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||
|
||||
|
||||
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
|
||||
@ -276,8 +277,17 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
|
||||
public class NoaDoubleAlgebra(scope: NoaScope) :
|
||||
NoaPartialDivisionAlgebra<Double, NoaDoubleTensor>(scope) {
|
||||
|
||||
private fun Tensor<Double>.castHelper(): NoaDoubleTensor =
|
||||
copyFromArray(
|
||||
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toDoubleArray(),
|
||||
this.shape, Device.CPU
|
||||
)
|
||||
|
||||
override val Tensor<Double>.tensor: NoaDoubleTensor
|
||||
get() = TODO("Not yet implemented")
|
||||
get() = when (this) {
|
||||
is NoaDoubleTensor -> this
|
||||
else -> castHelper()
|
||||
}
|
||||
|
||||
override fun wrap(tensorHandle: TensorHandle): NoaDoubleTensor =
|
||||
NoaDoubleTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
@ -342,8 +352,17 @@ public class NoaDoubleAlgebra(scope: NoaScope) :
|
||||
public class NoaFloatAlgebra(scope: NoaScope) :
|
||||
NoaPartialDivisionAlgebra<Float, NoaFloatTensor>(scope) {
|
||||
|
||||
private fun Tensor<Float>.castHelper(): NoaFloatTensor =
|
||||
copyFromArray(
|
||||
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toFloatArray(),
|
||||
this.shape, Device.CPU
|
||||
)
|
||||
|
||||
override val Tensor<Float>.tensor: NoaFloatTensor
|
||||
get() = TODO("Not yet implemented")
|
||||
get() = when (this) {
|
||||
is NoaFloatTensor -> this
|
||||
else -> castHelper()
|
||||
}
|
||||
|
||||
override fun wrap(tensorHandle: TensorHandle): NoaFloatTensor =
|
||||
NoaFloatTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
@ -408,8 +427,17 @@ public class NoaFloatAlgebra(scope: NoaScope) :
|
||||
public class NoaLongAlgebra(scope: NoaScope) :
|
||||
NoaAlgebra<Long, NoaLongTensor>(scope) {
|
||||
|
||||
private fun Tensor<Long>.castHelper(): NoaLongTensor =
|
||||
copyFromArray(
|
||||
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toLongArray(),
|
||||
this.shape, Device.CPU
|
||||
)
|
||||
|
||||
override val Tensor<Long>.tensor: NoaLongTensor
|
||||
get() = TODO("Not yet implemented")
|
||||
get() = when (this) {
|
||||
is NoaLongTensor -> this
|
||||
else -> castHelper()
|
||||
}
|
||||
|
||||
override fun wrap(tensorHandle: TensorHandle): NoaLongTensor =
|
||||
NoaLongTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
@ -459,8 +487,17 @@ public class NoaLongAlgebra(scope: NoaScope) :
|
||||
public class NoaIntAlgebra(scope: NoaScope) :
|
||||
NoaAlgebra<Int, NoaIntTensor>(scope) {
|
||||
|
||||
private fun Tensor<Int>.castHelper(): NoaIntTensor =
|
||||
copyFromArray(
|
||||
TensorLinearStructure(this.shape).indices().map(this::get).toMutableList().toIntArray(),
|
||||
this.shape, Device.CPU
|
||||
)
|
||||
|
||||
override val Tensor<Int>.tensor: NoaIntTensor
|
||||
get() = TODO("Not yet implemented")
|
||||
get() = when (this) {
|
||||
is NoaIntTensor -> this
|
||||
else -> castHelper()
|
||||
}
|
||||
|
||||
override fun wrap(tensorHandle: TensorHandle): NoaIntTensor =
|
||||
NoaIntTensor(scope = scope, tensorHandle = tensorHandle)
|
||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.noa
|
||||
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.noa.memory.NoaScope
|
||||
import space.kscience.kmath.tensors.core.TensorLinearStructure
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user