[WIP] wasm prototype

This commit is contained in:
Alexander Nozik 2023-02-11 20:24:35 +03:00
parent 2c6d1e89c5
commit 8437dd1cc1
7 changed files with 135 additions and 3 deletions

View File

@ -10,7 +10,6 @@ import kotlinx.html.h3
import space.kscience.kmath.data.XYErrorColumnarData import space.kscience.kmath.data.XYErrorColumnarData
import space.kscience.kmath.distributions.NormalDistribution import space.kscience.kmath.distributions.NormalDistribution
import space.kscience.kmath.expressions.Symbol import space.kscience.kmath.expressions.Symbol
import space.kscience.kmath.expressions.autodiff
import space.kscience.kmath.expressions.binding import space.kscience.kmath.expressions.binding
import space.kscience.kmath.expressions.symbol import space.kscience.kmath.expressions.symbol
import space.kscience.kmath.operations.asIterable import space.kscience.kmath.operations.asIterable
@ -62,7 +61,6 @@ suspend fun main() {
val result = XYErrorColumnarData.of(x, y, yErr).fitWith( val result = XYErrorColumnarData.of(x, y, yErr).fitWith(
QowOptimizer, QowOptimizer,
Double.autodiff,
mapOf(a to 0.9, b to 1.2, c to 2.0, e to 1.0, d to 1.0, e to 0.0), mapOf(a to 0.9, b to 1.2, c to 2.0, e to 1.0, d to 1.0, e to 0.0),
OptimizationParameters(a, b, c, d) OptimizationParameters(a, b, c, d)
) { arg -> ) { arg ->

View File

@ -9,7 +9,7 @@ kotlin.native.ignoreDisabledTargets=true
org.gradle.configureondemand=true org.gradle.configureondemand=true
org.gradle.jvmargs=-Xmx4096m org.gradle.jvmargs=-Xmx4096m
toolsVersion=0.14.0-kotlin-1.8.10 toolsVersion=0.14.1-kotlin-1.8.20-Beta
org.gradle.parallel=true org.gradle.parallel=true

View File

@ -6,6 +6,7 @@ kscience{
jvm() jvm()
js() js()
native() native()
wasm()
dependencies { dependencies {
api(projects.kmathMemory) api(projects.kmathMemory)

View File

@ -8,6 +8,12 @@ kscience {
native() native()
} }
kotlin {
wasm {
browser()
}
}
readme { readme {
maturity = space.kscience.gradle.Maturity.DEVELOPMENT maturity = space.kscience.gradle.Maturity.DEVELOPMENT
description = """ description = """

View File

@ -0,0 +1,100 @@
/*
* Copyright 2018-2023 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.memory
import kotlin.wasm.*
@PublishedApi
internal class WasmMemory(
val array: ByteArray,
val startOffset: Int = 0,
override val size: Int = array.size,
) : Memory {
@Suppress("NOTHING_TO_INLINE")
private inline fun position(o: Int): Int = startOffset + o
override fun view(offset: Int, length: Int): Memory {
require(offset >= 0) { "offset shouldn't be negative: $offset" }
require(length >= 0) { "length shouldn't be negative: $length" }
require(offset + length <= size) { "Can't view memory outside the parent region." }
return WasmMemory(array, position(offset), length)
}
override fun copy(): Memory {
val copy = array.copyOfRange(startOffset, startOffset + size)
return WasmMemory(copy)
}
private val reader: MemoryReader = object : MemoryReader {
override val memory: Memory get() = this@WasmMemory
override fun readDouble(offset: Int) = array.getDoubleAt(position(offset))
override fun readFloat(offset: Int) = array.getFloatAt(position(offset))
override fun readByte(offset: Int) = array[position(offset)]
override fun readShort(offset: Int) = array.getShortAt(position(offset))
override fun readInt(offset: Int) = array.getIntAt(position(offset))
override fun readLong(offset: Int) = array.getLongAt(position(offset))
override fun release() {
// does nothing on JVM
}
}
override fun reader(): MemoryReader = reader
private val writer: MemoryWriter = object : MemoryWriter {
override val memory: Memory get() = this@WasmMemory
override fun writeDouble(offset: Int, value: Double) {
array.setDoubleAt(position(offset), value)
}
override fun writeFloat(offset: Int, value: Float) {
array.setFloatAt(position(offset), value)
}
override fun writeByte(offset: Int, value: Byte) {
array[position(offset)] = value
}
override fun writeShort(offset: Int, value: Short) {
array.setShortAt(position(offset), value)
}
override fun writeInt(offset: Int, value: Int) {
array.setIntAt(position(offset), value)
}
override fun writeLong(offset: Int, value: Long) {
array.setLongAt(position(offset), value)
}
override fun release() {
// does nothing on JVM
}
}
override fun writer(): MemoryWriter = writer
}
/**
* Wraps a [Memory] around existing [ByteArray]. This operation is unsafe since the array is not copied
* and could be mutated independently of the resulting [Memory].
*/
public actual fun Memory.Companion.wrap(array: ByteArray): Memory = WasmMemory(array)
/**
* Allocates the most effective platform-specific memory.
*/
public actual fun Memory.Companion.allocate(length: Int): Memory {
val array = ByteArray(length)
return WasmMemory(array)
}

View File

@ -12,6 +12,7 @@ import space.kscience.kmath.expressions.*
import space.kscience.kmath.misc.FeatureSet import space.kscience.kmath.misc.FeatureSet
import space.kscience.kmath.misc.Loggable import space.kscience.kmath.misc.Loggable
import space.kscience.kmath.misc.UnstableKMathAPI import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.ExtendedField import space.kscience.kmath.operations.ExtendedField
import space.kscience.kmath.operations.bindSymbol import space.kscience.kmath.operations.bindSymbol
import kotlin.math.pow import kotlin.math.pow
@ -147,6 +148,31 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
) )
} }
public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
optimizer: Optimizer<Double, XYFit>,
startingPoint: Map<Symbol, Double>,
vararg features: OptimizationFeature = emptyArray(),
xSymbol: Symbol = Symbol.x,
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
pointWeight: PointWeight = PointWeight.byYSigma,
model: DSField<Double, DoubleField>.(DS<Double, DoubleField>) -> DS<Double, DoubleField>,
): XYFit {
val modelExpression: DifferentiableExpression<Double> = Double.autodiff.differentiate {
val x = bindSymbol(xSymbol)
model(x)
}
return fitWith(
optimizer = optimizer,
modelExpression = modelExpression,
startingPoint = startingPoint,
features = features,
xSymbol = xSymbol,
pointToCurveDistance = pointToCurveDistance,
pointWeight = pointWeight
)
}
/** /**
* Compute chi squared value for completed fit. Return null for incomplete fit * Compute chi squared value for completed fit. Return null for incomplete fit
*/ */

View File

@ -6,6 +6,7 @@ kscience{
jvm() jvm()
js() js()
native() native()
wasm()
} }
kotlin.sourceSets { kotlin.sourceSets {