Update tests of GSL
This commit is contained in:
parent
20767a3b35
commit
c34af4d8bd
@ -191,7 +191,7 @@ internal class AsmBuilder<T>(
|
||||
}
|
||||
|
||||
val cls = classLoader.defineClass(className, classWriter.toByteArray())
|
||||
java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
// java.io.File("dump.class").writeBytes(classWriter.toByteArray())
|
||||
val l = MethodHandles.publicLookup()
|
||||
|
||||
if (hasConstants)
|
||||
|
@ -1,6 +1,5 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.memScoped
|
||||
import org.gnu.gsl.gsl_block_calloc
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertFailsWith
|
||||
@ -17,7 +16,7 @@ internal class ErrorHandler {
|
||||
@Test
|
||||
fun matrixAllocation() {
|
||||
assertFailsWith<GslException> {
|
||||
memScoped { GslRealMatrixContext(this).produce(Int.MAX_VALUE, Int.MAX_VALUE) { _, _ -> 0.0 } }
|
||||
GslRealMatrixContext { produce(Int.MAX_VALUE, Int.MAX_VALUE) { _, _ -> 0.0 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,46 +1,71 @@
|
||||
package kscience.kmath.gsl
|
||||
|
||||
import kotlinx.cinterop.memScoped
|
||||
import kscience.kmath.linear.RealMatrixContext
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.*
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
import kotlin.time.measureTime
|
||||
|
||||
internal class RealTest {
|
||||
@Test
|
||||
fun testScale() = memScoped {
|
||||
(GslRealMatrixContext(this)) {
|
||||
val ma = produce(10, 10) { _, _ -> 0.1 }
|
||||
val mb = (ma * 20.0)
|
||||
assertEquals(mb[0, 1], 2.0)
|
||||
}
|
||||
fun testScale() = GslRealMatrixContext {
|
||||
val ma = produce(10, 10) { _, _ -> 0.1 }
|
||||
val mb = (ma * 20.0)
|
||||
assertEquals(mb[0, 1], 2.0)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDotOfMatrixAndVector() = memScoped {
|
||||
(GslRealMatrixContext(this)) {
|
||||
val ma = produce(2, 2) { _, _ -> 100.0 }
|
||||
val vb = RealBuffer(2) { 0.1 }
|
||||
val res1 = ma dot vb
|
||||
val res2 = RealMatrixContext { ma dot vb }
|
||||
println(res1.asSequence().toList())
|
||||
println(res2.asSequence().toList())
|
||||
assertTrue(res1.contentEquals(res2))
|
||||
}
|
||||
fun testDotOfMatrixAndVector() = GslRealMatrixContext {
|
||||
val ma = produce(2, 2) { _, _ -> 100.0 }
|
||||
val vb = RealBuffer(2) { 0.1 }
|
||||
val res1 = ma dot vb
|
||||
val res2 = RealMatrixContext { ma dot vb }
|
||||
println(res1.asSequence().toList())
|
||||
println(res2.asSequence().toList())
|
||||
assertTrue(res1.contentEquals(res2))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testDotOfMatrixAndMatrix() = memScoped {
|
||||
(GslRealMatrixContext(this)) {
|
||||
val ma = produce(2, 2) { _, _ -> 100.0 }
|
||||
val mb = produce(2, 2) { _, _ -> 100.0 }
|
||||
val res1: Matrix<Double> = ma dot mb
|
||||
val res2: Matrix<Double> = RealMatrixContext { ma dot mb }
|
||||
println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
assertEquals(res1, res2)
|
||||
fun testDotOfMatrixAndMatrix() = GslRealMatrixContext {
|
||||
val ma = produce(2, 2) { _, _ -> 100.0 }
|
||||
val mb = produce(2, 2) { _, _ -> 100.0 }
|
||||
val res1: Matrix<Double> = ma dot mb
|
||||
val res2: Matrix<Double> = RealMatrixContext { ma dot mb }
|
||||
println(res1.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
println(res2.rows.asIterable().map { it.asSequence() }.flatMap(Sequence<*>::toList))
|
||||
assertEquals(res1, res2)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testManyCalls() {
|
||||
val r1 = GslRealMatrixContext {
|
||||
var prod = produce(20, 20) { _, _ -> 100.0 }
|
||||
val mult = produce(20, 20) { _, _ -> 3.0 }
|
||||
|
||||
measureTime {
|
||||
repeat(100) {
|
||||
prod = prod dot mult
|
||||
}
|
||||
}.also(::println)
|
||||
|
||||
prod
|
||||
}
|
||||
|
||||
val r2 = RealMatrixContext {
|
||||
var prod = produce(20, 20) { _, _ -> 100.0 }
|
||||
val mult = produce(20, 20) { _, _ -> 3.0 }
|
||||
|
||||
measureTime {
|
||||
repeat(100) {
|
||||
prod = prod dot mult
|
||||
}
|
||||
}.also(::println)
|
||||
|
||||
prod
|
||||
}
|
||||
|
||||
assertTrue(NDStructure.equals(r1, r2))
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user