unsafe buffer casts moved to internal utils
This commit is contained in:
parent
04f6ef1ed0
commit
4e4690e510
@ -43,14 +43,6 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
|
||||
else -> FloatArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [FloatBuffer] over this array.
|
||||
*
|
||||
|
@ -42,14 +42,6 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
|
||||
else -> IntArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [IntBuffer] over this array.
|
||||
*
|
||||
|
@ -42,14 +42,6 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
|
||||
else -> LongArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [LongBuffer] over this array.
|
||||
*
|
||||
|
@ -47,14 +47,6 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
|
||||
else -> DoubleArray(size, ::get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
||||
is RealBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns [RealBuffer] over this array.
|
||||
*
|
||||
|
@ -1,7 +1,5 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.unsafeToIntArray
|
||||
|
||||
public class RealLinearOpsTensorAlgebra :
|
||||
LinearOpsTensorAlgebra<Double, RealTensor>,
|
||||
RealTensorAlgebra()
|
||||
|
@ -1,7 +1,5 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||
|
||||
|
||||
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.unsafeToDoubleArray
|
||||
import space.kscience.kmath.structures.*
|
||||
import kotlin.math.max
|
||||
|
||||
|
||||
@ -94,3 +94,35 @@ internal inline fun <T, TensorType : TensorStructure<T>,
|
||||
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
|
||||
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
|
||||
check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
|
||||
|
||||
/**
|
||||
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
|
||||
is IntBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
|
||||
is LongBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
|
||||
is FloatBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
|
||||
*/
|
||||
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
|
||||
is RealBuffer -> array
|
||||
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
package space.kscience.kmath.tensors
|
||||
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
@ -10,7 +10,7 @@ class TestRealTensorAlgebra {
|
||||
fun doublePlus() = RealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
|
||||
val res = 10.0 + tensor
|
||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -18,7 +18,7 @@ class TestRealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
|
||||
val res = tensor.transpose(0, 0)
|
||||
|
||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(1))
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ class TestRealTensorAlgebra {
|
||||
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
val res = tensor.transpose(1, 0)
|
||||
|
||||
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
|
||||
assertTrue(res.shape contentEquals intArrayOf(2, 3))
|
||||
}
|
||||
|
||||
@ -42,9 +42,9 @@ class TestRealTensorAlgebra {
|
||||
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
|
||||
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
|
||||
|
||||
assertTrue(res01.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res12.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -70,9 +70,9 @@ class TestRealTensorAlgebra {
|
||||
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
|
||||
|
||||
assertTrue(res[0].buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].buffer.toDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].buffer.toDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||
assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
|
||||
assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
|
||||
assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -82,14 +82,14 @@ class TestRealTensorAlgebra {
|
||||
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
|
||||
|
||||
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
|
||||
assertTrue((tensor2 - tensor1).buffer.toDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
|
||||
|
||||
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
|
||||
assertTrue((tensor3 - tensor1).buffer.toDoubleArray()
|
||||
assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
|
||||
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
|
||||
|
||||
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
|
||||
assertTrue((tensor3 - tensor2).buffer.toDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user