forked from kscience/kmath
Explicit mutability for StructureND builders
This commit is contained in:
parent
875e32679b
commit
cdfddb7551
@ -3,6 +3,7 @@
|
||||
## Unreleased
|
||||
|
||||
### Added
|
||||
- Explicit `mutableStructureND` builders for mutable stucures
|
||||
|
||||
### Changed
|
||||
|
||||
|
@ -3,8 +3,6 @@ plugins {
|
||||
`version-catalog`
|
||||
}
|
||||
|
||||
java.targetCompatibility = JavaVersion.VERSION_11
|
||||
|
||||
repositories {
|
||||
mavenLocal()
|
||||
maven("https://repo.kotlin.link")
|
||||
@ -26,6 +24,11 @@ dependencies {
|
||||
implementation("com.fasterxml.jackson.module:jackson-module-kotlin:2.14.+")
|
||||
}
|
||||
|
||||
kotlin.sourceSets.all {
|
||||
languageSettings.optIn("kotlin.OptIn")
|
||||
kotlin{
|
||||
jvmToolchain{
|
||||
languageVersion.set(JavaLanguageVersion.of(11))
|
||||
}
|
||||
sourceSets.all {
|
||||
languageSettings.optIn("kotlin.OptIn")
|
||||
}
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ dependencies {
|
||||
|
||||
implementation("org.slf4j:slf4j-simple:1.7.32")
|
||||
// plotting
|
||||
implementation("space.kscience:plotlykt-server:0.5.0")
|
||||
implementation("space.kscience:plotlykt-server:0.5.3-dev-1")
|
||||
}
|
||||
|
||||
kotlin {
|
||||
|
@ -8,14 +8,15 @@ package space.kscience.kmath.operations
|
||||
import space.kscience.kmath.commons.linear.CMLinearSpace
|
||||
import space.kscience.kmath.linear.matrix
|
||||
import space.kscience.kmath.nd.DoubleBufferND
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.Structure2D
|
||||
import space.kscience.kmath.nd.mutableStructureND
|
||||
import space.kscience.kmath.nd.ndAlgebra
|
||||
import space.kscience.kmath.viktor.ViktorStructureND
|
||||
import space.kscience.kmath.viktor.viktorAlgebra
|
||||
import kotlin.collections.component1
|
||||
import kotlin.collections.component2
|
||||
|
||||
fun main() {
|
||||
val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.structureND(ShapeND(2, 2)) { (i, j) ->
|
||||
val viktorStructure = DoubleField.viktorAlgebra.mutableStructureND(2, 2) { (i, j) ->
|
||||
if (i == j) 2.0 else 0.0
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,10 @@ import space.kscience.kmath.structures.slice
|
||||
import space.kscience.plotly.*
|
||||
import kotlin.math.PI
|
||||
|
||||
fun main() = with(Double.algebra.bufferAlgebra.seriesAlgebra()) {
|
||||
fun Double.Companion.seriesAlgebra() = Double.algebra.bufferAlgebra.seriesAlgebra()
|
||||
|
||||
|
||||
fun main() = with(Double.seriesAlgebra()) {
|
||||
|
||||
|
||||
fun Plot.plotSeries(name: String, buffer: Buffer<Double>) {
|
||||
|
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* Copyright 2018-2023 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.series
|
||||
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.DoubleBuffer
|
||||
import space.kscience.kmath.structures.asBuffer
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.plotly.*
|
||||
import space.kscience.plotly.models.Scatter
|
||||
import space.kscience.plotly.models.ScatterMode
|
||||
import kotlin.random.Random
|
||||
|
||||
fun main(): Unit = with(Double.seriesAlgebra()) {
|
||||
|
||||
val random = Random(1234)
|
||||
|
||||
val arrayOfRandoms = DoubleArray(20) { random.nextDouble() }
|
||||
|
||||
val series1: DoubleBuffer = arrayOfRandoms.asBuffer()
|
||||
val series2: Series<Double> = series1.moveBy(3)
|
||||
|
||||
val res = series2 - series1
|
||||
|
||||
println(res.size)
|
||||
|
||||
println(res)
|
||||
|
||||
fun Plot.series(name: String, buffer: Buffer<Double>, block: Scatter.() -> Unit = {}) {
|
||||
scatter {
|
||||
this.name = name
|
||||
x.numbers = buffer.offsetIndices
|
||||
y.doubles = buffer.toDoubleArray()
|
||||
block()
|
||||
}
|
||||
}
|
||||
|
||||
Plotly.plot {
|
||||
series("series1", series1)
|
||||
series("series2", series2)
|
||||
series("dif", res){
|
||||
mode = ScatterMode.lines
|
||||
line.color("magenta")
|
||||
}
|
||||
}.makeFile(resourceLocation = ResourceLocation.REMOTE)
|
||||
}
|
@ -52,6 +52,15 @@ class StreamDoubleFieldND(override val shape: ShapeND) : FieldND<Double, DoubleF
|
||||
return BufferND(strides, array.asBuffer())
|
||||
}
|
||||
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): MutableBufferND<Double> {
|
||||
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
|
||||
val index = strides.index(offset)
|
||||
DoubleField.initializer(index)
|
||||
}.toArray()
|
||||
|
||||
return MutableBufferND(strides, array.asBuffer())
|
||||
}
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<Double>.map(
|
||||
transform: DoubleField.(Double) -> Double,
|
||||
|
@ -0,0 +1,26 @@
|
||||
/*
|
||||
* Copyright 2018-2023 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.structures
|
||||
|
||||
import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.algebra
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
fun main(): Unit = with(Double.algebra.ndAlgebra) {
|
||||
val structure: MutableStructure2D<Double> = mutableStructureND(ShapeND(2, 2)) { (i, j) ->
|
||||
i.toDouble() + j.toDouble()
|
||||
}.as2D()
|
||||
|
||||
structure[0, 1] = -2.0
|
||||
|
||||
val structure2 = mutableStructureND(2, 2) { (i, j) -> i.toDouble() + j.toDouble() }.as2D()
|
||||
|
||||
structure2[0, 1] = 2.0
|
||||
|
||||
|
||||
println(structure + structure2)
|
||||
}
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
|
@ -37,8 +37,8 @@ kotlin.sourceSets {
|
||||
filter { it.name.contains("test", true) }
|
||||
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
||||
.forEach {
|
||||
it.optIn("space.kscience.kmath.misc.PerformancePitfall")
|
||||
it.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||
it.optIn("space.kscience.kmath.PerformancePitfall")
|
||||
it.optIn("space.kscience.kmath.UnstableKMathAPI")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16,16 +16,22 @@ import kotlin.reflect.KClass
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
*/
|
||||
public interface AlgebraND<T, out C : Algebra<T>>: Algebra<StructureND<T>> {
|
||||
public interface AlgebraND<T, out C : Algebra<T>> : Algebra<StructureND<T>> {
|
||||
/**
|
||||
* The algebra over elements of ND structure.
|
||||
*/
|
||||
public val elementAlgebra: C
|
||||
|
||||
/**
|
||||
* Produces a new [MutableStructureND] using given initializer function.
|
||||
*/
|
||||
public fun mutableStructureND(shape: ShapeND, initializer: C.(IntArray) -> T): MutableStructureND<T>
|
||||
|
||||
/**
|
||||
* Produces a new [StructureND] using given initializer function.
|
||||
*/
|
||||
public fun structureND(shape: ShapeND, initializer: C.(IntArray) -> T): StructureND<T>
|
||||
public fun structureND(shape: ShapeND, initializer: C.(IntArray) -> T): StructureND<T> =
|
||||
mutableStructureND(shape, initializer)
|
||||
|
||||
/**
|
||||
* Maps elements from one structure to another one by applying [transform] to them.
|
||||
|
@ -16,9 +16,10 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||
public val bufferAlgebra: BufferAlgebra<T, A>
|
||||
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||
|
||||
override fun structureND(shape: ShapeND, initializer: A.(IntArray) -> T): BufferND<T> {
|
||||
//TODO change AlgebraND contract to include this
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MutableBufferND<T> {
|
||||
val indexer = indexerBuilder(shape)
|
||||
return BufferND(
|
||||
return MutableBufferND(
|
||||
indexer,
|
||||
bufferAlgebra.buffer(indexer.linearSize) { offset ->
|
||||
elementAlgebra.initializer(indexer.index(offset))
|
||||
@ -26,6 +27,9 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||
)
|
||||
}
|
||||
|
||||
override fun structureND(shape: ShapeND, initializer: A.(IntArray) -> T): BufferND<T> =
|
||||
mutableStructureND(shape, initializer)
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
|
||||
is BufferND -> this
|
||||
@ -133,6 +137,11 @@ public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.structureND(
|
||||
initializer: A.(IntArray) -> T,
|
||||
): BufferND<T> = structureND(ShapeND(shape), initializer)
|
||||
|
||||
public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mutableStructureND(
|
||||
vararg shape: Int,
|
||||
initializer: A.(IntArray) -> T,
|
||||
): MutableBufferND<T> = mutableStructureND(ShapeND(shape), initializer)
|
||||
|
||||
public fun <T, EA : Algebra<T>, A> A.structureND(
|
||||
initializer: EA.(IntArray) -> T,
|
||||
): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = structureND(shape, initializer)
|
||||
|
@ -74,7 +74,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
|
||||
transform: DoubleField.(Double, Double) -> Double,
|
||||
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
|
||||
|
||||
override fun structureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
|
||||
val indexer = indexerBuilder(shape)
|
||||
return DoubleBufferND(
|
||||
indexer,
|
||||
@ -225,7 +225,7 @@ public class DoubleFieldND(override val shape: ShapeND) :
|
||||
|
||||
override fun number(value: Number): DoubleBufferND {
|
||||
val d = value.toDouble() // minimize conversions
|
||||
return structureND(shape) { d }
|
||||
return mutableStructureND(shape) { d }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -244,6 +244,7 @@ public interface MutableStructureND<T> : StructureND<T> {
|
||||
* Set value at specified indices
|
||||
*/
|
||||
@PerformancePitfall
|
||||
@Deprecated("")
|
||||
public operator fun <T> MutableStructureND<T>.set(vararg index: Int, value: T) {
|
||||
set(index, value)
|
||||
}
|
@ -17,6 +17,12 @@ public fun <T, A : Algebra<T>> AlgebraND<T, A>.structureND(
|
||||
initializer: A.(IntArray) -> T
|
||||
): StructureND<T> = structureND(ShapeND(shapeFirst, *shapeRest), initializer)
|
||||
|
||||
public fun <T, A : Algebra<T>> AlgebraND<T, A>.mutableStructureND(
|
||||
shapeFirst: Int,
|
||||
vararg shapeRest: Int,
|
||||
initializer: A.(IntArray) -> T
|
||||
): MutableStructureND<T> = mutableStructureND(ShapeND(shapeFirst, *shapeRest), initializer)
|
||||
|
||||
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: ShapeND): StructureND<T> = structureND(shape) { zero }
|
||||
|
||||
@JvmName("zeroVarArg")
|
||||
|
@ -6,7 +6,8 @@
|
||||
package space.kscience.kmath.operations
|
||||
|
||||
import space.kscience.kmath.structures.Buffer
|
||||
import space.kscience.kmath.structures.BufferFactory
|
||||
import space.kscience.kmath.structures.MutableBuffer
|
||||
import space.kscience.kmath.structures.MutableBufferFactory
|
||||
|
||||
public interface WithSize {
|
||||
public val size: Int
|
||||
@ -17,7 +18,7 @@ public interface WithSize {
|
||||
*/
|
||||
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||
public val elementAlgebra: A
|
||||
public val elementBufferFactory: BufferFactory<T> get() = elementAlgebra.bufferFactory
|
||||
public val elementBufferFactory: MutableBufferFactory<T> get() = elementAlgebra.bufferFactory
|
||||
|
||||
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
|
||||
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
||||
@ -73,11 +74,11 @@ private inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
|
||||
return elementBufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
|
||||
}
|
||||
|
||||
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
|
||||
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): MutableBuffer<T> {
|
||||
return elementBufferFactory(size, initializer)
|
||||
}
|
||||
|
||||
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
|
||||
public fun <T, A> A.buffer(initializer: (Int) -> T): MutableBuffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
|
||||
return elementBufferFactory(size, initializer)
|
||||
}
|
||||
|
||||
|
@ -33,7 +33,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
protected val multikStat: Statistics = multikEngine.getStatistics()
|
||||
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
override fun structureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): MultikTensor<T> {
|
||||
val strides = ColumnStrides(shape)
|
||||
val memoryView = initMemoryView<T>(strides.linearSize, type)
|
||||
strides.asSequence().forEachIndexed { linearIndex, tensorIndex ->
|
||||
@ -49,7 +49,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
for (el in array) data[count++] = elementAlgebra.transform(el)
|
||||
NDArray(data, shape = shape.asArray(), dim = array.dim).wrap()
|
||||
} else {
|
||||
structureND(shape) { index ->
|
||||
mutableStructureND(shape) { index ->
|
||||
transform(get(index))
|
||||
}
|
||||
}
|
||||
@ -70,7 +70,7 @@ public abstract class MultikTensorAlgebra<T, A : Ring<T>>(
|
||||
}
|
||||
NDArray(data, shape = array.shape, dim = array.dim).wrap()
|
||||
} else {
|
||||
structureND(shape) { index ->
|
||||
mutableStructureND(shape) { index ->
|
||||
transform(index, get(index))
|
||||
}
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
||||
public val StructureND<T>.ndArray: INDArray
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun structureND(shape: ShapeND, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
val struct: Nd4jArrayStructure<T> = Nd4j.create(*shape.asArray())!!.wrap()
|
||||
struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
|
||||
|
@ -37,20 +37,20 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
||||
*/
|
||||
public val StructureND<T>.ndArray: INDArray
|
||||
|
||||
override fun structureND(shape: ShapeND, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T>
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: A.(IntArray) -> T): Nd4jArrayStructure<T>
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.map(transform: A.(T) -> T): Nd4jArrayStructure<T> =
|
||||
structureND(shape) { index -> elementAlgebra.transform(get(index)) }
|
||||
mutableStructureND(shape) { index -> elementAlgebra.transform(get(index)) }
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): Nd4jArrayStructure<T> =
|
||||
structureND(shape) { index -> elementAlgebra.transform(index, get(index)) }
|
||||
mutableStructureND(shape) { index -> elementAlgebra.transform(index, get(index)) }
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): Nd4jArrayStructure<T> {
|
||||
require(left.shape.contentEquals(right.shape))
|
||||
return structureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
||||
return mutableStructureND(left.shape) { index -> elementAlgebra.transform(left[index], right[index]) }
|
||||
}
|
||||
|
||||
override fun T.plus(arg: StructureND<T>): Nd4jArrayStructure<T> = arg.ndArray.add(this).wrap()
|
||||
@ -178,7 +178,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
override fun structureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure<Double> {
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): Nd4jArrayStructure<Double> {
|
||||
val array: INDArray = Nd4j.zeros(*shape.asArray())
|
||||
val indices = ColumnStrides(shape)
|
||||
indices.asSequence().forEach { index ->
|
||||
|
@ -14,6 +14,7 @@ import space.kscience.kmath.PerformancePitfall
|
||||
import space.kscience.kmath.UnstableKMathAPI
|
||||
import space.kscience.kmath.expressions.Symbol
|
||||
import space.kscience.kmath.nd.ColumnStrides
|
||||
import space.kscience.kmath.nd.MutableStructureND
|
||||
import space.kscience.kmath.nd.ShapeND
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
@ -36,10 +37,10 @@ public class DoubleTensorFlowAlgebra internal constructor(
|
||||
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
override fun structureND(
|
||||
override fun mutableStructureND(
|
||||
shape: ShapeND,
|
||||
initializer: DoubleField.(IntArray) -> Double,
|
||||
): StructureND<Double> {
|
||||
): MutableStructureND<Double> {
|
||||
val res = TFloat64.tensorOf(org.tensorflow.ndarray.Shape.of(*shape.toLongArray())) { array ->
|
||||
ColumnStrides(shape).forEach { index ->
|
||||
array.setDouble(elementAlgebra.initializer(index), *index.toLongArray())
|
||||
|
@ -15,12 +15,12 @@ kscience{
|
||||
|
||||
kotlin.sourceSets {
|
||||
all {
|
||||
languageSettings.optIn("space.kscience.kmath.misc.UnstableKMathAPI")
|
||||
languageSettings.optIn("space.kscience.kmath.UnstableKMathAPI")
|
||||
}
|
||||
|
||||
filter { it.name.contains("test", true) }
|
||||
.map(org.jetbrains.kotlin.gradle.plugin.KotlinSourceSet::languageSettings)
|
||||
.forEach { it.optIn("space.kscience.kmath.misc.PerformancePitfall") }
|
||||
.forEach { it.optIn("space.kscience.kmath.PerformancePitfall") }
|
||||
|
||||
commonMain {
|
||||
dependencies {
|
||||
|
@ -127,7 +127,7 @@ public open class DoubleTensorAlgebra :
|
||||
* @param initializer mapping tensor indices to values.
|
||||
* @return tensor with the [shape] shape and data generated by the [initializer].
|
||||
*/
|
||||
override fun structureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray(
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): DoubleTensor = fromArray(
|
||||
shape,
|
||||
RowStrides(shape).asSequence().map { DoubleField.initializer(it) }.toMutableList().toDoubleArray()
|
||||
)
|
||||
|
@ -118,7 +118,7 @@ public open class IntTensorAlgebra : TensorAlgebra<Int, IntRing> {
|
||||
* @param initializer mapping tensor indices to values.
|
||||
* @return tensor with the [shape] shape and data generated by the [initializer].
|
||||
*/
|
||||
override fun structureND(shape: ShapeND, initializer: IntRing.(IntArray) -> Int): IntTensor = fromArray(
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: IntRing.(IntArray) -> Int): IntTensor = fromArray(
|
||||
shape,
|
||||
RowStrides(shape).asSequence().map { IntRing.initializer(it) }.toMutableList().toIntArray()
|
||||
)
|
||||
|
@ -92,7 +92,7 @@ internal class TestDoubleTensor {
|
||||
|
||||
@Test
|
||||
fun test2D() = with(DoubleTensorAlgebra) {
|
||||
val tensor: DoubleTensor = structureND(ShapeND(3, 3)) { (i, j) -> (i - j).toDouble() }
|
||||
val tensor: DoubleTensor = mutableStructureND(ShapeND(3, 3)) { (i, j) -> (i - j).toDouble() }
|
||||
//println(tensor.toPrettyString())
|
||||
val tensor2d = tensor.asDoubleTensor2D()
|
||||
assertBufferEquals(DoubleBuffer(1.0, 0.0, -1.0), tensor2d.rows[1])
|
||||
@ -101,7 +101,7 @@ internal class TestDoubleTensor {
|
||||
|
||||
@Test
|
||||
fun testMatrixIteration() = with(DoubleTensorAlgebra) {
|
||||
val tensor = structureND(ShapeND(3, 3, 3, 3)) { index -> index.sum().toDouble() }
|
||||
val tensor = mutableStructureND(ShapeND(3, 3, 3, 3)) { index -> index.sum().toDouble() }
|
||||
tensor.forEachMatrix { index, matrix ->
|
||||
println(index.joinToString { it.toString() })
|
||||
println(matrix)
|
||||
|
@ -17,8 +17,7 @@ import space.kscience.kmath.operations.ExtendedFieldOps
|
||||
import space.kscience.kmath.operations.NumbersAddOps
|
||||
import space.kscience.kmath.operations.PowerOperations
|
||||
|
||||
@OptIn(UnstableKMathAPI::class, PerformancePitfall::class)
|
||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public open class ViktorFieldOpsND :
|
||||
FieldOpsND<Double, DoubleField>,
|
||||
ExtendedFieldOps<StructureND<Double>>,
|
||||
@ -27,13 +26,13 @@ public open class ViktorFieldOpsND :
|
||||
public val StructureND<Double>.f64Buffer: F64Array
|
||||
get() = when (this) {
|
||||
is ViktorStructureND -> this.f64Buffer
|
||||
else -> structureND(shape) { this@f64Buffer[it] }.f64Buffer
|
||||
else -> mutableStructureND(shape) { this@f64Buffer[it] }.f64Buffer
|
||||
}
|
||||
|
||||
override val elementAlgebra: DoubleField get() = DoubleField
|
||||
|
||||
@OptIn(UnsafeKMathAPI::class)
|
||||
override fun structureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
||||
override fun mutableStructureND(shape: ShapeND, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
||||
F64Array(*shape.asArray()).apply {
|
||||
ColumnStrides(shape).asSequence().forEach { index ->
|
||||
set(value = DoubleField.initializer(index), indices = index)
|
||||
|
Loading…
Reference in New Issue
Block a user