forked from kscience/kmath
Initial commit. Algebra operations. NDArray
This commit is contained in:
parent
ed6d56b627
commit
b42beb4b68
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,14 +1,9 @@
|
|||||||
.gradle
|
.gradle
|
||||||
/build/
|
/build/
|
||||||
|
.idea/
|
||||||
# Ignore Gradle GUI config
|
|
||||||
gradle-app.setting
|
|
||||||
|
|
||||||
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
# Avoid ignoring Gradle wrapper jar file (.jar files are usually ignored)
|
||||||
!gradle-wrapper.jar
|
!gradle-wrapper.jar
|
||||||
|
|
||||||
# Cache of project
|
# Cache of project
|
||||||
.gradletasknamecache
|
.gradletasknamecache
|
||||||
|
|
||||||
# # Work around https://youtrack.jetbrains.com/issue/IDEA-116898
|
|
||||||
# gradle/wrapper/gradle-wrapper.properties
|
|
||||||
|
19
.idea/gradle.xml
Normal file
19
.idea/gradle.xml
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="GradleSettings">
|
||||||
|
<option name="linkedExternalProjectsSettings">
|
||||||
|
<GradleProjectSettings>
|
||||||
|
<option name="distributionType" value="LOCAL" />
|
||||||
|
<option name="externalProjectPath" value="$PROJECT_DIR$" />
|
||||||
|
<option name="gradleHome" value="$USER_HOME$/.posh_gvm/gradle/current" />
|
||||||
|
<option name="modules">
|
||||||
|
<set>
|
||||||
|
<option value="$PROJECT_DIR$" />
|
||||||
|
<option value="$PROJECT_DIR$/common" />
|
||||||
|
<option value="$PROJECT_DIR$/jvm" />
|
||||||
|
</set>
|
||||||
|
</option>
|
||||||
|
</GradleProjectSettings>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
</project>
|
7
.idea/misc.xml
Normal file
7
.idea/misc.xml
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||||
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_10" project-jdk-name="10" project-jdk-type="JavaSDK">
|
||||||
|
<output url="file://$PROJECT_DIR$/out" />
|
||||||
|
</component>
|
||||||
|
</project>
|
23
.idea/modules/common/common_main.iml
Normal file
23
.idea/modules/common/common_main.iml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module version="4">
|
||||||
|
<component name="FacetManager">
|
||||||
|
<facet type="kotlin-language" name="Kotlin">
|
||||||
|
<configuration version="3" platform="Common (experimental) " useProjectSettings="false">
|
||||||
|
<compilerSettings />
|
||||||
|
<compilerArguments>
|
||||||
|
<option name="destination" value="$MODULE_DIR$/../../../common/build/classes/kotlin/main" />
|
||||||
|
<option name="classpath" value="$USER_HOME$/.gradle/caches/modules-2/files-2.1/org.jetbrains.kotlin/kotlin-stdlib-common/1.2.40/37abadf1cca4450f39672a80d24a379d2fd06356/kotlin-stdlib-common-1.2.40.jar" />
|
||||||
|
<option name="languageVersion" value="1.2" />
|
||||||
|
<option name="apiVersion" value="1.2" />
|
||||||
|
<option name="pluginOptions">
|
||||||
|
<array />
|
||||||
|
</option>
|
||||||
|
<option name="pluginClasspaths">
|
||||||
|
<array />
|
||||||
|
</option>
|
||||||
|
<option name="multiPlatform" value="true" />
|
||||||
|
</compilerArguments>
|
||||||
|
</configuration>
|
||||||
|
</facet>
|
||||||
|
</component>
|
||||||
|
</module>
|
23
.idea/modules/common/common_test.iml
Normal file
23
.idea/modules/common/common_test.iml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module version="4">
|
||||||
|
<component name="FacetManager">
|
||||||
|
<facet type="kotlin-language" name="Kotlin">
|
||||||
|
<configuration version="3" platform="Common (experimental) " useProjectSettings="false">
|
||||||
|
<compilerSettings />
|
||||||
|
<compilerArguments>
|
||||||
|
<option name="destination" value="$MODULE_DIR$/../../../common/build/classes/kotlin/test" />
|
||||||
|
<option name="classpath" value="$MODULE_DIR$/../../../common/build/classes/java/main;D:/Work/Projects/kmath/common/build/classes/kotlin/main;D:/Work/Projects/kmath/common/build/resources/main;C:/Users/darksnake/.gradle/caches/modules-2/files-2.1/org.jetbrains.kotlin/kotlin-test-annotations-common/1.2.40/2d4a60c84c93f625f5bffe4cc76b21b243688f5a/kotlin-test-annotations-common-1.2.40.jar;C:/Users/darksnake/.gradle/caches/modules-2/files-2.1/org.jetbrains.kotlin/kotlin-test-common/1.2.40/2364922e3ca01d51cad8f585a2c8c8d731bb375a/kotlin-test-common-1.2.40.jar;C:/Users/darksnake/.gradle/caches/modules-2/files-2.1/org.jetbrains.kotlin/kotlin-stdlib-common/1.2.40/37abadf1cca4450f39672a80d24a379d2fd06356/kotlin-stdlib-common-1.2.40.jar;D:/Work/Projects/kmath/common/build/classes/kotlin/main" />
|
||||||
|
<option name="languageVersion" value="1.2" />
|
||||||
|
<option name="apiVersion" value="1.2" />
|
||||||
|
<option name="pluginOptions">
|
||||||
|
<array />
|
||||||
|
</option>
|
||||||
|
<option name="pluginClasspaths">
|
||||||
|
<array />
|
||||||
|
</option>
|
||||||
|
<option name="multiPlatform" value="true" />
|
||||||
|
</compilerArguments>
|
||||||
|
</configuration>
|
||||||
|
</facet>
|
||||||
|
</component>
|
||||||
|
</module>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
3
build.gradle
Normal file
3
build.gradle
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
group = 'scientifik'
|
||||||
|
version = '0.1-SNAPSHOT'
|
||||||
|
|
25
common/build.gradle
Normal file
25
common/build.gradle
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
buildscript {
|
||||||
|
ext.kotlin_version = '1.2.40'
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
dependencies {
|
||||||
|
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
description = "Platform-independent interfaces for kotlin maths"
|
||||||
|
|
||||||
|
apply plugin: 'kotlin-platform-common'
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
compile "org.jetbrains.kotlin:kotlin-stdlib-common:$kotlin_version"
|
||||||
|
testCompile "org.jetbrains.kotlin:kotlin-test-annotations-common:$kotlin_version"
|
||||||
|
testCompile "org.jetbrains.kotlin:kotlin-test-common:$kotlin_version"
|
||||||
|
}
|
||||||
|
|
103
common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt
Normal file
103
common/src/main/kotlin/scientifik/kmath/operations/Algebra.kt
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A general interface representing linear context of some kind.
|
||||||
|
* The context defines sum operation for its elements and multiplication by real value.
|
||||||
|
* One must note that in some cases context is a singleton class, but in some cases it
|
||||||
|
* works as a context for operations inside it.
|
||||||
|
*
|
||||||
|
* TODO do we need commutative context?
|
||||||
|
*/
|
||||||
|
interface Space<T> {
|
||||||
|
/**
|
||||||
|
* Neutral element for sum operation
|
||||||
|
*/
|
||||||
|
val zero: T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Addition operation for two context elements
|
||||||
|
*/
|
||||||
|
fun add(a: T, b: T): T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplication operation for context element and real number
|
||||||
|
*/
|
||||||
|
fun multiply(a: T, k: Double): T
|
||||||
|
|
||||||
|
//Operation to be performed in this context
|
||||||
|
operator fun T.unaryMinus(): T = multiply(this, -1.0)
|
||||||
|
|
||||||
|
operator fun T.plus(b: T): T = add(this, b)
|
||||||
|
operator fun T.minus(b: T): T = add(this, -b)
|
||||||
|
operator fun T.times(k: Number) = multiply(this, k.toDouble())
|
||||||
|
operator fun T.div(k: Number) = multiply(this, 1.0 / k.toDouble())
|
||||||
|
operator fun Number.times(b: T) = b * this
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The element of linear context
|
||||||
|
* @param S self type of the element. Needed for static type checking
|
||||||
|
*/
|
||||||
|
interface SpaceElement<S : SpaceElement<S>> {
|
||||||
|
/**
|
||||||
|
* The context this element belongs to
|
||||||
|
*/
|
||||||
|
val context: Space<S>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Self value. Needed for static type checking. Needed to avoid type erasure on JVM.
|
||||||
|
*/
|
||||||
|
val self: S
|
||||||
|
|
||||||
|
operator fun plus(b: S): S = with(context) { self + b }
|
||||||
|
operator fun minus(b: S): S = with(context) { self - b }
|
||||||
|
operator fun times(k: Number): S = with(context) { self * k }
|
||||||
|
operator fun div(k: Number): S = with(context) { self / k }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The same as {@link Space} but with additional multiplication operation
|
||||||
|
*/
|
||||||
|
interface Ring<T> : Space<T> {
|
||||||
|
/**
|
||||||
|
* neutral operation for multiplication
|
||||||
|
*/
|
||||||
|
val one: T
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiplication for two field elements
|
||||||
|
*/
|
||||||
|
fun multiply(a: T, b: T): T
|
||||||
|
|
||||||
|
operator fun T.times(b: T): T = multiply(this, b)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ring element
|
||||||
|
*/
|
||||||
|
interface RingElement<S : RingElement<S>> : SpaceElement<S> {
|
||||||
|
override val context: Ring<S>
|
||||||
|
|
||||||
|
operator fun times(b: S): S = with(context) { self * b }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Four operations algebra
|
||||||
|
*/
|
||||||
|
interface Field<T> : Ring<T> {
|
||||||
|
fun divide(a: T, b: T): T
|
||||||
|
|
||||||
|
operator fun T.div(b: T): T = divide(this, b)
|
||||||
|
operator fun Double.div(b: T) = this * divide(one, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Field element
|
||||||
|
*/
|
||||||
|
interface FieldElement<S : FieldElement<S>> : RingElement<S> {
|
||||||
|
override val context: Field<S>
|
||||||
|
|
||||||
|
operator fun div(b: S): S = with(context) { self / b }
|
||||||
|
}
|
78
common/src/main/kotlin/scientifik/kmath/operations/Fields.kt
Normal file
78
common/src/main/kotlin/scientifik/kmath/operations/Fields.kt
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package scientifik.kmath.operations
|
||||||
|
|
||||||
|
import kotlin.math.sqrt
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Field for real values
|
||||||
|
*/
|
||||||
|
object RealField : Field<Real> {
|
||||||
|
override val zero: Real = Real(0.0)
|
||||||
|
override fun add(a: Real, b: Real): Real = Real(a.value + b.value)
|
||||||
|
override val one: Real = Real(1.0)
|
||||||
|
override fun multiply(a: Real, b: Real): Real = Real(a.value * b.value)
|
||||||
|
override fun multiply(a: Real, k: Double): Real = Real(a.value * k)
|
||||||
|
override fun divide(a: Real, b: Real): Real = Real(a.value / b.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Real field element wrapping double
|
||||||
|
*/
|
||||||
|
class Real(val value: Double) : FieldElement<Real>, Number() {
|
||||||
|
override fun toByte(): Byte = value.toByte()
|
||||||
|
override fun toChar(): Char = value.toChar()
|
||||||
|
override fun toDouble(): Double = value
|
||||||
|
override fun toFloat(): Float = value.toFloat()
|
||||||
|
override fun toInt(): Int = value.toInt()
|
||||||
|
override fun toLong(): Long = value.toLong()
|
||||||
|
override fun toShort(): Short = value.toShort()
|
||||||
|
|
||||||
|
//values are dynamically calculated to save memory
|
||||||
|
override val self
|
||||||
|
get() = this
|
||||||
|
override val context
|
||||||
|
get() = RealField
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field for complex numbers
|
||||||
|
*/
|
||||||
|
object ComplexField : Field<Complex> {
|
||||||
|
override val zero: Complex = Complex(0.0, 0.0)
|
||||||
|
|
||||||
|
override fun add(a: Complex, b: Complex): Complex = Complex(a.re + b.re, a.im + b.im)
|
||||||
|
|
||||||
|
override fun multiply(a: Complex, k: Double): Complex = Complex(a.re * k, a.im * k)
|
||||||
|
|
||||||
|
override val one: Complex = Complex(1.0, 0.0)
|
||||||
|
|
||||||
|
override fun multiply(a: Complex, b: Complex): Complex = Complex(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re)
|
||||||
|
|
||||||
|
override fun divide(a: Complex, b: Complex): Complex = Complex(a.re * b.re + a.im * b.im, a.re * b.im - a.im * b.re) / b.square
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Complex number class
|
||||||
|
*/
|
||||||
|
data class Complex(val re: Double, val im: Double) : FieldElement<Complex> {
|
||||||
|
override val self: Complex
|
||||||
|
get() = this
|
||||||
|
override val context: Field<Complex>
|
||||||
|
get() = ComplexField
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A complex conjugate
|
||||||
|
*/
|
||||||
|
val conjugate: Complex
|
||||||
|
get() = Complex(re, -im)
|
||||||
|
|
||||||
|
val square: Double
|
||||||
|
get() = re * re + im * im
|
||||||
|
|
||||||
|
val module: Double
|
||||||
|
get() = sqrt(square)
|
||||||
|
|
||||||
|
|
||||||
|
//TODO is it convenient?
|
||||||
|
operator fun not() = conjugate
|
||||||
|
}
|
120
common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt
Normal file
120
common/src/main/kotlin/scientifik/kmath/structures/NDArray.kt
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.FieldElement
|
||||||
|
import scientifik.kmath.operations.Real
|
||||||
|
|
||||||
|
class ShapeMismatchException(val expected: List<Int>, val actual: List<Int>) : RuntimeException()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Field for n-dimensional arrays.
|
||||||
|
* @param shape - the list of dimensions of the array
|
||||||
|
* @param elementField - operations field defined on individual array element
|
||||||
|
*/
|
||||||
|
abstract class NDField<T : FieldElement<T>>(val shape: List<Int>, val elementField: Field<T>) : Field<NDArray<T>> {
|
||||||
|
/**
|
||||||
|
* Create new instance of NDArray using field shape and given initializer
|
||||||
|
*/
|
||||||
|
abstract fun produce(initializer: (List<Int>) -> T): NDArray<T>
|
||||||
|
|
||||||
|
override val zero: NDArray<T>
|
||||||
|
get() = produce { elementField.zero }
|
||||||
|
|
||||||
|
private fun checkShape(vararg arrays: NDArray<T>) {
|
||||||
|
arrays.forEach {
|
||||||
|
if (shape != it.shape) {
|
||||||
|
throw ShapeMismatchException(shape, it.shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-by-element addition
|
||||||
|
*/
|
||||||
|
override fun add(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
|
checkShape(a, b)
|
||||||
|
return produce { a[it] + b[it] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Multiply all elements by cinstant
|
||||||
|
*/
|
||||||
|
override fun multiply(a: NDArray<T>, k: Double): NDArray<T> {
|
||||||
|
checkShape(a)
|
||||||
|
return produce { a[it] * k }
|
||||||
|
}
|
||||||
|
|
||||||
|
override val one: NDArray<T>
|
||||||
|
get() = produce { elementField.one }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-by-element multiplication
|
||||||
|
*/
|
||||||
|
override fun multiply(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
|
checkShape(a)
|
||||||
|
return produce { a[it] * b[it] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Element-by-element division
|
||||||
|
*/
|
||||||
|
override fun divide(a: NDArray<T>, b: NDArray<T>): NDArray<T> {
|
||||||
|
checkShape(a)
|
||||||
|
return produce { a[it] / b[it] }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
interface NDArray<T : FieldElement<T>> : FieldElement<NDArray<T>>, Iterable<Pair<List<Int>, T>> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The list of dimensions of this NDArray
|
||||||
|
*/
|
||||||
|
val shape: List<Int>
|
||||||
|
get() = (context as NDField<T>).shape
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of dimentsions for this array
|
||||||
|
*/
|
||||||
|
val dimension: Int
|
||||||
|
get() = shape.size
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the element with given indexes. If number of indexes is different from {@link dimension}, throws exception.
|
||||||
|
*/
|
||||||
|
operator fun get(vararg index: Int): T
|
||||||
|
|
||||||
|
operator fun get(index: List<Int>): T {
|
||||||
|
return get(*index.toIntArray())
|
||||||
|
}
|
||||||
|
|
||||||
|
override operator fun iterator(): Iterator<Pair<List<Int>, T>> {
|
||||||
|
return iterateIndexes(shape).map { Pair(it, this[it]) }.iterator()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate new NDArray, using given transformation for each element
|
||||||
|
*/
|
||||||
|
fun transform(action: (List<Int>, T) -> T): NDArray<T> = (context as NDField<T>).produce { action(it, this[it]) }
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
/**
|
||||||
|
* Iterate over all indexes in the nd-shape
|
||||||
|
*/
|
||||||
|
fun iterateIndexes(shape: List<Int>): Sequence<List<Int>> {
|
||||||
|
return if (shape.size == 1) {
|
||||||
|
(0 until shape[0]).asSequence().map { listOf(it) }
|
||||||
|
} else {
|
||||||
|
val tailShape = ArrayList(shape).apply { remove(0) }
|
||||||
|
val tailSequence: List<List<Int>> = iterateIndexes(tailShape).toList()
|
||||||
|
(0 until shape[0]).asSequence().map { firstIndex ->
|
||||||
|
//adding first element to each of provided index lists
|
||||||
|
tailSequence.map { listOf(firstIndex) + it }.asSequence()
|
||||||
|
}.flatten()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
expect fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real>
|
33
jvm/build.gradle
Normal file
33
jvm/build.gradle
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
buildscript {
|
||||||
|
ext.kotlin_version = '1.2.40'
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
dependencies {
|
||||||
|
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
apply plugin: 'kotlin-platform-jvm'
|
||||||
|
|
||||||
|
repositories {
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
expectedBy project(":common")
|
||||||
|
compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlin_version"
|
||||||
|
testCompile "junit:junit:4.12"
|
||||||
|
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
|
||||||
|
testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
|
||||||
|
}
|
||||||
|
|
||||||
|
compileKotlin {
|
||||||
|
kotlinOptions.jvmTarget = "1.8"
|
||||||
|
}
|
||||||
|
compileTestKotlin {
|
||||||
|
kotlinOptions.jvmTarget = "1.8"
|
||||||
|
}
|
||||||
|
sourceCompatibility = "1.8"
|
@ -0,0 +1,64 @@
|
|||||||
|
package scientifik.kmath.structures
|
||||||
|
|
||||||
|
import scientifik.kmath.operations.Field
|
||||||
|
import scientifik.kmath.operations.Real
|
||||||
|
import scientifik.kmath.operations.RealField
|
||||||
|
import java.nio.DoubleBuffer
|
||||||
|
|
||||||
|
private class RealNDField(shape: List<Int>) : NDField<Real>(shape, RealField) {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Strides for memory access
|
||||||
|
*/
|
||||||
|
private val strides: List<Int> by lazy {
|
||||||
|
ArrayList<Int>(shape.size).apply {
|
||||||
|
var current = 1
|
||||||
|
shape.forEach{
|
||||||
|
current *=it
|
||||||
|
add(current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun offset(index: List<Int>): Int {
|
||||||
|
return index.mapIndexed { i, value ->
|
||||||
|
if (value < 0 || value >= shape[i]) {
|
||||||
|
throw RuntimeException("Index out of shape bounds: ($i,$value)")
|
||||||
|
}
|
||||||
|
value * strides[i]
|
||||||
|
}.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
val capacity: Int
|
||||||
|
get() = strides[shape.size - 1]
|
||||||
|
|
||||||
|
|
||||||
|
override fun produce(initializer: (List<Int>) -> Real): NDArray<Real> {
|
||||||
|
//TODO use sparse arrays for large capacities
|
||||||
|
val buffer = DoubleBuffer.allocate(capacity)
|
||||||
|
NDArray.iterateIndexes(shape).forEach {
|
||||||
|
buffer.put(offset(it), initializer(it).value)
|
||||||
|
}
|
||||||
|
return RealNDArray(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
inner class RealNDArray(val data: DoubleBuffer) : NDArray<Real> {
|
||||||
|
|
||||||
|
override val context: Field<NDArray<Real>>
|
||||||
|
get() = this@RealNDField
|
||||||
|
|
||||||
|
override fun get(vararg index: Int): Real {
|
||||||
|
return Real(data.get(offset(index.asList())))
|
||||||
|
}
|
||||||
|
|
||||||
|
override val self: NDArray<Real>
|
||||||
|
get() = this
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
actual fun RealNDArray(shape: List<Int>, initializer: (List<Int>) -> Double): NDArray<Real> {
|
||||||
|
//TODO cache fields?
|
||||||
|
return RealNDField(shape).produce { Real(initializer(it)) }
|
||||||
|
}
|
4
settings.gradle
Normal file
4
settings.gradle
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
rootProject.name = 'kmath'
|
||||||
|
include 'common'
|
||||||
|
include 'jvm'
|
||||||
|
|
Loading…
Reference in New Issue
Block a user