Complex buffer optimization

This commit is contained in:
Alexander Nozik 2019-02-12 11:58:58 +03:00
parent f6cc23ce0a
commit c3989159ab
16 changed files with 286 additions and 65 deletions

View File

@ -1,6 +1,6 @@
plugins {
id "java"
id "me.champeau.gradle.jmh" version "0.4.7"
id "me.champeau.gradle.jmh" version "0.4.8"
id 'org.jetbrains.kotlin.jvm'
}
@ -15,8 +15,8 @@ dependencies {
implementation project(":kmath-coroutines")
implementation project(":kmath-commons")
implementation project(":kmath-koma")
compile group: "com.kyonifer", name:"koma-core-ejml", version: "0.12"
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
implementation group: "com.kyonifer", name:"koma-core-ejml", version: "0.12"
//compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
//jmh project(':kmath-core')
}

View File

@ -0,0 +1,36 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.toComplex
import kotlin.system.measureTimeMillis
fun main() {
val dim = 1000
val n = 1000
val realField = NDField.real(intArrayOf(dim, dim))
val complexField = NDField.complex(intArrayOf(dim, dim))
val realTime = measureTimeMillis {
realField.run {
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
}
}
}
println("Real addition completed in $realTime millis")
val complexTime = measureTimeMillis {
complexField.run {
var res: NDBuffer<Complex> = one
repeat(n) {
res += 1.0.toComplex()
}
}
}
println("Complex addition completed in $complexTime millis")
}

View File

@ -39,7 +39,7 @@ fun main(args: Array<String>) {
val specializedTime = measureTimeMillis {
specializedField.run {
var res:NDBuffer<Double> = one
var res: NDBuffer<Double> = one
repeat(n) {
res += 1.0
}

View File

@ -1,8 +1,8 @@
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
buildscript {
val kotlinVersion: String by rootProject.extra("1.3.20")
val ioVersion: String by rootProject.extra("0.1.2")
val kotlinVersion: String by rootProject.extra("1.3.21")
val ioVersion: String by rootProject.extra("0.1.4")
val coroutinesVersion: String by rootProject.extra("1.1.1")
val atomicfuVersion: String by rootProject.extra("0.12.1")
@ -27,7 +27,7 @@ allprojects {
apply(plugin = "com.jfrog.artifactory")
group = "scientifik"
version = "0.0.3-dev-5"
version = "0.0.3"
repositories {
//maven("https://dl.bintray.com/kotlin/kotlin-eap")

View File

@ -2,9 +2,11 @@ plugins {
kotlin("multiplatform")
}
val ioVersion: String by rootProject.extra
kotlin {
jvm ()
jvm()
js()
sourceSets {

View File

@ -1,9 +1,11 @@
package scientifik.kmath.operations
import kotlin.math.*
/**
* A field for complex numbers
*/
object ComplexField : Field<Complex> {
object ComplexField : ExtendedField<Complex> {
override val zero: Complex = Complex(0.0, 0.0)
override val one: Complex = Complex(1.0, 0.0)
@ -22,6 +24,17 @@ object ComplexField : Field<Complex> {
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
}
override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg))
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
override fun power(arg: Complex, pow: Number): Complex =
arg.abs.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
override fun ln(arg: Complex): Complex = ln(arg.abs) + i * atan2(arg.im, arg.re)
operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
@ -41,20 +54,18 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
override fun Complex.wrap(): Complex = this
override val context: ComplexField
get() = ComplexField
override val context: ComplexField get() = ComplexField
/**
* A complex conjugate
*/
val conjugate: Complex
get() = Complex(re, -im)
val conjugate: Complex get() = Complex(re, -im)
val square: Double
get() = re * re + im * im
val square: Double get() = re * re + im * im
val abs: Double
get() = kotlin.math.sqrt(square)
val abs: Double get() = sqrt(square)
val theta: Double get() = atan(im / re)
companion object
}

View File

@ -12,7 +12,7 @@ class BoxingNDField<T, F : Field<T>>(
override val strides: Strides = DefaultStrides(shape)
override fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
bufferFactory(size, initializer)
override fun check(vararg elements: NDBuffer<T>) {

View File

@ -5,8 +5,6 @@ import scientifik.kmath.operations.*
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
val strides: Strides
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T>
override fun check(vararg elements: NDBuffer<T>) {
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
}

View File

@ -1,11 +1,12 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Field
import scientifik.kmath.operations.ExtendedField
import kotlin.math.*
/**
* A simple field over linear buffers of [Double]
*/
class RealBufferField(val size: Int) : Field<Buffer<Double>> {
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
@ -56,4 +57,54 @@ class RealBufferField(val size: Int) : Field<Buffer<Double>> {
DoubleBuffer(DoubleArray(size) { a[it] / b[it] })
}
}
override fun sin(arg: Buffer<Double>): Buffer<Double> {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) {
val array = arg.array
DoubleBuffer(DoubleArray(size) { sin(array[it]) })
} else {
DoubleBuffer(DoubleArray(size) { sin(arg[it]) })
}
}
override fun cos(arg: Buffer<Double>): Buffer<Double> {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) {
val array = arg.array
DoubleBuffer(DoubleArray(size) { cos(array[it]) })
} else {
DoubleBuffer(DoubleArray(size) { cos(arg[it]) })
}
}
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) {
val array = arg.array
DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) })
} else {
DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) })
}
}
override fun exp(arg: Buffer<Double>): Buffer<Double> {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) {
val array = arg.array
DoubleBuffer(DoubleArray(size) { exp(array[it]) })
} else {
DoubleBuffer(DoubleArray(size) { exp(arg[it]) })
}
}
override fun ln(arg: Buffer<Double>): Buffer<Double> {
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
return if (arg is DoubleBuffer) {
val array = arg.array
DoubleBuffer(DoubleArray(size) { ln(array[it]) })
} else {
DoubleBuffer(DoubleArray(size) { ln(arg[it]) })
}
}
}

