Merge branch 'dev' into kotlingrad

# Conflicts:
#	examples/build.gradle.kts
#	examples/src/main/kotlin/kscience/kmath/ast/ExpressionsInterpretersBenchmark.kt
This commit is contained in:
Iaroslav Postovalov 2020-10-30 19:08:50 +07:00
commit 6b71d8525d
No known key found for this signature in database
GPG Key ID: 46E15E4A31B3BCD7
3 changed files with 91 additions and 1 deletions

View File

@ -20,6 +20,20 @@ dependencies {
implementation(project(":kmath-viktor"))
implementation(project(":kmath-dimensions"))
implementation(project(":kmath-ejml"))
implementation(project(":kmath-nd4j"))
implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
// uncomment if your system supports AVX2
// val os = System.getProperty("os.name")
//
// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when {
// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2")
// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2")
// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2")
// } else
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
implementation("org.slf4j:slf4j-simple:1.7.30")
@ -48,4 +62,6 @@ kotlin.sourceSets.all {
}
}
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
tasks.withType<KotlinCompile> {
kotlinOptions.jvmTarget = "11"
}

View File

@ -1,8 +1,10 @@
package kscience.kmath.structures
import kotlinx.coroutines.GlobalScope
import kscience.kmath.nd4j.Nd4jArrayField
import kscience.kmath.operations.RealField
import kscience.kmath.operations.invoke
import org.nd4j.linalg.factory.Nd4j
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.system.measureTimeMillis
@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) {
}
fun main() {
// initializing Nd4j
Nd4j.zeros(0)
val dim = 1000
val n = 1000
@ -23,6 +27,8 @@ fun main() {
val specializedField = NDField.real(dim, dim)
//A generic boxing field. It should be used for objects, not primitives.
val genericField = NDField.boxing(RealField, dim, dim)
// Nd4j specialized field.
val nd4jField = Nd4jArrayField.real(dim, dim)
measureAndPrint("Automatic field addition") {
autoField {
@ -43,6 +49,13 @@ fun main() {
}
}
measureAndPrint("Nd4j specialized addition") {
nd4jField {
var res = one
repeat(n) { res += 1.0 as Number }
}
}
measureAndPrint("Lazy addition") {
val res = specializedField.one.mapAsync(GlobalScope) {
var c = 0.0

View File

@ -126,6 +126,36 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
check(b)
return b.ndArray.rsub(this).wrap()
}
public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
*/
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/**
* Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
*/
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
}
}
}
/**
@ -145,6 +175,37 @@ public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd
check(b)
return b.ndArray.rdiv(this).wrap()
}
public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, RealNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
/**
* Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
*/
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
/**
* Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
*/
public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) }
/**
* Creates a most suitable implementation of [NDRing] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
}
}
}
/**