Merge branch 'nd4j' into dev

This commit is contained in:
Iaroslav Postovalov 2020-10-30 01:09:46 +07:00
commit 0731f2bd89
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
4 changed files with 161 additions and 71 deletions

View File

@ -19,7 +19,7 @@ repositories {
sourceSets.register("benchmarks") sourceSets.register("benchmarks")
dependencies { dependencies {
// implementation(project(":kmath-ast")) implementation(project(":kmath-ast"))
implementation(project(":kmath-core")) implementation(project(":kmath-core"))
implementation(project(":kmath-coroutines")) implementation(project(":kmath-coroutines"))
implementation(project(":kmath-commons")) implementation(project(":kmath-commons"))
@ -27,6 +27,20 @@ dependencies {
implementation(project(":kmath-viktor")) implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions")) implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml")) implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j"))
implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
// uncomment if your system supports AVX2
// val os = System.getProperty("os.name")
//
// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when {
// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2")
// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2")
// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2")
// } else
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11") implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20") implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
implementation("org.slf4j:slf4j-simple:1.7.30") implementation("org.slf4j:slf4j-simple:1.7.30")
@ -55,4 +69,6 @@ kotlin.sourceSets.all {
} }
} }
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" } tasks.withType<KotlinCompile> {
kotlinOptions.jvmTarget = "11"
}

View File

@ -1,70 +1,70 @@
package kscience.kmath.ast package kscience.kmath.ast
//
//import kscience.kmath.asm.compile import kscience.kmath.asm.compile
//import kscience.kmath.expressions.Expression import kscience.kmath.expressions.Expression
//import kscience.kmath.expressions.expressionInField import kscience.kmath.expressions.expressionInField
//import kscience.kmath.expressions.invoke import kscience.kmath.expressions.invoke
//import kscience.kmath.operations.Field import kscience.kmath.operations.Field
//import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
//import kotlin.random.Random import kotlin.random.Random
//import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
//
//class ExpressionsInterpretersBenchmark { class ExpressionsInterpretersBenchmark {
// private val algebra: Field<Double> = RealField private val algebra: Field<Double> = RealField
// fun functionalExpression() { fun functionalExpression() {
// val expr = algebra.expressionInField { val expr = algebra.expressionInField {
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0) symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
// } }
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// fun mstExpression() { fun mstExpression() {
// val expr = algebra.mstInField { val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// } }
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// fun asmExpression() { fun asmExpression() {
// val expr = algebra.mstInField { val expr = algebra.mstInField {
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0) symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
// }.compile() }.compile()
//
// invokeAndSum(expr) invokeAndSum(expr)
// } }
//
// private fun invokeAndSum(expr: Expression<Double>) { private fun invokeAndSum(expr: Expression<Double>) {
// val random = Random(0) val random = Random(0)
// var sum = 0.0 var sum = 0.0
//
// repeat(1000000) { repeat(1000000) {
// sum += expr("x" to random.nextDouble()) sum += expr("x" to random.nextDouble())
// } }
//
// println(sum) println(sum)
// } }
//} }
//
//fun main() { fun main() {
// val benchmark = ExpressionsInterpretersBenchmark() val benchmark = ExpressionsInterpretersBenchmark()
//
// val fe = measureTimeMillis { val fe = measureTimeMillis {
// benchmark.functionalExpression() benchmark.functionalExpression()
// } }
//
// println("fe=$fe") println("fe=$fe")
//
// val mst = measureTimeMillis { val mst = measureTimeMillis {
// benchmark.mstExpression() benchmark.mstExpression()
// } }
//
// println("mst=$mst") println("mst=$mst")
//
// val asm = measureTimeMillis { val asm = measureTimeMillis {
// benchmark.asmExpression() benchmark.asmExpression()
// } }
//
// println("asm=$asm") println("asm=$asm")
//} }

View File

@ -1,8 +1,10 @@
package kscience.kmath.structures package kscience.kmath.structures
import kotlinx.coroutines.GlobalScope import kotlinx.coroutines.GlobalScope
import kscience.kmath.nd4j.Nd4jArrayField
import kscience.kmath.operations.RealField import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke import kscience.kmath.operations.invoke
import org.nd4j.linalg.factory.Nd4j
import kotlin.contracts.InvocationKind import kotlin.contracts.InvocationKind
import kotlin.contracts.contract import kotlin.contracts.contract
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) {
} }
fun main() { fun main() {
// initializing Nd4j
Nd4j.zeros(0)
val dim = 1000 val dim = 1000
val n = 1000 val n = 1000
@ -23,6 +27,8 @@ fun main() {
val specializedField = NDField.real(dim, dim) val specializedField = NDField.real(dim, dim)
//A generic boxing field. It should be used for objects, not primitives. //A generic boxing field. It should be used for objects, not primitives.
val genericField = NDField.boxing(RealField, dim, dim) val genericField = NDField.boxing(RealField, dim, dim)
// Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim)
measureAndPrint("Automatic field addition") { measureAndPrint("Automatic field addition") {
autoField { autoField {
@ -43,6 +49,13 @@ fun main() {
} }
} }
measureAndPrint("Nd4j specialized addition") {
nd4jField {
var res = one
repeat(n) { res += 1.0 as Number }
}
}
measureAndPrint("Lazy addition") { measureAndPrint("Lazy addition") {
val res = specializedField.one.mapAsync(GlobalScope) { val res = specializedField.one.mapAsync(GlobalScope) {
var c = 0.0 var c = 0.0

View File

@ -126,6 +126,36 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
check(b) check(b)
return b.ndArray.rsub(this).wrap() return b.ndArray.rsub(this).wrap()
} }
public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/**
* Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
*/
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
}
}
} }
/** /**
@ -145,6 +175,37 @@ public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd
check(b) check(b)
return b.ndArray.rdiv(this).wrap() return b.ndArray.rdiv(this).wrap()
} }
public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, RealNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
}
}
} }
/** /**