Optimize Real NDField
This commit is contained in:
parent
1cb41f4dc2
commit
9829a16a32
@ -15,9 +15,9 @@ internal class ViktorBenchmark {
|
|||||||
final val n: Int = 100
|
final val n: Int = 100
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
final val autoField: BufferedNDField<Double, RealField> = NDAlgebra.auto(RealField, dim, dim)
|
final val autoField: NDField<Double, RealField> = NDAlgebra.auto(RealField, dim, dim)
|
||||||
final val realField: RealNDField = NDAlgebra.real(dim, dim)
|
final val realField: RealNDField = NDAlgebra.real(dim, dim)
|
||||||
final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim))
|
final val viktorField: ViktorNDField = ViktorNDField(dim, dim)
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun automaticFieldAddition() {
|
fun automaticFieldAddition() {
|
||||||
@ -27,6 +27,14 @@ internal class ViktorBenchmark {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun realFieldAddition() {
|
||||||
|
realField {
|
||||||
|
var res: NDStructure<Double> = one
|
||||||
|
repeat(n) { res += one }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun viktorFieldAddition() {
|
fun viktorFieldAddition() {
|
||||||
viktorField {
|
viktorField {
|
||||||
@ -41,22 +49,4 @@ internal class ViktorBenchmark {
|
|||||||
var res = one
|
var res = one
|
||||||
repeat(n) { res = res + one }
|
repeat(n) { res = res + one }
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
|
||||||
fun realFieldLog() {
|
|
||||||
realField {
|
|
||||||
val fortyTwo = produce { 42.0 }
|
|
||||||
var res = one
|
|
||||||
repeat(n) { res = ln(fortyTwo) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Benchmark
|
|
||||||
fun rawViktorLog() {
|
|
||||||
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
|
||||||
var res: F64Array
|
|
||||||
repeat(n) {
|
|
||||||
res = fortyTwo.log()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
package kscience.kmath.benchmarks
|
||||||
|
|
||||||
|
import kscience.kmath.nd.*
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.viktor.ViktorNDField
|
||||||
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
|
import org.openjdk.jmh.annotations.Benchmark
|
||||||
|
import org.openjdk.jmh.annotations.Scope
|
||||||
|
import org.openjdk.jmh.annotations.State
|
||||||
|
|
||||||
|
@State(Scope.Benchmark)
|
||||||
|
internal class ViktorLogBenchmark {
|
||||||
|
final val dim: Int = 1000
|
||||||
|
final val n: Int = 100
|
||||||
|
|
||||||
|
// automatically build context most suited for given type.
|
||||||
|
final val autoField: BufferedNDField<Double, RealField> = NDAlgebra.auto(RealField, dim, dim)
|
||||||
|
final val realField: RealNDField = NDAlgebra.real(dim, dim)
|
||||||
|
final val viktorField: ViktorNDField = ViktorNDField(intArrayOf(dim, dim))
|
||||||
|
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun realFieldLog() {
|
||||||
|
realField {
|
||||||
|
val fortyTwo = produce { 42.0 }
|
||||||
|
var res = one
|
||||||
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Benchmark
|
||||||
|
fun rawViktorLog() {
|
||||||
|
val fortyTwo = F64Array.full(dim, dim, init = 42.0)
|
||||||
|
var res: F64Array
|
||||||
|
repeat(n) {
|
||||||
|
res = fortyTwo.log()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -5,6 +5,7 @@ import kscience.kmath.nd.*
|
|||||||
import kscience.kmath.nd4j.Nd4jArrayField
|
import kscience.kmath.nd4j.Nd4jArrayField
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
|
import kscience.kmath.viktor.ViktorNDField
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
@ -25,18 +26,15 @@ fun main() {
|
|||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
val autoField = NDAlgebra.auto(RealField, dim, dim)
|
val autoField = NDAlgebra.auto(RealField, dim, dim)
|
||||||
// specialized nd-field for Double. It works as generic Double field as well
|
// specialized nd-field for Double. It works as generic Double field as well
|
||||||
val specializedField = NDAlgebra.real(dim, dim)
|
val realField = NDAlgebra.real(dim, dim)
|
||||||
//A generic boxing field. It should be used for objects, not primitives.
|
//A generic boxing field. It should be used for objects, not primitives.
|
||||||
val boxingField = NDAlgebra.field(RealField, Buffer.Companion::boxing, dim, dim)
|
val boxingField = NDAlgebra.field(RealField, Buffer.Companion::boxing, dim, dim)
|
||||||
// Nd4j specialized field.
|
// Nd4j specialized field.
|
||||||
val nd4jField = Nd4jArrayField.real(dim, dim)
|
val nd4jField = Nd4jArrayField.real(dim, dim)
|
||||||
|
//viktor field
|
||||||
measureAndPrint("Automatic field addition") {
|
val viktorField = ViktorNDField(dim,dim)
|
||||||
autoField {
|
//parallel processing based on Java Streams
|
||||||
var res: NDStructure<Double> = one
|
val parallelField = NDAlgebra.realWithStream(dim,dim)
|
||||||
repeat(n) { res += 1.0 }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
measureAndPrint("Boxing addition") {
|
measureAndPrint("Boxing addition") {
|
||||||
boxingField {
|
boxingField {
|
||||||
@ -46,7 +44,7 @@ fun main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Specialized addition") {
|
measureAndPrint("Specialized addition") {
|
||||||
specializedField {
|
realField {
|
||||||
var res: NDStructure<Double> = one
|
var res: NDStructure<Double> = one
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
@ -59,8 +57,29 @@ fun main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
measureAndPrint("Viktor addition") {
|
||||||
|
viktorField {
|
||||||
|
var res: NDStructure<Double> = one
|
||||||
|
repeat(n) { res += 1.0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
measureAndPrint("Parallel stream addition") {
|
||||||
|
parallelField {
|
||||||
|
var res: NDStructure<Double> = one
|
||||||
|
repeat(n) { res += 1.0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
measureAndPrint("Automatic field addition") {
|
||||||
|
autoField {
|
||||||
|
var res: NDStructure<Double> = one
|
||||||
|
repeat(n) { res += 1.0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
measureAndPrint("Lazy addition") {
|
measureAndPrint("Lazy addition") {
|
||||||
val res = specializedField.one.mapAsync(GlobalScope) {
|
val res = realField.one.mapAsync(GlobalScope) {
|
||||||
var c = 0.0
|
var c = 0.0
|
||||||
repeat(n) {
|
repeat(n) {
|
||||||
c += 1.0
|
c += 1.0
|
||||||
|
@ -0,0 +1,103 @@
|
|||||||
|
package kscience.kmath.structures
|
||||||
|
|
||||||
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import kscience.kmath.nd.*
|
||||||
|
import kscience.kmath.operations.ExtendedField
|
||||||
|
import kscience.kmath.operations.RealField
|
||||||
|
import kscience.kmath.operations.RingWithNumbers
|
||||||
|
import java.util.*
|
||||||
|
import java.util.stream.IntStream
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A demonstration implementation of NDField over Real using Java [DoubleStream] for parallel execution
|
||||||
|
*/
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public class StreamRealNDField(
|
||||||
|
shape: IntArray,
|
||||||
|
) : BufferedNDField<Double, RealField>(shape, RealField, Buffer.Companion::real),
|
||||||
|
RingWithNumbers<NDStructure<Double>>,
|
||||||
|
ExtendedField<NDStructure<Double>> {
|
||||||
|
|
||||||
|
override val zero: NDBuffer<Double> by lazy { produce { zero } }
|
||||||
|
override val one: NDBuffer<Double> by lazy { produce { one } }
|
||||||
|
|
||||||
|
override fun number(value: Number): NDBuffer<Double> {
|
||||||
|
val d = value.toDouble() // minimize conversions
|
||||||
|
return produce { d }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val NDStructure<Double>.buffer: RealBuffer
|
||||||
|
get() = when {
|
||||||
|
!shape.contentEquals(this@StreamRealNDField.shape) -> throw ShapeMismatchException(
|
||||||
|
this@StreamRealNDField.shape,
|
||||||
|
shape
|
||||||
|
)
|
||||||
|
this is NDBuffer && this.strides == this@StreamRealNDField.strides -> this.buffer as RealBuffer
|
||||||
|
else -> RealBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
override fun produce(initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
||||||
|
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||||
|
val index = strides.index(offset)
|
||||||
|
RealField.initializer(index)
|
||||||
|
}.toArray()
|
||||||
|
|
||||||
|
return NDBuffer(strides, array.asBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun map(
|
||||||
|
arg: NDStructure<Double>,
|
||||||
|
transform: RealField.(Double) -> Double,
|
||||||
|
): NDBuffer<Double> {
|
||||||
|
val array = Arrays.stream(arg.buffer.array).parallel().map { RealField.transform(it) }.toArray()
|
||||||
|
return NDBuffer(strides, array.asBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun mapIndexed(
|
||||||
|
arg: NDStructure<Double>,
|
||||||
|
transform: RealField.(index: IntArray, Double) -> Double,
|
||||||
|
): NDBuffer<Double> {
|
||||||
|
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||||
|
RealField.transform(
|
||||||
|
strides.index(offset),
|
||||||
|
arg.buffer.array[offset]
|
||||||
|
)
|
||||||
|
}.toArray()
|
||||||
|
|
||||||
|
return NDBuffer(strides, array.asBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun combine(
|
||||||
|
a: NDStructure<Double>,
|
||||||
|
b: NDStructure<Double>,
|
||||||
|
transform: RealField.(Double, Double) -> Double,
|
||||||
|
): NDBuffer<Double> {
|
||||||
|
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||||
|
RealField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
||||||
|
}.toArray()
|
||||||
|
return NDBuffer(strides, array.asBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun power(arg: NDStructure<Double>, pow: Number): NDBuffer<Double> = map(arg) { power(it, pow) }
|
||||||
|
|
||||||
|
override fun exp(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { exp(it) }
|
||||||
|
|
||||||
|
override fun ln(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { ln(it) }
|
||||||
|
|
||||||
|
override fun sin(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { sin(it) }
|
||||||
|
override fun cos(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { cos(it) }
|
||||||
|
override fun tan(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { tan(it) }
|
||||||
|
override fun asin(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { asin(it) }
|
||||||
|
override fun acos(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { acos(it) }
|
||||||
|
override fun atan(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { atan(it) }
|
||||||
|
|
||||||
|
override fun sinh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { sinh(it) }
|
||||||
|
override fun cosh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { cosh(it) }
|
||||||
|
override fun tanh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { tanh(it) }
|
||||||
|
override fun asinh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { asinh(it) }
|
||||||
|
override fun acosh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { acosh(it) }
|
||||||
|
override fun atanh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { atanh(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fun NDAlgebra.Companion.realWithStream(vararg shape: Int): StreamRealNDField = StreamRealNDField(shape)
|
@ -18,40 +18,36 @@ public interface BufferNDAlgebra<T, C> : NDAlgebra<T, C> {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
public val NDStructure<T>.ndBuffer: NDBuffer<T>
|
public val NDStructure<T>.buffer: Buffer<T>
|
||||||
get() = when {
|
get() = when {
|
||||||
!shape.contentEquals(this@BufferNDAlgebra.shape) -> throw ShapeMismatchException(
|
!shape.contentEquals(this@BufferNDAlgebra.shape) -> throw ShapeMismatchException(
|
||||||
this@BufferNDAlgebra.shape,
|
this@BufferNDAlgebra.shape,
|
||||||
shape
|
shape
|
||||||
)
|
)
|
||||||
this is NDBuffer && this.strides == this@BufferNDAlgebra.strides -> this
|
this is NDBuffer && this.strides == this@BufferNDAlgebra.strides -> this.buffer
|
||||||
else -> produce { this@ndBuffer[it] }
|
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun map(arg: NDStructure<T>, transform: C.(T) -> T): NDBuffer<T> {
|
override fun map(arg: NDStructure<T>, transform: C.(T) -> T): NDBuffer<T> {
|
||||||
val argAsBuffer = arg.ndBuffer
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||||
elementContext.transform(argAsBuffer.buffer[offset])
|
elementContext.transform(arg.buffer[offset])
|
||||||
}
|
}
|
||||||
return NDBuffer(strides, buffer)
|
return NDBuffer(strides, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun mapIndexed(arg: NDStructure<T>, transform: C.(index: IntArray, T) -> T): NDStructure<T> {
|
override fun mapIndexed(arg: NDStructure<T>, transform: C.(index: IntArray, T) -> T): NDStructure<T> {
|
||||||
val argAsBuffer = arg.ndBuffer
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||||
elementContext.transform(
|
elementContext.transform(
|
||||||
strides.index(offset),
|
strides.index(offset),
|
||||||
argAsBuffer[offset]
|
arg.buffer[offset]
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return NDBuffer(strides, buffer)
|
return NDBuffer(strides, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: C.(T, T) -> T): NDStructure<T> {
|
override fun combine(a: NDStructure<T>, b: NDStructure<T>, transform: C.(T, T) -> T): NDStructure<T> {
|
||||||
val aBuffer = a.ndBuffer
|
|
||||||
val bBuffer = b.ndBuffer
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
val buffer = bufferFactory(strides.linearSize) { offset ->
|
||||||
elementContext.transform(aBuffer.buffer[offset], bBuffer.buffer[offset])
|
elementContext.transform(a.buffer[offset], b.buffer[offset])
|
||||||
}
|
}
|
||||||
return NDBuffer(strides, buffer)
|
return NDBuffer(strides, buffer)
|
||||||
}
|
}
|
||||||
@ -119,10 +115,14 @@ public fun <T, A : Field<T>> NDAlgebra.Companion.field(
|
|||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
|
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
|
||||||
|
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Any, A : Field<T>> NDAlgebra.Companion.auto(
|
public inline fun <reified T : Any, A : Field<T>> NDAlgebra.Companion.auto(
|
||||||
field: A,
|
field: A,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedNDField<T, A> = BufferedNDField(shape, field, Buffer.Companion::auto)
|
): NDField<T, A> = when (field) {
|
||||||
|
RealField -> RealNDField(shape) as NDField<T, A>
|
||||||
|
else -> BufferedNDField(shape, field, Buffer.Companion::auto)
|
||||||
|
}
|
||||||
|
|
||||||
public inline fun <T, A : Field<T>, R> A.ndField(
|
public inline fun <T, A : Field<T>, R> A.ndField(
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
noinline bufferFactory: BufferFactory<T>,
|
||||||
|
@ -11,7 +11,7 @@ import kotlin.contracts.contract
|
|||||||
* An optimized nd-field for complex numbers
|
* An optimized nd-field for complex numbers
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public open class ComplexNDField(
|
public class ComplexNDField(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
) : BufferedNDField<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
|
) : BufferedNDField<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
|
||||||
RingWithNumbers<NDStructure<Complex>>,
|
RingWithNumbers<NDStructure<Complex>>,
|
||||||
|
@ -24,37 +24,46 @@ public class RealNDField(
|
|||||||
return produce { d }
|
return produce { d }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override val NDStructure<Double>.buffer: RealBuffer
|
||||||
|
get() = when {
|
||||||
|
!shape.contentEquals(this@RealNDField.shape) -> throw ShapeMismatchException(
|
||||||
|
this@RealNDField.shape,
|
||||||
|
shape
|
||||||
|
)
|
||||||
|
this is NDBuffer && this.strides == this@RealNDField.strides -> this.buffer as RealBuffer
|
||||||
|
else -> RealBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||||
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
override inline fun map(
|
override inline fun map(
|
||||||
arg: NDStructure<Double>,
|
arg: NDStructure<Double>,
|
||||||
transform: RealField.(Double) -> Double,
|
transform: RealField.(Double) -> Double,
|
||||||
): NDBuffer<Double> {
|
): NDBuffer<Double> {
|
||||||
val argAsBuffer = arg.ndBuffer
|
val buffer = RealBuffer(strides.linearSize) { offset -> RealField.transform(arg.buffer.array[offset]) }
|
||||||
val buffer = RealBuffer(strides.linearSize) { offset -> RealField.transform(argAsBuffer.buffer[offset]) }
|
|
||||||
return NDBuffer(strides, buffer)
|
return NDBuffer(strides, buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
override inline fun produce(initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
override inline fun produce(initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
||||||
val buffer = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
val array = DoubleArray(strides.linearSize) { offset ->
|
||||||
return NDBuffer(strides, buffer)
|
val index = strides.index(offset)
|
||||||
|
RealField.initializer(index)
|
||||||
|
}
|
||||||
|
return NDBuffer(strides, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
override inline fun mapIndexed(
|
override inline fun mapIndexed(
|
||||||
arg: NDStructure<Double>,
|
arg: NDStructure<Double>,
|
||||||
transform: RealField.(index: IntArray, Double) -> Double,
|
transform: RealField.(index: IntArray, Double) -> Double,
|
||||||
): NDBuffer<Double> {
|
): NDBuffer<Double> = NDBuffer(
|
||||||
val argAsBuffer = arg.ndBuffer
|
strides,
|
||||||
return NDBuffer(
|
buffer = RealBuffer(strides.linearSize) { offset ->
|
||||||
strides,
|
RealField.transform(
|
||||||
RealBuffer(strides.linearSize) { offset ->
|
strides.index(offset),
|
||||||
elementContext.transform(
|
arg.buffer.array[offset]
|
||||||
strides.index(offset),
|
)
|
||||||
argAsBuffer.buffer[offset]
|
})
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
override inline fun combine(
|
override inline fun combine(
|
||||||
@ -62,10 +71,8 @@ public class RealNDField(
|
|||||||
b: NDStructure<Double>,
|
b: NDStructure<Double>,
|
||||||
transform: RealField.(Double, Double) -> Double,
|
transform: RealField.(Double, Double) -> Double,
|
||||||
): NDBuffer<Double> {
|
): NDBuffer<Double> {
|
||||||
val aBuffer = a.ndBuffer
|
|
||||||
val bBuffer = b.ndBuffer
|
|
||||||
val buffer = RealBuffer(strides.linearSize) { offset ->
|
val buffer = RealBuffer(strides.linearSize) { offset ->
|
||||||
elementContext.transform(aBuffer.buffer[offset], bBuffer.buffer[offset])
|
RealField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
||||||
}
|
}
|
||||||
return NDBuffer(strides, buffer)
|
return NDBuffer(strides, buffer)
|
||||||
}
|
}
|
||||||
@ -91,19 +98,6 @@ public class RealNDField(
|
|||||||
override fun atanh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { atanh(it) }
|
override fun atanh(arg: NDStructure<Double>): NDBuffer<Double> = map(arg) { atanh(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fast element production using function inlining
|
|
||||||
*/
|
|
||||||
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
|
||||||
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
val array = DoubleArray(strides.linearSize) { offset ->
|
|
||||||
val index = strides.index(offset)
|
|
||||||
RealField.initializer(index)
|
|
||||||
}
|
|
||||||
return NDBuffer(strides, RealBuffer(array))
|
|
||||||
}
|
|
||||||
|
|
||||||
public fun NDAlgebra.Companion.real(vararg shape: Int): RealNDField = RealNDField(shape)
|
public fun NDAlgebra.Companion.real(vararg shape: Int): RealNDField = RealNDField(shape)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -9,7 +9,7 @@ import kotlin.contracts.InvocationKind
|
|||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public open class ShortNDRing(
|
public class ShortNDRing(
|
||||||
shape: IntArray,
|
shape: IntArray,
|
||||||
) : BufferedNDRing<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto),
|
) : BufferedNDRing<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto),
|
||||||
RingWithNumbers<NDStructure<Short>> {
|
RingWithNumbers<NDStructure<Short>> {
|
||||||
|
@ -58,7 +58,7 @@ public interface Structure2D<T> : NDStructure<T> {
|
|||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
crossinline init: (i: Int, j: Int) -> Double,
|
crossinline init: (i: Int, j: Int) -> Double,
|
||||||
): Matrix<Double> = NDAlgebra.real(rows, columns).produceInline { (i, j) ->
|
): Matrix<Double> = NDAlgebra.real(rows, columns).produce { (i, j) ->
|
||||||
init(i, j)
|
init(i, j)
|
||||||
}.as2D()
|
}.as2D()
|
||||||
}
|
}
|
||||||
|
@ -11,8 +11,8 @@ import kotlin.test.assertEquals
|
|||||||
@Suppress("UNUSED_VARIABLE")
|
@Suppress("UNUSED_VARIABLE")
|
||||||
class NumberNDFieldTest {
|
class NumberNDFieldTest {
|
||||||
val algebra = NDAlgebra.real(3,3)
|
val algebra = NDAlgebra.real(3,3)
|
||||||
val array1 = algebra.produceInline { (i, j) -> (i + j).toDouble() }
|
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() }
|
||||||
val array2 = algebra.produceInline { (i, j) -> (i - j).toDouble() }
|
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() }
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSum() {
|
fun testSum() {
|
||||||
|
@ -93,4 +93,6 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
|||||||
|
|
||||||
public override inline fun NDStructure<Double>.plus(arg: Double): ViktorNDStructure =
|
public override inline fun NDStructure<Double>.plus(arg: Double): ViktorNDStructure =
|
||||||
(f64Buffer.plus(arg)).asStructure()
|
(f64Buffer.plus(arg)).asStructure()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public fun ViktorNDField(vararg shape: Int): ViktorNDField = ViktorNDField(shape)
|
Loading…
Reference in New Issue
Block a user