View File

@ -15,8 +15,7 @@ class RealNDField(override val shape: IntArray) :
override val zero by lazy { produce { zero } }
override val one by lazy { produce { one } }
@Suppress("OVERRIDE_BY_INLINE")
override inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
DoubleBuffer(DoubleArray(size) { initializer(it) })
/**

View File

@ -15,8 +15,7 @@ class ShortNDRing(override val shape: IntArray) :
override val zero by lazy { produce { ShortRing.zero } }
override val one by lazy { produce { ShortRing.one } }
@Suppress("OVERRIDE_BY_INLINE")
override inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
ShortBuffer(ShortArray(size) { initializer(it) })
/**

View File

@ -7,8 +7,15 @@ import java.nio.ByteBuffer
* A specification for serialization and deserialization objects to buffer
*/
interface BufferSpec<T : Any> {
fun fromBuffer(buffer: ByteBuffer): T
fun toBuffer(value: T): ByteBuffer
/**
* Read an object from buffer in current position
*/
fun ByteBuffer.readObject(): T
/**
* Write object to [ByteBuffer] in current buffer position
*/
fun ByteBuffer.writeObject(value: T)
}
/**
@ -17,39 +24,21 @@ interface BufferSpec<T : Any> {
interface FixedSizeBufferSpec<T : Any> : BufferSpec<T> {
val unitSize: Int
/**
* Read an object from buffer in current position
*/
fun ByteBuffer.readObject(): T {
val buffer = ByteArray(unitSize)
get(buffer)
return fromBuffer(ByteBuffer.wrap(buffer))
}
/**
* Read an object from buffer in given index (not buffer position
*/
fun ByteBuffer.readObject(index: Int): T {
val dup = duplicate()
dup.position(index * unitSize)
return dup.readObject()
position(index * unitSize)
return readObject()
}
/**
* Write object to [ByteBuffer] in current buffer position
*/
fun ByteBuffer.writeObject(obj: T) {
val buffer = toBuffer(obj).apply { rewind() }
assert(buffer.limit() == unitSize)
put(buffer)
}
/**
* Put an object in given index
*/
fun ByteBuffer.writeObject(index: Int, obj: T) {
val dup = duplicate()
dup.position(index * unitSize)
dup.writeObject(obj)
position(index * unitSize)
writeObject(obj)
}
}

View File

