forked from kscience/kmath
[WIP] Refactor NDStructures
This commit is contained in:
parent
061398b009
commit
332c04b573
@ -1,17 +1,17 @@
|
|||||||
package kscience.kmath.operations
|
package kscience.kmath.operations
|
||||||
|
|
||||||
import kscience.kmath.nd.NDField
|
import kscience.kmath.nd.NDAlgebra
|
||||||
import kscience.kmath.structures.NDElement
|
import kscience.kmath.nd.complex
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
// 2d element
|
// 2d element
|
||||||
val element = NDElement.complex(2, 2) { (i,j) ->
|
val element = NDAlgebra.complex(2, 2).produce { (i,j) ->
|
||||||
Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
|
Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
|
||||||
}
|
}
|
||||||
println(element)
|
println(element)
|
||||||
|
|
||||||
// 1d element operation
|
// 1d element operation
|
||||||
val result = with(NDField.complex(8)) {
|
val result = with(NDAlgebra.complex(8)) {
|
||||||
val a = produce { (it) -> i * it - it.toDouble() }
|
val a = produce { (it) -> i * it - it.toDouble() }
|
||||||
val b = 3
|
val b = 3
|
||||||
val c = Complex(1.0, 1.0)
|
val c = Complex(1.0, 1.0)
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
import kotlinx.coroutines.GlobalScope
|
import kotlinx.coroutines.GlobalScope
|
||||||
import kscience.kmath.nd.*
|
import kscience.kmath.nd.NDAlgebra
|
||||||
|
import kscience.kmath.nd.NDStructure
|
||||||
|
import kscience.kmath.nd.field
|
||||||
|
import kscience.kmath.nd.real
|
||||||
import kscience.kmath.nd4j.Nd4jArrayField
|
import kscience.kmath.nd4j.Nd4jArrayField
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
@ -39,9 +42,11 @@ fun main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
measureAndPrint("Element addition") {
|
measureAndPrint("Element addition") {
|
||||||
var res: NDStructure<Double> = boxingField.one
|
boxingField {
|
||||||
|
var res: NDStructure<Double> = one
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
measureAndPrint("Specialized addition") {
|
measureAndPrint("Specialized addition") {
|
||||||
specializedField {
|
specializedField {
|
||||||
@ -52,7 +57,7 @@ fun main() {
|
|||||||
|
|
||||||
measureAndPrint("Nd4j specialized addition") {
|
measureAndPrint("Nd4j specialized addition") {
|
||||||
nd4jField {
|
nd4jField {
|
||||||
var res = one
|
var res:NDStructure<Double> = one
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
package kscience.kmath.commons.linear
|
package kscience.kmath.commons.linear
|
||||||
|
|
||||||
import kscience.kmath.linear.DiagonalFeature
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.linear.MatrixContext
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.misc.UnstableKMathAPI
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kscience.kmath.nd.Matrix
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
import kotlin.reflect.cast
|
import kotlin.reflect.cast
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package kscience.kmath.commons.linear
|
package kscience.kmath.commons.linear
|
||||||
|
|
||||||
|
import kscience.kmath.linear.Matrix
|
||||||
import kscience.kmath.linear.Point
|
import kscience.kmath.linear.Point
|
||||||
import kscience.kmath.nd.Matrix
|
|
||||||
import org.apache.commons.math3.linear.*
|
import org.apache.commons.math3.linear.*
|
||||||
|
|
||||||
public enum class CMDecomposition {
|
public enum class CMDecomposition {
|
||||||
|
@ -86,7 +86,7 @@ public fun <T, A : Space<T>> NDAlgebra.Companion.space(
|
|||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedNDSpace<T, A> = BufferedNDSpace(shape, space, bufferFactory)
|
): BufferedNDSpace<T, A> = BufferedNDSpace(shape, space, bufferFactory)
|
||||||
|
|
||||||
public inline fun <T, A : Space<T>, R> A.nd(
|
public inline fun <T, A : Space<T>, R> A.ndSpace(
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
noinline bufferFactory: BufferFactory<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
action: BufferedNDSpace<T, A>.() -> R,
|
action: BufferedNDSpace<T, A>.() -> R,
|
||||||
@ -102,7 +102,7 @@ public fun <T, A : Ring<T>> NDAlgebra.Companion.ring(
|
|||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedNDRing<T, A> = BufferedNDRing(shape, ring, bufferFactory)
|
): BufferedNDRing<T, A> = BufferedNDRing(shape, ring, bufferFactory)
|
||||||
|
|
||||||
public inline fun <T, A : Ring<T>, R> A.nd(
|
public inline fun <T, A : Ring<T>, R> A.ndRing(
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
noinline bufferFactory: BufferFactory<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
action: BufferedNDRing<T, A>.() -> R,
|
action: BufferedNDRing<T, A>.() -> R,
|
||||||
@ -118,7 +118,7 @@ public fun <T, A : Field<T>> NDAlgebra.Companion.field(
|
|||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
|
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
|
||||||
|
|
||||||
public inline fun <T, A : Field<T>, R> A.nd(
|
public inline fun <T, A : Field<T>, R> A.ndField(
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
noinline bufferFactory: BufferFactory<T>,
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
action: BufferedNDField<T, A>.() -> R,
|
action: BufferedNDField<T, A>.() -> R,
|
||||||
|
@ -94,9 +94,12 @@ public open class RealNDField(
|
|||||||
/**
|
/**
|
||||||
* Fast element production using function inlining
|
* Fast element production using function inlining
|
||||||
*/
|
*/
|
||||||
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): NDBuffer<Double> {
|
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
||||||
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
val array = DoubleArray(strides.linearSize) { offset ->
|
||||||
|
val index = strides.index(offset)
|
||||||
|
RealField.initializer(index)
|
||||||
|
}
|
||||||
return NDBuffer(strides, RealBuffer(array))
|
return NDBuffer(strides, RealBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +52,15 @@ public interface Structure2D<T> : NDStructure<T> {
|
|||||||
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||||
}
|
}
|
||||||
|
|
||||||
public companion object
|
public companion object {
|
||||||
|
public inline fun real(
|
||||||
|
rows: Int,
|
||||||
|
columns: Int,
|
||||||
|
crossinline init: (i: Int, j: Int) -> Double,
|
||||||
|
): NDBuffer<Double> = NDAlgebra.real(rows, columns).produceInline { (i, j) ->
|
||||||
|
init(i, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package kscience.kmath.operations
|
package kscience.kmath.operations
|
||||||
|
|
||||||
import kscience.kmath.misc.UnstableKMathAPI
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import kscience.kmath.nd.BufferedNDRing
|
||||||
import kscience.kmath.nd.NDAlgebra
|
import kscience.kmath.nd.NDAlgebra
|
||||||
import kscience.kmath.operations.BigInt.Companion.BASE
|
import kscience.kmath.operations.BigInt.Companion.BASE
|
||||||
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||||
import kscience.kmath.structures.*
|
import kscience.kmath.structures.Buffer
|
||||||
|
import kscience.kmath.structures.MutableBuffer
|
||||||
import kotlin.math.log2
|
import kotlin.math.log2
|
||||||
import kotlin.math.max
|
import kotlin.math.max
|
||||||
import kotlin.math.min
|
import kotlin.math.min
|
||||||
@ -463,10 +465,5 @@ public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigIn
|
|||||||
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
||||||
boxing(size, initializer)
|
boxing(size, initializer)
|
||||||
|
|
||||||
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BufferedNDRing<BigInt, BigIntField> =
|
||||||
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
BufferedNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||||
|
|
||||||
public fun NDElement.Companion.bigInt(
|
|
||||||
vararg shape: Int,
|
|
||||||
initializer: BigIntField.(IntArray) -> BigInt
|
|
||||||
): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
|
|
||||||
|
@ -43,7 +43,7 @@ public interface Buffer<T> {
|
|||||||
* Checks content equality with another buffer.
|
* Checks content equality with another buffer.
|
||||||
*/
|
*/
|
||||||
public fun contentEquals(other: Buffer<*>): Boolean =
|
public fun contentEquals(other: Buffer<*>): Boolean =
|
||||||
kscience.kmath.nd.mapIndexed { index, value -> value == other[index] }.all { it }
|
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
|
@ -60,7 +60,7 @@ public class FlaggedRealBuffer(public val values: DoubleArray, public val flags:
|
|||||||
|
|
||||||
override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||||
|
|
||||||
override operator fun iterator(): Iterator<Double?> = kscience.kmath.nd.map {
|
override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||||
if (isValid(it)) values[it] else null
|
if (isValid(it)) values[it] else null
|
||||||
}.iterator()
|
}.iterator()
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,7 @@ public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected
|
|||||||
private val reader: MemoryReader = memory.reader()
|
private val reader: MemoryReader = memory.reader()
|
||||||
|
|
||||||
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||||
override operator fun iterator(): Iterator<T> = kscience.kmath.nd.map { get(it) }.iterator()
|
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
import kscience.kmath.nd.NDField
|
import kscience.kmath.nd.NDAlgebra
|
||||||
|
import kscience.kmath.nd.get
|
||||||
|
import kscience.kmath.nd.real
|
||||||
import kscience.kmath.operations.internal.FieldVerifier
|
import kscience.kmath.operations.internal.FieldVerifier
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
@ -8,12 +10,12 @@ import kotlin.test.assertEquals
|
|||||||
internal class NDFieldTest {
|
internal class NDFieldTest {
|
||||||
@Test
|
@Test
|
||||||
fun verify() {
|
fun verify() {
|
||||||
NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
NDAlgebra.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testStrides() {
|
fun testStrides() {
|
||||||
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
val ndArray = NDAlgebra.real(10, 10).produce { (it[0] + it[1]).toDouble() }
|
||||||
assertEquals(ndArray[5, 5], 10.0)
|
assertEquals(ndArray[5, 5], 10.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
package kscience.kmath.structures
|
package kscience.kmath.structures
|
||||||
|
|
||||||
import kscience.kmath.nd.NDField
|
import kscience.kmath.nd.*
|
||||||
import kscience.kmath.nd.NDStructure
|
|
||||||
import kscience.kmath.operations.Norm
|
import kscience.kmath.operations.Norm
|
||||||
import kscience.kmath.structures.NDElement.Companion.real2D
|
import kscience.kmath.operations.invoke
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.math.pow
|
import kotlin.math.pow
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -11,25 +10,30 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
@Suppress("UNUSED_VARIABLE")
|
@Suppress("UNUSED_VARIABLE")
|
||||||
class NumberNDFieldTest {
|
class NumberNDFieldTest {
|
||||||
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
|
val algebra = NDAlgebra.real(3,3)
|
||||||
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
|
val array1 = algebra.produceInline { (i, j) -> (i + j).toDouble() }
|
||||||
|
val array2 = algebra.produceInline { (i, j) -> (i - j).toDouble() }
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSum() {
|
fun testSum() {
|
||||||
|
algebra {
|
||||||
val sum = array1 + array2
|
val sum = array1 + array2
|
||||||
assertEquals(4.0, sum[2, 2])
|
assertEquals(4.0, sum[2, 2])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testProduct() {
|
fun testProduct() {
|
||||||
|
algebra {
|
||||||
val product = array1 * array2
|
val product = array1 * array2
|
||||||
assertEquals(0.0, product[2, 2])
|
assertEquals(0.0, product[2, 2])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testGeneration() {
|
fun testGeneration() {
|
||||||
|
|
||||||
val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
val array = Structure2D.real(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||||
|
|
||||||
for (i in 0..2) {
|
for (i in 0..2) {
|
||||||
for (j in 0..2) {
|
for (j in 0..2) {
|
||||||
@ -41,17 +45,21 @@ class NumberNDFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testExternalFunction() {
|
fun testExternalFunction() {
|
||||||
|
algebra {
|
||||||
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||||
val result = function(array1) + 1.0
|
val result = function(array1) + 1.0
|
||||||
assertEquals(10.0, result[1, 1])
|
assertEquals(10.0, result[1, 1])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testLibraryFunction() {
|
fun testLibraryFunction() {
|
||||||
|
algebra {
|
||||||
val abs: (Double) -> Double = ::abs
|
val abs: (Double) -> Double = ::abs
|
||||||
val result = abs(array2)
|
val result = abs(array2)
|
||||||
assertEquals(2.0, result[0, 2])
|
assertEquals(2.0, result[0, 2])
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun combineTest() {
|
fun combineTest() {
|
||||||
@ -65,6 +73,8 @@ class NumberNDFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testInternalContext() {
|
fun testInternalContext() {
|
||||||
(NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
|
algebra {
|
||||||
|
(NDAlgebra.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ public class LazyNDStructure<T>(
|
|||||||
|
|
||||||
public override fun elements(): Sequence<Pair<IntArray, T>> {
|
public override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
val res = runBlocking { kscience.kmath.nd.map { index -> index to await(index) } }
|
val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } }
|
||||||
return res.asSequence()
|
return res.asSequence()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
package kscience.kmath.ejml
|
package kscience.kmath.ejml
|
||||||
|
|
||||||
import kscience.kmath.linear.InverseMatrixFeature
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.linear.MatrixContext
|
|
||||||
import kscience.kmath.linear.Point
|
|
||||||
import kscience.kmath.misc.UnstableKMathAPI
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kscience.kmath.nd.Matrix
|
|
||||||
import kscience.kmath.nd.getFeature
|
import kscience.kmath.nd.getFeature
|
||||||
import org.ejml.simple.SimpleMatrix
|
import org.ejml.simple.SimpleMatrix
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
package kscience.kmath.real
|
package kscience.kmath.real
|
||||||
|
|
||||||
import kscience.kmath.linear.MatrixContext
|
import kscience.kmath.linear.*
|
||||||
import kscience.kmath.linear.real
|
|
||||||
import kscience.kmath.misc.UnstableKMathAPI
|
import kscience.kmath.misc.UnstableKMathAPI
|
||||||
import kscience.kmath.nd.Matrix
|
|
||||||
import kscience.kmath.structures.Buffer
|
import kscience.kmath.structures.Buffer
|
||||||
import kscience.kmath.structures.RealBuffer
|
import kscience.kmath.structures.RealBuffer
|
||||||
|
import kscience.kmath.structures.asIterable
|
||||||
|
import kotlin.math.pow
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Functions for convenient "numpy-like" operations with Double matrices.
|
* Functions for convenient "numpy-like" operations with Double matrices.
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
package kaceince.kmath.real
|
package kaceince.kmath.real
|
||||||
|
|
||||||
import kscience.kmath.nd.Matrix
|
import kscience.kmath.linear.Matrix
|
||||||
|
import kscience.kmath.linear.build
|
||||||
import kscience.kmath.real.*
|
import kscience.kmath.real.*
|
||||||
|
import kscience.kmath.structures.contentEquals
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
import kotlin.test.assertTrue
|
import kotlin.test.assertTrue
|
||||||
|
@ -17,7 +17,7 @@ public data class BinDef<T : Comparable<T>>(
|
|||||||
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
|
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
|
||||||
val upper = space { center + sizes / 2.0 }
|
val upper = space { center + sizes / 2.0 }
|
||||||
val lower = space { center - sizes / 2.0 }
|
val lower = space { center - sizes / 2.0 }
|
||||||
return kscience.kmath.nd.mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
|
return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ public class RealHistogram(
|
|||||||
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||||
|
|
||||||
private fun getDef(index: IntArray): BinDef<Double> {
|
private fun getDef(index: IntArray): BinDef<Double> {
|
||||||
val center = kscience.kmath.nd.mapIndexed { axis, i ->
|
val center = index.mapIndexed { axis, i ->
|
||||||
when (i) {
|
when (i) {
|
||||||
0 -> Double.NEGATIVE_INFINITY
|
0 -> Double.NEGATIVE_INFINITY
|
||||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||||
@ -100,7 +100,7 @@ public class RealHistogram(
|
|||||||
}
|
}
|
||||||
|
|
||||||
public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
|
public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
|
||||||
kscience.kmath.nd.map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
|
weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
|
||||||
.iterator()
|
.iterator()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -7,6 +7,13 @@ import kscience.kmath.structures.*
|
|||||||
import org.nd4j.linalg.api.ndarray.INDArray
|
import org.nd4j.linalg.api.ndarray.INDArray
|
||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
|
|
||||||
|
internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray {
|
||||||
|
val arrayShape = array.shape().toIntArray()
|
||||||
|
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
|
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
|
||||||
*
|
*
|
||||||
@ -18,6 +25,7 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C> {
|
|||||||
* Wraps [INDArray] to [N].
|
* Wraps [INDArray] to [N].
|
||||||
*/
|
*/
|
||||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||||
|
|
||||||
public val NDStructure<T>.ndArray: INDArray
|
public val NDStructure<T>.ndArray: INDArray
|
||||||
get() = when {
|
get() = when {
|
||||||
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
|
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
|
||||||
@ -213,7 +221,7 @@ public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArray
|
|||||||
public override val elementContext: RealField
|
public override val elementContext: RealField
|
||||||
get() = RealField
|
get() = RealField
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(asRealStructure())
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asRealStructure()
|
||||||
|
|
||||||
public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
||||||
return ndArray.div(arg).wrap()
|
return ndArray.div(arg).wrap()
|
||||||
@ -247,7 +255,7 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
|
|||||||
public override val elementContext: FloatField
|
public override val elementContext: FloatField
|
||||||
get() = FloatField
|
get() = FloatField
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(asFloatStructure())
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||||
|
|
||||||
public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
|
public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
|
||||||
return ndArray.div(arg).wrap()
|
return ndArray.div(arg).wrap()
|
||||||
@ -281,7 +289,7 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
|
|||||||
public override val elementContext: IntRing
|
public override val elementContext: IntRing
|
||||||
get() = IntRing
|
get() = IntRing
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = check(asIntStructure())
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
|
||||||
|
|
||||||
public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
|
public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
|
||||||
return ndArray.add(arg).wrap()
|
return ndArray.add(arg).wrap()
|
||||||
@ -307,7 +315,7 @@ public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayR
|
|||||||
public override val elementContext: LongRing
|
public override val elementContext: LongRing
|
||||||
get() = LongRing
|
get() = LongRing
|
||||||
|
|
||||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = check(asLongStructure())
|
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
|
||||||
|
|
||||||
public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
|
public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
|
||||||
return ndArray.add(arg).wrap()
|
return ndArray.add(arg).wrap()
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
package kscience.kmath.viktor
|
package kscience.kmath.viktor
|
||||||
|
|
||||||
import kscience.kmath.nd.DefaultStrides
|
import kscience.kmath.nd.*
|
||||||
import kscience.kmath.nd.MutableNDStructure
|
|
||||||
import kscience.kmath.nd.NDField
|
|
||||||
import kscience.kmath.nd.Strides
|
|
||||||
import kscience.kmath.operations.RealField
|
import kscience.kmath.operations.RealField
|
||||||
import org.jetbrains.bio.viktor.F64Array
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
|
|
||||||
@ -24,14 +21,25 @@ public inline class ViktorNDStructure(public val f64Buffer: F64Array) : MutableN
|
|||||||
public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this)
|
public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this)
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField, ViktorNDStructure> {
|
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField> {
|
||||||
|
|
||||||
|
public val NDStructure<Double>.f64Buffer: F64Array
|
||||||
|
get() = when {
|
||||||
|
!shape.contentEquals(this@ViktorNDField.shape) -> throw ShapeMismatchException(
|
||||||
|
this@ViktorNDField.shape,
|
||||||
|
shape
|
||||||
|
)
|
||||||
|
this is ViktorNDStructure && this.f64Buffer.shape.contentEquals(this@ViktorNDField.shape) -> this.f64Buffer
|
||||||
|
else -> produce { this@f64Buffer[it] }.f64Buffer
|
||||||
|
}
|
||||||
|
|
||||||
public override val zero: ViktorNDStructure
|
public override val zero: ViktorNDStructure
|
||||||
get() = F64Array.full(init = 0.0, shape = shape).asStructure()
|
get() = F64Array.full(init = 0.0, shape = shape).asStructure()
|
||||||
|
|
||||||
public override val one: ViktorNDStructure
|
public override val one: ViktorNDStructure
|
||||||
get() = F64Array.full(init = 1.0, shape = shape).asStructure()
|
get() = F64Array.full(init = 1.0, shape = shape).asStructure()
|
||||||
|
|
||||||
public val strides: Strides = DefaultStrides(shape)
|
private val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
public override val elementContext: RealField get() = RealField
|
public override val elementContext: RealField get() = RealField
|
||||||
|
|
||||||
@ -42,7 +50,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
|||||||
}
|
}
|
||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
public override fun map(arg: ViktorNDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure =
|
public override fun map(arg: NDStructure<Double>, transform: RealField.(Double) -> Double): ViktorNDStructure =
|
||||||
F64Array(*shape).apply {
|
F64Array(*shape).apply {
|
||||||
this@ViktorNDField.strides.indices().forEach { index ->
|
this@ViktorNDField.strides.indices().forEach { index ->
|
||||||
set(value = RealField.transform(arg[index]), indices = index)
|
set(value = RealField.transform(arg[index]), indices = index)
|
||||||
@ -50,7 +58,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
|||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
public override fun mapIndexed(
|
public override fun mapIndexed(
|
||||||
arg: ViktorNDStructure,
|
arg: NDStructure<Double>,
|
||||||
transform: RealField.(index: IntArray, Double) -> Double
|
transform: RealField.(index: IntArray, Double) -> Double
|
||||||
): ViktorNDStructure = F64Array(*shape).apply {
|
): ViktorNDStructure = F64Array(*shape).apply {
|
||||||
this@ViktorNDField.strides.indices().forEach { index ->
|
this@ViktorNDField.strides.indices().forEach { index ->
|
||||||
@ -59,8 +67,8 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
|||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
public override fun combine(
|
public override fun combine(
|
||||||
a: ViktorNDStructure,
|
a: NDStructure<Double>,
|
||||||
b: ViktorNDStructure,
|
b: NDStructure<Double>,
|
||||||
transform: RealField.(Double, Double) -> Double
|
transform: RealField.(Double, Double) -> Double
|
||||||
): ViktorNDStructure = F64Array(*shape).apply {
|
): ViktorNDStructure = F64Array(*shape).apply {
|
||||||
this@ViktorNDField.strides.indices().forEach { index ->
|
this@ViktorNDField.strides.indices().forEach { index ->
|
||||||
@ -68,21 +76,21 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
|||||||
}
|
}
|
||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
public override fun add(a: ViktorNDStructure, b: ViktorNDStructure): ViktorNDStructure =
|
public override fun add(a: NDStructure<Double>, b: NDStructure<Double>): ViktorNDStructure =
|
||||||
(a.f64Buffer + b.f64Buffer).asStructure()
|
(a.f64Buffer + b.f64Buffer).asStructure()
|
||||||
|
|
||||||
public override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure =
|
public override fun multiply(a: NDStructure<Double>, k: Number): ViktorNDStructure =
|
||||||
(a.f64Buffer * k.toDouble()).asStructure()
|
(a.f64Buffer * k.toDouble()).asStructure()
|
||||||
|
|
||||||
public override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure =
|
public override inline fun NDStructure<Double>.plus(b: NDStructure<Double>): ViktorNDStructure =
|
||||||
(f64Buffer + b.f64Buffer).asStructure()
|
(f64Buffer + b.f64Buffer).asStructure()
|
||||||
|
|
||||||
public override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure =
|
public override inline fun NDStructure<Double>.minus(b: NDStructure<Double>): ViktorNDStructure =
|
||||||
(f64Buffer - b.f64Buffer).asStructure()
|
(f64Buffer - b.f64Buffer).asStructure()
|
||||||
|
|
||||||
public override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure =
|
public override inline fun NDStructure<Double>.times(k: Number): ViktorNDStructure =
|
||||||
(f64Buffer * k.toDouble()).asStructure()
|
(f64Buffer * k.toDouble()).asStructure()
|
||||||
|
|
||||||
public override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure =
|
public override inline fun NDStructure<Double>.plus(arg: Double): ViktorNDStructure =
|
||||||
(f64Buffer.plus(arg)).asStructure()
|
(f64Buffer.plus(arg)).asStructure()
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user