v0.3.0-dev-9 #324

Merged
altavir merged 265 commits from dev into master 2021-05-08 17:16:29 +03:00
8 changed files with 47 additions and 51 deletions
Showing only changes of commit 4e4690e510 - Show all commits

View File

@ -43,14 +43,6 @@ public fun Buffer<Float>.toFloatArray(): FloatArray = when(this) {
else -> FloatArray(size, ::get) else -> FloatArray(size, ::get)
} }
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/** /**
* Returns [FloatBuffer] over this array. * Returns [FloatBuffer] over this array.
* *

View File

@ -42,14 +42,6 @@ public fun Buffer<Int>.toIntArray(): IntArray = when(this) {
else -> IntArray(size, ::get) else -> IntArray(size, ::get)
} }
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/** /**
* Returns [IntBuffer] over this array. * Returns [IntBuffer] over this array.
* *

View File

@ -42,14 +42,6 @@ public fun Buffer<Long>.toLongArray(): LongArray = when(this) {
else -> LongArray(size, ::get) else -> LongArray(size, ::get)
} }
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/** /**
* Returns [LongBuffer] over this array. * Returns [LongBuffer] over this array.
* *

View File

@ -47,14 +47,6 @@ public fun Buffer<Double>.toDoubleArray(): DoubleArray = when(this) {
else -> DoubleArray(size, ::get) else -> DoubleArray(size, ::get)
} }
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
public fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}
/** /**
* Returns [RealBuffer] over this array. * Returns [RealBuffer] over this array.
* *

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToIntArray
public class RealLinearOpsTensorAlgebra : public class RealLinearOpsTensorAlgebra :
LinearOpsTensorAlgebra<Double, RealTensor>, LinearOpsTensorAlgebra<Double, RealTensor>,
RealTensorAlgebra() RealTensorAlgebra()

View File

@ -1,7 +1,5 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray
public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> { public open class RealTensorAlgebra : TensorPartialDivisionAlgebra<Double, RealTensor> {

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.unsafeToDoubleArray import space.kscience.kmath.structures.*
import kotlin.math.max import kotlin.math.max
@ -94,3 +94,35 @@ internal inline fun <T, TensorType : TensorStructure<T>,
TorchTensorAlgebraType : TensorAlgebra<T, TensorType>> TorchTensorAlgebraType : TensorAlgebra<T, TensorType>>
TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit = TorchTensorAlgebraType.checkView(a: TensorType, shape: IntArray): Unit =
check(a.shape.reduce(Int::times) == shape.reduce(Int::times)) check(a.shape.reduce(Int::times) == shape.reduce(Int::times))
/**
* Returns a reference to [IntArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Int>.unsafeToIntArray(): IntArray = when(this) {
is IntBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to IntArray")
}
/**
* Returns a reference to [LongArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Long>.unsafeToLongArray(): LongArray = when(this) {
is LongBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to LongArray")
}
/**
* Returns a reference to [FloatArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Float>.unsafeToFloatArray(): FloatArray = when(this) {
is FloatBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to FloatArray")
}
/**
* Returns a reference to [DoubleArray] containing all of the elements of this [Buffer].
*/
internal fun Buffer<Double>.unsafeToDoubleArray(): DoubleArray = when(this) {
is RealBuffer -> array
else -> throw RuntimeException("Failed to cast Buffer to DoubleArray")
}

View File

@ -1,6 +1,6 @@
package space.kscience.kmath.tensors package space.kscience.kmath.tensors
import space.kscience.kmath.structures.toDoubleArray
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -10,7 +10,7 @@ class TestRealTensorAlgebra {
fun doublePlus() = RealTensorAlgebra { fun doublePlus() = RealTensorAlgebra {
val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0)) val tensor = RealTensor(intArrayOf(2), doubleArrayOf(1.0, 2.0))
val res = 10.0 + tensor val res = 10.0 + tensor
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(11.0,12.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(11.0,12.0))
} }
@Test @Test
@ -18,7 +18,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0)) val tensor = RealTensor(intArrayOf(1), doubleArrayOf(0.0))
val res = tensor.transpose(0, 0) val res = tensor.transpose(0, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(0.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(0.0))
assertTrue(res.shape contentEquals intArrayOf(1)) assertTrue(res.shape contentEquals intArrayOf(1))
} }
@ -27,7 +27,7 @@ class TestRealTensorAlgebra {
val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) val tensor = RealTensor(intArrayOf(3, 2), doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
val res = tensor.transpose(1, 0) val res = tensor.transpose(1, 0)
assertTrue(res.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) assertTrue(res.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))
assertTrue(res.shape contentEquals intArrayOf(2, 3)) assertTrue(res.shape contentEquals intArrayOf(2, 3))
} }
@ -42,9 +42,9 @@ class TestRealTensorAlgebra {
assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1)) assertTrue(res02.shape contentEquals intArrayOf(3, 2, 1))
assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2)) assertTrue(res12.shape contentEquals intArrayOf(1, 3, 2))
assertTrue(res01.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res01.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res02.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res02.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
assertTrue(res12.buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0)) assertTrue(res12.buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 4.0, 2.0, 5.0, 3.0, 6.0))
} }
@Test @Test
@ -70,9 +70,9 @@ class TestRealTensorAlgebra {
assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[1].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3)) assertTrue(res[2].shape contentEquals intArrayOf(1, 2, 3))
assertTrue(res[0].buffer.toDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)) assertTrue(res[0].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))
assertTrue(res[1].buffer.toDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0)) assertTrue(res[1].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(10.0, 20.0, 30.0, 10.0, 20.0, 30.0))
assertTrue(res[2].buffer.toDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0)) assertTrue(res[2].buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(500.0, 500.0, 500.0, 500.0, 500.0, 500.0))
} }
@Test @Test
@ -82,14 +82,14 @@ class TestRealTensorAlgebra {
val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0)) val tensor3 = RealTensor(intArrayOf(1, 1, 1), doubleArrayOf(500.0))
assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3)) assertTrue((tensor2 - tensor1).shape contentEquals intArrayOf(2, 3))
assertTrue((tensor2 - tensor1).buffer.toDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0)) assertTrue((tensor2 - tensor1).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(9.0, 18.0, 27.0, 6.0, 15.0, 24.0))
assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3)) assertTrue((tensor3 - tensor1).shape contentEquals intArrayOf(1, 2, 3))
assertTrue((tensor3 - tensor1).buffer.toDoubleArray() assertTrue((tensor3 - tensor1).buffer.unsafeToDoubleArray()
contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0)) contentEquals doubleArrayOf(499.0, 498.0, 497.0, 496.0, 495.0, 494.0))
assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3)) assertTrue((tensor3 - tensor2).shape contentEquals intArrayOf(1, 1, 3))
assertTrue((tensor3 - tensor2).buffer.toDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0)) assertTrue((tensor3 - tensor2).buffer.unsafeToDoubleArray() contentEquals doubleArrayOf(490.0, 480.0, 470.0))
} }
} }