@ -4,16 +4,20 @@ import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import java.nio.ByteBuffer
/**
* A serialization specification for complex numbers
*/
object ComplexBufferSpec : FixedSizeBufferSpec<Complex> {
override val unitSize: Int = 16
override fun fromBuffer(buffer: ByteBuffer): Complex {
val re = buffer.getDouble(0)
val im = buffer.getDouble(8)
override fun ByteBuffer.readObject(): Complex {
val re = double
val im = double
return Complex(re, im)
}
override fun toBuffer(value: Complex): ByteBuffer = ByteBuffer.allocate(16).apply {
override fun ByteBuffer.writeObject(value: Complex) {
putDouble(value.re)
putDouble(value.im)
}
@ -22,14 +26,13 @@ object ComplexBufferSpec : FixedSizeBufferSpec<Complex> {
/**
* Create a read-only/mutable buffer which ignores boxing
*/
fun Buffer.Companion.complex(size: Int): Buffer<Complex> =
ObjectBuffer.create(ComplexBufferSpec, size)
fun Buffer.Companion.complex(size: Int, initializer: ((Int) -> Complex)? = null): Buffer<Complex> =
ObjectBuffer.create(ComplexBufferSpec, size, initializer)
fun MutableBuffer.Companion.complex(size: Int) =
ObjectBuffer.create(ComplexBufferSpec, size)
fun MutableBuffer.Companion.complex(size: Int, initializer: ((Int) -> Complex)? = null) =
ObjectBuffer.create(ComplexBufferSpec, size, initializer)
fun NDField.Companion.complex(shape: IntArray) =
BoxingNDField(shape, ComplexField) { size, init -> ObjectBuffer.create(ComplexBufferSpec, size, init) }
fun NDField.Companion.complex(shape: IntArray) = ComplexNDField(shape)
fun NDElement.Companion.complex(shape: IntArray, initializer: ComplexField.(IntArray) -> Complex) =
NDField.complex(shape).produce(initializer)

View File

@ -0,0 +1,125 @@
package scientifik.kmath.structures
import scientifik.kmath.operations.Complex
import scientifik.kmath.operations.ComplexField
import scientifik.kmath.operations.FieldElement
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
/**
* An optimized nd-field for complex numbers
*/
class ComplexNDField(override val shape: IntArray) :
BufferedNDField<Complex, ComplexField>,
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
override val strides: Strides = DefaultStrides(shape)
override val elementContext: ComplexField get() = ComplexField
override val zero by lazy { produce { zero } }
override val one by lazy { produce { one } }
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
Buffer.complex(size) { initializer(it) }
/**
* Inline transform an NDStructure to another structure
*/
override fun map(
arg: NDBuffer<Complex>,
transform: ComplexField.(Complex) -> Complex
): ComplexNDElement {
check(arg)
val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) }
return BufferedNDFieldElement(this, array)
}
override fun produce(initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement {
val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
return BufferedNDFieldElement(this, array)
}
override fun mapIndexed(
arg: NDBuffer<Complex>,
transform: ComplexField.(index: IntArray, Complex) -> Complex
): ComplexNDElement {
check(arg)
return BufferedNDFieldElement(
this,
buildBuffer(arg.strides.linearSize) { offset ->
elementContext.transform(
arg.strides.index(offset),
arg.buffer[offset]
)
})
}
override fun combine(
a: NDBuffer<Complex>,
b: NDBuffer<Complex>,
transform: ComplexField.(Complex, Complex) -> Complex
): ComplexNDElement {
check(a, b)
return BufferedNDFieldElement(
this,
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
}
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
BufferedNDFieldElement(this@ComplexNDField, buffer)
override fun power(arg: NDBuffer<Complex>, pow: Number) = map(arg) { power(it, pow) }
override fun exp(arg: NDBuffer<Complex>) = map(arg) { exp(it) }
override fun ln(arg: NDBuffer<Complex>) = map(arg) { ln(it) }
override fun sin(arg: NDBuffer<Complex>) = map(arg) { sin(it) }
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) }
}
/**
* Fast element production using function inlining
*/
inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
return BufferedNDFieldElement(this, buffer)
}
/**
* Map one [ComplexNDElement] using function with indexes
*/
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) =
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
/**
* Map one [ComplexNDElement] using function without indexes
*/
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
return BufferedNDFieldElement(context, buffer)
}
/**
* Element by element application of any operation on elements to the whole array. Just like in numpy
*/
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
ndElement.map { this@invoke(it) }
/* plus and minus */
/**
* Summation operation for [BufferedNDElement] and single element
*/
operator fun ComplexNDElement.plus(arg: Complex) =
map { it + arg }
/**
* Subtraction operation between [BufferedNDElement] and single element
*/
operator fun ComplexNDElement.minus(arg: Complex) =
map { it - arg }

View File

@ -2,6 +2,9 @@ package scientifik.kmath.structures
import java.nio.ByteBuffer
/**
* A non-boxing buffer based on [ByteBuffer] storage
*/
class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: FixedSizeBufferSpec<T>) :
MutableBuffer<T> {
override val size: Int

View File

@ -6,17 +6,22 @@ import java.nio.ByteBuffer
object RealBufferSpec : FixedSizeBufferSpec<Real> {
override val unitSize: Int = 8
override fun fromBuffer(buffer: ByteBuffer): Real = Real(buffer.double)
override fun ByteBuffer.readObject(): Real = Real(double)
override fun toBuffer(value: Real): ByteBuffer = ByteBuffer.allocate(8).apply { putDouble(value.value) }
override fun ByteBuffer.writeObject(value: Real) {
putDouble(value.value)
}
}
object DoubleBufferSpec : FixedSizeBufferSpec<Double> {
override val unitSize: Int = 8
override fun fromBuffer(buffer: ByteBuffer): Double = buffer.double
override fun ByteBuffer.readObject() = double
override fun ByteBuffer.writeObject(value: Double) {
putDouble(value)
}
override fun toBuffer(value: Double): ByteBuffer = ByteBuffer.allocate(8).apply { putDouble(value) }
}
fun Double.Companion.createBuffer(size: Int) = ObjectBuffer.create(DoubleBufferSpec, size)