unsafe buffer casts moved to internal utils

This commit is contained in:
Roland Grinis 2021-03-15 08:48:31 +00:00
parent 04f6ef1ed0
commit 4e4690e510
8 changed files with 47 additions and 51 deletions

View File

@ -43,14 +43,6 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
else -> FloatArray(size, ::get) else -> FloatArray(size, ::get)
} }
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/** /**
* Returns [FloatBuffer] over this array. * Returns [FloatBuffer] over this array.
* *

View File

@ -42,14 +42,6 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
else -> IntArray(size, ::get) else -> IntArray(size, ::get)
} }
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/** /**
* Returns [IntBuffer] over this array. * Returns [IntBuffer] over this array.
* *

View File

@ -42,14 +42,6 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
else -> LongArray(size, ::get) else -> LongArray(size, ::get)
} }
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/** /**
* Returns [LongBuffer] over this array. * Returns [LongBuffer] over this array.
* *

View File

@ -47,14 +47,6 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
else -> DoubleArray(size, ::get) else -> DoubleArray(size, ::get)
} }
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
/** /**
* Returns [RealBuffer] over this array. * Returns [RealBuffer] over this array.
* *

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToIntArray
public class RealLinearOpsTensorAlgebra : public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>, LinearOpsTensorAlgebra<Double, RealTensor>,
RealTensorAlgebra() RealTensorAlgebra()

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray import space.kscience.kmath.structures.*
import kotlin.math.max import kotlin.math.max
@ -94,3 +94,35 @@ internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>> TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit = TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -10,7 +10,7 @@ class TestRealTensorAlgebra {
fun doublePlus() = RealTensorAlgebra { fun doublePlus() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0)) val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor val res = 10.0 + tensor
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(11.0,12.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
} }
@Test @Test
@ -18,7 +18,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0)) val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0) val res = tensor.transpose(0, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(0.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1)) assertTrue(res.shape contentEquals intArrayOf(1))
} }
@ -27,7 +27,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0) val res = tensor.transpose(1, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3)) assertTrue(res.shape contentEquals intArrayOf(2, 3))
} }
@ -42,9 +42,9 @@ class TestRealTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test @Test
@ -70,9 +70,9 @@ class TestRealTensorAlgebra {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.toDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.toDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
} }
@Test @Test
@ -82,14 +82,14 @@ class TestRealTensorAlgebra {
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3)) assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.toDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3)) assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.toDoubleArray() assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)) contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3)) assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.toDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
} }
} }