unsafe buffer casts moved to internal utils

This commit is contained in:
Roland Grinis 2021-03-15 08:48:31 +00:00
parent 04f6ef1ed0
commit 4e4690e510
8 changed files with 47 additions and 51 deletions

View File

@ -43,14 +43,6 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
else -> FloatArray(size, ::get)
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns [FloatBuffer] over this array.
*

View File

@ -42,14 +42,6 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
else -> IntArray(size, ::get)
}
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/**
* Returns [IntBuffer] over this array.
*

View File

@ -42,14 +42,6 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
else -> LongArray(size, ::get)
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns [LongBuffer] over this array.
*

View File

@ -47,14 +47,6 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
else -> DoubleArray(size, ::get)
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
/**
* Returns [RealBuffer] over this array.
*

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToIntArray
public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>,
RealTensorAlgebra()

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray
import space.kscience.kmath.structures.*
import kotlin.math.max
@ -94,3 +94,35 @@ internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Test
import kotlin.test.assertTrue
@ -10,7 +10,7 @@ class TestRealTensorAlgebra {
fun doublePlus() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
}
@Test
@ -18,7 +18,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(0.0))
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1))
}
@ -27,7 +27,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3))
}
@ -42,9 +42,9 @@ class TestRealTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
}
@Test
@ -70,9 +70,9 @@ class TestRealTensorAlgebra {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.toDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.toDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
}
@Test
@ -82,14 +82,14 @@ class TestRealTensorAlgebra {
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.toDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.toDoubleArray()
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.toDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
}
}