forked from kscience/kmath
unsafe buffer casts moved to internal utils
This commit is contained in:
parent
04f6ef1ed0
commit
4e4690e510
@ -43,14 +43,6 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
|
|||||||
else -> FloatArray(size, ::get)
|
else -> FloatArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
|
||||||
is FloatBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [FloatBuffer] over this array.
|
* Returns [FloatBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -42,14 +42,6 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
|
|||||||
else -> IntArray(size, ::get)
|
else -> IntArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
|
||||||
is IntBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [IntBuffer] over this array.
|
* Returns [IntBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -42,14 +42,6 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
|
|||||||
else -> LongArray(size, ::get)
|
else -> LongArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
|
||||||
is LongBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [LongBuffer] over this array.
|
* Returns [LongBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -47,14 +47,6 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
|
|||||||
else -> DoubleArray(size, ::get)
|
else -> DoubleArray(size, ::get)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
|
||||||
*/
|
|
||||||
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
|
||||||
is RealBuffer -> array
|
|
||||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns [RealBuffer] over this array.
|
* Returns [RealBuffer] over this array.
|
||||||
*
|
*
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.unsafeToIntArray
|
|
||||||
|
|
||||||
public class RealLinearOpsTensorAlgebra :
|
public class RealLinearOpsTensorAlgebra :
|
||||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||||
RealTensorAlgebra()
|
RealTensorAlgebra()
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
|
||||||
|
|
||||||
|
|
||||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
import space.kscience.kmath.structures.*
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
|
|
||||||
|
|
||||||
@ -94,3 +94,35 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
|||||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
||||||
|
is IntBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
||||||
|
is LongBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
||||||
|
is FloatBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||||
|
*/
|
||||||
|
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
||||||
|
is RealBuffer -> array
|
||||||
|
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||||
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
package space.kscience.kmath.tensors
|
package space.kscience.kmath.tensors
|
||||||
|
|
||||||
import space.kscience.kmath.structures.toDoubleArray
|
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ class TestRealTensorAlgebra {
|
|||||||
fun doublePlus() = RealTensorAlgebra {
|
fun doublePlus() = RealTensorAlgebra {
|
||||||
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||||
val res = 10.0 + tensor
|
val res = 10.0 + tensor
|
||||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -18,7 +18,7 @@ class TestRealTensorAlgebra {
|
|||||||
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
|
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||||
val res = tensor.transpose(0, 0)
|
val res = tensor.transpose(0, 0)
|
||||||
|
|
||||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(0.0))
|
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
|
||||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ class TestRealTensorAlgebra {
|
|||||||
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
val res = tensor.transpose(1, 0)
|
val res = tensor.transpose(1, 0)
|
||||||
|
|
||||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,9 +42,9 @@ class TestRealTensorAlgebra {
|
|||||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||||
|
|
||||||
assertTrue(res01.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res02.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
assertTrue(res12.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -70,9 +70,9 @@ class TestRealTensorAlgebra {
|
|||||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||||
|
|
||||||
assertTrue(res[0].buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||||
assertTrue(res[1].buffer.toDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||||
assertTrue(res[2].buffer.toDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -82,14 +82,14 @@ class TestRealTensorAlgebra {
|
|||||||
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||||
|
|
||||||
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
||||||
assertTrue((tensor2 - tensor1).buffer.toDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
||||||
assertTrue((tensor3 - tensor1).buffer.toDoubleArray()
|
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
|
||||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
||||||
|
|
||||||
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
||||||
assertTrue((tensor3 - tensor2).buffer.toDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user