Merge branch 'nd4j' into dev
This commit is contained in:
commit
0731f2bd89
@ -19,7 +19,7 @@ repositories {
|
|||||||
sourceSets.register("benchmarks")
|
sourceSets.register("benchmarks")
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
// implementation(project(":kmath-ast"))
|
implementation(project(":kmath-ast"))
|
||||||
implementation(project(":kmath-core"))
|
implementation(project(":kmath-core"))
|
||||||
implementation(project(":kmath-coroutines"))
|
implementation(project(":kmath-coroutines"))
|
||||||
implementation(project(":kmath-commons"))
|
implementation(project(":kmath-commons"))
|
||||||
@ -27,6 +27,20 @@ dependencies {
|
|||||||
implementation(project(":kmath-viktor"))
|
implementation(project(":kmath-viktor"))
|
||||||
implementation(project(":kmath-dimensions"))
|
implementation(project(":kmath-dimensions"))
|
||||||
implementation(project(":kmath-ejml"))
|
implementation(project(":kmath-ejml"))
|
||||||
|
implementation(project(":kmath-nd4j"))
|
||||||
|
implementation("org.deeplearning4j:deeplearning4j-core:1.0.0-beta7")
|
||||||
|
implementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||||
|
|
||||||
|
// uncomment if your system supports AVX2
|
||||||
|
// val os = System.getProperty("os.name")
|
||||||
|
//
|
||||||
|
// if (System.getProperty("os.arch") in arrayOf("x86_64", "amd64")) when {
|
||||||
|
// os.startsWith("Windows") -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:windows-x86_64-avx2")
|
||||||
|
// os == "Linux" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:linux-x86_64-avx2")
|
||||||
|
// os == "Mac OS X" -> implementation("org.nd4j:nd4j-native:1.0.0-beta7:macosx-x86_64-avx2")
|
||||||
|
// } else
|
||||||
|
implementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||||
|
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
implementation("org.jetbrains.kotlinx:kotlinx-io:0.2.0-npm-dev-11")
|
||||||
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
implementation("org.jetbrains.kotlinx:kotlinx.benchmark.runtime:0.2.0-dev-20")
|
||||||
implementation("org.slf4j:slf4j-simple:1.7.30")
|
implementation("org.slf4j:slf4j-simple:1.7.30")
|
||||||
@ -55,4 +69,6 @@ kotlin.sourceSets.all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.withType<KotlinCompile> { kotlinOptions.jvmTarget = "11" }
|
tasks.withType<KotlinCompile> {
|
||||||
|
kotlinOptions.jvmTarget = "11"
|
||||||
|
}
|
||||||
|
@ -1,70 +1,70 @@
|
|||||||
package kscience.kmath.ast
|
package kscience.kmath.ast
|
||||||
//
|
|
||||||
//import kscience.kmath.asm.compile
|
import kscience.kmath.asm.compile
|
||||||
//import kscience.kmath.expressions.Expression
|
import kscience.kmath.expressions.Expression
|
||||||
//import kscience.kmath.expressions.expressionInField
|
import kscience.kmath.expressions.expressionInField
|
||||||
//import kscience.kmath.expressions.invoke
|
import kscience.kmath.expressions.invoke
|
||||||
//import kscience.kmath.operations.Field
|
import kscience.kmath.operations.Field
|
||||||
//import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
//import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
//import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
//
|
|
||||||
//class ExpressionsInterpretersBenchmark {
|
class ExpressionsInterpretersBenchmark {
|
||||||
// private val algebra: Field<Double> = RealField
|
private val algebra: Field<Double> = RealField
|
||||||
// fun functionalExpression() {
|
fun functionalExpression() {
|
||||||
// val expr = algebra.expressionInField {
|
val expr = algebra.expressionInField {
|
||||||
// variable("x") * const(2.0) + const(2.0) / variable("x") - const(16.0)
|
symbol("x") * const(2.0) + const(2.0) / symbol("x") - const(16.0)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// fun mstExpression() {
|
fun mstExpression() {
|
||||||
// val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// fun asmExpression() {
|
fun asmExpression() {
|
||||||
// val expr = algebra.mstInField {
|
val expr = algebra.mstInField {
|
||||||
// symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
symbol("x") * number(2.0) + number(2.0) / symbol("x") - number(16.0)
|
||||||
// }.compile()
|
}.compile()
|
||||||
//
|
|
||||||
// invokeAndSum(expr)
|
invokeAndSum(expr)
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// private fun invokeAndSum(expr: Expression<Double>) {
|
private fun invokeAndSum(expr: Expression<Double>) {
|
||||||
// val random = Random(0)
|
val random = Random(0)
|
||||||
// var sum = 0.0
|
var sum = 0.0
|
||||||
//
|
|
||||||
// repeat(1000000) {
|
repeat(1000000) {
|
||||||
// sum += expr("x" to random.nextDouble())
|
sum += expr("x" to random.nextDouble())
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// println(sum)
|
println(sum)
|
||||||
// }
|
}
|
||||||
//}
|
}
|
||||||
//
|
|
||||||
//fun main() {
|
fun main() {
|
||||||
// val benchmark = ExpressionsInterpretersBenchmark()
|
val benchmark = ExpressionsInterpretersBenchmark()
|
||||||
//
|
|
||||||
// val fe = measureTimeMillis {
|
val fe = measureTimeMillis {
|
||||||
// benchmark.functionalExpression()
|
benchmark.functionalExpression()
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// println("fe=$fe")
|
println("fe=$fe")
|
||||||
//
|
|
||||||
// val mst = measureTimeMillis {
|
val mst = measureTimeMillis {
|
||||||
// benchmark.mstExpression()
|
benchmark.mstExpression()
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// println("mst=$mst")
|
println("mst=$mst")
|
||||||
//
|
|
||||||
// val asm = measureTimeMillis {
|
val asm = measureTimeMillis {
|
||||||
// benchmark.asmExpression()
|
benchmark.asmExpression()
|
||||||
// }
|
}
|
||||||
//
|
|
||||||
// println("asm=$asm")
|
println("asm=$asm")
|
||||||
//}
|
}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.GlobalScope
|
import kotlinx.coroutines.GlobalScope
|
||||||
|
import kscience.kmath.nd4j.Nd4jArrayField
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
@ -14,6 +16,8 @@ internal inline fun measureAndPrint(title: String, block: () -> Unit) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
|
// initializing Nd4j
|
||||||
|
Nd4j.zeros(0)
|
||||||
val dim = 1000
|
val dim = 1000
|
||||||
val n = 1000
|
val n = 1000
|
||||||
|
|
||||||
@ -23,6 +27,8 @@ fun main() {
|
|||||||
val specializedField = NDField.real(dim, dim)
|
val specializedField = NDField.real(dim, dim)
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val genericField = NDField.boxing(RealField, dim, dim)
|
val genericField = NDField.boxing(RealField, dim, dim)
|
||||||
|
// Nd4j specialized field.
|
||||||
|
val nd4jField = Nd4jArrayField.real(dim, dim)
|
||||||
|
|
||||||
measureAndPrint("Automatic field addition") {
|
measureAndPrint("Automatic field addition") {
|
||||||
autoField {
|
autoField {
|
||||||
@ -43,6 +49,13 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
measureAndPrint("Nd4j specialized addition") {
|
||||||
|
nd4jField {
|
||||||
|
var res = one
|
||||||
|
repeat(n) { res += 1.0 as Number }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
measureAndPrint("Lazy addition") {
|
measureAndPrint("Lazy addition") {
|
||||||
val res = specializedField.one.mapAsync(GlobalScope) {
|
val res = specializedField.one.mapAsync(GlobalScope) {
|
||||||
var c = 0.0
|
var c = 0.0
|
||||||
|
@ -126,6 +126,36 @@ public interface Nd4jArrayRing<T, R> : NDRing<T, R, Nd4jArrayStructure<T>>, Nd4j
|
|||||||
check(b)
|
check(b)
|
||||||
return b.ndArray.rsub(this).wrap()
|
return b.ndArray.rsub(this).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDRing] for [Int] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
|
||||||
|
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDRing] for [Long] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
|
||||||
|
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a most suitable implementation of [NDRing] using reified class.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
|
||||||
|
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||||
|
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||||
|
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -145,6 +175,37 @@ public interface Nd4jArrayField<T, F> : NDField<T, F, Nd4jArrayStructure<T>>, Nd
|
|||||||
check(b)
|
check(b)
|
||||||
return b.ndArray.rdiv(this).wrap()
|
return b.ndArray.rdiv(this).wrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
private val realNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, RealNd4jArrayField>> =
|
||||||
|
ThreadLocal.withInitial { hashMapOf() }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDField] for [Float] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
|
||||||
|
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an [NDField] for [Double] values or pull it from cache if it was created previously.
|
||||||
|
*/
|
||||||
|
public fun real(vararg shape: Int): Nd4jArrayRing<Double, RealField> =
|
||||||
|
realNd4jArrayFieldCache.get().getOrPut(shape) { RealNd4jArrayField(shape) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a most suitable implementation of [NDRing] using reified class.
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
|
||||||
|
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||||
|
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||||
|
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
Loading…
Reference in New Issue
Block a user