forked from kscience/kmath
Shapeless ND and Buffer algebras
This commit is contained in:
parent
8d2770c275
commit
d0354da80a
@ -42,6 +42,9 @@
|
|||||||
- Use `Symbol` factory function instead of `StringSymbol`
|
- Use `Symbol` factory function instead of `StringSymbol`
|
||||||
- New discoverability pattern: `<Type>.algebra.<nd/etc>`
|
- New discoverability pattern: `<Type>.algebra.<nd/etc>`
|
||||||
- Adjusted commons-math API for linear solvers to match conventions.
|
- Adjusted commons-math API for linear solvers to match conventions.
|
||||||
|
- Buffer algebra does not require size anymore
|
||||||
|
- Operations -> Ops
|
||||||
|
- Default Buffer and ND algebras are now Ops and lack neutral elements (0, 1) as well as algebra-level shapes.
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
- Specialized `DoubleBufferAlgebra`
|
- Specialized `DoubleBufferAlgebra`
|
||||||
|
@ -9,9 +9,10 @@ import kotlinx.benchmark.Benchmark
|
|||||||
import kotlinx.benchmark.Blackhole
|
import kotlinx.benchmark.Blackhole
|
||||||
import kotlinx.benchmark.Scope
|
import kotlinx.benchmark.Scope
|
||||||
import kotlinx.benchmark.State
|
import kotlinx.benchmark.State
|
||||||
|
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.nd.autoNdAlgebra
|
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
|
import space.kscience.kmath.nd.one
|
||||||
import space.kscience.kmath.nd4j.nd4j
|
import space.kscience.kmath.nd4j.nd4j
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -23,21 +24,21 @@ import space.kscience.kmath.tensors.core.tensorAlgebra
|
|||||||
internal class NDFieldBenchmark {
|
internal class NDFieldBenchmark {
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
|
fun autoFieldAdd(blackhole: Blackhole) = with(autoField) {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += one }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
|
fun specializedFieldAdd(blackhole: Blackhole) = with(specializedField) {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
|
fun boxingFieldAdd(blackhole: Blackhole) = with(genericField) {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -56,19 +57,20 @@ internal class NDFieldBenchmark {
|
|||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Benchmark
|
@Benchmark
|
||||||
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
|
fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
|
||||||
// var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(dim, dim)
|
||||||
// repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
// blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
// }
|
}
|
||||||
|
|
||||||
private companion object {
|
private companion object {
|
||||||
private const val dim = 1000
|
private const val dim = 1000
|
||||||
private const val n = 100
|
private const val n = 100
|
||||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
private val shape = intArrayOf(dim, dim)
|
||||||
private val specializedField = DoubleField.ndAlgebra(dim, dim)
|
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||||
private val genericField = DoubleField.ndAlgebra(Buffer.Companion::boxing, dim, dim)
|
private val specializedField = DoubleField.ndAlgebra
|
||||||
private val nd4jField = DoubleField.nd4j(dim, dim)
|
private val genericField = BufferedFieldOpsND(DoubleField, Buffer.Companion::boxing)
|
||||||
|
private val nd4jField = DoubleField.nd4j
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,18 +10,17 @@ import kotlinx.benchmark.Blackhole
|
|||||||
import kotlinx.benchmark.Scope
|
import kotlinx.benchmark.Scope
|
||||||
import kotlinx.benchmark.State
|
import kotlinx.benchmark.State
|
||||||
import org.jetbrains.bio.viktor.F64Array
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.autoNdAlgebra
|
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.viktor.ViktorNDField
|
import space.kscience.kmath.structures.Buffer
|
||||||
|
import space.kscience.kmath.viktor.ViktorFieldND
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
internal class ViktorBenchmark {
|
internal class ViktorBenchmark {
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun automaticFieldAddition(blackhole: Blackhole) {
|
fun automaticFieldAddition(blackhole: Blackhole) {
|
||||||
with(autoField) {
|
with(autoField) {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -30,7 +29,7 @@ internal class ViktorBenchmark {
|
|||||||
@Benchmark
|
@Benchmark
|
||||||
fun realFieldAddition(blackhole: Blackhole) {
|
fun realFieldAddition(blackhole: Blackhole) {
|
||||||
with(realField) {
|
with(realField) {
|
||||||
var res: StructureND<Double> = one
|
var res: StructureND<Double> = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -39,7 +38,7 @@ internal class ViktorBenchmark {
|
|||||||
@Benchmark
|
@Benchmark
|
||||||
fun viktorFieldAddition(blackhole: Blackhole) {
|
fun viktorFieldAddition(blackhole: Blackhole) {
|
||||||
with(viktorField) {
|
with(viktorField) {
|
||||||
var res = one
|
var res = one(shape)
|
||||||
repeat(n) { res += 1.0 }
|
repeat(n) { res += 1.0 }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -56,10 +55,11 @@ internal class ViktorBenchmark {
|
|||||||
private companion object {
|
private companion object {
|
||||||
private const val dim = 1000
|
private const val dim = 1000
|
||||||
private const val n = 100
|
private const val n = 100
|
||||||
|
private val shape = Shape(dim, dim)
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||||
private val realField = DoubleField.ndAlgebra(dim, dim)
|
private val realField = DoubleField.ndAlgebra
|
||||||
private val viktorField = ViktorNDField(dim, dim)
|
private val viktorField = ViktorFieldND(dim, dim)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,18 +10,21 @@ import kotlinx.benchmark.Blackhole
|
|||||||
import kotlinx.benchmark.Scope
|
import kotlinx.benchmark.Scope
|
||||||
import kotlinx.benchmark.State
|
import kotlinx.benchmark.State
|
||||||
import org.jetbrains.bio.viktor.F64Array
|
import org.jetbrains.bio.viktor.F64Array
|
||||||
import space.kscience.kmath.nd.autoNdAlgebra
|
import space.kscience.kmath.nd.BufferedFieldOpsND
|
||||||
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
|
import space.kscience.kmath.nd.one
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.viktor.ViktorFieldND
|
import space.kscience.kmath.viktor.ViktorFieldND
|
||||||
|
|
||||||
@State(Scope.Benchmark)
|
@State(Scope.Benchmark)
|
||||||
internal class ViktorLogBenchmark {
|
internal class ViktorLogBenchmark {
|
||||||
@Benchmark
|
@Benchmark
|
||||||
fun realFieldLog(blackhole: Blackhole) {
|
fun realFieldLog(blackhole: Blackhole) {
|
||||||
with(realNdField) {
|
with(realField) {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce(shape) { 42.0 }
|
||||||
var res = one
|
var res = one(shape)
|
||||||
repeat(n) { res = ln(fortyTwo) }
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
}
|
}
|
||||||
@ -30,7 +33,7 @@ internal class ViktorLogBenchmark {
|
|||||||
@Benchmark
|
@Benchmark
|
||||||
fun viktorFieldLog(blackhole: Blackhole) {
|
fun viktorFieldLog(blackhole: Blackhole) {
|
||||||
with(viktorField) {
|
with(viktorField) {
|
||||||
val fortyTwo = produce { 42.0 }
|
val fortyTwo = produce(shape) { 42.0 }
|
||||||
var res = one
|
var res = one
|
||||||
repeat(n) { res = ln(fortyTwo) }
|
repeat(n) { res = ln(fortyTwo) }
|
||||||
blackhole.consume(res)
|
blackhole.consume(res)
|
||||||
@ -48,10 +51,11 @@ internal class ViktorLogBenchmark {
|
|||||||
private companion object {
|
private companion object {
|
||||||
private const val dim = 1000
|
private const val dim = 1000
|
||||||
private const val n = 100
|
private const val n = 100
|
||||||
|
private val shape = Shape(dim, dim)
|
||||||
|
|
||||||
// automatically build context most suited for given type.
|
// automatically build context most suited for given type.
|
||||||
private val autoField = DoubleField.autoNdAlgebra(dim, dim)
|
private val autoField = BufferedFieldOpsND(DoubleField, Buffer.Companion::auto)
|
||||||
private val realNdField = DoubleField.ndAlgebra(dim, dim)
|
private val realField = DoubleField.ndAlgebra
|
||||||
private val viktorField = ViktorFieldND(intArrayOf(dim, dim))
|
private val viktorField = ViktorFieldND(dim, dim)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,27 +11,27 @@ import space.kscience.kmath.complex.bufferAlgebra
|
|||||||
import space.kscience.kmath.complex.ndAlgebra
|
import space.kscience.kmath.complex.ndAlgebra
|
||||||
import space.kscience.kmath.nd.BufferND
|
import space.kscience.kmath.nd.BufferND
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.nd.produce
|
||||||
|
|
||||||
fun main() = Complex.algebra {
|
fun main() = Complex.algebra {
|
||||||
val complex = 2 + 2 * i
|
val complex = 2 + 2 * i
|
||||||
println(complex * 8 - 5 * i)
|
println(complex * 8 - 5 * i)
|
||||||
|
|
||||||
//flat buffer
|
//flat buffer
|
||||||
val buffer = bufferAlgebra(8).run {
|
val buffer = with(bufferAlgebra){
|
||||||
buffer { Complex(it, -it) }.map { Complex(it.im, it.re) }
|
buffer(8) { Complex(it, -it) }.map { Complex(it.im, it.re) }
|
||||||
}
|
}
|
||||||
println(buffer)
|
println(buffer)
|
||||||
|
|
||||||
|
|
||||||
// 2d element
|
// 2d element
|
||||||
val element: BufferND<Complex> = ndAlgebra(2, 2).produce { (i, j) ->
|
val element: BufferND<Complex> = ndAlgebra.produce(2, 2) { (i, j) ->
|
||||||
Complex(i - j, i + j)
|
Complex(i - j, i + j)
|
||||||
}
|
}
|
||||||
println(element)
|
println(element)
|
||||||
|
|
||||||
// 1d element operation
|
// 1d element operation
|
||||||
val result: StructureND<Complex> = ndAlgebra(8).run {
|
val result: StructureND<Complex> = ndAlgebra{
|
||||||
val a = produce { (it) -> i * it - it.toDouble() }
|
val a = produce(8) { (it) -> i * it - it.toDouble() }
|
||||||
val b = 3
|
val b = 3
|
||||||
val c = Complex(1.0, 1.0)
|
val c = Complex(1.0, 1.0)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ import space.kscience.kmath.nd.ndAlgebra
|
|||||||
import space.kscience.kmath.nd4j.Nd4jArrayField
|
import space.kscience.kmath.nd4j.Nd4jArrayField
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.viktor.ViktorNDField
|
import space.kscience.kmath.viktor.ViktorFieldND
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
import kotlin.system.measureTimeMillis
|
import kotlin.system.measureTimeMillis
|
||||||
@ -41,7 +41,7 @@ fun main() {
|
|||||||
// Nd4j specialized field.
|
// Nd4j specialized field.
|
||||||
val nd4jField = Nd4jArrayField.real(dim, dim)
|
val nd4jField = Nd4jArrayField.real(dim, dim)
|
||||||
//viktor field
|
//viktor field
|
||||||
val viktorField = ViktorNDField(dim, dim)
|
val viktorField = ViktorFieldND(dim, dim)
|
||||||
//parallel processing based on Java Streams
|
//parallel processing based on Java Streams
|
||||||
val parallelField = DoubleField.ndStreaming(dim, dim)
|
val parallelField = DoubleField.ndStreaming(dim, dim)
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ package space.kscience.kmath.structures
|
|||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
import space.kscience.kmath.operations.NumbersAddOps
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import java.util.stream.IntStream
|
import java.util.stream.IntStream
|
||||||
|
|
||||||
@ -17,11 +17,11 @@ import java.util.stream.IntStream
|
|||||||
* execution.
|
* execution.
|
||||||
*/
|
*/
|
||||||
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
||||||
NumbersAddOperations<StructureND<Double>>,
|
NumbersAddOps<StructureND<Double>>,
|
||||||
ExtendedField<StructureND<Double>> {
|
ExtendedField<StructureND<Double>> {
|
||||||
|
|
||||||
private val strides = DefaultStrides(shape)
|
private val strides = DefaultStrides(shape)
|
||||||
override val elementContext: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
override val zero: BufferND<Double> by lazy { produce { zero } }
|
override val zero: BufferND<Double> by lazy { produce { zero } }
|
||||||
override val one: BufferND<Double> by lazy { produce { one } }
|
override val one: BufferND<Double> by lazy { produce { one } }
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
|||||||
this@StreamDoubleFieldND.shape,
|
this@StreamDoubleFieldND.shape,
|
||||||
shape
|
shape
|
||||||
)
|
)
|
||||||
this is BufferND && this.strides == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
|
this is BufferND && this.indexes == this@StreamDoubleFieldND.strides -> this.buffer as DoubleBuffer
|
||||||
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
|
|||||||
return BufferND(strides, array.asBuffer())
|
return BufferND(strides, array.asBuffer())
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun combine(
|
override fun zip(
|
||||||
a: StructureND<Double>,
|
a: StructureND<Double>,
|
||||||
b: StructureND<Double>,
|
b: StructureND<Double>,
|
||||||
transform: DoubleField.(Double, Double) -> Double,
|
transform: DoubleField.(Double, Double) -> Double,
|
||||||
|
@ -8,6 +8,7 @@ package space.kscience.kmath.structures
|
|||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.buffer
|
import space.kscience.kmath.operations.buffer
|
||||||
import space.kscience.kmath.operations.bufferAlgebra
|
import space.kscience.kmath.operations.bufferAlgebra
|
||||||
|
import space.kscience.kmath.operations.withSize
|
||||||
|
|
||||||
inline fun <reified R : Any> MutableBuffer.Companion.same(
|
inline fun <reified R : Any> MutableBuffer.Companion.same(
|
||||||
n: Int,
|
n: Int,
|
||||||
@ -16,7 +17,7 @@ inline fun <reified R : Any> MutableBuffer.Companion.same(
|
|||||||
|
|
||||||
|
|
||||||
fun main() {
|
fun main() {
|
||||||
with(DoubleField.bufferAlgebra(5)) {
|
with(DoubleField.bufferAlgebra.withSize(5)) {
|
||||||
println(number(2.0) + buffer(1, 2, 3, 4, 5))
|
println(number(2.0) + buffer(1, 2, 3, 4, 5))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
2
gradle/wrapper/gradle-wrapper.properties
vendored
2
gradle/wrapper/gradle-wrapper.properties
vendored
@ -1,5 +1,5 @@
|
|||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.1.1-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
@ -18,10 +18,10 @@ import com.github.h0tk3y.betterParse.parser.ParseResult
|
|||||||
import com.github.h0tk3y.betterParse.parser.Parser
|
import com.github.h0tk3y.betterParse.parser.Parser
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.expressions.MST
|
||||||
import space.kscience.kmath.expressions.Symbol
|
import space.kscience.kmath.expressions.Symbol
|
||||||
import space.kscience.kmath.operations.FieldOperations
|
import space.kscience.kmath.operations.FieldOps
|
||||||
import space.kscience.kmath.operations.GroupOperations
|
import space.kscience.kmath.operations.GroupOps
|
||||||
import space.kscience.kmath.operations.PowerOperations
|
import space.kscience.kmath.operations.PowerOperations
|
||||||
import space.kscience.kmath.operations.RingOperations
|
import space.kscience.kmath.operations.RingOps
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
|
* better-parse implementation of grammar defined in the ArithmeticsEvaluator.g4.
|
||||||
@ -60,7 +60,7 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
|||||||
.or(binaryFunction)
|
.or(binaryFunction)
|
||||||
.or(unaryFunction)
|
.or(unaryFunction)
|
||||||
.or(singular)
|
.or(singular)
|
||||||
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOperations.MINUS_OPERATION, it) })
|
.or(-minus and parser(ArithmeticsEvaluator::term) map { MST.Unary(GroupOps.MINUS_OPERATION, it) })
|
||||||
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
|
.or(-lpar and parser(ArithmeticsEvaluator::subSumChain) and -rpar)
|
||||||
|
|
||||||
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
private val powChain: Parser<MST> by leftAssociative(term = term, operator = pow) { a, _, b ->
|
||||||
@ -72,9 +72,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
|||||||
operator = div or mul use TokenMatch::type
|
operator = div or mul use TokenMatch::type
|
||||||
) { a, op, b ->
|
) { a, op, b ->
|
||||||
if (op == div)
|
if (op == div)
|
||||||
MST.Binary(FieldOperations.DIV_OPERATION, a, b)
|
MST.Binary(FieldOps.DIV_OPERATION, a, b)
|
||||||
else
|
else
|
||||||
MST.Binary(RingOperations.TIMES_OPERATION, a, b)
|
MST.Binary(RingOps.TIMES_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
private val subSumChain: Parser<MST> by leftAssociative(
|
private val subSumChain: Parser<MST> by leftAssociative(
|
||||||
@ -82,9 +82,9 @@ public object ArithmeticsEvaluator : Grammar<MST>() {
|
|||||||
operator = plus or minus use TokenMatch::type
|
operator = plus or minus use TokenMatch::type
|
||||||
) { a, op, b ->
|
) { a, op, b ->
|
||||||
if (op == plus)
|
if (op == plus)
|
||||||
MST.Binary(GroupOperations.PLUS_OPERATION, a, b)
|
MST.Binary(GroupOps.PLUS_OPERATION, a, b)
|
||||||
else
|
else
|
||||||
MST.Binary(GroupOperations.MINUS_OPERATION, a, b)
|
MST.Binary(GroupOps.MINUS_OPERATION, a, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
override val rootParser: Parser<MST> by subSumChain
|
override val rootParser: Parser<MST> by subSumChain
|
||||||
|
@ -39,7 +39,7 @@ public val PrintNumeric: RenderFeature = RenderFeature { _, node ->
|
|||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
|
private fun printSignedNumberString(s: String): MathSyntax = if (s.startsWith('-'))
|
||||||
UnaryMinusSyntax(
|
UnaryMinusSyntax(
|
||||||
operation = GroupOperations.MINUS_OPERATION,
|
operation = GroupOps.MINUS_OPERATION,
|
||||||
operand = OperandSyntax(
|
operand = OperandSyntax(
|
||||||
operand = NumberSyntax(string = s.removePrefix("-")),
|
operand = NumberSyntax(string = s.removePrefix("-")),
|
||||||
parentheses = true,
|
parentheses = true,
|
||||||
@ -72,7 +72,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
|
|||||||
val exponent = afterE.toDouble().toString().removeSuffix(".0")
|
val exponent = afterE.toDouble().toString().removeSuffix(".0")
|
||||||
|
|
||||||
return MultiplicationSyntax(
|
return MultiplicationSyntax(
|
||||||
operation = RingOperations.TIMES_OPERATION,
|
operation = RingOps.TIMES_OPERATION,
|
||||||
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
|
left = OperandSyntax(operand = NumberSyntax(significand), parentheses = true),
|
||||||
right = OperandSyntax(
|
right = OperandSyntax(
|
||||||
operand = SuperscriptSyntax(
|
operand = SuperscriptSyntax(
|
||||||
@ -91,7 +91,7 @@ public class PrettyPrintFloats(public val types: Set<KClass<out Number>>) : Rend
|
|||||||
|
|
||||||
if (toString.startsWith('-'))
|
if (toString.startsWith('-'))
|
||||||
return UnaryMinusSyntax(
|
return UnaryMinusSyntax(
|
||||||
operation = GroupOperations.MINUS_OPERATION,
|
operation = GroupOps.MINUS_OPERATION,
|
||||||
operand = OperandSyntax(operand = infty, parentheses = true),
|
operand = OperandSyntax(operand = infty, parentheses = true),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -211,9 +211,9 @@ public class BinaryPlus(operations: Collection<String>?) : Binary(operations) {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
* The default instance configured with [GroupOps.PLUS_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
public val Default: BinaryPlus = BinaryPlus(setOf(GroupOps.PLUS_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,9 +233,9 @@ public class BinaryMinus(operations: Collection<String>?) : Binary(operations) {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
* The default instance configured with [GroupOps.MINUS_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
public val Default: BinaryMinus = BinaryMinus(setOf(GroupOps.MINUS_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -253,9 +253,9 @@ public class UnaryPlus(operations: Collection<String>?) : Unary(operations) {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [GroupOperations.PLUS_OPERATION].
|
* The default instance configured with [GroupOps.PLUS_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOperations.PLUS_OPERATION))
|
public val Default: UnaryPlus = UnaryPlus(setOf(GroupOps.PLUS_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,9 +273,9 @@ public class UnaryMinus(operations: Collection<String>?) : Unary(operations) {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [GroupOperations.MINUS_OPERATION].
|
* The default instance configured with [GroupOps.MINUS_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOperations.MINUS_OPERATION))
|
public val Default: UnaryMinus = UnaryMinus(setOf(GroupOps.MINUS_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,9 +295,9 @@ public class Fraction(operations: Collection<String>?) : Binary(operations) {
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [FieldOperations.DIV_OPERATION].
|
* The default instance configured with [FieldOps.DIV_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: Fraction = Fraction(setOf(FieldOperations.DIV_OPERATION))
|
public val Default: Fraction = Fraction(setOf(FieldOps.DIV_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -422,9 +422,9 @@ public class Multiplication(operations: Collection<String>?) : Binary(operations
|
|||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
/**
|
/**
|
||||||
* The default instance configured with [RingOperations.TIMES_OPERATION].
|
* The default instance configured with [RingOps.TIMES_OPERATION].
|
||||||
*/
|
*/
|
||||||
public val Default: Multiplication = Multiplication(setOf(RingOperations.TIMES_OPERATION))
|
public val Default: Multiplication = Multiplication(setOf(RingOps.TIMES_OPERATION))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,10 +7,10 @@ package space.kscience.kmath.ast.rendering
|
|||||||
|
|
||||||
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
|
import space.kscience.kmath.ast.rendering.FeaturedMathRendererWithPostProcess.PostProcessPhase
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.FieldOperations
|
import space.kscience.kmath.operations.FieldOps
|
||||||
import space.kscience.kmath.operations.GroupOperations
|
import space.kscience.kmath.operations.GroupOps
|
||||||
import space.kscience.kmath.operations.PowerOperations
|
import space.kscience.kmath.operations.PowerOperations
|
||||||
import space.kscience.kmath.operations.RingOperations
|
import space.kscience.kmath.operations.RingOps
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Removes unnecessary times (×) symbols from [MultiplicationSyntax].
|
* Removes unnecessary times (×) symbols from [MultiplicationSyntax].
|
||||||
@ -306,10 +306,10 @@ public class SimplifyParentheses(public val precedenceFunction: (MathSyntax) ->
|
|||||||
|
|
||||||
is BinarySyntax -> when (it.operation) {
|
is BinarySyntax -> when (it.operation) {
|
||||||
PowerOperations.POW_OPERATION -> 1
|
PowerOperations.POW_OPERATION -> 1
|
||||||
RingOperations.TIMES_OPERATION -> 3
|
RingOps.TIMES_OPERATION -> 3
|
||||||
FieldOperations.DIV_OPERATION -> 3
|
FieldOps.DIV_OPERATION -> 3
|
||||||
GroupOperations.MINUS_OPERATION -> 4
|
GroupOps.MINUS_OPERATION -> 4
|
||||||
GroupOperations.PLUS_OPERATION -> 4
|
GroupOps.PLUS_OPERATION -> 4
|
||||||
else -> 0
|
else -> 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
|
|||||||
|
|
||||||
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
|
import space.kscience.kmath.ast.rendering.TestUtils.testLatex
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.expressions.MST
|
||||||
import space.kscience.kmath.operations.GroupOperations
|
import space.kscience.kmath.operations.GroupOps
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class TestLatex {
|
internal class TestLatex {
|
||||||
@ -36,7 +36,7 @@ internal class TestLatex {
|
|||||||
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
|
fun unaryOperator() = testLatex("sin(1)", "\\operatorname{sin}\\,\\left(1\\right)")
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun unaryPlus() = testLatex(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "+1")
|
fun unaryPlus() = testLatex(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "+1")
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun unaryMinus() = testLatex("-x", "-x")
|
fun unaryMinus() = testLatex("-x", "-x")
|
||||||
|
@ -7,7 +7,7 @@ package space.kscience.kmath.ast.rendering
|
|||||||
|
|
||||||
import space.kscience.kmath.ast.rendering.TestUtils.testMathML
|
import space.kscience.kmath.ast.rendering.TestUtils.testMathML
|
||||||
import space.kscience.kmath.expressions.MST
|
import space.kscience.kmath.expressions.MST
|
||||||
import space.kscience.kmath.operations.GroupOperations
|
import space.kscience.kmath.operations.GroupOps
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
|
|
||||||
internal class TestMathML {
|
internal class TestMathML {
|
||||||
@ -47,7 +47,7 @@ internal class TestMathML {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun unaryPlus() =
|
fun unaryPlus() =
|
||||||
testMathML(MST.Unary(GroupOperations.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
|
testMathML(MST.Unary(GroupOps.PLUS_OPERATION, MST.Numeric(1)), "<mo>+</mo><mn>1</mn>")
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")
|
fun unaryMinus() = testMathML("-x", "<mo>-</mo><mi>x</mi>")
|
||||||
|
@ -108,8 +108,8 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
|
|||||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.f64.const(mst.value)
|
||||||
|
|
||||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
||||||
GroupOperations.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
GroupOps.MINUS_OPERATION -> ctx.f64.neg(visit(mst.value))
|
||||||
GroupOperations.PLUS_OPERATION -> visit(mst.value)
|
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
||||||
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
PowerOperations.SQRT_OPERATION -> ctx.f64.sqrt(visit(mst.value))
|
||||||
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.SIN_OPERATION -> ctx.call("sin", arrayOf(visit(mst.value)), f64)
|
||||||
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
|
TrigonometricOperations.COS_OPERATION -> ctx.call("cos", arrayOf(visit(mst.value)), f64)
|
||||||
@ -129,10 +129,10 @@ internal class DoubleWasmBuilder(target: MST) : WasmBuilder<Double>(f64, DoubleF
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.f64.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOperations.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.f64.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOperations.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.f64.mul(visit(mst.left), visit(mst.right))
|
||||||
FieldOperations.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
FieldOps.DIV_OPERATION -> ctx.f64.div(visit(mst.left), visit(mst.right))
|
||||||
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
|
PowerOperations.POW_OPERATION -> ctx.call("pow", arrayOf(visit(mst.left), visit(mst.right)), f64)
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(mst)
|
||||||
}
|
}
|
||||||
@ -142,15 +142,15 @@ internal class IntWasmBuilder(target: MST) : WasmBuilder<Int>(i32, IntRing, targ
|
|||||||
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
override fun visitNumeric(mst: Numeric): ExpressionRef = ctx.i32.const(mst.value)
|
||||||
|
|
||||||
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
override fun visitUnary(mst: Unary): ExpressionRef = when (mst.operation) {
|
||||||
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(ctx.i32.const(0), visit(mst.value))
|
||||||
GroupOperations.PLUS_OPERATION -> visit(mst.value)
|
GroupOps.PLUS_OPERATION -> visit(mst.value)
|
||||||
else -> super.visitUnary(mst)
|
else -> super.visitUnary(mst)
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
override fun visitBinary(mst: Binary): ExpressionRef = when (mst.operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
GroupOps.PLUS_OPERATION -> ctx.i32.add(visit(mst.left), visit(mst.right))
|
||||||
GroupOperations.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
GroupOps.MINUS_OPERATION -> ctx.i32.sub(visit(mst.left), visit(mst.right))
|
||||||
RingOperations.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
RingOps.TIMES_OPERATION -> ctx.i32.mul(visit(mst.left), visit(mst.right))
|
||||||
else -> super.visitBinary(mst)
|
else -> super.visitBinary(mst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import org.apache.commons.math3.analysis.differentiation.DerivativeStructure
|
|||||||
import space.kscience.kmath.expressions.*
|
import space.kscience.kmath.expressions.*
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
import space.kscience.kmath.operations.NumbersAddOps
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A field over commons-math [DerivativeStructure].
|
* A field over commons-math [DerivativeStructure].
|
||||||
@ -22,7 +22,7 @@ public class DerivativeStructureField(
|
|||||||
public val order: Int,
|
public val order: Int,
|
||||||
bindings: Map<Symbol, Double>,
|
bindings: Map<Symbol, Double>,
|
||||||
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
) : ExtendedField<DerivativeStructure>, ExpressionAlgebra<Double, DerivativeStructure>,
|
||||||
NumbersAddOperations<DerivativeStructure> {
|
NumbersAddOps<DerivativeStructure> {
|
||||||
public val numberOfVariables: Int = bindings.size
|
public val numberOfVariables: Int = bindings.size
|
||||||
|
|
||||||
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
override val zero: DerivativeStructure by lazy { DerivativeStructure(numberOfVariables, order) }
|
||||||
|
@ -52,7 +52,7 @@ private val PI_DIV_2 = Complex(PI / 2, 0)
|
|||||||
public object ComplexField :
|
public object ComplexField :
|
||||||
ExtendedField<Complex>,
|
ExtendedField<Complex>,
|
||||||
Norm<Complex, Complex>,
|
Norm<Complex, Complex>,
|
||||||
NumbersAddOperations<Complex>,
|
NumbersAddOps<Complex>,
|
||||||
ScaleOperations<Complex> {
|
ScaleOperations<Complex> {
|
||||||
|
|
||||||
override val zero: Complex = 0.0.toComplex()
|
override val zero: Complex = 0.0.toComplex()
|
||||||
@ -216,7 +216,6 @@ public data class Complex(val re: Double, val im: Double) {
|
|||||||
|
|
||||||
public val Complex.Companion.algebra: ComplexField get() = ComplexField
|
public val Complex.Companion.algebra: ComplexField get() = ComplexField
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a complex number with real part equal to this real.
|
* Creates a complex number with real part equal to this real.
|
||||||
*
|
*
|
||||||
|
@ -6,13 +6,8 @@
|
|||||||
package space.kscience.kmath.complex
|
package space.kscience.kmath.complex
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.BufferND
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.BufferedFieldND
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.nd.StructureND
|
|
||||||
import space.kscience.kmath.operations.BufferField
|
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
|
||||||
import space.kscience.kmath.operations.bufferAlgebra
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
@ -22,100 +17,61 @@ import kotlin.contracts.contract
|
|||||||
* An optimized nd-field for complex numbers
|
* An optimized nd-field for complex numbers
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class ComplexFieldND(
|
public sealed class ComplexFieldOpsND : BufferedFieldOpsND<Complex, ComplexField>(ComplexField.bufferAlgebra),
|
||||||
shape: IntArray,
|
ScaleOperations<StructureND<Complex>>, ExtendedFieldOps<StructureND<Complex>> {
|
||||||
) : BufferedFieldND<Complex, ComplexField>(shape, ComplexField, Buffer.Companion::complex),
|
|
||||||
NumbersAddOperations<StructureND<Complex>>,
|
|
||||||
ExtendedField<StructureND<Complex>> {
|
|
||||||
|
|
||||||
override val zero: BufferND<Complex> by lazy { produce { zero } }
|
override fun StructureND<Complex>.toBufferND(): BufferND<Complex> = when (this) {
|
||||||
override val one: BufferND<Complex> by lazy { produce { one } }
|
is BufferND -> this
|
||||||
|
else -> {
|
||||||
override fun number(value: Number): BufferND<Complex> {
|
val indexer = indexerBuilder(shape)
|
||||||
val d = value.toComplex() // minimize conversions
|
BufferND(indexer, Buffer.complex(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||||
return produce { d }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//TODO do specialization
|
||||||
// @Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
// override inline fun map(
|
|
||||||
// arg: AbstractNDBuffer<Double>,
|
|
||||||
// transform: DoubleField.(Double) -> Double,
|
|
||||||
// ): RealNDElement {
|
|
||||||
// check(arg)
|
|
||||||
// val array = RealBuffer(arg.strides.linearSize) { offset -> DoubleField.transform(arg.buffer[offset]) }
|
|
||||||
// return BufferedNDFieldElement(this, array)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// @Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
// override inline fun produce(initializer: DoubleField.(IntArray) -> Double): RealNDElement {
|
|
||||||
// val array = RealBuffer(strides.linearSize) { offset -> elementContext.initializer(strides.index(offset)) }
|
|
||||||
// return BufferedNDFieldElement(this, array)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// @Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
// override inline fun mapIndexed(
|
|
||||||
// arg: AbstractNDBuffer<Double>,
|
|
||||||
// transform: DoubleField.(index: IntArray, Double) -> Double,
|
|
||||||
// ): RealNDElement {
|
|
||||||
// check(arg)
|
|
||||||
// return BufferedNDFieldElement(
|
|
||||||
// this,
|
|
||||||
// RealBuffer(arg.strides.linearSize) { offset ->
|
|
||||||
// elementContext.transform(
|
|
||||||
// arg.strides.index(offset),
|
|
||||||
// arg.buffer[offset]
|
|
||||||
// )
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// @Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
// override inline fun combine(
|
|
||||||
// a: AbstractNDBuffer<Double>,
|
|
||||||
// b: AbstractNDBuffer<Double>,
|
|
||||||
// transform: DoubleField.(Double, Double) -> Double,
|
|
||||||
// ): RealNDElement {
|
|
||||||
// check(a, b)
|
|
||||||
// val buffer = RealBuffer(strides.linearSize) { offset ->
|
|
||||||
// elementContext.transform(a.buffer[offset], b.buffer[offset])
|
|
||||||
// }
|
|
||||||
// return BufferedNDFieldElement(this, buffer)
|
|
||||||
// }
|
|
||||||
|
|
||||||
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> = arg.map { power(it, pow) }
|
override fun scale(a: StructureND<Complex>, value: Double): BufferND<Complex> =
|
||||||
|
mapInline(a.toBufferND()) { it * value }
|
||||||
|
|
||||||
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = arg.map { exp(it) }
|
override fun power(arg: StructureND<Complex>, pow: Number): BufferND<Complex> =
|
||||||
|
mapInline(arg.toBufferND()) { power(it, pow) }
|
||||||
|
|
||||||
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = arg.map { ln(it) }
|
override fun exp(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { exp(it) }
|
||||||
|
override fun ln(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { ln(it) }
|
||||||
|
|
||||||
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sin(it) }
|
override fun sin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sin(it) }
|
||||||
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cos(it) }
|
override fun cos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cos(it) }
|
||||||
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tan(it) }
|
override fun tan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tan(it) }
|
||||||
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asin(it) }
|
override fun asin(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asin(it) }
|
||||||
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acos(it) }
|
override fun acos(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acos(it) }
|
||||||
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atan(it) }
|
override fun atan(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atan(it) }
|
||||||
|
|
||||||
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { sinh(it) }
|
override fun sinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { sinh(it) }
|
||||||
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { cosh(it) }
|
override fun cosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { cosh(it) }
|
||||||
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { tanh(it) }
|
override fun tanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { tanh(it) }
|
||||||
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { asinh(it) }
|
override fun asinh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { asinh(it) }
|
||||||
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { acosh(it) }
|
override fun acosh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { acosh(it) }
|
||||||
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = arg.map { atanh(it) }
|
override fun atanh(arg: StructureND<Complex>): BufferND<Complex> = mapInline(arg.toBufferND()) { atanh(it) }
|
||||||
}
|
|
||||||
|
|
||||||
|
public companion object : ComplexFieldOpsND()
|
||||||
/**
|
|
||||||
* Fast element production using function inlining
|
|
||||||
*/
|
|
||||||
public inline fun BufferedFieldND<Complex, ComplexField>.produceInline(initializer: ComplexField.(Int) -> Complex): BufferND<Complex> {
|
|
||||||
contract { callsInPlace(initializer, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
val buffer = Buffer.complex(strides.linearSize) { offset -> ComplexField.initializer(offset) }
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public fun ComplexField.bufferAlgebra(size: Int): BufferField<Complex, ComplexField> =
|
public val ComplexField.bufferAlgebra: BufferFieldOps<Complex, ComplexField>
|
||||||
bufferAlgebra(Buffer.Companion::complex, size)
|
get() = bufferAlgebra(Buffer.Companion::complex)
|
||||||
|
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public class ComplexFieldND(override val shape: Shape) :
|
||||||
|
ComplexFieldOpsND(), FieldND<Complex, ComplexField>, NumbersAddOps<StructureND<Complex>> {
|
||||||
|
|
||||||
|
override fun number(value: Number): BufferND<Complex> {
|
||||||
|
val d = value.toDouble() // minimize conversions
|
||||||
|
return produce(shape) { d.toComplex() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public val ComplexField.ndAlgebra: ComplexFieldOpsND get() = ComplexFieldOpsND
|
||||||
|
|
||||||
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)
|
public fun ComplexField.ndAlgebra(vararg shape: Int): ComplexFieldND = ComplexFieldND(shape)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ public val Quaternion.r: Double
|
|||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
|
public object QuaternionField : Field<Quaternion>, Norm<Quaternion, Quaternion>, PowerOperations<Quaternion>,
|
||||||
ExponentialOperations<Quaternion>, NumbersAddOperations<Quaternion>, ScaleOperations<Quaternion> {
|
ExponentialOperations<Quaternion>, NumbersAddOps<Quaternion>, ScaleOperations<Quaternion> {
|
||||||
override val zero: Quaternion = 0.toQuaternion()
|
override val zero: Quaternion = 0.toQuaternion()
|
||||||
override val one: Quaternion = 1.toQuaternion()
|
override val one: Quaternion = 1.toQuaternion()
|
||||||
|
|
||||||
|
@ -52,13 +52,13 @@ public open class FunctionalExpressionGroup<T, out A : Group<T>>(
|
|||||||
override val zero: Expression<T> get() = const(algebra.zero)
|
override val zero: Expression<T> get() = const(algebra.zero)
|
||||||
|
|
||||||
override fun Expression<T>.unaryMinus(): Expression<T> =
|
override fun Expression<T>.unaryMinus(): Expression<T> =
|
||||||
unaryOperation(GroupOperations.MINUS_OPERATION, this)
|
unaryOperation(GroupOps.MINUS_OPERATION, this)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds an Expression of addition of two another expressions.
|
* Builds an Expression of addition of two another expressions.
|
||||||
*/
|
*/
|
||||||
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun add(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperation(GroupOperations.PLUS_OPERATION, a, b)
|
binaryOperation(GroupOps.PLUS_OPERATION, a, b)
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Builds an Expression of multiplication of expression by number.
|
// * Builds an Expression of multiplication of expression by number.
|
||||||
@ -89,7 +89,7 @@ public open class FunctionalExpressionRing<T, out A : Ring<T>>(
|
|||||||
* Builds an Expression of multiplication of two expressions.
|
* Builds an Expression of multiplication of two expressions.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun multiply(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
public operator fun Expression<T>.times(arg: T): Expression<T> = this * const(arg)
|
||||||
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
public operator fun T.times(arg: Expression<T>): Expression<T> = arg * this
|
||||||
@ -108,7 +108,7 @@ public open class FunctionalExpressionField<T, out A : Field<T>>(
|
|||||||
* Builds an Expression of division an expression by another one.
|
* Builds an Expression of division an expression by another one.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
override fun divide(a: Expression<T>, b: Expression<T>): Expression<T> =
|
||||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
|
||||||
|
|
||||||
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
public operator fun Expression<T>.div(arg: T): Expression<T> = this / const(arg)
|
||||||
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
public operator fun T.div(arg: Expression<T>): Expression<T> = arg / this
|
||||||
|
@ -31,18 +31,18 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
|||||||
|
|
||||||
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
override fun number(value: Number): MST.Numeric = MstNumericAlgebra.number(value)
|
||||||
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
override fun bindSymbolOrNull(value: String): Symbol = MstNumericAlgebra.bindSymbolOrNull(value)
|
||||||
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOperations.PLUS_OPERATION)(a, b)
|
override fun add(a: MST, b: MST): MST.Binary = binaryOperationFunction(GroupOps.PLUS_OPERATION)(a, b)
|
||||||
override operator fun MST.unaryPlus(): MST.Unary =
|
override operator fun MST.unaryPlus(): MST.Unary =
|
||||||
unaryOperationFunction(GroupOperations.PLUS_OPERATION)(this)
|
unaryOperationFunction(GroupOps.PLUS_OPERATION)(this)
|
||||||
|
|
||||||
override operator fun MST.unaryMinus(): MST.Unary =
|
override operator fun MST.unaryMinus(): MST.Unary =
|
||||||
unaryOperationFunction(GroupOperations.MINUS_OPERATION)(this)
|
unaryOperationFunction(GroupOps.MINUS_OPERATION)(this)
|
||||||
|
|
||||||
override operator fun MST.minus(b: MST): MST.Binary =
|
override operator fun MST.minus(b: MST): MST.Binary =
|
||||||
binaryOperationFunction(GroupOperations.MINUS_OPERATION)(this, b)
|
binaryOperationFunction(GroupOps.MINUS_OPERATION)(this, b)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, number(value))
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, number(value))
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
override fun binaryOperationFunction(operation: String): (left: MST, right: MST) -> MST.Binary =
|
||||||
MstNumericAlgebra.binaryOperationFunction(operation)
|
MstNumericAlgebra.binaryOperationFunction(operation)
|
||||||
@ -56,7 +56,7 @@ public object MstGroup : Group<MST>, NumericAlgebra<MST>, ScaleOperations<MST> {
|
|||||||
*/
|
*/
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
|
public object MstRing : Ring<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||||
override inline val zero: MST.Numeric get() = MstGroup.zero
|
override inline val zero: MST.Numeric get() = MstGroup.zero
|
||||||
override val one: MST.Numeric = number(1.0)
|
override val one: MST.Numeric = number(1.0)
|
||||||
|
|
||||||
@ -65,10 +65,10 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
|
|||||||
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
|
override fun add(a: MST, b: MST): MST.Binary = MstGroup.add(a, b)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
|
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary =
|
override fun multiply(a: MST, b: MST): MST.Binary =
|
||||||
binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, b)
|
binaryOperationFunction(RingOps.TIMES_OPERATION)(a, b)
|
||||||
|
|
||||||
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
|
override operator fun MST.unaryPlus(): MST.Unary = MstGroup { +this@unaryPlus }
|
||||||
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
|
override operator fun MST.unaryMinus(): MST.Unary = MstGroup { -this@unaryMinus }
|
||||||
@ -86,7 +86,7 @@ public object MstRing : Ring<MST>, NumbersAddOperations<MST>, ScaleOperations<MS
|
|||||||
*/
|
*/
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE")
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<MST> {
|
public object MstField : Field<MST>, NumbersAddOps<MST>, ScaleOperations<MST> {
|
||||||
override inline val zero: MST.Numeric get() = MstRing.zero
|
override inline val zero: MST.Numeric get() = MstRing.zero
|
||||||
override inline val one: MST.Numeric get() = MstRing.one
|
override inline val one: MST.Numeric get() = MstRing.one
|
||||||
|
|
||||||
@ -95,11 +95,11 @@ public object MstField : Field<MST>, NumbersAddOperations<MST>, ScaleOperations<
|
|||||||
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
override fun add(a: MST, b: MST): MST.Binary = MstRing.add(a, b)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST.Binary =
|
override fun scale(a: MST, value: Double): MST.Binary =
|
||||||
MstGroup.binaryOperationFunction(RingOperations.TIMES_OPERATION)(a, MstGroup.number(value))
|
MstGroup.binaryOperationFunction(RingOps.TIMES_OPERATION)(a, MstGroup.number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
override fun multiply(a: MST, b: MST): MST.Binary = MstRing.multiply(a, b)
|
||||||
override fun divide(a: MST, b: MST): MST.Binary =
|
override fun divide(a: MST, b: MST): MST.Binary =
|
||||||
binaryOperationFunction(FieldOperations.DIV_OPERATION)(a, b)
|
binaryOperationFunction(FieldOps.DIV_OPERATION)(a, b)
|
||||||
|
|
||||||
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
override operator fun MST.unaryPlus(): MST.Unary = MstRing { +this@unaryPlus }
|
||||||
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
override operator fun MST.unaryMinus(): MST.Unary = MstRing { -this@unaryMinus }
|
||||||
@ -138,7 +138,7 @@ public object MstExtendedField : ExtendedField<MST>, NumericAlgebra<MST> {
|
|||||||
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
override fun sqrt(arg: MST): MST = unaryOperationFunction(PowerOperations.SQRT_OPERATION)(arg)
|
||||||
|
|
||||||
override fun scale(a: MST, value: Double): MST =
|
override fun scale(a: MST, value: Double): MST =
|
||||||
binaryOperation(GroupOperations.PLUS_OPERATION, a, number(value))
|
binaryOperation(GroupOps.PLUS_OPERATION, a, number(value))
|
||||||
|
|
||||||
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
override fun multiply(a: MST, b: MST): MST.Binary = MstField.multiply(a, b)
|
||||||
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
override fun divide(a: MST, b: MST): MST.Binary = MstField.divide(a, b)
|
||||||
|
@ -59,7 +59,7 @@ public fun <T : Any> DerivationResult<T>.grad(vararg variables: Symbol): Point<T
|
|||||||
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
public open class SimpleAutoDiffField<T : Any, F : Field<T>>(
|
||||||
public val context: F,
|
public val context: F,
|
||||||
bindings: Map<Symbol, T>,
|
bindings: Map<Symbol, T>,
|
||||||
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOperations<AutoDiffValue<T>> {
|
) : Field<AutoDiffValue<T>>, ExpressionAlgebra<T, AutoDiffValue<T>>, NumbersAddOps<AutoDiffValue<T>> {
|
||||||
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
override val zero: AutoDiffValue<T> get() = const(context.zero)
|
||||||
override val one: AutoDiffValue<T> get() = const(context.one)
|
override val one: AutoDiffValue<T> get() = const(context.one)
|
||||||
|
|
||||||
|
@ -6,12 +6,10 @@
|
|||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.BufferedRingND
|
import space.kscience.kmath.nd.BufferedRingOpsND
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.nd.asND
|
import space.kscience.kmath.nd.asND
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.operations.Ring
|
|
||||||
import space.kscience.kmath.operations.invoke
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.VirtualBuffer
|
import space.kscience.kmath.structures.VirtualBuffer
|
||||||
@ -19,31 +17,28 @@ import space.kscience.kmath.structures.indices
|
|||||||
|
|
||||||
|
|
||||||
public class BufferedLinearSpace<T, out A : Ring<T>>(
|
public class BufferedLinearSpace<T, out A : Ring<T>>(
|
||||||
override val elementAlgebra: A,
|
private val bufferAlgebra: BufferAlgebra<T, A>
|
||||||
private val bufferFactory: BufferFactory<T>,
|
|
||||||
) : LinearSpace<T, A> {
|
) : LinearSpace<T, A> {
|
||||||
|
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||||
|
|
||||||
private fun ndRing(
|
private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
|
||||||
rows: Int,
|
|
||||||
cols: Int,
|
|
||||||
): BufferedRingND<T, A> = elementAlgebra.ndAlgebra(bufferFactory, rows, cols)
|
|
||||||
|
|
||||||
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
|
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
|
||||||
ndRing(rows, columns).produce { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
|
ndAlgebra.produce(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
|
||||||
|
|
||||||
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
|
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
|
||||||
bufferFactory(size) { elementAlgebra.initializer(it) }
|
bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }
|
||||||
|
|
||||||
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndRing(rowNum, colNum).run {
|
override fun Matrix<T>.unaryMinus(): Matrix<T> = ndAlgebra {
|
||||||
asND().map { -it }.as2D()
|
asND().map { -it }.as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
|
override fun Matrix<T>.plus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||||
asND().plus(other.asND()).as2D()
|
asND().plus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndRing(rowNum, colNum).run {
|
override fun Matrix<T>.minus(other: Matrix<T>): Matrix<T> = ndAlgebra {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||||
asND().minus(other.asND()).as2D()
|
asND().minus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
@ -88,11 +83,11 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<T>.times(value: T): Matrix<T> = ndRing(rowNum, colNum).run {
|
override fun Matrix<T>.times(value: T): Matrix<T> = ndAlgebra {
|
||||||
asND().map { it * value }.as2D()
|
asND().map { it * value }.as2D()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
|
public fun <T, A : Ring<T>> A.linearSpace(bufferFactory: BufferFactory<T>): BufferedLinearSpace<T, A> =
|
||||||
BufferedLinearSpace(this, bufferFactory)
|
BufferedLinearSpace(BufferRingOps(this, bufferFactory))
|
||||||
|
@ -6,11 +6,12 @@
|
|||||||
package space.kscience.kmath.linear
|
package space.kscience.kmath.linear
|
||||||
|
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.DoubleFieldND
|
import space.kscience.kmath.nd.DoubleFieldOpsND
|
||||||
import space.kscience.kmath.nd.as2D
|
import space.kscience.kmath.nd.as2D
|
||||||
import space.kscience.kmath.nd.asND
|
import space.kscience.kmath.nd.asND
|
||||||
import space.kscience.kmath.operations.DoubleBufferOperations
|
import space.kscience.kmath.operations.DoubleBufferOps
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
|
|
||||||
@ -18,30 +19,27 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
|
|||||||
|
|
||||||
override val elementAlgebra: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
private fun ndRing(
|
|
||||||
rows: Int,
|
|
||||||
cols: Int,
|
|
||||||
): DoubleFieldND = DoubleFieldND(intArrayOf(rows, cols))
|
|
||||||
|
|
||||||
override fun buildMatrix(
|
override fun buildMatrix(
|
||||||
rows: Int,
|
rows: Int,
|
||||||
columns: Int,
|
columns: Int,
|
||||||
initializer: DoubleField.(i: Int, j: Int) -> Double
|
initializer: DoubleField.(i: Int, j: Int) -> Double
|
||||||
): Matrix<Double> = ndRing(rows, columns).produce { (i, j) -> DoubleField.initializer(i, j) }.as2D()
|
): Matrix<Double> = DoubleFieldOpsND.produce(intArrayOf(rows, columns)) { (i, j) ->
|
||||||
|
DoubleField.initializer(i, j)
|
||||||
|
}.as2D()
|
||||||
|
|
||||||
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
|
override fun buildVector(size: Int, initializer: DoubleField.(Int) -> Double): DoubleBuffer =
|
||||||
DoubleBuffer(size) { DoubleField.initializer(it) }
|
DoubleBuffer(size) { DoubleField.initializer(it) }
|
||||||
|
|
||||||
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = ndRing(rowNum, colNum).run {
|
override fun Matrix<Double>.unaryMinus(): Matrix<Double> = DoubleFieldOpsND {
|
||||||
asND().map { -it }.as2D()
|
asND().map { -it }.as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
|
override fun Matrix<Double>.plus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::plus. Expected $shape but found ${other.shape}" }
|
||||||
asND().plus(other.asND()).as2D()
|
asND().plus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = ndRing(rowNum, colNum).run {
|
override fun Matrix<Double>.minus(other: Matrix<Double>): Matrix<Double> = DoubleFieldOpsND {
|
||||||
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
require(shape.contentEquals(other.shape)) { "Shape mismatch on Matrix::minus. Expected $shape but found ${other.shape}" }
|
||||||
asND().minus(other.asND()).as2D()
|
asND().minus(other.asND()).as2D()
|
||||||
}
|
}
|
||||||
@ -84,23 +82,23 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun Matrix<Double>.times(value: Double): Matrix<Double> = ndRing(rowNum, colNum).run {
|
override fun Matrix<Double>.times(value: Double): Matrix<Double> = DoubleFieldOpsND {
|
||||||
asND().map { it * value }.as2D()
|
asND().map { it * value }.as2D()
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
|
public override fun Point<Double>.plus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
|
||||||
this@plus + other
|
this@plus + other
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOperations.run {
|
public override fun Point<Double>.minus(other: Point<Double>): DoubleBuffer = DoubleBufferOps.run {
|
||||||
this@minus - other
|
this@minus - other
|
||||||
}
|
}
|
||||||
|
|
||||||
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOperations.run {
|
public override fun Point<Double>.times(value: Double): DoubleBuffer = DoubleBufferOps.run {
|
||||||
scale(this@times, value)
|
scale(this@times, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOperations.run {
|
public operator fun Point<Double>.div(value: Double): DoubleBuffer = DoubleBufferOps.run {
|
||||||
scale(this@div, 1.0 / value)
|
scale(this@div, 1.0 / value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import space.kscience.kmath.nd.MutableStructure2D
|
|||||||
import space.kscience.kmath.nd.Structure2D
|
import space.kscience.kmath.nd.Structure2D
|
||||||
import space.kscience.kmath.nd.StructureFeature
|
import space.kscience.kmath.nd.StructureFeature
|
||||||
import space.kscience.kmath.nd.as1D
|
import space.kscience.kmath.nd.as1D
|
||||||
|
import space.kscience.kmath.operations.BufferRingOps
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
@ -188,7 +189,7 @@ public interface LinearSpace<T, out A : Ring<T>> {
|
|||||||
public fun <T : Any, A : Ring<T>> buffered(
|
public fun <T : Any, A : Ring<T>> buffered(
|
||||||
algebra: A,
|
algebra: A,
|
||||||
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
||||||
): LinearSpace<T, A> = BufferedLinearSpace(algebra, bufferFactory)
|
): LinearSpace<T, A> = BufferedLinearSpace(BufferRingOps(algebra, bufferFactory))
|
||||||
|
|
||||||
@Deprecated("use DoubleField.linearSpace")
|
@Deprecated("use DoubleField.linearSpace")
|
||||||
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)
|
public val double: LinearSpace<Double, DoubleField> = buffered(DoubleField, ::DoubleBuffer)
|
||||||
|
@ -7,7 +7,6 @@ package space.kscience.kmath.nd
|
|||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.*
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -19,6 +18,14 @@ import kotlin.reflect.KClass
|
|||||||
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
public class ShapeMismatchException(public val expected: IntArray, public val actual: IntArray) :
|
||||||
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
RuntimeException("Shape ${actual.contentToString()} doesn't fit in expected shape ${expected.contentToString()}.")
|
||||||
|
|
||||||
|
public typealias Shape = IntArray
|
||||||
|
|
||||||
|
public fun Shape(shapeFirst: Int, vararg shapeRest: Int): Shape = intArrayOf(shapeFirst, *shapeRest)
|
||||||
|
|
||||||
|
public interface WithShape {
|
||||||
|
public val shape: Shape
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The base interface for all ND-algebra implementations.
|
* The base interface for all ND-algebra implementations.
|
||||||
*
|
*
|
||||||
@ -26,20 +33,15 @@ public class ShapeMismatchException(public val expected: IntArray, public val ac
|
|||||||
* @param C the type of the element context.
|
* @param C the type of the element context.
|
||||||
*/
|
*/
|
||||||
public interface AlgebraND<T, out C : Algebra<T>> {
|
public interface AlgebraND<T, out C : Algebra<T>> {
|
||||||
/**
|
|
||||||
* The shape of ND-structures this algebra operates on.
|
|
||||||
*/
|
|
||||||
public val shape: IntArray
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The algebra over elements of ND structure.
|
* The algebra over elements of ND structure.
|
||||||
*/
|
*/
|
||||||
public val elementContext: C
|
public val elementAlgebra: C
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produces a new NDStructure using given initializer function.
|
* Produces a new NDStructure using given initializer function.
|
||||||
*/
|
*/
|
||||||
public fun produce(initializer: C.(IntArray) -> T): StructureND<T>
|
public fun produce(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps elements from one structure to another one by applying [transform] to them.
|
* Maps elements from one structure to another one by applying [transform] to them.
|
||||||
@ -54,7 +56,7 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
|||||||
/**
|
/**
|
||||||
* Combines two structures into one.
|
* Combines two structures into one.
|
||||||
*/
|
*/
|
||||||
public fun combine(a: StructureND<T>, b: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
|
public fun zip(left: StructureND<T>, right: StructureND<T>, transform: C.(T, T) -> T): StructureND<T>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Element-wise invocation of function working on [T] on a [StructureND].
|
* Element-wise invocation of function working on [T] on a [StructureND].
|
||||||
@ -77,7 +79,6 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get a feature of the structure in this scope. Structure features take precedence other context features.
|
* Get a feature of the structure in this scope. Structure features take precedence other context features.
|
||||||
*
|
*
|
||||||
@ -89,37 +90,13 @@ public interface AlgebraND<T, out C : Algebra<T>> {
|
|||||||
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
|
public inline fun <T : Any, reified F : StructureFeature> AlgebraND<T, *>.getFeature(structure: StructureND<T>): F? =
|
||||||
getFeature(structure, F::class)
|
getFeature(structure, F::class)
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if given elements are consistent with this context.
|
|
||||||
*
|
|
||||||
* @param structures the structures to check.
|
|
||||||
* @return the array of valid structures.
|
|
||||||
*/
|
|
||||||
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(vararg structures: StructureND<T>): Array<out StructureND<T>> =
|
|
||||||
structures
|
|
||||||
.map(StructureND<T>::shape)
|
|
||||||
.singleOrNull { !shape.contentEquals(it) }
|
|
||||||
?.let<IntArray, Array<out StructureND<T>>> { throw ShapeMismatchException(shape, it) }
|
|
||||||
?: structures
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Checks if given element is consistent with this context.
|
|
||||||
*
|
|
||||||
* @param element the structure to check.
|
|
||||||
* @return the valid structure.
|
|
||||||
*/
|
|
||||||
internal fun <T, C : Algebra<T>> AlgebraND<T, C>.checkShape(element: StructureND<T>): StructureND<T> {
|
|
||||||
if (!element.shape.contentEquals(shape)) throw ShapeMismatchException(shape, element.shape)
|
|
||||||
return element
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Space of [StructureND].
|
* Space of [StructureND].
|
||||||
*
|
*
|
||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param S the type of group over structure elements.
|
* @param A the type of group over structure elements.
|
||||||
*/
|
*/
|
||||||
public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND<T, S> {
|
public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>, AlgebraND<T, A> {
|
||||||
/**
|
/**
|
||||||
* Element-wise addition.
|
* Element-wise addition.
|
||||||
*
|
*
|
||||||
@ -128,7 +105,7 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
|
|||||||
* @return the sum.
|
* @return the sum.
|
||||||
*/
|
*/
|
||||||
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun add(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||||
combine(a, b) { aValue, bValue -> add(aValue, bValue) }
|
zip(a, b) { aValue, bValue -> add(aValue, bValue) }
|
||||||
|
|
||||||
// TODO move to extensions after KEEP-176
|
// TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
@ -171,13 +148,17 @@ public interface GroupND<T, out S : Group<T>> : Group<StructureND<T>>, AlgebraND
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
|
||||||
|
override val zero: StructureND<T> get() = produce(shape) { elementAlgebra.zero }
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Ring of [StructureND].
|
* Ring of [StructureND].
|
||||||
*
|
*
|
||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param R the type of ring over structure elements.
|
* @param A the type of ring over structure elements.
|
||||||
*/
|
*/
|
||||||
public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R> {
|
public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, GroupOpsND<T, A> {
|
||||||
/**
|
/**
|
||||||
* Element-wise multiplication.
|
* Element-wise multiplication.
|
||||||
*
|
*
|
||||||
@ -186,7 +167,7 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
|
|||||||
* @return the product.
|
* @return the product.
|
||||||
*/
|
*/
|
||||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun multiply(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||||
combine(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
zip(a, b) { aValue, bValue -> multiply(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after KEEP-176
|
||||||
|
|
||||||
@ -211,13 +192,19 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
|
|||||||
public companion object
|
public companion object
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
|
||||||
|
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Field of [StructureND].
|
* Field of [StructureND].
|
||||||
*
|
*
|
||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param F the type field over structure elements.
|
* @param A the type field over structure elements.
|
||||||
*/
|
*/
|
||||||
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F> {
|
public interface FieldOpsND<T, out A : Field<T>> : FieldOps<StructureND<T>>, RingOpsND<T, A>,
|
||||||
|
ScaleOperations<StructureND<T>> {
|
||||||
/**
|
/**
|
||||||
* Element-wise division.
|
* Element-wise division.
|
||||||
*
|
*
|
||||||
@ -226,9 +213,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
|
|||||||
* @return the quotient.
|
* @return the quotient.
|
||||||
*/
|
*/
|
||||||
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
override fun divide(a: StructureND<T>, b: StructureND<T>): StructureND<T> =
|
||||||
combine(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
zip(a, b) { aValue, bValue -> divide(aValue, bValue) }
|
||||||
|
|
||||||
//TODO move to extensions after KEEP-176
|
//TODO move to extensions after https://github.com/Kotlin/KEEP/blob/master/proposals/context-receivers.md
|
||||||
/**
|
/**
|
||||||
* Divides an ND structure by an element of it.
|
* Divides an ND structure by an element of it.
|
||||||
*
|
*
|
||||||
@ -247,42 +234,9 @@ public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T,
|
|||||||
*/
|
*/
|
||||||
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
|
public operator fun T.div(arg: StructureND<T>): StructureND<T> = arg.map { divide(it, this@div) }
|
||||||
|
|
||||||
/**
|
|
||||||
* Element-wise scaling.
|
|
||||||
*
|
|
||||||
* @param a the multiplicand.
|
|
||||||
* @param value the multiplier.
|
|
||||||
* @return the product.
|
|
||||||
*/
|
|
||||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
|
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { scale(it, value) }
|
||||||
|
}
|
||||||
// @ThreadLocal
|
|
||||||
// public companion object {
|
public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
|
||||||
// private val realNDFieldCache: MutableMap<IntArray, RealNDField> = hashMapOf()
|
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one }
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Create a nd-field for [Double] values or pull it from cache if it was created previously.
|
|
||||||
// */
|
|
||||||
// public fun real(vararg shape: Int): RealNDField = realNDFieldCache.getOrPut(shape) { RealNDField(shape) }
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Create an ND field with boxing generic buffer.
|
|
||||||
// */
|
|
||||||
// public fun <T : Any, F : Field<T>> boxing(
|
|
||||||
// field: F,
|
|
||||||
// vararg shape: Int,
|
|
||||||
// bufferFactory: BufferFactory<T> = Buffer.Companion::boxing,
|
|
||||||
// ): BufferedNDField<T, F> = BufferedNDField(shape, field, bufferFactory)
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Create a most suitable implementation for nd-field using reified class.
|
|
||||||
// */
|
|
||||||
// @Suppress("UNCHECKED_CAST")
|
|
||||||
// public inline fun <reified T : Any, F : Field<T>> auto(field: F, vararg shape: Int): NDField<T, F> =
|
|
||||||
// when {
|
|
||||||
// T::class == Double::class -> real(*shape) as NDField<T, F>
|
|
||||||
// T::class == Complex::class -> complex(*shape) as BufferedNDField<T, F>
|
|
||||||
// else -> BoxingNDField(shape, field, Buffer.Companion::auto)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
}
|
}
|
@ -3,145 +3,173 @@
|
|||||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@file:OptIn(UnstableKMathAPI::class)
|
||||||
|
|
||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.structures.Buffer
|
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import kotlin.contracts.InvocationKind
|
|
||||||
import kotlin.contracts.contract
|
|
||||||
import kotlin.jvm.JvmName
|
|
||||||
|
|
||||||
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
|
||||||
public val strides: Strides
|
public val indexerBuilder: (IntArray) -> ShapeIndex
|
||||||
public val bufferFactory: BufferFactory<T>
|
public val bufferAlgebra: BufferAlgebra<T, A>
|
||||||
|
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
|
||||||
|
|
||||||
override fun produce(initializer: A.(IntArray) -> T): BufferND<T> = BufferND(
|
override fun produce(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
|
||||||
strides,
|
val indexer = indexerBuilder(shape)
|
||||||
bufferFactory(strides.linearSize) { offset ->
|
return BufferND(
|
||||||
elementContext.initializer(strides.index(offset))
|
indexer,
|
||||||
|
bufferAlgebra.buffer(indexer.linearSize) { offset ->
|
||||||
|
elementAlgebra.initializer(indexer.index(offset))
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun StructureND<T>.toBufferND(): BufferND<T> = when (this) {
|
||||||
|
is BufferND -> this
|
||||||
|
else -> {
|
||||||
|
val indexer = indexerBuilder(shape)
|
||||||
|
BufferND(indexer, bufferAlgebra.buffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> = mapInline(toBufferND(), transform)
|
||||||
|
|
||||||
|
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> =
|
||||||
|
mapIndexedInline(toBufferND(), transform)
|
||||||
|
|
||||||
|
override fun zip(left: StructureND<T>, right: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> =
|
||||||
|
zipInline(left.toBufferND(), right.toBufferND(), transform)
|
||||||
|
|
||||||
|
public companion object {
|
||||||
|
public val defaultIndexerBuilder: (IntArray) -> ShapeIndex = DefaultStrides.Companion::invoke
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapInline(
|
||||||
|
arg: BufferND<T>,
|
||||||
|
crossinline transform: A.(T) -> T
|
||||||
|
): BufferND<T> {
|
||||||
|
val indexes = arg.indexes
|
||||||
|
return BufferND(indexes, bufferAlgebra.mapInline(arg.buffer, transform))
|
||||||
|
}
|
||||||
|
|
||||||
|
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.mapIndexedInline(
|
||||||
|
arg: BufferND<T>,
|
||||||
|
crossinline transform: A.(index: IntArray, arg: T) -> T
|
||||||
|
): BufferND<T> {
|
||||||
|
val indexes = arg.indexes
|
||||||
|
return BufferND(
|
||||||
|
indexes,
|
||||||
|
bufferAlgebra.mapIndexedInline(arg.buffer) { offset, value ->
|
||||||
|
transform(indexes.index(offset), value)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
public val StructureND<T>.buffer: Buffer<T>
|
|
||||||
get() = when {
|
|
||||||
!shape.contentEquals(this@BufferAlgebraND.shape) -> throw ShapeMismatchException(
|
|
||||||
this@BufferAlgebraND.shape,
|
|
||||||
shape
|
|
||||||
)
|
|
||||||
this is BufferND && this.strides == this@BufferAlgebraND.strides -> this.buffer
|
|
||||||
else -> bufferFactory(strides.linearSize) { offset -> get(strides.index(offset)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun StructureND<T>.map(transform: A.(T) -> T): BufferND<T> {
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
|
||||||
elementContext.transform(buffer[offset])
|
|
||||||
}
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun StructureND<T>.mapIndexed(transform: A.(index: IntArray, T) -> T): BufferND<T> {
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
|
||||||
elementContext.transform(
|
|
||||||
strides.index(offset),
|
|
||||||
buffer[offset]
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun combine(a: StructureND<T>, b: StructureND<T>, transform: A.(T, T) -> T): BufferND<T> {
|
|
||||||
val buffer = bufferFactory(strides.linearSize) { offset ->
|
|
||||||
elementContext.transform(a.buffer[offset], b.buffer[offset])
|
|
||||||
}
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class BufferedGroupND<T, out A : Group<T>>(
|
internal inline fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.zipInline(
|
||||||
final override val shape: IntArray,
|
l: BufferND<T>,
|
||||||
final override val elementContext: A,
|
r: BufferND<T>,
|
||||||
final override val bufferFactory: BufferFactory<T>,
|
crossinline block: A.(l: T, r: T) -> T
|
||||||
) : GroupND<T, A>, BufferAlgebraND<T, A> {
|
): BufferND<T> {
|
||||||
override val strides: Strides = DefaultStrides(shape)
|
require(l.indexes == r.indexes)
|
||||||
override val zero: BufferND<T> by lazy { produce { zero } }
|
val indexes = l.indexes
|
||||||
override fun StructureND<T>.unaryMinus(): StructureND<T> = produce { -get(it) }
|
return BufferND(indexes, bufferAlgebra.zipInline(l.buffer, r.buffer, block))
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class BufferedRingND<T, out R : Ring<T>>(
|
public open class BufferedGroupNDOps<T, out A : Group<T>>(
|
||||||
shape: IntArray,
|
override val bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
elementContext: R,
|
override val indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||||
bufferFactory: BufferFactory<T>,
|
) : GroupOpsND<T, A>, BufferAlgebraND<T, A> {
|
||||||
) : BufferedGroupND<T, R>(shape, elementContext, bufferFactory), RingND<T, R> {
|
override fun StructureND<T>.unaryMinus(): StructureND<T> = map { -it }
|
||||||
override val one: BufferND<T> by lazy { produce { one } }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public open class BufferedFieldND<T, out R : Field<T>>(
|
public open class BufferedRingOpsND<T, out A : Ring<T>>(
|
||||||
shape: IntArray,
|
bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
elementContext: R,
|
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||||
bufferFactory: BufferFactory<T>,
|
) : BufferedGroupNDOps<T, A>(bufferAlgebra, indexerBuilder), RingOpsND<T, A>
|
||||||
) : BufferedRingND<T, R>(shape, elementContext, bufferFactory), FieldND<T, R> {
|
|
||||||
|
public open class BufferedFieldOpsND<T, out A : Field<T>>(
|
||||||
|
bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
|
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||||
|
) : BufferedRingOpsND<T, A>(bufferAlgebra, indexerBuilder), FieldOpsND<T, A> {
|
||||||
|
|
||||||
|
public constructor(
|
||||||
|
elementAlgebra: A,
|
||||||
|
bufferFactory: BufferFactory<T>,
|
||||||
|
indexerBuilder: (IntArray) -> ShapeIndex = BufferAlgebraND.defaultIndexerBuilder
|
||||||
|
) : this(BufferFieldOps(elementAlgebra, bufferFactory), indexerBuilder)
|
||||||
|
|
||||||
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
override fun scale(a: StructureND<T>, value: Double): StructureND<T> = a.map { it * value }
|
||||||
}
|
}
|
||||||
|
|
||||||
// group factories
|
public val <T, A : Group<T>> BufferAlgebra<T, A>.nd: BufferedGroupNDOps<T, A> get() = BufferedGroupNDOps(this)
|
||||||
public fun <T, A : Group<T>> A.ndAlgebra(
|
public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get() = BufferedRingOpsND(this)
|
||||||
bufferFactory: BufferFactory<T>,
|
public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
|
||||||
vararg shape: Int,
|
|
||||||
): BufferedGroupND<T, A> = BufferedGroupND(shape, this, bufferFactory)
|
|
||||||
|
|
||||||
@JvmName("withNdGroup")
|
|
||||||
public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
|
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
|
||||||
vararg shape: Int,
|
|
||||||
action: BufferedGroupND<T, A>.() -> R,
|
|
||||||
): R {
|
|
||||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
|
||||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
|
||||||
}
|
|
||||||
|
|
||||||
//ring factories
|
public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce(
|
||||||
public fun <T, A : Ring<T>> A.ndAlgebra(
|
|
||||||
bufferFactory: BufferFactory<T>,
|
|
||||||
vararg shape: Int,
|
vararg shape: Int,
|
||||||
): BufferedRingND<T, A> = BufferedRingND(shape, this, bufferFactory)
|
initializer: A.(IntArray) -> T
|
||||||
|
): BufferND<T> = produce(shape, initializer)
|
||||||
|
|
||||||
@JvmName("withNdRing")
|
//// group factories
|
||||||
public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
|
//public fun <T, A : Group<T>> A.ndAlgebra(
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
// bufferAlgebra: BufferAlgebra<T, A>,
|
||||||
vararg shape: Int,
|
// vararg shape: Int,
|
||||||
action: BufferedRingND<T, A>.() -> R,
|
//): BufferedGroupNDOps<T, A> = BufferedGroupNDOps(bufferAlgebra)
|
||||||
): R {
|
//
|
||||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
//@JvmName("withNdGroup")
|
||||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
//public inline fun <T, A : Group<T>, R> A.withNdAlgebra(
|
||||||
}
|
// noinline bufferFactory: BufferFactory<T>,
|
||||||
|
// vararg shape: Int,
|
||||||
|
// action: BufferedGroupNDOps<T, A>.() -> R,
|
||||||
|
//): R {
|
||||||
|
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||||
|
//}
|
||||||
|
|
||||||
//field factories
|
////ring factories
|
||||||
public fun <T, A : Field<T>> A.ndAlgebra(
|
//public fun <T, A : Ring<T>> A.ndAlgebra(
|
||||||
bufferFactory: BufferFactory<T>,
|
// bufferFactory: BufferFactory<T>,
|
||||||
vararg shape: Int,
|
// vararg shape: Int,
|
||||||
): BufferedFieldND<T, A> = BufferedFieldND(shape, this, bufferFactory)
|
//): BufferedRingNDOps<T, A> = BufferedRingNDOps(shape, this, bufferFactory)
|
||||||
|
//
|
||||||
/**
|
//@JvmName("withNdRing")
|
||||||
* Create a [FieldND] for this [Field] inferring proper buffer factory from the type
|
//public inline fun <T, A : Ring<T>, R> A.withNdAlgebra(
|
||||||
*/
|
// noinline bufferFactory: BufferFactory<T>,
|
||||||
@UnstableKMathAPI
|
// vararg shape: Int,
|
||||||
@Suppress("UNCHECKED_CAST")
|
// action: BufferedRingNDOps<T, A>.() -> R,
|
||||||
public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
|
//): R {
|
||||||
vararg shape: Int,
|
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
): FieldND<T, A> = when (this) {
|
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||||
DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
|
//}
|
||||||
else -> BufferedFieldND(shape, this, Buffer.Companion::auto)
|
//
|
||||||
}
|
////field factories
|
||||||
|
//public fun <T, A : Field<T>> A.ndAlgebra(
|
||||||
@JvmName("withNdField")
|
// bufferFactory: BufferFactory<T>,
|
||||||
public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
|
// vararg shape: Int,
|
||||||
noinline bufferFactory: BufferFactory<T>,
|
//): BufferedFieldNDOps<T, A> = BufferedFieldNDOps(shape, this, bufferFactory)
|
||||||
vararg shape: Int,
|
//
|
||||||
action: BufferedFieldND<T, A>.() -> R,
|
///**
|
||||||
): R {
|
// * Create a [FieldND] for this [Field] inferring proper buffer factory from the type
|
||||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
// */
|
||||||
return ndAlgebra(bufferFactory, *shape).run(action)
|
//@UnstableKMathAPI
|
||||||
}
|
//@Suppress("UNCHECKED_CAST")
|
||||||
|
//public inline fun <reified T : Any, A : Field<T>> A.autoNdAlgebra(
|
||||||
|
// vararg shape: Int,
|
||||||
|
//): FieldND<T, A> = when (this) {
|
||||||
|
// DoubleField -> DoubleFieldND(shape) as FieldND<T, A>
|
||||||
|
// else -> BufferedFieldNDOps(shape, this, Buffer.Companion::auto)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//@JvmName("withNdField")
|
||||||
|
//public inline fun <T, A : Field<T>, R> A.withNdAlgebra(
|
||||||
|
// noinline bufferFactory: BufferFactory<T>,
|
||||||
|
// vararg shape: Int,
|
||||||
|
// action: BufferedFieldNDOps<T, A>.() -> R,
|
||||||
|
//): R {
|
||||||
|
// contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
|
// return ndAlgebra(bufferFactory, *shape).run(action)
|
||||||
|
//}
|
@ -15,26 +15,20 @@ import space.kscience.kmath.structures.MutableBufferFactory
|
|||||||
* Represents [StructureND] over [Buffer].
|
* Represents [StructureND] over [Buffer].
|
||||||
*
|
*
|
||||||
* @param T the type of items.
|
* @param T the type of items.
|
||||||
* @param strides The strides to access elements of [Buffer] by linear indices.
|
* @param indexes The strides to access elements of [Buffer] by linear indices.
|
||||||
* @param buffer The underlying buffer.
|
* @param buffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public open class BufferND<out T>(
|
public open class BufferND<out T>(
|
||||||
public val strides: Strides,
|
public val indexes: ShapeIndex,
|
||||||
public val buffer: Buffer<T>,
|
public val buffer: Buffer<T>,
|
||||||
) : StructureND<T> {
|
) : StructureND<T> {
|
||||||
|
|
||||||
init {
|
override operator fun get(index: IntArray): T = buffer[indexes.offset(index)]
|
||||||
if (strides.linearSize != buffer.size) {
|
|
||||||
error("Expected buffer side of ${strides.linearSize}, but found ${buffer.size}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun get(index: IntArray): T = buffer[strides.offset(index)]
|
override val shape: IntArray get() = indexes.shape
|
||||||
|
|
||||||
override val shape: IntArray get() = strides.shape
|
|
||||||
|
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
override fun elements(): Sequence<Pair<IntArray, T>> = strides.indices().map {
|
override fun elements(): Sequence<Pair<IntArray, T>> = indexes.indices().map {
|
||||||
it to this[it]
|
it to this[it]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,7 +43,7 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
|||||||
crossinline transform: (T) -> R,
|
crossinline transform: (T) -> R,
|
||||||
): BufferND<R> {
|
): BufferND<R> {
|
||||||
return if (this is BufferND<T>)
|
return if (this is BufferND<T>)
|
||||||
BufferND(this.strides, factory.invoke(strides.linearSize) { transform(buffer[it]) })
|
BufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(buffer[it]) })
|
||||||
else {
|
else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
BufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
@ -64,11 +58,11 @@ public inline fun <T, reified R : Any> StructureND<T>.mapToBuffer(
|
|||||||
* @param mutableBuffer The underlying buffer.
|
* @param mutableBuffer The underlying buffer.
|
||||||
*/
|
*/
|
||||||
public class MutableBufferND<T>(
|
public class MutableBufferND<T>(
|
||||||
strides: Strides,
|
strides: ShapeIndex,
|
||||||
public val mutableBuffer: MutableBuffer<T>,
|
public val mutableBuffer: MutableBuffer<T>,
|
||||||
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
) : MutableStructureND<T>, BufferND<T>(strides, mutableBuffer) {
|
||||||
override fun set(index: IntArray, value: T) {
|
override fun set(index: IntArray, value: T) {
|
||||||
mutableBuffer[strides.offset(index)] = value
|
mutableBuffer[indexes.offset(index)] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +74,7 @@ public inline fun <T, reified R : Any> MutableStructureND<T>.mapToMutableBuffer(
|
|||||||
crossinline transform: (T) -> R,
|
crossinline transform: (T) -> R,
|
||||||
): MutableBufferND<R> {
|
): MutableBufferND<R> {
|
||||||
return if (this is MutableBufferND<T>)
|
return if (this is MutableBufferND<T>)
|
||||||
MutableBufferND(this.strides, factory.invoke(strides.linearSize) { transform(mutableBuffer[it]) })
|
MutableBufferND(this.indexes, factory.invoke(indexes.linearSize) { transform(mutableBuffer[it]) })
|
||||||
else {
|
else {
|
||||||
val strides = DefaultStrides(shape)
|
val strides = DefaultStrides(shape)
|
||||||
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
MutableBufferND(strides, factory.invoke(strides.linearSize) { transform(get(strides.index(it))) })
|
||||||
|
@ -6,108 +6,68 @@
|
|||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.*
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(DoubleField.bufferAlgebra),
|
||||||
public class DoubleFieldND(
|
ScaleOperations<StructureND<Double>>, ExtendedFieldOps<StructureND<Double>> {
|
||||||
shape: IntArray,
|
|
||||||
) : BufferedFieldND<Double, DoubleField>(shape, DoubleField, ::DoubleBuffer),
|
|
||||||
NumbersAddOperations<StructureND<Double>>,
|
|
||||||
ScaleOperations<StructureND<Double>>,
|
|
||||||
ExtendedField<StructureND<Double>> {
|
|
||||||
|
|
||||||
override val zero: BufferND<Double> by lazy { produce { zero } }
|
override fun StructureND<Double>.toBufferND(): BufferND<Double> = when (this) {
|
||||||
override val one: BufferND<Double> by lazy { produce { one } }
|
is BufferND -> this
|
||||||
|
else -> {
|
||||||
|
val indexer = indexerBuilder(shape)
|
||||||
|
BufferND(indexer, DoubleBuffer(indexer.linearSize) { offset -> get(indexer.index(offset)) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO do specialization
|
||||||
|
|
||||||
|
override fun scale(a: StructureND<Double>, value: Double): BufferND<Double> =
|
||||||
|
mapInline(a.toBufferND()) { it * value }
|
||||||
|
|
||||||
|
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> =
|
||||||
|
mapInline(arg.toBufferND()) { power(it, pow) }
|
||||||
|
|
||||||
|
override fun exp(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { exp(it) }
|
||||||
|
override fun ln(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { ln(it) }
|
||||||
|
|
||||||
|
override fun sin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sin(it) }
|
||||||
|
override fun cos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cos(it) }
|
||||||
|
override fun tan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tan(it) }
|
||||||
|
override fun asin(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asin(it) }
|
||||||
|
override fun acos(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acos(it) }
|
||||||
|
override fun atan(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atan(it) }
|
||||||
|
|
||||||
|
override fun sinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { sinh(it) }
|
||||||
|
override fun cosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { cosh(it) }
|
||||||
|
override fun tanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { tanh(it) }
|
||||||
|
override fun asinh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { asinh(it) }
|
||||||
|
override fun acosh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { acosh(it) }
|
||||||
|
override fun atanh(arg: StructureND<Double>): BufferND<Double> = mapInline(arg.toBufferND()) { atanh(it) }
|
||||||
|
|
||||||
|
public companion object : DoubleFieldOpsND()
|
||||||
|
}
|
||||||
|
|
||||||
|
@OptIn(UnstableKMathAPI::class)
|
||||||
|
public class DoubleFieldND(override val shape: Shape) :
|
||||||
|
DoubleFieldOpsND(), FieldND<Double, DoubleField>, NumbersAddOps<StructureND<Double>> {
|
||||||
|
|
||||||
override fun number(value: Number): BufferND<Double> {
|
override fun number(value: Number): BufferND<Double> {
|
||||||
val d = value.toDouble() // minimize conversions
|
val d = value.toDouble() // minimize conversions
|
||||||
return produce { d }
|
return produce(shape) { d }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val StructureND<Double>.buffer: DoubleBuffer
|
|
||||||
get() = when {
|
|
||||||
!shape.contentEquals(this@DoubleFieldND.shape) -> throw ShapeMismatchException(
|
|
||||||
this@DoubleFieldND.shape,
|
|
||||||
shape
|
|
||||||
)
|
|
||||||
this is BufferND && this.strides == this@DoubleFieldND.strides -> this.buffer as DoubleBuffer
|
|
||||||
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
override inline fun StructureND<Double>.map(
|
|
||||||
transform: DoubleField.(Double) -> Double,
|
|
||||||
): BufferND<Double> {
|
|
||||||
val buffer = DoubleBuffer(strides.linearSize) { offset -> DoubleField.transform(buffer.array[offset]) }
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
override inline fun produce(initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
|
|
||||||
val array = DoubleArray(strides.linearSize) { offset ->
|
|
||||||
val index = strides.index(offset)
|
|
||||||
DoubleField.initializer(index)
|
|
||||||
}
|
|
||||||
return BufferND(strides, DoubleBuffer(array))
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
override inline fun StructureND<Double>.mapIndexed(
|
|
||||||
transform: DoubleField.(index: IntArray, Double) -> Double,
|
|
||||||
): BufferND<Double> = BufferND(
|
|
||||||
strides,
|
|
||||||
buffer = DoubleBuffer(strides.linearSize) { offset ->
|
|
||||||
DoubleField.transform(
|
|
||||||
strides.index(offset),
|
|
||||||
buffer.array[offset]
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE")
|
|
||||||
override inline fun combine(
|
|
||||||
a: StructureND<Double>,
|
|
||||||
b: StructureND<Double>,
|
|
||||||
transform: DoubleField.(Double, Double) -> Double,
|
|
||||||
): BufferND<Double> {
|
|
||||||
val buffer = DoubleBuffer(strides.linearSize) { offset ->
|
|
||||||
DoubleField.transform(a.buffer.array[offset], b.buffer.array[offset])
|
|
||||||
}
|
|
||||||
return BufferND(strides, buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun scale(a: StructureND<Double>, value: Double): StructureND<Double> = a.map { it * value }
|
|
||||||
|
|
||||||
override fun power(arg: StructureND<Double>, pow: Number): BufferND<Double> = arg.map { power(it, pow) }
|
|
||||||
|
|
||||||
override fun exp(arg: StructureND<Double>): BufferND<Double> = arg.map { exp(it) }
|
|
||||||
override fun ln(arg: StructureND<Double>): BufferND<Double> = arg.map { ln(it) }
|
|
||||||
|
|
||||||
override fun sin(arg: StructureND<Double>): BufferND<Double> = arg.map { sin(it) }
|
|
||||||
override fun cos(arg: StructureND<Double>): BufferND<Double> = arg.map { cos(it) }
|
|
||||||
override fun tan(arg: StructureND<Double>): BufferND<Double> = arg.map { tan(it) }
|
|
||||||
override fun asin(arg: StructureND<Double>): BufferND<Double> = arg.map { asin(it) }
|
|
||||||
override fun acos(arg: StructureND<Double>): BufferND<Double> = arg.map { acos(it) }
|
|
||||||
override fun atan(arg: StructureND<Double>): BufferND<Double> = arg.map { atan(it) }
|
|
||||||
|
|
||||||
override fun sinh(arg: StructureND<Double>): BufferND<Double> = arg.map { sinh(it) }
|
|
||||||
override fun cosh(arg: StructureND<Double>): BufferND<Double> = arg.map { cosh(it) }
|
|
||||||
override fun tanh(arg: StructureND<Double>): BufferND<Double> = arg.map { tanh(it) }
|
|
||||||
override fun asinh(arg: StructureND<Double>): BufferND<Double> = arg.map { asinh(it) }
|
|
||||||
override fun acosh(arg: StructureND<Double>): BufferND<Double> = arg.map { acosh(it) }
|
|
||||||
override fun atanh(arg: StructureND<Double>): BufferND<Double> = arg.map { atanh(it) }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public val DoubleField.ndAlgebra: DoubleFieldOpsND get() = DoubleFieldOpsND
|
||||||
|
|
||||||
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
|
public fun DoubleField.ndAlgebra(vararg shape: Int): DoubleFieldND = DoubleFieldND(shape)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Produce a context for n-dimensional operations inside this real field
|
* Produce a context for n-dimensional operations inside this real field
|
||||||
*/
|
*/
|
||||||
|
@UnstableKMathAPI
|
||||||
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
|
public inline fun <R> DoubleField.withNdAlgebra(vararg shape: Int, action: DoubleFieldND.() -> R): R {
|
||||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
return DoubleFieldND(shape).run(action)
|
return DoubleFieldND(shape).run(action)
|
||||||
|
@ -0,0 +1,120 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import kotlin.native.concurrent.ThreadLocal
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A converter from linear index to multivariate index
|
||||||
|
*/
|
||||||
|
public interface ShapeIndex{
|
||||||
|
public val shape: Shape
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get linear index from multidimensional index
|
||||||
|
*/
|
||||||
|
public fun offset(index: IntArray): Int
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get multidimensional from linear
|
||||||
|
*/
|
||||||
|
public fun index(offset: Int): IntArray
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
||||||
|
*/
|
||||||
|
public val linearSize: Int
|
||||||
|
|
||||||
|
// TODO introduce a fast way to calculate index of the next element?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Iterate over ND indices in a natural order
|
||||||
|
*/
|
||||||
|
public fun indices(): Sequence<IntArray>
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean
|
||||||
|
override fun hashCode(): Int
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Linear transformation of indexes
|
||||||
|
*/
|
||||||
|
public abstract class Strides: ShapeIndex {
|
||||||
|
/**
|
||||||
|
* Array strides
|
||||||
|
*/
|
||||||
|
public abstract val strides: IntArray
|
||||||
|
|
||||||
|
public override fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
||||||
|
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
||||||
|
value * strides[i]
|
||||||
|
}.sum()
|
||||||
|
|
||||||
|
// TODO introduce a fast way to calculate index of the next element?
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Iterate over ND indices in a natural order
|
||||||
|
*/
|
||||||
|
public override fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple implementation of [Strides].
|
||||||
|
*/
|
||||||
|
public class DefaultStrides private constructor(override val shape: IntArray) : Strides() {
|
||||||
|
override val linearSize: Int get() = strides[shape.size]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Strides for memory access
|
||||||
|
*/
|
||||||
|
override val strides: IntArray by lazy {
|
||||||
|
sequence {
|
||||||
|
var current = 1
|
||||||
|
yield(1)
|
||||||
|
|
||||||
|
shape.forEach {
|
||||||
|
current *= it
|
||||||
|
yield(current)
|
||||||
|
}
|
||||||
|
}.toList().toIntArray()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun index(offset: Int): IntArray {
|
||||||
|
val res = IntArray(shape.size)
|
||||||
|
var current = offset
|
||||||
|
var strideIndex = strides.size - 2
|
||||||
|
|
||||||
|
while (strideIndex >= 0) {
|
||||||
|
res[strideIndex] = (current / strides[strideIndex])
|
||||||
|
current %= strides[strideIndex]
|
||||||
|
strideIndex--
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other !is DefaultStrides) return false
|
||||||
|
if (!shape.contentEquals(other.shape)) return false
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int = shape.contentHashCode()
|
||||||
|
|
||||||
|
@ThreadLocal
|
||||||
|
public companion object {
|
||||||
|
//private val defaultStridesCache = HashMap<IntArray, Strides>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cached builder for default strides
|
||||||
|
*/
|
||||||
|
public operator fun invoke(shape: IntArray): Strides = DefaultStrides(shape)
|
||||||
|
//defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
||||||
|
|
||||||
|
//TODO fix cache
|
||||||
|
}
|
||||||
|
}
|
@ -6,34 +6,27 @@
|
|||||||
package space.kscience.kmath.nd
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
import space.kscience.kmath.operations.NumbersAddOps
|
||||||
import space.kscience.kmath.operations.ShortRing
|
import space.kscience.kmath.operations.ShortRing
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.operations.bufferAlgebra
|
||||||
import space.kscience.kmath.structures.ShortBuffer
|
|
||||||
import kotlin.contracts.InvocationKind
|
import kotlin.contracts.InvocationKind
|
||||||
import kotlin.contracts.contract
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
public sealed class ShortRingOpsND : BufferedRingOpsND<Short, ShortRing>(ShortRing.bufferAlgebra) {
|
||||||
|
public companion object : ShortRingOpsND()
|
||||||
|
}
|
||||||
|
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public class ShortRingND(
|
public class ShortRingND(
|
||||||
shape: IntArray,
|
override val shape: Shape
|
||||||
) : BufferedRingND<Short, ShortRing>(shape, ShortRing, Buffer.Companion::auto),
|
) : ShortRingOpsND(), RingND<Short, ShortRing>, NumbersAddOps<StructureND<Short>> {
|
||||||
NumbersAddOperations<StructureND<Short>> {
|
|
||||||
|
|
||||||
override val zero: BufferND<Short> by lazy { produce { zero } }
|
|
||||||
override val one: BufferND<Short> by lazy { produce { one } }
|
|
||||||
|
|
||||||
override fun number(value: Number): BufferND<Short> {
|
override fun number(value: Number): BufferND<Short> {
|
||||||
val d = value.toShort() // minimize conversions
|
val d = value.toShort() // minimize conversions
|
||||||
return produce { d }
|
return produce(shape) { d }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Fast element production using function inlining.
|
|
||||||
*/
|
|
||||||
public inline fun BufferedRingND<Short, ShortRing>.produceInline(crossinline initializer: ShortRing.(Int) -> Short): BufferND<Short> =
|
|
||||||
BufferND(strides, ShortBuffer(ShortArray(strides.linearSize) { offset -> ShortRing.initializer(offset) }))
|
|
||||||
|
|
||||||
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
|
public inline fun <R> ShortRing.withNdAlgebra(vararg shape: Int, action: ShortRingND.() -> R): R {
|
||||||
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
contract { callsInPlace(action, InvocationKind.EXACTLY_ONCE) }
|
||||||
return ShortRingND(shape).run(action)
|
return ShortRingND(shape).run(action)
|
||||||
|
@ -15,7 +15,6 @@ import space.kscience.kmath.structures.Buffer
|
|||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import kotlin.jvm.JvmName
|
import kotlin.jvm.JvmName
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
import kotlin.native.concurrent.ThreadLocal
|
|
||||||
import kotlin.reflect.KClass
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
public interface StructureFeature : Feature<StructureFeature>
|
public interface StructureFeature : Feature<StructureFeature>
|
||||||
@ -72,7 +71,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
|||||||
if (st1 === st2) return true
|
if (st1 === st2) return true
|
||||||
|
|
||||||
// fast comparison of buffers if possible
|
// fast comparison of buffers if possible
|
||||||
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
|
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
|
||||||
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
||||||
|
|
||||||
//element by element comparison if it could not be avoided
|
//element by element comparison if it could not be avoided
|
||||||
@ -88,7 +87,7 @@ public interface StructureND<out T> : Featured<StructureFeature> {
|
|||||||
if (st1 === st2) return true
|
if (st1 === st2) return true
|
||||||
|
|
||||||
// fast comparison of buffers if possible
|
// fast comparison of buffers if possible
|
||||||
if (st1 is BufferND && st2 is BufferND && st1.strides == st2.strides)
|
if (st1 is BufferND && st2 is BufferND && st1.indexes == st2.indexes)
|
||||||
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
return Buffer.contentEquals(st1.buffer, st2.buffer)
|
||||||
|
|
||||||
//element by element comparison if it could not be avoided
|
//element by element comparison if it could not be avoided
|
||||||
@ -187,11 +186,11 @@ public fun <T : Comparable<T>> LinearSpace<T, Ring<T>>.contentEquals(
|
|||||||
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
||||||
*/
|
*/
|
||||||
@PerformancePitfall
|
@PerformancePitfall
|
||||||
public fun <T : Comparable<T>> GroupND<T, Ring<T>>.contentEquals(
|
public fun <T : Comparable<T>> GroupOpsND<T, Ring<T>>.contentEquals(
|
||||||
st1: StructureND<T>,
|
st1: StructureND<T>,
|
||||||
st2: StructureND<T>,
|
st2: StructureND<T>,
|
||||||
absoluteTolerance: T,
|
absoluteTolerance: T,
|
||||||
): Boolean = st1.elements().all { (index, value) -> elementContext { (value - st2[index]) } < absoluteTolerance }
|
): Boolean = st1.elements().all { (index, value) -> elementAlgebra { (value - st2[index]) } < absoluteTolerance }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
* Indicates whether some [StructureND] is equal to another one with [absoluteTolerance].
|
||||||
@ -231,107 +230,10 @@ public interface MutableStructureND<T> : StructureND<T> {
|
|||||||
* Transform a structure element-by element in place.
|
* Transform a structure element-by element in place.
|
||||||
*/
|
*/
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (IntArray, T) -> T): Unit =
|
public inline fun <T> MutableStructureND<T>.mapInPlace(action: (index: IntArray, t: T) -> T): Unit =
|
||||||
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
elements().forEach { (index, oldValue) -> this[index] = action(index, oldValue) }
|
||||||
|
|
||||||
/**
|
public inline fun <reified T : Any> StructureND<T>.zip(
|
||||||
* A way to convert ND indices to linear one and back.
|
|
||||||
*/
|
|
||||||
public interface Strides {
|
|
||||||
/**
|
|
||||||
* Shape of NDStructure
|
|
||||||
*/
|
|
||||||
public val shape: IntArray
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Array strides
|
|
||||||
*/
|
|
||||||
public val strides: IntArray
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get linear index from multidimensional index
|
|
||||||
*/
|
|
||||||
public fun offset(index: IntArray): Int = index.mapIndexed { i, value ->
|
|
||||||
if (value < 0 || value >= shape[i]) throw IndexOutOfBoundsException("Index $value out of shape bounds: (0,${this.shape[i]})")
|
|
||||||
value * strides[i]
|
|
||||||
}.sum()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get multidimensional from linear
|
|
||||||
*/
|
|
||||||
public fun index(offset: Int): IntArray
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The size of linear buffer to accommodate all elements of ND-structure corresponding to strides
|
|
||||||
*/
|
|
||||||
public val linearSize: Int
|
|
||||||
|
|
||||||
// TODO introduce a fast way to calculate index of the next element?
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterate over ND indices in a natural order
|
|
||||||
*/
|
|
||||||
public fun indices(): Sequence<IntArray> = (0 until linearSize).asSequence().map(::index)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simple implementation of [Strides].
|
|
||||||
*/
|
|
||||||
public class DefaultStrides private constructor(override val shape: IntArray) : Strides {
|
|
||||||
override val linearSize: Int
|
|
||||||
get() = strides[shape.size]
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Strides for memory access
|
|
||||||
*/
|
|
||||||
override val strides: IntArray by lazy {
|
|
||||||
sequence {
|
|
||||||
var current = 1
|
|
||||||
yield(1)
|
|
||||||
|
|
||||||
shape.forEach {
|
|
||||||
current *= it
|
|
||||||
yield(current)
|
|
||||||
}
|
|
||||||
}.toList().toIntArray()
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun index(offset: Int): IntArray {
|
|
||||||
val res = IntArray(shape.size)
|
|
||||||
var current = offset
|
|
||||||
var strideIndex = strides.size - 2
|
|
||||||
|
|
||||||
while (strideIndex >= 0) {
|
|
||||||
res[strideIndex] = (current / strides[strideIndex])
|
|
||||||
current %= strides[strideIndex]
|
|
||||||
strideIndex--
|
|
||||||
}
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
if (this === other) return true
|
|
||||||
if (other !is DefaultStrides) return false
|
|
||||||
if (!shape.contentEquals(other.shape)) return false
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun hashCode(): Int = shape.contentHashCode()
|
|
||||||
|
|
||||||
@ThreadLocal
|
|
||||||
public companion object {
|
|
||||||
private val defaultStridesCache = HashMap<IntArray, Strides>()
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Cached builder for default strides
|
|
||||||
*/
|
|
||||||
public operator fun invoke(shape: IntArray): Strides =
|
|
||||||
defaultStridesCache.getOrPut(shape) { DefaultStrides(shape) }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public inline fun <reified T : Any> StructureND<T>.combine(
|
|
||||||
struct: StructureND<T>,
|
struct: StructureND<T>,
|
||||||
crossinline block: (T, T) -> T,
|
crossinline block: (T, T) -> T,
|
||||||
): StructureND<T> {
|
): StructureND<T> {
|
||||||
|
@ -0,0 +1,34 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2018-2021 KMath contributors.
|
||||||
|
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package space.kscience.kmath.nd
|
||||||
|
|
||||||
|
import space.kscience.kmath.operations.Algebra
|
||||||
|
import space.kscience.kmath.operations.Group
|
||||||
|
import space.kscience.kmath.operations.Ring
|
||||||
|
import kotlin.jvm.JvmName
|
||||||
|
|
||||||
|
|
||||||
|
public fun <T, A : Algebra<T>> AlgebraND<T, A>.produce(
|
||||||
|
shapeFirst: Int,
|
||||||
|
vararg shapeRest: Int,
|
||||||
|
initializer: A.(IntArray) -> T
|
||||||
|
): StructureND<T> = produce(Shape(shapeFirst, *shapeRest), initializer)
|
||||||
|
|
||||||
|
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = produce(shape) { zero }
|
||||||
|
|
||||||
|
@JvmName("zeroVarArg")
|
||||||
|
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
|
||||||
|
shapeFirst: Int,
|
||||||
|
vararg shapeRest: Int,
|
||||||
|
): StructureND<T> = produce(shapeFirst, *shapeRest) { zero }
|
||||||
|
|
||||||
|
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = produce(shape) { one }
|
||||||
|
|
||||||
|
@JvmName("oneVarArg")
|
||||||
|
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
|
||||||
|
shapeFirst: Int,
|
||||||
|
vararg shapeRest: Int,
|
||||||
|
): StructureND<T> = produce(shapeFirst, *shapeRest) { one }
|
@ -117,7 +117,7 @@ public inline operator fun <A : Algebra<*>, R> A.invoke(block: A.() -> R): R = r
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this semispace.
|
* @param T the type of element of this semispace.
|
||||||
*/
|
*/
|
||||||
public interface GroupOperations<T> : Algebra<T> {
|
public interface GroupOps<T> : Algebra<T> {
|
||||||
/**
|
/**
|
||||||
* Addition of two elements.
|
* Addition of two elements.
|
||||||
*
|
*
|
||||||
@ -162,7 +162,7 @@ public interface GroupOperations<T> : Algebra<T> {
|
|||||||
* @return the difference.
|
* @return the difference.
|
||||||
*/
|
*/
|
||||||
public operator fun T.minus(b: T): T = add(this, -b)
|
public operator fun T.minus(b: T): T = add(this, -b)
|
||||||
|
// Dynamic dispatch of operations
|
||||||
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
override fun unaryOperationFunction(operation: String): (arg: T) -> T = when (operation) {
|
||||||
PLUS_OPERATION -> { arg -> +arg }
|
PLUS_OPERATION -> { arg -> +arg }
|
||||||
MINUS_OPERATION -> { arg -> -arg }
|
MINUS_OPERATION -> { arg -> -arg }
|
||||||
@ -193,7 +193,7 @@ public interface GroupOperations<T> : Algebra<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this semispace.
|
* @param T the type of element of this semispace.
|
||||||
*/
|
*/
|
||||||
public interface Group<T> : GroupOperations<T> {
|
public interface Group<T> : GroupOps<T> {
|
||||||
/**
|
/**
|
||||||
* The neutral element of addition.
|
* The neutral element of addition.
|
||||||
*/
|
*/
|
||||||
@ -206,7 +206,7 @@ public interface Group<T> : GroupOperations<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this semiring.
|
* @param T the type of element of this semiring.
|
||||||
*/
|
*/
|
||||||
public interface RingOperations<T> : GroupOperations<T> {
|
public interface RingOps<T> : GroupOps<T> {
|
||||||
/**
|
/**
|
||||||
* Multiplies two elements.
|
* Multiplies two elements.
|
||||||
*
|
*
|
||||||
@ -242,7 +242,7 @@ public interface RingOperations<T> : GroupOperations<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this ring.
|
* @param T the type of element of this ring.
|
||||||
*/
|
*/
|
||||||
public interface Ring<T> : Group<T>, RingOperations<T> {
|
public interface Ring<T> : Group<T>, RingOps<T> {
|
||||||
/**
|
/**
|
||||||
* The neutral element of multiplication
|
* The neutral element of multiplication
|
||||||
*/
|
*/
|
||||||
@ -256,7 +256,7 @@ public interface Ring<T> : Group<T>, RingOperations<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this semifield.
|
* @param T the type of element of this semifield.
|
||||||
*/
|
*/
|
||||||
public interface FieldOperations<T> : RingOperations<T> {
|
public interface FieldOps<T> : RingOps<T> {
|
||||||
/**
|
/**
|
||||||
* Division of two elements.
|
* Division of two elements.
|
||||||
*
|
*
|
||||||
@ -295,6 +295,6 @@ public interface FieldOperations<T> : RingOperations<T> {
|
|||||||
*
|
*
|
||||||
* @param T the type of element of this field.
|
* @param T the type of element of this field.
|
||||||
*/
|
*/
|
||||||
public interface Field<T> : Ring<T>, FieldOperations<T>, ScaleOperations<T>, NumericAlgebra<T> {
|
public interface Field<T> : Ring<T>, FieldOps<T>, ScaleOperations<T>, NumericAlgebra<T> {
|
||||||
override fun number(value: Number): T = scale(one, value.toDouble())
|
override fun number(value: Number): T = scale(one, value.toDouble())
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
import space.kscience.kmath.nd.BufferedRingND
|
import space.kscience.kmath.nd.BufferedRingOpsND
|
||||||
import space.kscience.kmath.operations.BigInt.Companion.BASE
|
import space.kscience.kmath.operations.BigInt.Companion.BASE
|
||||||
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
import space.kscience.kmath.operations.BigInt.Companion.BASE_SIZE
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
@ -26,7 +26,7 @@ private typealias TBase = ULong
|
|||||||
* @author Peter Klimai
|
* @author Peter Klimai
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public object BigIntField : Field<BigInt>, NumbersAddOperations<BigInt>, ScaleOperations<BigInt> {
|
public object BigIntField : Field<BigInt>, NumbersAddOps<BigInt>, ScaleOperations<BigInt> {
|
||||||
override val zero: BigInt = BigInt.ZERO
|
override val zero: BigInt = BigInt.ZERO
|
||||||
override val one: BigInt = BigInt.ONE
|
override val one: BigInt = BigInt.ONE
|
||||||
|
|
||||||
@ -542,5 +542,5 @@ public inline fun MutableBuffer.Companion.bigInt(size: Int, initializer: (Int) -
|
|||||||
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
public inline fun BigInt.mutableBuffer(size: Int, initializer: (Int) -> BigInt): Buffer<BigInt> =
|
||||||
Buffer.boxing(size, initializer)
|
Buffer.boxing(size, initializer)
|
||||||
|
|
||||||
public fun BigIntField.nd(vararg shape: Int): BufferedRingND<BigInt, BigIntField> =
|
public val BigIntField.nd: BufferedRingOpsND<BigInt, BigIntField>
|
||||||
BufferedRingND(shape, BigIntField, BigInt::buffer)
|
get() = BufferedRingOpsND(BufferRingOps(BigIntField, BigInt::buffer))
|
||||||
|
@ -5,32 +5,34 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.operations
|
package space.kscience.kmath.operations
|
||||||
|
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
|
||||||
import space.kscience.kmath.structures.Buffer
|
import space.kscience.kmath.structures.Buffer
|
||||||
import space.kscience.kmath.structures.BufferFactory
|
import space.kscience.kmath.structures.BufferFactory
|
||||||
import space.kscience.kmath.structures.DoubleBuffer
|
import space.kscience.kmath.structures.DoubleBuffer
|
||||||
|
import space.kscience.kmath.structures.ShortBuffer
|
||||||
|
|
||||||
|
public interface WithSize {
|
||||||
|
public val size: Int
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An algebra over [Buffer]
|
* An algebra over [Buffer]
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
public interface BufferAlgebra<T, out A : Algebra<T>> : Algebra<Buffer<T>> {
|
||||||
public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
|
|
||||||
public val bufferFactory: BufferFactory<T>
|
|
||||||
public val elementAlgebra: A
|
public val elementAlgebra: A
|
||||||
public val size: Int
|
public val bufferFactory: BufferFactory<T>
|
||||||
|
|
||||||
public fun buffer(vararg elements: T): Buffer<T> {
|
public fun buffer(size: Int, vararg elements: T): Buffer<T> {
|
||||||
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
||||||
return bufferFactory(size) { elements[it] }
|
return bufferFactory(size) { elements[it] }
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO move to multi-receiver inline extension
|
//TODO move to multi-receiver inline extension
|
||||||
public fun Buffer<T>.map(block: (T) -> T): Buffer<T> = bufferFactory(size) { block(get(it)) }
|
public fun Buffer<T>.map(block: A.(T) -> T): Buffer<T> = mapInline(this, block)
|
||||||
|
|
||||||
public fun Buffer<T>.zip(other: Buffer<T>, block: (left: T, right: T) -> T): Buffer<T> {
|
public fun Buffer<T>.mapIndexed(block: A.(index: Int, arg: T) -> T): Buffer<T> = mapIndexedInline(this, block)
|
||||||
require(size == other.size) { "Incompatible buffer sizes. left: $size, right: ${other.size}" }
|
|
||||||
return bufferFactory(size) { block(this[it], other[it]) }
|
public fun Buffer<T>.zip(other: Buffer<T>, block: A.(left: T, right: T) -> T): Buffer<T> =
|
||||||
}
|
zipInline(this, other, block)
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
|
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
|
||||||
val operationFunction = elementAlgebra.unaryOperationFunction(operation)
|
val operationFunction = elementAlgebra.unaryOperationFunction(operation)
|
||||||
@ -45,112 +47,149 @@ public interface BufferAlgebra<T, A : Algebra<T>> : Algebra<Buffer<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
/**
|
||||||
public fun <T> BufferField<T, *>.buffer(initializer: (Int) -> T): Buffer<T> {
|
* Inline map
|
||||||
|
*/
|
||||||
|
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapInline(
|
||||||
|
buffer: Buffer<T>,
|
||||||
|
crossinline block: A.(T) -> T
|
||||||
|
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(buffer[it]) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline map
|
||||||
|
*/
|
||||||
|
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.mapIndexedInline(
|
||||||
|
buffer: Buffer<T>,
|
||||||
|
crossinline block: A.(index: Int, arg: T) -> T
|
||||||
|
): Buffer<T> = bufferFactory(buffer.size) { elementAlgebra.block(it, buffer[it]) }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inline zip
|
||||||
|
*/
|
||||||
|
public inline fun <T, A : Algebra<T>> BufferAlgebra<T, A>.zipInline(
|
||||||
|
l: Buffer<T>,
|
||||||
|
r: Buffer<T>,
|
||||||
|
crossinline block: A.(l: T, r: T) -> T
|
||||||
|
): Buffer<T> {
|
||||||
|
require(l.size == r.size) { "Incompatible buffer sizes. left: ${l.size}, right: ${r.size}" }
|
||||||
|
return bufferFactory(l.size) { elementAlgebra.block(l[it], r[it]) }
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T> BufferAlgebra<T, *>.buffer(size: Int, initializer: (Int) -> T): Buffer<T> {
|
||||||
|
return bufferFactory(size, initializer)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun <T, A> A.buffer(initializer: (Int) -> T): Buffer<T> where A : BufferAlgebra<T, *>, A : WithSize {
|
||||||
return bufferFactory(size, initializer)
|
return bufferFactory(size, initializer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.sin(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::sin)
|
mapInline(arg) { sin(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.cos(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::cos)
|
mapInline(arg) { cos(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.tan(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::tan)
|
mapInline(arg) { tan(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.asin(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::asin)
|
mapInline(arg) { asin(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.acos(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::acos)
|
mapInline(arg) { acos(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : TrigonometricOperations<T>> BufferAlgebra<T, A>.atan(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::atan)
|
mapInline(arg) { atan(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.exp(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::exp)
|
mapInline(arg) { exp(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.ln(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::ln)
|
mapInline(arg) { ln(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.sinh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::sinh)
|
mapInline(arg) { sinh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.cosh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::cosh)
|
mapInline(arg) { cosh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.tanh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::tanh)
|
mapInline(arg) { tanh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.asinh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::asinh)
|
mapInline(arg) { asinh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.acosh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::acosh)
|
mapInline(arg) { acosh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
|
public fun <T, A : ExponentialOperations<T>> BufferAlgebra<T, A>.atanh(arg: Buffer<T>): Buffer<T> =
|
||||||
arg.map(elementAlgebra::atanh)
|
mapInline(arg) { atanh(it) }
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
|
public fun <T, A : PowerOperations<T>> BufferAlgebra<T, A>.pow(arg: Buffer<T>, pow: Number): Buffer<T> =
|
||||||
with(elementAlgebra) { arg.map { power(it, pow) } }
|
mapInline(arg) { power(it, pow) }
|
||||||
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
public open class BufferRingOps<T, A: Ring<T>>(
|
||||||
public class BufferField<T, A : Field<T>>(
|
|
||||||
override val bufferFactory: BufferFactory<T>,
|
|
||||||
override val elementAlgebra: A,
|
override val elementAlgebra: A,
|
||||||
|
override val bufferFactory: BufferFactory<T>,
|
||||||
|
) : BufferAlgebra<T, A>, RingOps<Buffer<T>>{
|
||||||
|
|
||||||
|
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
|
||||||
|
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
|
||||||
|
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||||
|
|
||||||
|
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> =
|
||||||
|
super<BufferAlgebra>.unaryOperationFunction(operation)
|
||||||
|
|
||||||
|
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
|
||||||
|
super<BufferAlgebra>.binaryOperationFunction(operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
public val ShortRing.bufferAlgebra: BufferRingOps<Short, ShortRing>
|
||||||
|
get() = BufferRingOps(ShortRing, ::ShortBuffer)
|
||||||
|
|
||||||
|
public open class BufferFieldOps<T, A : Field<T>>(
|
||||||
|
elementAlgebra: A,
|
||||||
|
bufferFactory: BufferFactory<T>,
|
||||||
|
) : BufferRingOps<T, A>(elementAlgebra, bufferFactory), BufferAlgebra<T, A>, FieldOps<Buffer<T>>, ScaleOperations<Buffer<T>> {
|
||||||
|
|
||||||
|
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l + r }
|
||||||
|
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l * r }
|
||||||
|
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = zipInline(a, b) { l, r -> l / r }
|
||||||
|
|
||||||
|
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = a.map { scale(it, value) }
|
||||||
|
override fun Buffer<T>.unaryMinus(): Buffer<T> = map { -it }
|
||||||
|
|
||||||
|
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> =
|
||||||
|
super<BufferRingOps>.binaryOperationFunction(operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
public class BufferField<T, A : Field<T>>(
|
||||||
|
elementAlgebra: A,
|
||||||
|
bufferFactory: BufferFactory<T>,
|
||||||
override val size: Int
|
override val size: Int
|
||||||
) : BufferAlgebra<T, A>, Field<Buffer<T>> {
|
) : BufferFieldOps<T, A>(elementAlgebra, bufferFactory), Field<Buffer<T>>, WithSize {
|
||||||
|
|
||||||
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
|
override val zero: Buffer<T> = bufferFactory(size) { elementAlgebra.zero }
|
||||||
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
|
override val one: Buffer<T> = bufferFactory(size) { elementAlgebra.one }
|
||||||
|
|
||||||
|
|
||||||
override fun add(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::add)
|
|
||||||
override fun multiply(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::multiply)
|
|
||||||
override fun divide(a: Buffer<T>, b: Buffer<T>): Buffer<T> = a.zip(b, elementAlgebra::divide)
|
|
||||||
|
|
||||||
override fun scale(a: Buffer<T>, value: Double): Buffer<T> = with(elementAlgebra) { a.map { scale(it, value) } }
|
|
||||||
override fun Buffer<T>.unaryMinus(): Buffer<T> = with(elementAlgebra) { map { -it } }
|
|
||||||
|
|
||||||
override fun unaryOperationFunction(operation: String): (arg: Buffer<T>) -> Buffer<T> {
|
|
||||||
return super<BufferAlgebra>.unaryOperationFunction(operation)
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun binaryOperationFunction(operation: String): (left: Buffer<T>, right: Buffer<T>) -> Buffer<T> {
|
|
||||||
return super<BufferAlgebra>.binaryOperationFunction(operation)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate full buffer field from given buffer operations
|
||||||
|
*/
|
||||||
|
public fun <T, A : Field<T>> BufferFieldOps<T, A>.withSize(size: Int): BufferField<T, A> =
|
||||||
|
BufferField(elementAlgebra, bufferFactory, size)
|
||||||
|
|
||||||
//Double buffer specialization
|
//Double buffer specialization
|
||||||
|
|
||||||
@UnstableKMathAPI
|
|
||||||
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
|
public fun BufferField<Double, *>.buffer(vararg elements: Number): Buffer<Double> {
|
||||||
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
require(elements.size == size) { "Expected $size elements but found ${elements.size}" }
|
||||||
return bufferFactory(size) { elements[it].toDouble() }
|
return bufferFactory(size) { elements[it].toDouble() }
|
||||||
}
|
}
|
||||||
|
|
||||||
@UnstableKMathAPI
|
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>): BufferFieldOps<T, A> =
|
||||||
public fun <T, A : Field<T>> A.bufferAlgebra(bufferFactory: BufferFactory<T>, size: Int): BufferField<T, A> =
|
BufferFieldOps(this, bufferFactory)
|
||||||
BufferField(bufferFactory, this, size)
|
|
||||||
|
|
||||||
@UnstableKMathAPI
|
public val DoubleField.bufferAlgebra: BufferFieldOps<Double, DoubleField>
|
||||||
public fun DoubleField.bufferAlgebra(size: Int): BufferField<Double, DoubleField> =
|
get() = BufferFieldOps(DoubleField, ::DoubleBuffer)
|
||||||
BufferField(::DoubleBuffer, DoubleField, size)
|
|
||||||
|
|
||||||
|
@ -13,21 +13,21 @@ import space.kscience.kmath.structures.DoubleBuffer
|
|||||||
*
|
*
|
||||||
* @property size the size of buffers to operate on.
|
* @property size the size of buffers to operate on.
|
||||||
*/
|
*/
|
||||||
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOperations() {
|
public class DoubleBufferField(public val size: Int) : ExtendedField<Buffer<Double>>, DoubleBufferOps() {
|
||||||
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
override val zero: Buffer<Double> by lazy { DoubleBuffer(size) { 0.0 } }
|
||||||
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
override val one: Buffer<Double> by lazy { DoubleBuffer(size) { 1.0 } }
|
||||||
|
|
||||||
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.sinh(arg)
|
override fun sinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.sinh(arg)
|
||||||
|
|
||||||
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.cosh(arg)
|
override fun cosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.cosh(arg)
|
||||||
|
|
||||||
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.tanh(arg)
|
override fun tanh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.tanh(arg)
|
||||||
|
|
||||||
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.asinh(arg)
|
override fun asinh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.asinh(arg)
|
||||||
|
|
||||||
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOperations>.acosh(arg)
|
override fun acosh(arg: Buffer<Double>): DoubleBuffer = super<DoubleBufferOps>.acosh(arg)
|
||||||
|
|
||||||
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOperations>.atanh(arg)
|
override fun atanh(arg: Buffer<Double>): DoubleBuffer= super<DoubleBufferOps>.atanh(arg)
|
||||||
|
|
||||||
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
|
// override fun number(value: Number): Buffer<Double> = DoubleBuffer(size) { value.toDouble() }
|
||||||
//
|
//
|
||||||
|
@ -12,9 +12,9 @@ import space.kscience.kmath.structures.DoubleBuffer
|
|||||||
import kotlin.math.*
|
import kotlin.math.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [ExtendedFieldOperations] over [DoubleBuffer].
|
* [ExtendedFieldOps] over [DoubleBuffer].
|
||||||
*/
|
*/
|
||||||
public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
public abstract class DoubleBufferOps : ExtendedFieldOps<Buffer<Double>>, Norm<Buffer<Double>, Double> {
|
||||||
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
|
override fun Buffer<Double>.unaryMinus(): DoubleBuffer = if (this is DoubleBuffer) {
|
||||||
DoubleBuffer(size) { -array[it] }
|
DoubleBuffer(size) { -array[it] }
|
||||||
} else {
|
} else {
|
||||||
@ -185,7 +185,7 @@ public abstract class DoubleBufferOperations : ExtendedFieldOperations<Buffer<Do
|
|||||||
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
|
DoubleBuffer(DoubleArray(a.size) { aArray[it] * value })
|
||||||
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
|
} else DoubleBuffer(DoubleArray(a.size) { a[it] * value })
|
||||||
|
|
||||||
public companion object : DoubleBufferOperations()
|
public companion object : DoubleBufferOps()
|
||||||
}
|
}
|
||||||
|
|
||||||
public object DoubleL2Norm : Norm<Point<Double>, Double> {
|
public object DoubleL2Norm : Norm<Point<Double>, Double> {
|
@ -150,7 +150,7 @@ public interface ScaleOperations<T> : Algebra<T> {
|
|||||||
* TODO to be removed and replaced by extensions after multiple receivers are there
|
* TODO to be removed and replaced by extensions after multiple receivers are there
|
||||||
*/
|
*/
|
||||||
@UnstableKMathAPI
|
@UnstableKMathAPI
|
||||||
public interface NumbersAddOperations<T> : Ring<T>, NumericAlgebra<T> {
|
public interface NumbersAddOps<T> : Ring<T>, NumericAlgebra<T> {
|
||||||
/**
|
/**
|
||||||
* Addition of element and scalar.
|
* Addition of element and scalar.
|
||||||
*
|
*
|
||||||
|
@ -10,8 +10,8 @@ import kotlin.math.pow as kpow
|
|||||||
/**
|
/**
|
||||||
* Advanced Number-like semifield that implements basic operations.
|
* Advanced Number-like semifield that implements basic operations.
|
||||||
*/
|
*/
|
||||||
public interface ExtendedFieldOperations<T> :
|
public interface ExtendedFieldOps<T> :
|
||||||
FieldOperations<T>,
|
FieldOps<T>,
|
||||||
TrigonometricOperations<T>,
|
TrigonometricOperations<T>,
|
||||||
PowerOperations<T>,
|
PowerOperations<T>,
|
||||||
ExponentialOperations<T>,
|
ExponentialOperations<T>,
|
||||||
@ -35,14 +35,14 @@ public interface ExtendedFieldOperations<T> :
|
|||||||
ExponentialOperations.ACOSH_OPERATION -> ::acosh
|
ExponentialOperations.ACOSH_OPERATION -> ::acosh
|
||||||
ExponentialOperations.ASINH_OPERATION -> ::asinh
|
ExponentialOperations.ASINH_OPERATION -> ::asinh
|
||||||
ExponentialOperations.ATANH_OPERATION -> ::atanh
|
ExponentialOperations.ATANH_OPERATION -> ::atanh
|
||||||
else -> super<FieldOperations>.unaryOperationFunction(operation)
|
else -> super<FieldOps>.unaryOperationFunction(operation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Advanced Number-like field that implements basic operations.
|
* Advanced Number-like field that implements basic operations.
|
||||||
*/
|
*/
|
||||||
public interface ExtendedField<T> : ExtendedFieldOperations<T>, Field<T>, NumericAlgebra<T>{
|
public interface ExtendedField<T> : ExtendedFieldOps<T>, Field<T>, NumericAlgebra<T>{
|
||||||
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
|
override fun sinh(arg: T): T = (exp(arg) - exp(-arg)) / 2.0
|
||||||
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
|
override fun cosh(arg: T): T = (exp(arg) + exp(-arg)) / 2.0
|
||||||
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
override fun tanh(arg: T): T = (exp(arg) - exp(-arg)) / (exp(-arg) + exp(arg))
|
||||||
|
@ -7,6 +7,7 @@ package space.kscience.kmath.structures
|
|||||||
|
|
||||||
import space.kscience.kmath.nd.get
|
import space.kscience.kmath.nd.get
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
import space.kscience.kmath.nd.ndAlgebra
|
||||||
|
import space.kscience.kmath.nd.produce
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.testutils.FieldVerifier
|
import space.kscience.kmath.testutils.FieldVerifier
|
||||||
@ -21,7 +22,7 @@ internal class NDFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testStrides() {
|
fun testStrides() {
|
||||||
val ndArray = DoubleField.ndAlgebra(10, 10).produce { (it[0] + it[1]).toDouble() }
|
val ndArray = DoubleField.ndAlgebra.produce(10, 10) { (it[0] + it[1]).toDouble() }
|
||||||
assertEquals(ndArray[5, 5], 10.0)
|
assertEquals(ndArray[5, 5], 10.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,7 @@ package space.kscience.kmath.structures
|
|||||||
|
|
||||||
import space.kscience.kmath.linear.linearSpace
|
import space.kscience.kmath.linear.linearSpace
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.nd.combine
|
|
||||||
import space.kscience.kmath.nd.get
|
|
||||||
import space.kscience.kmath.nd.ndAlgebra
|
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.Norm
|
import space.kscience.kmath.operations.Norm
|
||||||
import space.kscience.kmath.operations.algebra
|
import space.kscience.kmath.operations.algebra
|
||||||
@ -22,9 +19,9 @@ import kotlin.test.assertEquals
|
|||||||
|
|
||||||
@Suppress("UNUSED_VARIABLE")
|
@Suppress("UNUSED_VARIABLE")
|
||||||
class NumberNDFieldTest {
|
class NumberNDFieldTest {
|
||||||
val algebra = DoubleField.ndAlgebra(3, 3)
|
val algebra = DoubleField.ndAlgebra
|
||||||
val array1 = algebra.produce { (i, j) -> (i + j).toDouble() }
|
val array1 = algebra.produce(3, 3) { (i, j) -> (i + j).toDouble() }
|
||||||
val array2 = algebra.produce { (i, j) -> (i - j).toDouble() }
|
val array2 = algebra.produce(3, 3) { (i, j) -> (i - j).toDouble() }
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSum() {
|
fun testSum() {
|
||||||
@ -77,7 +74,7 @@ class NumberNDFieldTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun combineTest() {
|
fun combineTest() {
|
||||||
val division = array1.combine(array2, Double::div)
|
val division = array1.zip(array2, Double::div)
|
||||||
}
|
}
|
||||||
|
|
||||||
object L2Norm : Norm<StructureND<Number>, Double> {
|
object L2Norm : Norm<StructureND<Number>, Double> {
|
||||||
|
@ -10,12 +10,12 @@ import kotlinx.coroutines.flow.Flow
|
|||||||
import kotlinx.coroutines.flow.map
|
import kotlinx.coroutines.flow.map
|
||||||
import kotlinx.coroutines.flow.runningReduce
|
import kotlinx.coroutines.flow.runningReduce
|
||||||
import kotlinx.coroutines.flow.scan
|
import kotlinx.coroutines.flow.scan
|
||||||
import space.kscience.kmath.operations.GroupOperations
|
import space.kscience.kmath.operations.GroupOps
|
||||||
import space.kscience.kmath.operations.Ring
|
import space.kscience.kmath.operations.Ring
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
|
|
||||||
public fun <T> Flow<T>.cumulativeSum(group: GroupOperations<T>): Flow<T> =
|
public fun <T> Flow<T>.cumulativeSum(group: GroupOps<T>): Flow<T> =
|
||||||
group { runningReduce { sum, element -> sum + element } }
|
group { runningReduce { sum, element -> sum + element } }
|
||||||
|
|
||||||
@ExperimentalCoroutinesApi
|
@ExperimentalCoroutinesApi
|
||||||
|
@ -13,8 +13,8 @@ import space.kscience.kmath.structures.DoubleBuffer
|
|||||||
* Map one [BufferND] using function without indices.
|
* Map one [BufferND] using function without indices.
|
||||||
*/
|
*/
|
||||||
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
|
public inline fun BufferND<Double>.mapInline(crossinline transform: DoubleField.(Double) -> Double): BufferND<Double> {
|
||||||
val array = DoubleArray(strides.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
|
val array = DoubleArray(indexes.linearSize) { offset -> DoubleField.transform(buffer[offset]) }
|
||||||
return BufferND(strides, DoubleBuffer(array))
|
return BufferND(indexes, DoubleBuffer(array))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -28,10 +28,9 @@ public class DoubleHistogramSpace(
|
|||||||
|
|
||||||
public val dimension: Int get() = lower.size
|
public val dimension: Int get() = lower.size
|
||||||
|
|
||||||
private val shape = IntArray(binNums.size) { binNums[it] + 2 }
|
override val shape: IntArray = IntArray(binNums.size) { binNums[it] + 2 }
|
||||||
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
|
override val histogramValueSpace: DoubleFieldND = DoubleField.ndAlgebra(*shape)
|
||||||
|
|
||||||
override val strides: Strides get() = histogramValueSpace.strides
|
|
||||||
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
private val binSize = DoubleBuffer(dimension) { (upper[it] - lower[it]) / binNums[it] }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -52,7 +51,7 @@ public class DoubleHistogramSpace(
|
|||||||
val lowerBoundary = index.mapIndexed { axis, i ->
|
val lowerBoundary = index.mapIndexed { axis, i ->
|
||||||
when (i) {
|
when (i) {
|
||||||
0 -> Double.NEGATIVE_INFINITY
|
0 -> Double.NEGATIVE_INFINITY
|
||||||
strides.shape[axis] - 1 -> upper[axis]
|
shape[axis] - 1 -> upper[axis]
|
||||||
else -> lower[axis] + (i.toDouble()) * binSize[axis]
|
else -> lower[axis] + (i.toDouble()) * binSize[axis]
|
||||||
}
|
}
|
||||||
}.asBuffer()
|
}.asBuffer()
|
||||||
@ -60,7 +59,7 @@ public class DoubleHistogramSpace(
|
|||||||
val upperBoundary = index.mapIndexed { axis, i ->
|
val upperBoundary = index.mapIndexed { axis, i ->
|
||||||
when (i) {
|
when (i) {
|
||||||
0 -> lower[axis]
|
0 -> lower[axis]
|
||||||
strides.shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
shape[axis] - 1 -> Double.POSITIVE_INFINITY
|
||||||
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
|
else -> lower[axis] + (i.toDouble() + 1) * binSize[axis]
|
||||||
}
|
}
|
||||||
}.asBuffer()
|
}.asBuffer()
|
||||||
@ -75,7 +74,7 @@ public class DoubleHistogramSpace(
|
|||||||
}
|
}
|
||||||
|
|
||||||
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
override fun produce(builder: HistogramBuilder<Double>.() -> Unit): IndexedHistogram<Double, Double> {
|
||||||
val ndCounter = StructureND.auto(strides) { Counter.real() }
|
val ndCounter = StructureND.auto(shape) { Counter.real() }
|
||||||
val hBuilder = HistogramBuilder<Double> { point, value ->
|
val hBuilder = HistogramBuilder<Double> { point, value ->
|
||||||
val index = getIndex(point)
|
val index = getIndex(point)
|
||||||
ndCounter[index].add(value.toDouble())
|
ndCounter[index].add(value.toDouble())
|
||||||
|
@ -8,8 +8,9 @@ package space.kscience.kmath.histogram
|
|||||||
import space.kscience.kmath.domains.Domain
|
import space.kscience.kmath.domains.Domain
|
||||||
import space.kscience.kmath.linear.Point
|
import space.kscience.kmath.linear.Point
|
||||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||||
|
import space.kscience.kmath.nd.DefaultStrides
|
||||||
import space.kscience.kmath.nd.FieldND
|
import space.kscience.kmath.nd.FieldND
|
||||||
import space.kscience.kmath.nd.Strides
|
import space.kscience.kmath.nd.Shape
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
import space.kscience.kmath.operations.Group
|
import space.kscience.kmath.operations.Group
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
@ -34,10 +35,10 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
|||||||
return context.produceBin(index, values[index])
|
return context.produceBin(index, values[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
override val dimension: Int get() = context.strides.shape.size
|
override val dimension: Int get() = context.shape.size
|
||||||
|
|
||||||
override val bins: Iterable<Bin<T>>
|
override val bins: Iterable<Bin<T>>
|
||||||
get() = context.strides.indices().map {
|
get() = DefaultStrides(context.shape).indices().map {
|
||||||
context.produceBin(it, values[it])
|
context.produceBin(it, values[it])
|
||||||
}.asIterable()
|
}.asIterable()
|
||||||
|
|
||||||
@ -49,7 +50,7 @@ public class IndexedHistogram<T : Comparable<T>, V : Any>(
|
|||||||
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
public interface IndexedHistogramSpace<T : Comparable<T>, V : Any>
|
||||||
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
|
: Group<IndexedHistogram<T, V>>, ScaleOperations<IndexedHistogram<T, V>> {
|
||||||
//public val valueSpace: Space<V>
|
//public val valueSpace: Space<V>
|
||||||
public val strides: Strides
|
public val shape: Shape
|
||||||
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
|
public val histogramValueSpace: FieldND<V, *> //= NDAlgebra.space(valueSpace, Buffer.Companion::boxing, *shape),
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
package space.kscience.kmath.histogram
|
package space.kscience.kmath.histogram
|
||||||
|
|
||||||
|
import space.kscience.kmath.nd.DefaultStrides
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import space.kscience.kmath.real.DoubleVector
|
import space.kscience.kmath.real.DoubleVector
|
||||||
import kotlin.random.Random
|
import kotlin.random.Random
|
||||||
@ -69,7 +70,7 @@ internal class MultivariateHistogramTest {
|
|||||||
}
|
}
|
||||||
val res = histogram1 - histogram2
|
val res = histogram1 - histogram2
|
||||||
assertTrue {
|
assertTrue {
|
||||||
strides.indices().all { index ->
|
DefaultStrides(shape).indices().all { index ->
|
||||||
res.values[index] <= histogram1.values[index]
|
res.values[index] <= histogram1.values[index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,8 +106,8 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
|||||||
is Symbol -> toSVar()
|
is Symbol -> toSVar()
|
||||||
|
|
||||||
is MST.Unary -> when (operation) {
|
is MST.Unary -> when (operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> +value.toSFun<X>()
|
GroupOps.PLUS_OPERATION -> +value.toSFun<X>()
|
||||||
GroupOperations.MINUS_OPERATION -> -value.toSFun<X>()
|
GroupOps.MINUS_OPERATION -> -value.toSFun<X>()
|
||||||
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
TrigonometricOperations.SIN_OPERATION -> sin(value.toSFun())
|
||||||
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
TrigonometricOperations.COS_OPERATION -> cos(value.toSFun())
|
||||||
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
TrigonometricOperations.TAN_OPERATION -> tan(value.toSFun())
|
||||||
@ -124,10 +124,10 @@ public fun <X : SFun<X>> MST.toSFun(): SFun<X> = when (this) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> when (operation) {
|
is MST.Binary -> when (operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
GroupOps.PLUS_OPERATION -> left.toSFun<X>() + right.toSFun()
|
||||||
GroupOperations.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
GroupOps.MINUS_OPERATION -> left.toSFun<X>() - right.toSFun()
|
||||||
RingOperations.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
RingOps.TIMES_OPERATION -> left.toSFun<X>() * right.toSFun()
|
||||||
FieldOperations.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
FieldOps.DIV_OPERATION -> left.toSFun<X>() / right.toSFun()
|
||||||
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
PowerOperations.POW_OPERATION -> left.toSFun<X>() pow (right as MST.Numeric).toSConst()
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
@ -15,13 +15,6 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.*
|
import space.kscience.kmath.operations.*
|
||||||
|
|
||||||
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
|
|
||||||
val arrayShape = array.shape().toIntArray()
|
|
||||||
if (!shape.contentEquals(arrayShape)) throw ShapeMismatchException(shape, arrayShape)
|
|
||||||
return array
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [AlgebraND] over [Nd4jArrayAlgebra].
|
* Represents [AlgebraND] over [Nd4jArrayAlgebra].
|
||||||
*
|
*
|
||||||
@ -39,33 +32,34 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
|||||||
*/
|
*/
|
||||||
public val StructureND<T>.ndArray: INDArray
|
public val StructureND<T>.ndArray: INDArray
|
||||||
|
|
||||||
override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
override fun produce(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||||
val struct = Nd4j.create(*shape)!!.wrap()
|
val struct = Nd4j.create(*shape)!!.wrap()
|
||||||
struct.indicesIterator().forEach { struct[it] = elementContext.initializer(it) }
|
struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
|
||||||
return struct
|
return struct
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
override fun StructureND<T>.map(transform: C.(T) -> T): Nd4jArrayStructure<T> {
|
||||||
val newStruct = ndArray.dup().wrap()
|
val newStruct = ndArray.dup().wrap()
|
||||||
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementContext.transform(value) }
|
newStruct.elements().forEach { (idx, value) -> newStruct[idx] = elementAlgebra.transform(value) }
|
||||||
return newStruct
|
return newStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun StructureND<T>.mapIndexed(
|
override fun StructureND<T>.mapIndexed(
|
||||||
transform: C.(index: IntArray, T) -> T,
|
transform: C.(index: IntArray, T) -> T,
|
||||||
): Nd4jArrayStructure<T> {
|
): Nd4jArrayStructure<T> {
|
||||||
val new = Nd4j.create(*this@Nd4jArrayAlgebra.shape).wrap()
|
val new = Nd4j.create(*shape).wrap()
|
||||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(idx, this[idx]) }
|
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(idx, this[idx]) }
|
||||||
return new
|
return new
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun combine(
|
override fun zip(
|
||||||
a: StructureND<T>,
|
left: StructureND<T>,
|
||||||
b: StructureND<T>,
|
right: StructureND<T>,
|
||||||
transform: C.(T, T) -> T,
|
transform: C.(T, T) -> T,
|
||||||
): Nd4jArrayStructure<T> {
|
): Nd4jArrayStructure<T> {
|
||||||
val new = Nd4j.create(*shape).wrap()
|
require(left.shape.contentEquals(right.shape)) { "Can't zip tow structures of shape ${left.shape} and ${right.shape}" }
|
||||||
new.indicesIterator().forEach { idx -> new[idx] = elementContext.transform(a[idx], b[idx]) }
|
val new = Nd4j.create(*left.shape).wrap()
|
||||||
|
new.indicesIterator().forEach { idx -> new[idx] = elementAlgebra.transform(left[idx], right[idx]) }
|
||||||
return new
|
return new
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -76,10 +70,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
|
|||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param S the type of space of structure elements.
|
* @param S the type of space of structure elements.
|
||||||
*/
|
*/
|
||||||
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
|
public sealed interface Nd4jArrayGroupOps<T, out S : Ring<T>> : GroupOpsND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||||
|
|
||||||
override val zero: Nd4jArrayStructure<T>
|
|
||||||
get() = Nd4j.zeros(*shape).wrap()
|
|
||||||
|
|
||||||
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun add(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.add(b.ndArray).wrap()
|
a.ndArray.add(b.ndArray).wrap()
|
||||||
@ -101,10 +92,7 @@ public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4j
|
|||||||
* @param R the type of ring of structure elements.
|
* @param R the type of ring of structure elements.
|
||||||
*/
|
*/
|
||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> {
|
public sealed interface Nd4jArrayRingOps<T, out R : Ring<T>> : RingOpsND<T, R>, Nd4jArrayGroupOps<T, R> {
|
||||||
|
|
||||||
override val one: Nd4jArrayStructure<T>
|
|
||||||
get() = Nd4j.ones(*shape).wrap()
|
|
||||||
|
|
||||||
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun multiply(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.mul(b.ndArray).wrap()
|
a.ndArray.mul(b.ndArray).wrap()
|
||||||
@ -125,21 +113,12 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
|
||||||
ThreadLocal.withInitial(::HashMap)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
|
|
||||||
*/
|
|
||||||
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
|
|
||||||
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a most suitable implementation of [RingND] using reified class.
|
* Creates a most suitable implementation of [RingND] using reified class.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
|
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRingOps<T, Ring<T>> = when {
|
||||||
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
|
T::class == Int::class -> IntRing.nd4j as Nd4jArrayRingOps<T, Ring<T>>
|
||||||
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -151,38 +130,21 @@ public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jAr
|
|||||||
* @param T the type of the element contained in ND structure.
|
* @param T the type of the element contained in ND structure.
|
||||||
* @param F the type field of structure elements.
|
* @param F the type field of structure elements.
|
||||||
*/
|
*/
|
||||||
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
|
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldOpsND<T, F>, Nd4jArrayRingOps<T, F> {
|
||||||
|
|
||||||
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||||
a.ndArray.div(b.ndArray).wrap()
|
a.ndArray.div(b.ndArray).wrap()
|
||||||
|
|
||||||
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
|
public operator fun Number.div(b: StructureND<T>): Nd4jArrayStructure<T> = b.ndArray.rdiv(this).wrap()
|
||||||
|
|
||||||
public companion object {
|
public companion object {
|
||||||
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
|
|
||||||
ThreadLocal.withInitial(::HashMap)
|
|
||||||
|
|
||||||
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
|
|
||||||
ThreadLocal.withInitial(::HashMap)
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
|
|
||||||
*/
|
|
||||||
public fun float(vararg shape: Int): Nd4jArrayRing<Float, FloatField> =
|
|
||||||
floatNd4jArrayFieldCache.get().getOrPut(shape) { FloatNd4jArrayField(shape) }
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an [FieldND] for [Double] values or pull it from cache if it was created previously.
|
|
||||||
*/
|
|
||||||
public fun real(vararg shape: Int): Nd4jArrayRing<Double, DoubleField> =
|
|
||||||
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a most suitable implementation of [FieldND] using reified class.
|
* Creates a most suitable implementation of [FieldND] using reified class.
|
||||||
*/
|
*/
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
|
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
|
||||||
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
|
T::class == Float::class -> FloatField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||||
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
|
T::class == Double::class -> DoubleField.nd4j as Nd4jArrayField<T, Field<T>>
|
||||||
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -191,8 +153,9 @@ public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4
|
|||||||
/**
|
/**
|
||||||
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
|
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
|
||||||
*/
|
*/
|
||||||
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>,
|
public sealed interface Nd4jArrayExtendedFieldOps<T, out F : ExtendedField<T>> :
|
||||||
Nd4jArrayField<T, F> {
|
ExtendedFieldOps<StructureND<T>>, Nd4jArrayField<T, F> {
|
||||||
|
|
||||||
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
|
override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
|
||||||
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
|
override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
|
||||||
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
|
override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
|
||||||
@ -221,63 +184,59 @@ public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : Ex
|
|||||||
/**
|
/**
|
||||||
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
||||||
*/
|
*/
|
||||||
public class DoubleNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> {
|
public open class DoubleNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Double, DoubleField> {
|
||||||
override val elementContext: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override val StructureND<Double>.ndArray: INDArray
|
override val StructureND<Double>.ndArray: INDArray
|
||||||
get() = when (this) {
|
get() = when (this) {
|
||||||
is Nd4jArrayStructure<Double> -> checkShape(ndArray)
|
is Nd4jArrayStructure<Double> -> ndArray
|
||||||
else -> Nd4j.zeros(*shape).also {
|
else -> Nd4j.zeros(*shape).also {
|
||||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> {
|
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> = a.ndArray.mul(value).wrap()
|
||||||
return a.ndArray.mul(value).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> {
|
override operator fun StructureND<Double>.div(arg: Double): Nd4jArrayStructure<Double> = ndArray.div(arg).wrap()
|
||||||
return ndArray.div(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> {
|
override operator fun StructureND<Double>.plus(arg: Double): Nd4jArrayStructure<Double> = ndArray.add(arg).wrap()
|
||||||
return ndArray.add(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> {
|
override operator fun StructureND<Double>.minus(arg: Double): Nd4jArrayStructure<Double> = ndArray.sub(arg).wrap()
|
||||||
return ndArray.sub(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> {
|
override operator fun StructureND<Double>.times(arg: Double): Nd4jArrayStructure<Double> = ndArray.mul(arg).wrap()
|
||||||
return ndArray.mul(arg).wrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
override operator fun Double.div(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
|
||||||
return arg.ndArray.rdiv(this).wrap()
|
arg.ndArray.rdiv(this).wrap()
|
||||||
}
|
|
||||||
|
|
||||||
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> =
|
||||||
return arg.ndArray.rsub(this).wrap()
|
arg.ndArray.rsub(this).wrap()
|
||||||
}
|
|
||||||
|
public companion object : DoubleNd4jArrayFieldOps()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun DoubleField.nd4j(vararg shape: Int): DoubleNd4jArrayField = DoubleNd4jArrayField(intArrayOf(*shape))
|
public val DoubleField.nd4j: DoubleNd4jArrayFieldOps get() = DoubleNd4jArrayFieldOps
|
||||||
|
|
||||||
|
public class DoubleNd4jArrayField(override val shape: Shape) : DoubleNd4jArrayFieldOps(), FieldND<Double, DoubleField>
|
||||||
|
|
||||||
|
public fun DoubleField.nd4j(shapeFirst: Int, vararg shapeRest: Int): DoubleNd4jArrayField =
|
||||||
|
DoubleNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
||||||
*/
|
*/
|
||||||
public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> {
|
public open class FloatNd4jArrayFieldOps : Nd4jArrayExtendedFieldOps<Float, FloatField> {
|
||||||
override val elementContext: FloatField get() = FloatField
|
override val elementAlgebra: FloatField get() = FloatField
|
||||||
|
|
||||||
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
override fun INDArray.wrap(): Nd4jArrayStructure<Float> = asFloatStructure()
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override val StructureND<Float>.ndArray: INDArray
|
override val StructureND<Float>.ndArray: INDArray
|
||||||
get() = when (this) {
|
get() = when (this) {
|
||||||
is Nd4jArrayStructure<Float> -> checkShape(ndArray)
|
is Nd4jArrayStructure<Float> -> ndArray
|
||||||
else -> Nd4j.zeros(*shape).also {
|
else -> Nd4j.zeros(*shape).also {
|
||||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||||
}
|
}
|
||||||
@ -303,21 +262,29 @@ public class FloatNd4jArrayField(override val shape: IntArray) : Nd4jArrayExtend
|
|||||||
|
|
||||||
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
||||||
arg.ndArray.rsub(this).wrap()
|
arg.ndArray.rsub(this).wrap()
|
||||||
|
|
||||||
|
public companion object : FloatNd4jArrayFieldOps()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class FloatNd4jArrayField(override val shape: Shape) : FloatNd4jArrayFieldOps(), RingND<Float, FloatField>
|
||||||
|
|
||||||
|
public val FloatField.nd4j: FloatNd4jArrayFieldOps get() = FloatNd4jArrayFieldOps
|
||||||
|
|
||||||
|
public fun FloatField.nd4j(shapeFirst: Int, vararg shapeRest: Int): FloatNd4jArrayField =
|
||||||
|
FloatNd4jArrayField(intArrayOf(shapeFirst, * shapeRest))
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Represents [RingND] over [Nd4jArrayIntStructure].
|
* Represents [RingND] over [Nd4jArrayIntStructure].
|
||||||
*/
|
*/
|
||||||
public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int, IntRing> {
|
public open class IntNd4jArrayRingOps : Nd4jArrayRingOps<Int, IntRing> {
|
||||||
override val elementContext: IntRing
|
override val elementAlgebra: IntRing get() = IntRing
|
||||||
get() = IntRing
|
|
||||||
|
|
||||||
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
|
override fun INDArray.wrap(): Nd4jArrayStructure<Int> = asIntStructure()
|
||||||
|
|
||||||
@OptIn(PerformancePitfall::class)
|
@OptIn(PerformancePitfall::class)
|
||||||
override val StructureND<Int>.ndArray: INDArray
|
override val StructureND<Int>.ndArray: INDArray
|
||||||
get() = when (this) {
|
get() = when (this) {
|
||||||
is Nd4jArrayStructure<Int> -> checkShape(ndArray)
|
is Nd4jArrayStructure<Int> -> ndArray
|
||||||
else -> Nd4j.zeros(*shape).also {
|
else -> Nd4j.zeros(*shape).also {
|
||||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||||
}
|
}
|
||||||
@ -334,4 +301,13 @@ public class IntNd4jArrayRing(override val shape: IntArray) : Nd4jArrayRing<Int,
|
|||||||
|
|
||||||
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
|
override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
|
||||||
arg.ndArray.rsub(this).wrap()
|
arg.ndArray.rsub(this).wrap()
|
||||||
|
|
||||||
|
public companion object : IntNd4jArrayRingOps()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public val IntRing.nd4j: IntNd4jArrayRingOps get() = IntNd4jArrayRingOps
|
||||||
|
|
||||||
|
public class IntNd4jArrayRing(override val shape: Shape) : IntNd4jArrayRingOps(), RingND<Int, IntRing>
|
||||||
|
|
||||||
|
public fun IntRing.nd4j(shapeFirst: Int, vararg shapeRest: Int): IntNd4jArrayRing =
|
||||||
|
IntNd4jArrayRing(intArrayOf(shapeFirst, * shapeRest))
|
@ -8,6 +8,10 @@ package space.kscience.kmath.nd4j
|
|||||||
import org.nd4j.linalg.factory.Nd4j
|
import org.nd4j.linalg.factory.Nd4j
|
||||||
import space.kscience.kmath.misc.PerformancePitfall
|
import space.kscience.kmath.misc.PerformancePitfall
|
||||||
import space.kscience.kmath.nd.StructureND
|
import space.kscience.kmath.nd.StructureND
|
||||||
|
import space.kscience.kmath.nd.one
|
||||||
|
import space.kscience.kmath.nd.produce
|
||||||
|
import space.kscience.kmath.operations.DoubleField
|
||||||
|
import space.kscience.kmath.operations.IntRing
|
||||||
import space.kscience.kmath.operations.invoke
|
import space.kscience.kmath.operations.invoke
|
||||||
import kotlin.math.PI
|
import kotlin.math.PI
|
||||||
import kotlin.test.Test
|
import kotlin.test.Test
|
||||||
@ -19,7 +23,7 @@ import kotlin.test.fail
|
|||||||
internal class Nd4jArrayAlgebraTest {
|
internal class Nd4jArrayAlgebraTest {
|
||||||
@Test
|
@Test
|
||||||
fun testProduce() {
|
fun testProduce() {
|
||||||
val res = with(DoubleNd4jArrayField(intArrayOf(2, 2))) { produce { it.sum().toDouble() } }
|
val res = DoubleField.nd4j.produce(2, 2) { it.sum().toDouble() }
|
||||||
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
|
||||||
expected[intArrayOf(0, 0)] = 0.0
|
expected[intArrayOf(0, 0)] = 0.0
|
||||||
expected[intArrayOf(0, 1)] = 1.0
|
expected[intArrayOf(0, 1)] = 1.0
|
||||||
@ -30,7 +34,9 @@ internal class Nd4jArrayAlgebraTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMap() {
|
fun testMap() {
|
||||||
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one.map { it + it * 2 } }
|
val res = IntRing.nd4j {
|
||||||
|
one(2, 2).map { it + it * 2 }
|
||||||
|
}
|
||||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
expected[intArrayOf(0, 0)] = 3
|
expected[intArrayOf(0, 0)] = 3
|
||||||
expected[intArrayOf(0, 1)] = 3
|
expected[intArrayOf(0, 1)] = 3
|
||||||
@ -41,7 +47,7 @@ internal class Nd4jArrayAlgebraTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testAdd() {
|
fun testAdd() {
|
||||||
val res = with(IntNd4jArrayRing(intArrayOf(2, 2))) { one + 25 }
|
val res = IntRing.nd4j { one(2, 2) + 25 }
|
||||||
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
val expected = (Nd4j.create(2, 2) ?: fail()).asIntStructure()
|
||||||
expected[intArrayOf(0, 0)] = 26
|
expected[intArrayOf(0, 0)] = 26
|
||||||
expected[intArrayOf(0, 1)] = 26
|
expected[intArrayOf(0, 1)] = 26
|
||||||
@ -51,10 +57,10 @@ internal class Nd4jArrayAlgebraTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
|
fun testSin() = DoubleField.nd4j{
|
||||||
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 }
|
val initial = produce(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
|
||||||
val transformed = sin(initial)
|
val transformed = sin(initial)
|
||||||
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
|
val expected = produce(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||||
|
|
||||||
println(transformed)
|
println(transformed)
|
||||||
assertTrue { StructureND.contentEquals(transformed, expected) }
|
assertTrue { StructureND.contentEquals(transformed, expected) }
|
||||||
|
@ -64,8 +64,8 @@ public fun MST.toIExpr(): IExpr = when (this) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
is MST.Unary -> when (operation) {
|
is MST.Unary -> when (operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> value.toIExpr()
|
GroupOps.PLUS_OPERATION -> value.toIExpr()
|
||||||
GroupOperations.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
GroupOps.MINUS_OPERATION -> F.Negate(value.toIExpr())
|
||||||
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
|
TrigonometricOperations.SIN_OPERATION -> F.Sin(value.toIExpr())
|
||||||
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
|
TrigonometricOperations.COS_OPERATION -> F.Cos(value.toIExpr())
|
||||||
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
|
TrigonometricOperations.TAN_OPERATION -> F.Tan(value.toIExpr())
|
||||||
@ -85,10 +85,10 @@ public fun MST.toIExpr(): IExpr = when (this) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
is MST.Binary -> when (operation) {
|
is MST.Binary -> when (operation) {
|
||||||
GroupOperations.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
GroupOps.PLUS_OPERATION -> left.toIExpr() + right.toIExpr()
|
||||||
GroupOperations.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
GroupOps.MINUS_OPERATION -> left.toIExpr() - right.toIExpr()
|
||||||
RingOperations.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
RingOps.TIMES_OPERATION -> left.toIExpr() * right.toIExpr()
|
||||||
FieldOperations.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
FieldOps.DIV_OPERATION -> F.Divide(left.toIExpr(), right.toIExpr())
|
||||||
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
|
PowerOperations.POW_OPERATION -> F.Power(left.toIExpr(), F.symjify((right as MST.Numeric).value))
|
||||||
else -> error("Binary operation $operation not defined in $this")
|
else -> error("Binary operation $operation not defined in $this")
|
||||||
}
|
}
|
||||||
|
@ -373,8 +373,12 @@ public open class DoubleTensorAlgebra :
|
|||||||
return resTensor
|
return resTensor
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun diagonalEmbedding(diagonalEntries: Tensor<Double>, offset: Int, dim1: Int, dim2: Int):
|
override fun diagonalEmbedding(
|
||||||
DoubleTensor {
|
diagonalEntries: Tensor<Double>,
|
||||||
|
offset: Int,
|
||||||
|
dim1: Int,
|
||||||
|
dim2: Int
|
||||||
|
): DoubleTensor {
|
||||||
val n = diagonalEntries.shape.size
|
val n = diagonalEntries.shape.size
|
||||||
val d1 = minusIndexFrom(n + 1, dim1)
|
val d1 = minusIndexFrom(n + 1, dim1)
|
||||||
val d2 = minusIndexFrom(n + 1, dim2)
|
val d2 = minusIndexFrom(n + 1, dim2)
|
||||||
|
@ -44,7 +44,7 @@ internal fun indexFromOffset(offset: Int, strides: IntArray, nDim: Int): IntArra
|
|||||||
*
|
*
|
||||||
* @param shape the shape of the tensor.
|
* @param shape the shape of the tensor.
|
||||||
*/
|
*/
|
||||||
internal class TensorLinearStructure(override val shape: IntArray) : Strides {
|
internal class TensorLinearStructure(override val shape: IntArray) : Strides() {
|
||||||
override val strides: IntArray
|
override val strides: IntArray
|
||||||
get() = stridesFromShape(shape)
|
get() = stridesFromShape(shape)
|
||||||
|
|
||||||
@ -54,4 +54,18 @@ internal class TensorLinearStructure(override val shape: IntArray) : Strides {
|
|||||||
override val linearSize: Int
|
override val linearSize: Int
|
||||||
get() = shape.reduce(Int::times)
|
get() = shape.reduce(Int::times)
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean {
|
||||||
|
if (this === other) return true
|
||||||
|
if (other == null || this::class != other::class) return false
|
||||||
|
|
||||||
|
other as TensorLinearStructure
|
||||||
|
|
||||||
|
if (!shape.contentEquals(other.shape)) return false
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
return shape.contentHashCode()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,8 +26,11 @@ internal fun <T> Tensor<T>.copyToBufferedTensor(): BufferedTensor<T> =
|
|||||||
|
|
||||||
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
internal fun <T> Tensor<T>.toBufferedTensor(): BufferedTensor<T> = when (this) {
|
||||||
is BufferedTensor<T> -> this
|
is BufferedTensor<T> -> this
|
||||||
is MutableBufferND<T> -> if (this.strides.strides contentEquals TensorLinearStructure(this.shape).strides)
|
is MutableBufferND<T> -> if (this.indexes == TensorLinearStructure(this.shape)) {
|
||||||
BufferedTensor(this.shape, this.mutableBuffer, 0) else this.copyToBufferedTensor()
|
BufferedTensor(this.shape, this.mutableBuffer, 0)
|
||||||
|
} else {
|
||||||
|
this.copyToBufferedTensor()
|
||||||
|
}
|
||||||
else -> this.copyToBufferedTensor()
|
else -> this.copyToBufferedTensor()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import space.kscience.kmath.misc.UnstableKMathAPI
|
|||||||
import space.kscience.kmath.nd.*
|
import space.kscience.kmath.nd.*
|
||||||
import space.kscience.kmath.operations.DoubleField
|
import space.kscience.kmath.operations.DoubleField
|
||||||
import space.kscience.kmath.operations.ExtendedField
|
import space.kscience.kmath.operations.ExtendedField
|
||||||
import space.kscience.kmath.operations.NumbersAddOperations
|
import space.kscience.kmath.operations.NumbersAddOps
|
||||||
import space.kscience.kmath.operations.ScaleOperations
|
import space.kscience.kmath.operations.ScaleOperations
|
||||||
|
|
||||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
@ -34,7 +34,7 @@ public fun F64Array.asStructure(): ViktorStructureND = ViktorStructureND(this)
|
|||||||
@OptIn(UnstableKMathAPI::class)
|
@OptIn(UnstableKMathAPI::class)
|
||||||
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
@Suppress("OVERRIDE_BY_INLINE", "NOTHING_TO_INLINE")
|
||||||
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, DoubleField>,
|
||||||
NumbersAddOperations<StructureND<Double>>, ExtendedField<StructureND<Double>>,
|
NumbersAddOps<StructureND<Double>>, ExtendedField<StructureND<Double>>,
|
||||||
ScaleOperations<StructureND<Double>> {
|
ScaleOperations<StructureND<Double>> {
|
||||||
|
|
||||||
public val StructureND<Double>.f64Buffer: F64Array
|
public val StructureND<Double>.f64Buffer: F64Array
|
||||||
@ -44,7 +44,7 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
|
|||||||
shape
|
shape
|
||||||
)
|
)
|
||||||
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
|
this is ViktorStructureND && this.f64Buffer.shape.contentEquals(this@ViktorFieldND.shape) -> this.f64Buffer
|
||||||
else -> produce { this@f64Buffer[it] }.f64Buffer
|
else -> produce(shape) { this@f64Buffer[it] }.f64Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
|
override val zero: ViktorStructureND by lazy { F64Array.full(init = 0.0, shape = shape).asStructure() }
|
||||||
@ -52,9 +52,9 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
|
|||||||
|
|
||||||
private val strides: Strides = DefaultStrides(shape)
|
private val strides: Strides = DefaultStrides(shape)
|
||||||
|
|
||||||
override val elementContext: DoubleField get() = DoubleField
|
override val elementAlgebra: DoubleField get() = DoubleField
|
||||||
|
|
||||||
override fun produce(initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
override fun produce(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
|
||||||
F64Array(*shape).apply {
|
F64Array(*shape).apply {
|
||||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||||
set(value = DoubleField.initializer(index), indices = index)
|
set(value = DoubleField.initializer(index), indices = index)
|
||||||
@ -78,13 +78,13 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
|
|||||||
}
|
}
|
||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
override fun combine(
|
override fun zip(
|
||||||
a: StructureND<Double>,
|
left: StructureND<Double>,
|
||||||
b: StructureND<Double>,
|
right: StructureND<Double>,
|
||||||
transform: DoubleField.(Double, Double) -> Double,
|
transform: DoubleField.(Double, Double) -> Double,
|
||||||
): ViktorStructureND = F64Array(*shape).apply {
|
): ViktorStructureND = F64Array(*shape).apply {
|
||||||
this@ViktorFieldND.strides.indices().forEach { index ->
|
this@ViktorFieldND.strides.indices().forEach { index ->
|
||||||
set(value = DoubleField.transform(a[index], b[index]), indices = index)
|
set(value = DoubleField.transform(left[index], right[index]), indices = index)
|
||||||
}
|
}
|
||||||
}.asStructure()
|
}.asStructure()
|
||||||
|
|
||||||
@ -123,4 +123,4 @@ public class ViktorFieldND(override val shape: IntArray) : FieldND<Double, Doubl
|
|||||||
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
|
override fun ln(arg: StructureND<Double>): ViktorStructureND = arg.f64Buffer.log().asStructure()
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun ViktorNDField(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
|
public fun ViktorFieldND(vararg shape: Int): ViktorFieldND = ViktorFieldND(shape)
|
||||||
|
@ -1,16 +1,18 @@
|
|||||||
pluginManagement {
|
pluginManagement {
|
||||||
repositories {
|
repositories {
|
||||||
mavenLocal()
|
|
||||||
maven("https://repo.kotlin.link")
|
maven("https://repo.kotlin.link")
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
gradlePluginPortal()
|
gradlePluginPortal()
|
||||||
}
|
}
|
||||||
|
|
||||||
val kotlinVersion = "1.6.0-M1"
|
val kotlinVersion = "1.6.0-RC"
|
||||||
|
val toolsVersion = "0.10.5"
|
||||||
|
|
||||||
plugins {
|
plugins {
|
||||||
id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
|
id("org.jetbrains.kotlinx.benchmark") version "0.3.1"
|
||||||
id("ru.mipt.npm.gradle.project") version "0.10.5"
|
id("ru.mipt.npm.gradle.project") version toolsVersion
|
||||||
|
id("ru.mipt.npm.gradle.jvm") version toolsVersion
|
||||||
|
id("ru.mipt.npm.gradle.mpp") version toolsVersion
|
||||||
kotlin("multiplatform") version kotlinVersion
|
kotlin("multiplatform") version kotlinVersion
|
||||||
kotlin("plugin.allopen") version kotlinVersion
|
kotlin("plugin.allopen") version kotlinVersion
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user