forked from kscience/kmath
Complex buffer optimization
This commit is contained in:
parent
f6cc23ce0a
commit
c3989159ab
@ -1,6 +1,6 @@
|
|||||||
plugins {
|
plugins {
|
||||||
id "java"
|
id "java"
|
||||||
id "me.champeau.gradle.jmh" version "0.4.7"
|
id "me.champeau.gradle.jmh" version "0.4.8"
|
||||||
id 'org.jetbrains.kotlin.jvm'
|
id 'org.jetbrains.kotlin.jvm'
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -15,8 +15,8 @@ dependencies {
|
|||||||
implementation project(":kmath-coroutines")
|
implementation project(":kmath-coroutines")
|
||||||
implementation project(":kmath-commons")
|
implementation project(":kmath-commons")
|
||||||
implementation project(":kmath-koma")
|
implementation project(":kmath-koma")
|
||||||
compile group: "com.kyonifer", name:"koma-core-ejml", version: "0.12"
|
implementation group: "com.kyonifer", name:"koma-core-ejml", version: "0.12"
|
||||||
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
//compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
|
||||||
//jmh project(':kmath-core')
|
//jmh project(':kmath-core')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,36 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.toComplex
|
||||||
|
import kotlin.system.measureTimeMillis
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
val dim = 1000
|
||||||
|
val n = 1000
|
||||||
|
|
||||||
|
val realField = NDField.real(intArrayOf(dim, dim))
|
||||||
|
val complexField = NDField.complex(intArrayOf(dim, dim))
|
||||||
|
|
||||||
|
|
||||||
|
val realTime = measureTimeMillis {
|
||||||
|
realField.run {
|
||||||
|
var res: NDBuffer<Double> = one
|
||||||
|
repeat(n) {
|
||||||
|
res += 1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("Real addition completed in $realTime millis")
|
||||||
|
|
||||||
|
val complexTime = measureTimeMillis {
|
||||||
|
complexField.run {
|
||||||
|
var res: NDBuffer<Complex> = one
|
||||||
|
repeat(n) {
|
||||||
|
res += 1.0.toComplex()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println("Complex addition completed in $complexTime millis")
|
||||||
|
}
|
@ -1,8 +1,8 @@
|
|||||||
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
import org.jetbrains.kotlin.gradle.dsl.KotlinMultiplatformExtension
|
||||||
|
|
||||||
buildscript {
|
buildscript {
|
||||||
val kotlinVersion: String by rootProject.extra("1.3.20")
|
val kotlinVersion: String by rootProject.extra("1.3.21")
|
||||||
val ioVersion: String by rootProject.extra("0.1.2")
|
val ioVersion: String by rootProject.extra("0.1.4")
|
||||||
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
val coroutinesVersion: String by rootProject.extra("1.1.1")
|
||||||
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
val atomicfuVersion: String by rootProject.extra("0.12.1")
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ allprojects {
|
|||||||
apply(plugin = "com.jfrog.artifactory")
|
apply(plugin = "com.jfrog.artifactory")
|
||||||
|
|
||||||
group = "scientifik"
|
group = "scientifik"
|
||||||
version = "0.0.3-dev-5"
|
version = "0.0.3"
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
//maven("https://dl.bintray.com/kotlin/kotlin-eap")
|
||||||
|
@ -2,6 +2,8 @@ plugins {
|
|||||||
kotlin("multiplatform")
|
kotlin("multiplatform")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val ioVersion: String by rootProject.extra
|
||||||
|
|
||||||
|
|
||||||
kotlin {
|
kotlin {
|
||||||
jvm()
|
jvm()
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package scientifik.kmath.operations
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field for complex numbers
|
* A field for complex numbers
|
||||||
*/
|
*/
|
||||||
object ComplexField : Field<Complex> {
|
object ComplexField : ExtendedField<Complex> {
|
||||||
override val zero: Complex = Complex(0.0, 0.0)
|
override val zero: Complex = Complex(0.0, 0.0)
|
||||||
|
|
||||||
override val one: Complex = Complex(1.0, 0.0)
|
override val one: Complex = Complex(1.0, 0.0)
|
||||||
@ -22,6 +24,17 @@ object ComplexField : Field<Complex> {
|
|||||||
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
return Complex((a.re * b.re + a.im * b.im) / norm, (a.re * b.im - a.im * b.re) / norm)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: Complex): Complex = i / 2 * (exp(-i * arg) - exp(i * arg))
|
||||||
|
|
||||||
|
override fun cos(arg: Complex): Complex = (exp(-i * arg) + exp(i * arg)) / 2
|
||||||
|
|
||||||
|
override fun power(arg: Complex, pow: Number): Complex =
|
||||||
|
arg.abs.pow(pow.toDouble()) * (cos(pow.toDouble() * arg.theta) + i * sin(pow.toDouble() * arg.theta))
|
||||||
|
|
||||||
|
override fun exp(arg: Complex): Complex = exp(arg.re) * (cos(arg.im) + i * sin(arg.im))
|
||||||
|
|
||||||
|
override fun ln(arg: Complex): Complex = ln(arg.abs) + i * atan2(arg.im, arg.re)
|
||||||
|
|
||||||
operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
|
operator fun Double.plus(c: Complex) = add(this.toComplex(), c)
|
||||||
|
|
||||||
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
|
operator fun Double.minus(c: Complex) = add(this.toComplex(), -c)
|
||||||
@ -41,20 +54,18 @@ data class Complex(val re: Double, val im: Double) : FieldElement<Complex, Compl
|
|||||||
|
|
||||||
override fun Complex.wrap(): Complex = this
|
override fun Complex.wrap(): Complex = this
|
||||||
|
|
||||||
override val context: ComplexField
|
override val context: ComplexField get() = ComplexField
|
||||||
get() = ComplexField
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A complex conjugate
|
* A complex conjugate
|
||||||
*/
|
*/
|
||||||
val conjugate: Complex
|
val conjugate: Complex get() = Complex(re, -im)
|
||||||
get() = Complex(re, -im)
|
|
||||||
|
|
||||||
val square: Double
|
val square: Double get() = re * re + im * im
|
||||||
get() = re * re + im * im
|
|
||||||
|
|
||||||
val abs: Double
|
val abs: Double get() = sqrt(square)
|
||||||
get() = kotlin.math.sqrt(square)
|
|
||||||
|
val theta: Double get() = atan(im / re)
|
||||||
|
|
||||||
companion object
|
companion object
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ class BoxingNDField<T, F : Field<T>>(
|
|||||||
|
|
||||||
override val strides: Strides = DefaultStrides(shape)
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T> =
|
||||||
bufferFactory(size, initializer)
|
bufferFactory(size, initializer)
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
|
@ -5,8 +5,6 @@ import scientifik.kmath.operations.*
|
|||||||
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
interface BufferedNDAlgebra<T, C>: NDAlgebra<T, C, NDBuffer<T>>{
|
||||||
val strides: Strides
|
val strides: Strides
|
||||||
|
|
||||||
fun buildBuffer(size: Int, initializer: (Int) -> T): Buffer<T>
|
|
||||||
|
|
||||||
override fun check(vararg elements: NDBuffer<T>) {
|
override fun check(vararg elements: NDBuffer<T>) {
|
||||||
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
if (!elements.all { it.strides == this.strides }) error("Strides mismatch")
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
package scientifik.kmath.structures
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
import scientifik.kmath.operations.Field
|
import scientifik.kmath.operations.ExtendedField
|
||||||
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A simple field over linear buffers of [Double]
|
* A simple field over linear buffers of [Double]
|
||||||
*/
|
*/
|
||||||
class RealBufferField(val size: Int) : Field<Buffer<Double>> {
|
class RealBufferField(val size: Int) : ExtendedField<Buffer<Double>> {
|
||||||
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
|
override val zero: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 0.0 }
|
||||||
|
|
||||||
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
|
override val one: Buffer<Double> = Buffer.DoubleBufferFactory(size) { 1.0 }
|
||||||
@ -56,4 +57,54 @@ class RealBufferField(val size: Int) : Field<Buffer<Double>> {
|
|||||||
DoubleBuffer(DoubleArray(size) { a[it] / b[it] })
|
DoubleBuffer(DoubleArray(size) { a[it] / b[it] })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun sin(arg: Buffer<Double>): Buffer<Double> {
|
||||||
|
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||||
|
return if (arg is DoubleBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
DoubleBuffer(DoubleArray(size) { sin(array[it]) })
|
||||||
|
} else {
|
||||||
|
DoubleBuffer(DoubleArray(size) { sin(arg[it]) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun cos(arg: Buffer<Double>): Buffer<Double> {
|
||||||
|
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||||
|
return if (arg is DoubleBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
DoubleBuffer(DoubleArray(size) { cos(array[it]) })
|
||||||
|
} else {
|
||||||
|
DoubleBuffer(DoubleArray(size) { cos(arg[it]) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: Buffer<Double>, pow: Number): Buffer<Double> {
|
||||||
|
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||||
|
return if (arg is DoubleBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
DoubleBuffer(DoubleArray(size) { array[it].pow(pow.toDouble()) })
|
||||||
|
} else {
|
||||||
|
DoubleBuffer(DoubleArray(size) { arg[it].pow(pow.toDouble()) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun exp(arg: Buffer<Double>): Buffer<Double> {
|
||||||
|
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||||
|
return if (arg is DoubleBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
DoubleBuffer(DoubleArray(size) { exp(array[it]) })
|
||||||
|
} else {
|
||||||
|
DoubleBuffer(DoubleArray(size) { exp(arg[it]) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun ln(arg: Buffer<Double>): Buffer<Double> {
|
||||||
|
require(arg.size == size) { "The size of buffer is ${arg.size} but context requires $size " }
|
||||||
|
return if (arg is DoubleBuffer) {
|
||||||
|
val array = arg.array
|
||||||
|
DoubleBuffer(DoubleArray(size) { ln(array[it]) })
|
||||||
|
} else {
|
||||||
|
DoubleBuffer(DoubleArray(size) { ln(arg[it]) })
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -15,8 +15,7 @@ class RealNDField(override val shape: IntArray) :
|
|||||||
override val zero by lazy { produce { zero } }
|
override val zero by lazy { produce { zero } }
|
||||||
override val one by lazy { produce { one } }
|
override val one by lazy { produce { one } }
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
||||||
override inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Double): Buffer<Double> =
|
|
||||||
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
DoubleBuffer(DoubleArray(size) { initializer(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -15,8 +15,7 @@ class ShortNDRing(override val shape: IntArray) :
|
|||||||
override val zero by lazy { produce { ShortRing.zero } }
|
override val zero by lazy { produce { ShortRing.zero } }
|
||||||
override val one by lazy { produce { ShortRing.one } }
|
override val one by lazy { produce { ShortRing.one } }
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
|
||||||
override inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Short): Buffer<Short> =
|
|
||||||
ShortBuffer(ShortArray(size) { initializer(it) })
|
ShortBuffer(ShortArray(size) { initializer(it) })
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -7,8 +7,15 @@ import java.nio.ByteBuffer
|
|||||||
* A specification for serialization and deserialization objects to buffer
|
* A specification for serialization and deserialization objects to buffer
|
||||||
*/
|
*/
|
||||||
interface BufferSpec<T : Any> {
|
interface BufferSpec<T : Any> {
|
||||||
fun fromBuffer(buffer: ByteBuffer): T
|
/**
|
||||||
fun toBuffer(value: T): ByteBuffer
|
* Read an object from buffer in current position
|
||||||
|
*/
|
||||||
|
fun ByteBuffer.readObject(): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Write object to [ByteBuffer] in current buffer position
|
||||||
|
*/
|
||||||
|
fun ByteBuffer.writeObject(value: T)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -17,39 +24,21 @@ interface BufferSpec<T : Any> {
|
|||||||
interface FixedSizeBufferSpec<T : Any> : BufferSpec<T> {
|
interface FixedSizeBufferSpec<T : Any> : BufferSpec<T> {
|
||||||
val unitSize: Int
|
val unitSize: Int
|
||||||
|
|
||||||
/**
|
|
||||||
* Read an object from buffer in current position
|
|
||||||
*/
|
|
||||||
fun ByteBuffer.readObject(): T {
|
|
||||||
val buffer = ByteArray(unitSize)
|
|
||||||
get(buffer)
|
|
||||||
return fromBuffer(ByteBuffer.wrap(buffer))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read an object from buffer in given index (not buffer position
|
* Read an object from buffer in given index (not buffer position
|
||||||
*/
|
*/
|
||||||
fun ByteBuffer.readObject(index: Int): T {
|
fun ByteBuffer.readObject(index: Int): T {
|
||||||
val dup = duplicate()
|
position(index * unitSize)
|
||||||
dup.position(index * unitSize)
|
return readObject()
|
||||||
return dup.readObject()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Write object to [ByteBuffer] in current buffer position
|
|
||||||
*/
|
|
||||||
fun ByteBuffer.writeObject(obj: T) {
|
|
||||||
val buffer = toBuffer(obj).apply { rewind() }
|
|
||||||
assert(buffer.limit() == unitSize)
|
|
||||||
put(buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Put an object in given index
|
* Put an object in given index
|
||||||
*/
|
*/
|
||||||
fun ByteBuffer.writeObject(index: Int, obj: T) {
|
fun ByteBuffer.writeObject(index: Int, obj: T) {
|
||||||
val dup = duplicate()
|
position(index * unitSize)
|
||||||
dup.position(index * unitSize)
|
writeObject(obj)
|
||||||
dup.writeObject(obj)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -4,16 +4,20 @@ import scientifik.kmath.operations.Complex
|
|||||||
import scientifik.kmath.operations.ComplexField
|
import scientifik.kmath.operations.ComplexField
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A serialization specification for complex numbers
|
||||||
|
*/
|
||||||
object ComplexBufferSpec : FixedSizeBufferSpec<Complex> {
|
object ComplexBufferSpec : FixedSizeBufferSpec<Complex> {
|
||||||
|
|
||||||
override val unitSize: Int = 16
|
override val unitSize: Int = 16
|
||||||
|
|
||||||
override fun fromBuffer(buffer: ByteBuffer): Complex {
|
override fun ByteBuffer.readObject(): Complex {
|
||||||
val re = buffer.getDouble(0)
|
val re = double
|
||||||
val im = buffer.getDouble(8)
|
val im = double
|
||||||
return Complex(re, im)
|
return Complex(re, im)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun toBuffer(value: Complex): ByteBuffer = ByteBuffer.allocate(16).apply {
|
override fun ByteBuffer.writeObject(value: Complex) {
|
||||||
putDouble(value.re)
|
putDouble(value.re)
|
||||||
putDouble(value.im)
|
putDouble(value.im)
|
||||||
}
|
}
|
||||||
@ -22,14 +26,13 @@ object ComplexBufferSpec : FixedSizeBufferSpec<Complex> {
|
|||||||
/**
|
/**
|
||||||
* Create a read-only/mutable buffer which ignores boxing
|
* Create a read-only/mutable buffer which ignores boxing
|
||||||
*/
|
*/
|
||||||
fun Buffer.Companion.complex(size: Int): Buffer<Complex> =
|
fun Buffer.Companion.complex(size: Int, initializer: ((Int) -> Complex)? = null): Buffer<Complex> =
|
||||||
ObjectBuffer.create(ComplexBufferSpec, size)
|
ObjectBuffer.create(ComplexBufferSpec, size, initializer)
|
||||||
|
|
||||||
fun MutableBuffer.Companion.complex(size: Int) =
|
fun MutableBuffer.Companion.complex(size: Int, initializer: ((Int) -> Complex)? = null) =
|
||||||
ObjectBuffer.create(ComplexBufferSpec, size)
|
ObjectBuffer.create(ComplexBufferSpec, size, initializer)
|
||||||
|
|
||||||
fun NDField.Companion.complex(shape: IntArray) =
|
fun NDField.Companion.complex(shape: IntArray) = ComplexNDField(shape)
|
||||||
BoxingNDField(shape, ComplexField) { size, init -> ObjectBuffer.create(ComplexBufferSpec, size, init) }
|
|
||||||
|
|
||||||
fun NDElement.Companion.complex(shape: IntArray, initializer: ComplexField.(IntArray) -> Complex) =
|
fun NDElement.Companion.complex(shape: IntArray, initializer: ComplexField.(IntArray) -> Complex) =
|
||||||
NDField.complex(shape).produce(initializer)
|
NDField.complex(shape).produce(initializer)
|
||||||
|
@ -0,0 +1,125 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Complex
|
||||||
|
import scientifik.kmath.operations.ComplexField
|
||||||
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
|
||||||
|
typealias ComplexNDElement = BufferedNDFieldElement<Complex, ComplexField>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An optimized nd-field for complex numbers
|
||||||
|
*/
|
||||||
|
class ComplexNDField(override val shape: IntArray) :
|
||||||
|
BufferedNDField<Complex, ComplexField>,
|
||||||
|
ExtendedNDField<Complex, ComplexField, NDBuffer<Complex>> {
|
||||||
|
|
||||||
|
override val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
|
override val elementContext: ComplexField get() = ComplexField
|
||||||
|
override val zero by lazy { produce { zero } }
|
||||||
|
override val one by lazy { produce { one } }
|
||||||
|
|
||||||
|
inline fun buildBuffer(size: Int, crossinline initializer: (Int) -> Complex): Buffer<Complex> =
|
||||||
|
Buffer.complex(size) { initializer(it) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline transform an NDStructure to another structure
|
||||||
|
*/
|
||||||
|
override fun map(
|
||||||
|
arg: NDBuffer<Complex>,
|
||||||
|
transform: ComplexField.(Complex) -> Complex
|
||||||
|
): ComplexNDElement {
|
||||||
|
check(arg)
|
||||||
|
val array = buildBuffer(arg.strides.linearSize) { offset -> ComplexField.transform(arg.buffer[offset]) }
|
||||||
|
return BufferedNDFieldElement(this, array)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun produce(initializer: ComplexField.(IntArray) -> Complex): ComplexNDElement {
|
||||||
|
val array = buildBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
||||||
|
return BufferedNDFieldElement(this, array)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun mapIndexed(
|
||||||
|
arg: NDBuffer<Complex>,
|
||||||
|
transform: ComplexField.(index: IntArray, Complex) -> Complex
|
||||||
|
): ComplexNDElement {
|
||||||
|
check(arg)
|
||||||
|
return BufferedNDFieldElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(arg.strides.linearSize) { offset ->
|
||||||
|
elementContext.transform(
|
||||||
|
arg.strides.index(offset),
|
||||||
|
arg.buffer[offset]
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun combine(
|
||||||
|
a: NDBuffer<Complex>,
|
||||||
|
b: NDBuffer<Complex>,
|
||||||
|
transform: ComplexField.(Complex, Complex) -> Complex
|
||||||
|
): ComplexNDElement {
|
||||||
|
check(a, b)
|
||||||
|
return BufferedNDFieldElement(
|
||||||
|
this,
|
||||||
|
buildBuffer(strides.linearSize) { offset -> elementContext.transform(a.buffer[offset], b.buffer[offset]) })
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun NDBuffer<Complex>.toElement(): FieldElement<NDBuffer<Complex>, *, out BufferedNDField<Complex, ComplexField>> =
|
||||||
|
BufferedNDFieldElement(this@ComplexNDField, buffer)
|
||||||
|
|
||||||
|
override fun power(arg: NDBuffer<Complex>, pow: Number) = map(arg) { power(it, pow) }
|
||||||
|
|
||||||
|
override fun exp(arg: NDBuffer<Complex>) = map(arg) { exp(it) }
|
||||||
|
|
||||||
|
override fun ln(arg: NDBuffer<Complex>) = map(arg) { ln(it) }
|
||||||
|
|
||||||
|
override fun sin(arg: NDBuffer<Complex>) = map(arg) { sin(it) }
|
||||||
|
|
||||||
|
override fun cos(arg: NDBuffer<Complex>) = map(arg) { cos(it) }
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fast element production using function inlining
|
||||||
|
*/
|
||||||
|
inline fun BufferedNDField<Complex, ComplexField>.produceInline(crossinline initializer: ComplexField.(Int) -> Complex): ComplexNDElement {
|
||||||
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
|
||||||
|
return BufferedNDFieldElement(this, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map one [ComplexNDElement] using function with indexes
|
||||||
|
*/
|
||||||
|
inline fun ComplexNDElement.mapIndexed(crossinline transform: ComplexField.(index: IntArray, Complex) -> Complex) =
|
||||||
|
context.produceInline { offset -> transform(strides.index(offset), buffer[offset]) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map one [ComplexNDElement] using function without indexes
|
||||||
|
*/
|
||||||
|
inline fun ComplexNDElement.map(crossinline transform: ComplexField.(Complex) -> Complex): ComplexNDElement {
|
||||||
|
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.transform(buffer[offset]) }
|
||||||
|
return BufferedNDFieldElement(context, buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element by element application of any operation on elements to the whole array. Just like in numpy
|
||||||
|
*/
|
||||||
|
operator fun Function1<Complex, Complex>.invoke(ndElement: ComplexNDElement) =
|
||||||
|
ndElement.map { this@invoke(it) }
|
||||||
|
|
||||||
|
|
||||||
|
/* plus and minus */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Summation operation for [BufferedNDElement] and single element
|
||||||
|
*/
|
||||||
|
operator fun ComplexNDElement.plus(arg: Complex) =
|
||||||
|
map { it + arg }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Subtraction operation between [BufferedNDElement] and single element
|
||||||
|
*/
|
||||||
|
operator fun ComplexNDElement.minus(arg: Complex) =
|
||||||
|
map { it - arg }
|
@ -2,6 +2,9 @@ package scientifik.kmath.structures
|
|||||||
|
|
||||||
import java.nio.ByteBuffer
|
import java.nio.ByteBuffer
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A non-boxing buffer based on [ByteBuffer] storage
|
||||||
|
*/
|
||||||
class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: FixedSizeBufferSpec<T>) :
|
class ObjectBuffer<T : Any>(private val buffer: ByteBuffer, private val spec: FixedSizeBufferSpec<T>) :
|
||||||
MutableBuffer<T> {
|
MutableBuffer<T> {
|
||||||
override val size: Int
|
override val size: Int
|
||||||
|
@ -6,17 +6,22 @@ import java.nio.ByteBuffer
|
|||||||
object RealBufferSpec : FixedSizeBufferSpec<Real> {
|
object RealBufferSpec : FixedSizeBufferSpec<Real> {
|
||||||
override val unitSize: Int = 8
|
override val unitSize: Int = 8
|
||||||
|
|
||||||
override fun fromBuffer(buffer: ByteBuffer): Real = Real(buffer.double)
|
override fun ByteBuffer.readObject(): Real = Real(double)
|
||||||
|
|
||||||
override fun toBuffer(value: Real): ByteBuffer = ByteBuffer.allocate(8).apply { putDouble(value.value) }
|
override fun ByteBuffer.writeObject(value: Real) {
|
||||||
|
putDouble(value.value)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object DoubleBufferSpec : FixedSizeBufferSpec<Double> {
|
object DoubleBufferSpec : FixedSizeBufferSpec<Double> {
|
||||||
override val unitSize: Int = 8
|
override val unitSize: Int = 8
|
||||||
|
|
||||||
override fun fromBuffer(buffer: ByteBuffer): Double = buffer.double
|
override fun ByteBuffer.readObject() = double
|
||||||
|
|
||||||
|
override fun ByteBuffer.writeObject(value: Double) {
|
||||||
|
putDouble(value)
|
||||||
|
}
|
||||||
|
|
||||||
override fun toBuffer(value: Double): ByteBuffer = ByteBuffer.allocate(8).apply { putDouble(value) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fun Double.Companion.createBuffer(size: Int) = ObjectBuffer.create(DoubleBufferSpec, size)
|
fun Double.Companion.createBuffer(size: Int) = ObjectBuffer.create(DoubleBufferSpec, size)
|
||||||
|
Loading…
Reference in New Issue
Block a user