forked from kscience/kmath
[WIP] Refactor NDStructures
This commit is contained in:
parent
061398b009
commit
332c04b573
@ -1,17 +1,17 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.nd.NDField
|
||||
import kscience.kmath.structures.NDElement
|
||||
import kscience.kmath.nd.NDAlgebra
|
||||
import kscience.kmath.nd.complex
|
||||
|
||||
fun main() {
|
||||
// 2d element
|
||||
val element = NDElement.complex(2, 2) { (i,j) ->
|
||||
val element = NDAlgebra.complex(2, 2).produce { (i,j) ->
|
||||
Complex(i.toDouble() - j.toDouble(), i.toDouble() + j.toDouble())
|
||||
}
|
||||
println(element)
|
||||
|
||||
// 1d element operation
|
||||
val result = with(NDField.complex(8)) {
|
||||
val result = with(NDAlgebra.complex(8)) {
|
||||
val a = produce { (it) -> i * it - it.toDouble() }
|
||||
val b = 3
|
||||
val c = Complex(1.0, 1.0)
|
||||
|
@ -1,7 +1,10 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kotlinx.coroutines.GlobalScope
|
||||
import kscience.kmath.nd.*
|
||||
import kscience.kmath.nd.NDAlgebra
|
||||
import kscience.kmath.nd.NDStructure
|
||||
import kscience.kmath.nd.field
|
||||
import kscience.kmath.nd.real
|
||||
import kscience.kmath.nd4j.Nd4jArrayField
|
||||
import kscience.kmath.operations.RealField
|
||||
import kscience.kmath.operations.invoke
|
||||
@ -39,8 +42,10 @@ fun main() {
|
||||
}
|
||||
|
||||
measureAndPrint("Element addition") {
|
||||
var res: NDStructure<Double> = boxingField.one
|
||||
repeat(n) { res += 1.0 }
|
||||
boxingField {
|
||||
var res: NDStructure<Double> = one
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
||||
measureAndPrint("Specialized addition") {
|
||||
@ -52,7 +57,7 @@ fun main() {
|
||||
|
||||
measureAndPrint("Nd4j specialized addition") {
|
||||
nd4jField {
|
||||
var res = one
|
||||
var res:NDStructure<Double> = one
|
||||
repeat(n) { res += 1.0 }
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,7 @@
|
||||
package kscience.kmath.commons.linear
|
||||
|
||||
import kscience.kmath.linear.DiagonalFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.nd.Matrix
|
||||
import org.apache.commons.math3.linear.*
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.cast
|
||||
|
@ -1,7 +1,7 @@
|
||||
package kscience.kmath.commons.linear
|
||||
|
||||
import kscience.kmath.linear.Matrix
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.nd.Matrix
|
||||
import org.apache.commons.math3.linear.*
|
||||
|
||||
public enum class CMDecomposition {
|
||||
|
@ -86,7 +86,7 @@ public fun <T, A : Space<T>> NDAlgebra.Companion.space(
|
||||
vararg shape: Int,
|
||||
): BufferedNDSpace<T, A> = BufferedNDSpace(shape, space, bufferFactory)
|
||||
|
||||
public inline fun <T, A : Space<T>, R> A.nd(
|
||||
public inline fun <T, A : Space<T>, R> A.ndSpace(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedNDSpace<T, A>.() -> R,
|
||||
@ -102,7 +102,7 @@ public fun <T, A : Ring<T>> NDAlgebra.Companion.ring(
|
||||
vararg shape: Int,
|
||||
): BufferedNDRing<T, A> = BufferedNDRing(shape, ring, bufferFactory)
|
||||
|
||||
public inline fun <T, A : Ring<T>, R> A.nd(
|
||||
public inline fun <T, A : Ring<T>, R> A.ndRing(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedNDRing<T, A>.() -> R,
|
||||
@ -118,7 +118,7 @@ public fun <T, A : Field<T>> NDAlgebra.Companion.field(
|
||||
vararg shape: Int,
|
||||
): BufferedNDField<T, A> = BufferedNDField(shape, field, bufferFactory)
|
||||
|
||||
public inline fun <T, A : Field<T>, R> A.nd(
|
||||
public inline fun <T, A : Field<T>, R> A.ndField(
|
||||
noinline bufferFactory: BufferFactory<T>,
|
||||
vararg shape: Int,
|
||||
action: BufferedNDField<T, A>.() -> R,
|
||||
|
@ -94,9 +94,12 @@ public open class RealNDField(
|
||||
/**
|
||||
* Fast element production using function inlining
|
||||
*/
|
||||
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(Int) -> Double): NDBuffer<Double> {
|
||||
public inline fun BufferedNDField<Double, RealField>.produceInline(crossinline initializer: RealField.(IntArray) -> Double): NDBuffer<Double> {
|
||||
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
||||
val array = DoubleArray(strides.linearSize) { offset -> RealField.initializer(offset) }
|
||||
val array = DoubleArray(strides.linearSize) { offset ->
|
||||
val index = strides.index(offset)
|
||||
RealField.initializer(index)
|
||||
}
|
||||
return NDBuffer(strides, RealBuffer(array))
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,15 @@ public interface Structure2D<T> : NDStructure<T> {
|
||||
for (j in 0 until colNum) yield(intArrayOf(i, j) to get(i, j))
|
||||
}
|
||||
|
||||
public companion object
|
||||
public companion object {
|
||||
public inline fun real(
|
||||
rows: Int,
|
||||
columns: Int,
|
||||
crossinline init: (i: Int, j: Int) -> Double,
|
||||
): NDBuffer<Double> = NDAlgebra.real(rows, columns).produceInline { (i, j) ->
|
||||
init(i, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1,10 +1,12 @@
|
||||
package kscience.kmath.operations
|
||||
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.nd.BufferedNDRing
|
||||
import kscience.kmath.nd.NDAlgebra
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE
|
||||
import kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||
import kscience.kmath.structures.*
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.MutableBuffer
|
||||
import kotlin.math.log2
|
||||
import kotlin.math.max
|
||||
import kotlin.math.min
|
||||
@ -463,10 +465,5 @@ public inline fun Buffer.Companion.bigInt(size: Int, initializer: (Int) -> BigIn
|
||||
public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -> BigInt): MutableBuffer<BigInt> =
|
||||
boxing(size, initializer)
|
||||
|
||||
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BoxingNDRing<BigInt, BigIntField> =
|
||||
BoxingNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||
|
||||
public fun NDElement.Companion.bigInt(
|
||||
vararg shape: Int,
|
||||
initializer: BigIntField.(IntArray) -> BigInt
|
||||
): BufferedNDRingElement<BigInt, BigIntField> = NDAlgebra.bigInt(*shape).produce(initializer)
|
||||
public fun NDAlgebra.Companion.bigInt(vararg shape: Int): BufferedNDRing<BigInt, BigIntField> =
|
||||
BufferedNDRing(shape, BigIntField, Buffer.Companion::bigInt)
|
||||
|
@ -43,7 +43,7 @@ public interface Buffer<T> {
|
||||
* Checks content equality with another buffer.
|
||||
*/
|
||||
public fun contentEquals(other: Buffer<*>): Boolean =
|
||||
kscience.kmath.nd.mapIndexed { index, value -> value == other[index] }.all { it }
|
||||
asSequence().mapIndexed { index, value -> value == other[index] }.all { it }
|
||||
|
||||
public companion object {
|
||||
/**
|
||||
|
@ -60,7 +60,7 @@ public class FlaggedRealBuffer(public val values: DoubleArray, public val flags:
|
||||
|
||||
override operator fun get(index: Int): Double? = if (isValid(index)) values[index] else null
|
||||
|
||||
override operator fun iterator(): Iterator<Double?> = kscience.kmath.nd.map {
|
||||
override operator fun iterator(): Iterator<Double?> = values.indices.asSequence().map {
|
||||
if (isValid(it)) values[it] else null
|
||||
}.iterator()
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ public open class MemoryBuffer<T : Any>(protected val memory: Memory, protected
|
||||
private val reader: MemoryReader = memory.reader()
|
||||
|
||||
override operator fun get(index: Int): T = reader.read(spec, spec.objectSize * index)
|
||||
override operator fun iterator(): Iterator<T> = kscience.kmath.nd.map { get(it) }.iterator()
|
||||
override operator fun iterator(): Iterator<T> = (0 until size).asSequence().map { get(it) }.iterator()
|
||||
|
||||
public companion object {
|
||||
public fun <T : Any> create(spec: MemorySpec<T>, size: Int): MemoryBuffer<T> =
|
||||
|
@ -1,6 +1,8 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.nd.NDField
|
||||
import kscience.kmath.nd.NDAlgebra
|
||||
import kscience.kmath.nd.get
|
||||
import kscience.kmath.nd.real
|
||||
import kscience.kmath.operations.internal.FieldVerifier
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -8,12 +10,12 @@ import kotlin.test.assertEquals
|
||||
internal class NDFieldTest {
|
||||
@Test
|
||||
fun verify() {
|
||||
NDField.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||
NDAlgebra.real(12, 32).run { FieldVerifier(this, one + 3, one - 23, one * 12, 6.66) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testStrides() {
|
||||
val ndArray = NDElement.real(intArrayOf(10, 10)) { (it[0] + it[1]).toDouble() }
|
||||
val ndArray = NDAlgebra.real(10, 10).produce { (it[0] + it[1]).toDouble() }
|
||||
assertEquals(ndArray[5, 5], 10.0)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
package kscience.kmath.structures
|
||||
|
||||
import kscience.kmath.nd.NDField
|
||||
import kscience.kmath.nd.NDStructure
|
||||
import kscience.kmath.nd.*
|
||||
import kscience.kmath.operations.Norm
|
||||
import kscience.kmath.structures.NDElement.Companion.real2D
|
||||
import kscience.kmath.operations.invoke
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.pow
|
||||
import kotlin.test.Test
|
||||
@ -11,25 +10,30 @@ import kotlin.test.assertEquals
|
||||
|
||||
@Suppress("UNUSED_VARIABLE")
|
||||
class NumberNDFieldTest {
|
||||
val array1: RealNDElement = real2D(3, 3) { i, j -> (i + j).toDouble() }
|
||||
val array2: RealNDElement = real2D(3, 3) { i, j -> (i - j).toDouble() }
|
||||
val algebra = NDAlgebra.real(3,3)
|
||||
val array1 = algebra.produceInline { (i, j) -> (i + j).toDouble() }
|
||||
val array2 = algebra.produceInline { (i, j) -> (i - j).toDouble() }
|
||||
|
||||
@Test
|
||||
fun testSum() {
|
||||
val sum = array1 + array2
|
||||
assertEquals(4.0, sum[2, 2])
|
||||
algebra {
|
||||
val sum = array1 + array2
|
||||
assertEquals(4.0, sum[2, 2])
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testProduct() {
|
||||
val product = array1 * array2
|
||||
assertEquals(0.0, product[2, 2])
|
||||
algebra {
|
||||
val product = array1 * array2
|
||||
assertEquals(0.0, product[2, 2])
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGeneration() {
|
||||
|
||||
val array = real2D(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||
val array = Structure2D.real(3, 3) { i, j -> (i * 10 + j).toDouble() }
|
||||
|
||||
for (i in 0..2) {
|
||||
for (j in 0..2) {
|
||||
@ -41,16 +45,20 @@ class NumberNDFieldTest {
|
||||
|
||||
@Test
|
||||
fun testExternalFunction() {
|
||||
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||
val result = function(array1) + 1.0
|
||||
assertEquals(10.0, result[1, 1])
|
||||
algebra {
|
||||
val function: (Double) -> Double = { x -> x.pow(2) + 2 * x + 1 }
|
||||
val result = function(array1) + 1.0
|
||||
assertEquals(10.0, result[1, 1])
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLibraryFunction() {
|
||||
val abs: (Double) -> Double = ::abs
|
||||
val result = abs(array2)
|
||||
assertEquals(2.0, result[0, 2])
|
||||
algebra {
|
||||
val abs: (Double) -> Double = ::abs
|
||||
val result = abs(array2)
|
||||
assertEquals(2.0, result[0, 2])
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -65,6 +73,8 @@ class NumberNDFieldTest {
|
||||
|
||||
@Test
|
||||
fun testInternalContext() {
|
||||
(NDField.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
|
||||
algebra {
|
||||
(NDAlgebra.real(*array1.shape)) { with(L2Norm) { 1 + norm(array1) + exp(array2) } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ public class LazyNDStructure<T>(
|
||||
|
||||
public override fun elements(): Sequence<Pair<IntArray, T>> {
|
||||
val strides = DefaultStrides(shape)
|
||||
val res = runBlocking { kscience.kmath.nd.map { index -> index to await(index) } }
|
||||
val res = runBlocking { strides.indices().toList().map { index -> index to await(index) } }
|
||||
return res.asSequence()
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,7 @@
|
||||
package kscience.kmath.ejml
|
||||
|
||||
import kscience.kmath.linear.InverseMatrixFeature
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.Point
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.nd.Matrix
|
||||
import kscience.kmath.nd.getFeature
|
||||
import org.ejml.simple.SimpleMatrix
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
package kscience.kmath.real
|
||||
|
||||
import kscience.kmath.linear.MatrixContext
|
||||
import kscience.kmath.linear.real
|
||||
import kscience.kmath.linear.*
|
||||
import kscience.kmath.misc.UnstableKMathAPI
|
||||
import kscience.kmath.nd.Matrix
|
||||
import kscience.kmath.structures.Buffer
|
||||
import kscience.kmath.structures.RealBuffer
|
||||
import kscience.kmath.structures.asIterable
|
||||
import kotlin.math.pow
|
||||
|
||||
/*
|
||||
* Functions for convenient "numpy-like" operations with Double matrices.
|
||||
|
@ -1,7 +1,9 @@
|
||||
package kaceince.kmath.real
|
||||
|
||||
import kscience.kmath.nd.Matrix
|
||||
import kscience.kmath.linear.Matrix
|
||||
import kscience.kmath.linear.build
|
||||
import kscience.kmath.real.*
|
||||
import kscience.kmath.structures.contentEquals
|
||||
import kotlin.test.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
@ -17,7 +17,7 @@ public data class BinDef<T : Comparable<T>>(
|
||||
require(vector.size == center.size) { "Dimension mismatch for input vector. Expected ${center.size}, but found ${vector.size}" }
|
||||
val upper = space { center + sizes / 2.0 }
|
||||
val lower = space { center - sizes / 2.0 }
|
||||
return kscience.kmath.nd.mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
|
||||
return vector.asSequence().mapIndexed { i, value -> value in lower[i]..upper[i] }.all { it }
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ public class RealHistogram(
|
||||
public fun getValue(point: Buffer<out Double>): Long = getValue(getIndex(point))
|
||||
|
||||
private fun getDef(index: IntArray): BinDef<Double> {
|
||||
val center = kscience.kmath.nd.mapIndexed { axis, i ->
|
||||
val center = index.mapIndexed { axis, i ->
|
||||
when (i) {
|
||||
0 -> Double.NEGATIVE_INFINITY
|
||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||
@ -100,7 +100,7 @@ public class RealHistogram(
|
||||
}
|
||||
|
||||
public override operator fun iterator(): Iterator<MultivariateBin<Double>> =
|
||||
kscience.kmath.nd.map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
|
||||
weights.elements().map { (index, value) -> MultivariateBin(getDef(index), value.sum()) }
|
||||
.iterator()
|
||||
|
||||
/**
|
||||
|
@ -7,6 +7,13 @@ import kscience.kmath.structures.*
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
|
||||
internal fun NDAlgebra<*, *>.checkShape(array: INDArray): INDArray {
|
||||
val arrayShape = array.shape().toIntArray()
|
||||
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
|
||||
return array
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Represents [NDAlgebra] over [Nd4jArrayAlgebra].
|
||||
*
|
||||
@ -18,8 +25,9 @@ public interface Nd4jArrayAlgebra<T, C> : NDAlgebra<T, C> {
|
||||
* Wraps [INDArray] to [N].
|
||||
*/
|
||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||
|
||||
public val NDStructure<T>.ndArray: INDArray
|
||||
get() = when {
|
||||
get() = when {
|
||||
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
|
||||
this@Nd4jArrayAlgebra.shape,
|
||||
shape
|
||||
@ -213,7 +221,7 @@ public class RealNd4jArrayField(public override val shape: IntArray) : Nd4jArray
|
||||
public override val elementContext: RealField
|
||||
get() = RealField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(asRealStructure())
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asRealStructure()
|
||||
|
||||
public override operator fun NDStructure<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
||||
return ndArray.div(arg).wrap()
|
||||
@ -247,7 +255,7 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
|
||||
public override val elementContext: FloatField
|
||||
get() = FloatField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(asFloatStructure())
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||
|
||||
public override operator fun NDStructure<Float>.div(arg: Float): Nd4jArrayStructure<Float> {
|
||||
return ndArray.div(arg).wrap()
|
||||
@ -281,7 +289,7 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
|
||||
public override val elementContext: IntRing
|
||||
get() = IntRing
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = check(asIntStructure())
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
|
||||
|
||||
public override operator fun NDStructure<Int>.plus(arg: Int): Nd4jArrayStructure<Int> {
|
||||
return ndArray.add(arg).wrap()
|
||||
@ -307,7 +315,7 @@ public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayR
|
||||
public override val elementContext: LongRing
|
||||
get() = LongRing
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = check(asLongStructure())
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
|
||||
|
||||
public override operator fun NDStructure<Long>.plus(arg: Long): Nd4jArrayStructure<Long> {
|
||||
return ndArray.add(arg).wrap()
|
||||
|
@ -1,9 +1,6 @@
|
||||
package kscience.kmath.viktor
|
||||
|
||||
import kscience.kmath.nd.DefaultStrides
|
||||
import kscience.kmath.nd.MutableNDStructure
|
||||
import kscience.kmath.nd.NDField
|
||||
import kscience.kmath.nd.Strides
|
||||
import kscience.kmath.nd.*
|
||||
import kscience.kmath.operations.RealField
|
||||
import org.jetbrains.bio.viktor.F64Array
|
||||
|
||||
@ -24,14 +21,25 @@ public inline class ViktorNDStructure(public val f64Buffer: F64Array) : MutableN
|
||||
public fun F64Array.asStructure(): ViktorNDStructure = ViktorNDStructure(this)
|
||||
|
||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField, ViktorNDStructure> {
|
||||
public class ViktorNDField(public override val shape: IntArray) : NDField<Double, RealField> {
|
||||
|
||||
public val NDStructure<Double>.f64Buffer: F64Array
|
||||
get() = when {
|
||||
!shape.contentEquals(this@ViktorNDField.shape) -> throw ShapeMismatchException(
|
||||
this@ViktorNDField.shape,
|
||||
shape
|
||||
)
|
||||
this is ViktorNDStructure && this.f64Buffer.shape.contentEquals(this@ViktorNDField.shape) -> this.f64Buffer
|
||||
else -> produce { this@f64Buffer[it] }.f64Buffer
|
||||
}
|
||||
|
||||
public override val zero: ViktorNDStructure
|
||||
get() = F64Array.full(init = 0.0, shape = shape).asStructure()
|
||||
|
||||
public override val one: ViktorNDStructure
|
||||
get() = F64Array.full(init = 1.0, shape = shape).asStructure()
|
||||
|
||||
public val strides: Strides = DefaultStrides(shape)
|
||||
private val strides: Strides = DefaultStrides(shape)
|
||||
|
||||
public override val elementContext: RealField get() = RealField
|
||||
|
||||
@ -42,7 +50,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
public override fun map(arg: ViktorNDStructure, transform: RealField.(Double) -> Double): ViktorNDStructure =
|
||||
public override fun map(arg: NDStructure<Double>, transform: RealField.(Double) -> Double): ViktorNDStructure =
|
||||
F64Array(*shape).apply {
|
||||
this@ViktorNDField.strides.indices().forEach { index ->
|
||||
set(value = RealField.transform(arg[index]), indices = index)
|
||||
@ -50,7 +58,7 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
||||
}.asStructure()
|
||||
|
||||
public override fun mapIndexed(
|
||||
arg: ViktorNDStructure,
|
||||
arg: NDStructure<Double>,
|
||||
transform: RealField.(index: IntArray, Double) -> Double
|
||||
): ViktorNDStructure = F64Array(*shape).apply {
|
||||
this@ViktorNDField.strides.indices().forEach { index ->
|
||||
@ -59,8 +67,8 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
||||
}.asStructure()
|
||||
|
||||
public override fun combine(
|
||||
a: ViktorNDStructure,
|
||||
b: ViktorNDStructure,
|
||||
a: NDStructure<Double>,
|
||||
b: NDStructure<Double>,
|
||||
transform: RealField.(Double, Double) -> Double
|
||||
): ViktorNDStructure = F64Array(*shape).apply {
|
||||
this@ViktorNDField.strides.indices().forEach { index ->
|
||||
@ -68,21 +76,21 @@ public class ViktorNDField(public override val shape: IntArray) : NDField<Double
|
||||
}
|
||||
}.asStructure()
|
||||
|
||||
public override fun add(a: ViktorNDStructure, b: ViktorNDStructure): ViktorNDStructure =
|
||||
public override fun add(a: NDStructure<Double>, b: NDStructure<Double>): ViktorNDStructure =
|
||||
(a.f64Buffer + b.f64Buffer).asStructure()
|
||||
|
||||
public override fun multiply(a: ViktorNDStructure, k: Number): ViktorNDStructure =
|
||||
public override fun multiply(a: NDStructure<Double>, k: Number): ViktorNDStructure =
|
||||
(a.f64Buffer * k.toDouble()).asStructure()
|
||||
|
||||
public override inline fun ViktorNDStructure.plus(b: ViktorNDStructure): ViktorNDStructure =
|
||||
public override inline fun NDStructure<Double>.plus(b: NDStructure<Double>): ViktorNDStructure =
|
||||
(f64Buffer + b.f64Buffer).asStructure()
|
||||
|
||||
public override inline fun ViktorNDStructure.minus(b: ViktorNDStructure): ViktorNDStructure =
|
||||
public override inline fun NDStructure<Double>.minus(b: NDStructure<Double>): ViktorNDStructure =
|
||||
(f64Buffer - b.f64Buffer).asStructure()
|
||||
|
||||
public override inline fun ViktorNDStructure.times(k: Number): ViktorNDStructure =
|
||||
public override inline fun NDStructure<Double>.times(k: Number): ViktorNDStructure =
|
||||
(f64Buffer * k.toDouble()).asStructure()
|
||||
|
||||
public override inline fun ViktorNDStructure.plus(arg: Double): ViktorNDStructure =
|
||||
public override inline fun NDStructure<Double>.plus(arg: Double): ViktorNDStructure =
|
||||
(f64Buffer.plus(arg)).asStructure()
|
||||
}
|
Loading…
Reference in New Issue
Block a user