Add multik tensor factories and benchmarks

This commit is contained in:
Alexander Nozik 2021-10-18 11:35:09 +03:00
parent 827f115a92
commit a81ab474f7
23 changed files with 84 additions and 57 deletions

View File

@ -9,7 +9,12 @@ import kotlinx.benchmark.Benchmark
import kotlinx.benchmark.Blackhole import kotlinx.benchmark.Blackhole
import kotlinx.benchmark.Scope import kotlinx.benchmark.Scope
import kotlinx.benchmark.State import kotlinx.benchmark.State
import org.jetbrains.kotlinx.multik.api.Multik
import org.jetbrains.kotlinx.multik.api.ones
import org.jetbrains.kotlinx.multik.ndarray.data.DN
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
import space.kscience.kmath.multik.multikND import space.kscience.kmath.multik.multikND
import space.kscience.kmath.multik.multikTensorAlgebra
import space.kscience.kmath.nd.BufferedFieldOpsND import space.kscience.kmath.nd.BufferedFieldOpsND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
@ -73,6 +78,13 @@ internal class NDFieldBenchmark {
blackhole.consume(res) blackhole.consume(res)
} }
@Benchmark
fun multikInPlaceAdd(blackhole: Blackhole) = with(DoubleField.multikTensorAlgebra) {
val res = Multik.ones<Double, DN>(shape, DataType.DoubleDataType).wrap()
repeat(n) { res += 1.0 }
blackhole.consume(res)
}
// @Benchmark // @Benchmark
// fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) { // fun nd4jAdd(blackhole: Blackhole) = with(nd4jField) {
// var res: StructureND<Double> = one(dim, dim) // var res: StructureND<Double> = one(dim, dim)

View File

@ -23,7 +23,7 @@ internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun realFieldLog(blackhole: Blackhole) { fun realFieldLog(blackhole: Blackhole) {
with(realField) { with(realField) {
val fortyTwo = produce(shape) { 42.0 } val fortyTwo = structureND(shape) { 42.0 }
var res = one(shape) var res = one(shape)
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)
@ -33,7 +33,7 @@ internal class ViktorLogBenchmark {
@Benchmark @Benchmark
fun viktorFieldLog(blackhole: Blackhole) { fun viktorFieldLog(blackhole: Blackhole) {
with(viktorField) { with(viktorField) {
val fortyTwo = produce(shape) { 42.0 } val fortyTwo = structureND(shape) { 42.0 }
var res = one var res = one
repeat(n) { res = ln(fortyTwo) } repeat(n) { res = ln(fortyTwo) }
blackhole.consume(res) blackhole.consume(res)

View File

@ -9,7 +9,7 @@ import space.kscience.kmath.integration.gaussIntegrator
import space.kscience.kmath.integration.integrate import space.kscience.kmath.integration.integrate
import space.kscience.kmath.integration.value import space.kscience.kmath.integration.value
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.structureND
import space.kscience.kmath.nd.withNdAlgebra import space.kscience.kmath.nd.withNdAlgebra
import space.kscience.kmath.operations.algebra import space.kscience.kmath.operations.algebra
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -18,7 +18,7 @@ fun main(): Unit = Double.algebra {
withNdAlgebra(2, 2) { withNdAlgebra(2, 2) {
//Produce a diagonal StructureND //Produce a diagonal StructureND
fun diagonal(v: Double) = produce { (i, j) -> fun diagonal(v: Double) = structureND { (i, j) ->
if (i == j) v else 0.0 if (i == j) v else 0.0
} }

View File

@ -11,7 +11,7 @@ import space.kscience.kmath.complex.bufferAlgebra
import space.kscience.kmath.complex.ndAlgebra import space.kscience.kmath.complex.ndAlgebra
import space.kscience.kmath.nd.BufferND import space.kscience.kmath.nd.BufferND
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.structureND
fun main() = Complex.algebra { fun main() = Complex.algebra {
val complex = 2 + 2 * i val complex = 2 + 2 * i
@ -24,14 +24,14 @@ fun main() = Complex.algebra {
println(buffer) println(buffer)
// 2d element // 2d element
val element: BufferND<Complex> = ndAlgebra.produce(2, 2) { (i, j) -> val element: BufferND<Complex> = ndAlgebra.structureND(2, 2) { (i, j) ->
Complex(i - j, i + j) Complex(i - j, i + j)
} }
println(element) println(element)
// 1d element operation // 1d element operation
val result: StructureND<Complex> = ndAlgebra{ val result: StructureND<Complex> = ndAlgebra{
val a = produce(8) { (it) -> i * it - it.toDouble() } val a = structureND(8) { (it) -> i * it - it.toDouble() }
val b = 3 val b = 3
val c = Complex(1.0, 1.0) val c = Complex(1.0, 1.0)

View File

@ -10,7 +10,7 @@ import space.kscience.kmath.viktor.ViktorStructureND
import space.kscience.kmath.viktor.viktorAlgebra import space.kscience.kmath.viktor.viktorAlgebra
fun main() { fun main() {
val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.produce(Shape(2, 2)) { (i, j) -> val viktorStructure: ViktorStructureND = DoubleField.viktorAlgebra.structureND(Shape(2, 2)) { (i, j) ->
if (i == j) 2.0 else 0.0 if (i == j) 2.0 else 0.0
} }

View File

@ -12,7 +12,7 @@ import space.kscience.kmath.linear.transpose
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.as2D import space.kscience.kmath.nd.as2D
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import kotlin.system.measureTimeMillis import kotlin.system.measureTimeMillis
@ -55,7 +55,7 @@ fun complexExample() {
val x = one * 2.5 val x = one * 2.5
operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im) operator fun Number.plus(other: Complex) = Complex(this.toDouble() + other.re, other.im)
//a structure generator specific to this context //a structure generator specific to this context
val matrix = produce { (k, l) -> k + l * i } val matrix = structureND { (k, l) -> k + l * i }
//Perform sum //Perform sum
val sum = matrix + x + 1.0 val sum = matrix + x + 1.0

View File

@ -22,12 +22,12 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
private val strides = DefaultStrides(shape) private val strides = DefaultStrides(shape)
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override val zero: BufferND<Double> by lazy { produce(shape) { zero } } override val zero: BufferND<Double> by lazy { structureND(shape) { zero } }
override val one: BufferND<Double> by lazy { produce(shape) { one } } override val one: BufferND<Double> by lazy { structureND(shape) { one } }
override fun number(value: Number): BufferND<Double> { override fun number(value: Number): BufferND<Double> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce(shape) { d } return structureND(shape) { d }
} }
private val StructureND<Double>.buffer: DoubleBuffer private val StructureND<Double>.buffer: DoubleBuffer
@ -40,7 +40,7 @@ class StreamDoubleFieldND(override val shape: IntArray) : FieldND<Double, Double
else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) } else -> DoubleBuffer(strides.linearSize) { offset -> get(strides.index(offset)) }
} }
override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> { override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): BufferND<Double> {
val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset -> val array = IntStream.range(0, strides.linearSize).parallel().mapToDouble { offset ->
val index = strides.index(offset) val index = strides.index(offset)
DoubleField.initializer(index) DoubleField.initializer(index)

View File

@ -67,7 +67,7 @@ public class ComplexFieldND(override val shape: Shape) :
override fun number(value: Number): BufferND<Complex> { override fun number(value: Number): BufferND<Complex> {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce(shape) { d.toComplex() } return structureND(shape) { d.toComplex() }
} }
} }

View File

@ -24,7 +24,7 @@ public class BufferedLinearSpace<T, out A : Ring<T>>(
private val ndAlgebra = BufferedRingOpsND(bufferAlgebra) private val ndAlgebra = BufferedRingOpsND(bufferAlgebra)
override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> = override fun buildMatrix(rows: Int, columns: Int, initializer: A.(i: Int, j: Int) -> T): Matrix<T> =
ndAlgebra.produce(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D() ndAlgebra.structureND(intArrayOf(rows, columns)) { (i, j) -> elementAlgebra.initializer(i, j) }.as2D()
override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> = override fun buildVector(size: Int, initializer: A.(Int) -> T): Point<T> =
bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) } bufferAlgebra.buffer(size) { elementAlgebra.initializer(it) }

View File

@ -23,7 +23,7 @@ public object DoubleLinearSpace : LinearSpace<Double, DoubleField> {
rows: Int, rows: Int,
columns: Int, columns: Int,
initializer: DoubleField.(i: Int, j: Int) -> Double initializer: DoubleField.(i: Int, j: Int) -> Double
): Matrix<Double> = DoubleFieldOpsND.produce(intArrayOf(rows, columns)) { (i, j) -> ): Matrix<Double> = DoubleFieldOpsND.structureND(intArrayOf(rows, columns)) { (i, j) ->
DoubleField.initializer(i, j) DoubleField.initializer(i, j)
}.as2D() }.as2D()

View File

@ -39,9 +39,9 @@ public interface AlgebraND<T, out C : Algebra<T>> {
public val elementAlgebra: C public val elementAlgebra: C
/** /**
* Produces a new NDStructure using given initializer function. * Produces a new [StructureND] using given initializer function.
*/ */
public fun produce(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T> public fun structureND(shape: Shape, initializer: C.(IntArray) -> T): StructureND<T>
/** /**
* Maps elements from one structure to another one by applying [transform] to them. * Maps elements from one structure to another one by applying [transform] to them.
@ -149,7 +149,7 @@ public interface GroupOpsND<T, out A : GroupOps<T>> : GroupOps<StructureND<T>>,
} }
public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape { public interface GroupND<T, out A : Group<T>> : Group<StructureND<T>>, GroupOpsND<T, A>, WithShape {
override val zero: StructureND<T> get() = produce(shape) { elementAlgebra.zero } override val zero: StructureND<T> get() = structureND(shape) { elementAlgebra.zero }
} }
/** /**
@ -193,7 +193,7 @@ public interface RingOpsND<T, out A : RingOps<T>> : RingOps<StructureND<T>>, Gro
} }
public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape { public interface RingND<T, out A : Ring<T>> : Ring<StructureND<T>>, RingOpsND<T, A>, GroupND<T, A>, WithShape {
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one } override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
} }
@ -240,5 +240,5 @@ public interface FieldOpsND<T, out A : Field<T>> :
} }
public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape { public interface FieldND<T, out A : Field<T>> : Field<StructureND<T>>, FieldOpsND<T, A>, RingND<T, A>, WithShape {
override val one: StructureND<T> get() = produce(shape) { elementAlgebra.one } override val one: StructureND<T> get() = structureND(shape) { elementAlgebra.one }
} }

View File

@ -16,7 +16,7 @@ public interface BufferAlgebraND<T, out A : Algebra<T>> : AlgebraND<T, A> {
public val bufferAlgebra: BufferAlgebra<T, A> public val bufferAlgebra: BufferAlgebra<T, A>
override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra override val elementAlgebra: A get() = bufferAlgebra.elementAlgebra
override fun produce(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> { override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): BufferND<T> {
val indexer = indexerBuilder(shape) val indexer = indexerBuilder(shape)
return BufferND( return BufferND(
indexer, indexer,
@ -109,14 +109,14 @@ public val <T, A : Ring<T>> BufferAlgebra<T, A>.nd: BufferedRingOpsND<T, A> get(
public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this) public val <T, A : Field<T>> BufferAlgebra<T, A>.nd: BufferedFieldOpsND<T, A> get() = BufferedFieldOpsND(this)
public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.produce( public fun <T, A : Algebra<T>> BufferAlgebraND<T, A>.structureND(
vararg shape: Int, vararg shape: Int,
initializer: A.(IntArray) -> T initializer: A.(IntArray) -> T
): BufferND<T> = produce(shape, initializer) ): BufferND<T> = structureND(shape, initializer)
public fun <T, EA : Algebra<T>, A> A.produce( public fun <T, EA : Algebra<T>, A> A.structureND(
initializer: EA.(IntArray) -> T initializer: EA.(IntArray) -> T
): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = produce(shape, initializer) ): BufferND<T> where A : BufferAlgebraND<T, EA>, A : WithShape = structureND(shape, initializer)
//// group factories //// group factories
//public fun <T, A : Group<T>> A.ndAlgebra( //public fun <T, A : Group<T>> A.ndAlgebra(

View File

@ -60,7 +60,7 @@ public sealed class DoubleFieldOpsND : BufferedFieldOpsND<Double, DoubleField>(D
transform: DoubleField.(Double, Double) -> Double transform: DoubleField.(Double, Double) -> Double
): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) } ): BufferND<Double> = zipInline(left.toBufferND(), right.toBufferND()) { l, r -> DoubleField.transform(l, r) }
override fun produce(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND { override fun structureND(shape: Shape, initializer: DoubleField.(IntArray) -> Double): DoubleBufferND {
val indexer = indexerBuilder(shape) val indexer = indexerBuilder(shape)
return DoubleBufferND( return DoubleBufferND(
indexer, indexer,
@ -174,7 +174,7 @@ public class DoubleFieldND(override val shape: Shape) :
override fun number(value: Number): DoubleBufferND { override fun number(value: Number): DoubleBufferND {
val d = value.toDouble() // minimize conversions val d = value.toDouble() // minimize conversions
return produce(shape) { d } return structureND(shape) { d }
} }
} }

View File

@ -23,7 +23,7 @@ public class ShortRingND(
override fun number(value: Number): BufferND<Short> { override fun number(value: Number): BufferND<Short> {
val d = value.toShort() // minimize conversions val d = value.toShort() // minimize conversions
return produce(shape) { d } return structureND(shape) { d }
} }
} }

View File

@ -11,24 +11,24 @@ import space.kscience.kmath.operations.Ring
import kotlin.jvm.JvmName import kotlin.jvm.JvmName
public fun <T, A : Algebra<T>> AlgebraND<T, A>.produce( public fun <T, A : Algebra<T>> AlgebraND<T, A>.structureND(
shapeFirst: Int, shapeFirst: Int,
vararg shapeRest: Int, vararg shapeRest: Int,
initializer: A.(IntArray) -> T initializer: A.(IntArray) -> T
): StructureND<T> = produce(Shape(shapeFirst, *shapeRest), initializer) ): StructureND<T> = structureND(Shape(shapeFirst, *shapeRest), initializer)
public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = produce(shape) { zero } public fun <T, A : Group<T>> AlgebraND<T, A>.zero(shape: Shape): StructureND<T> = structureND(shape) { zero }
@JvmName("zeroVarArg") @JvmName("zeroVarArg")
public fun <T, A : Group<T>> AlgebraND<T, A>.zero( public fun <T, A : Group<T>> AlgebraND<T, A>.zero(
shapeFirst: Int, shapeFirst: Int,
vararg shapeRest: Int, vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { zero } ): StructureND<T> = structureND(shapeFirst, *shapeRest) { zero }
public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = produce(shape) { one } public fun <T, A : Ring<T>> AlgebraND<T, A>.one(shape: Shape): StructureND<T> = structureND(shape) { one }
@JvmName("oneVarArg") @JvmName("oneVarArg")
public fun <T, A : Ring<T>> AlgebraND<T, A>.one( public fun <T, A : Ring<T>> AlgebraND<T, A>.one(
shapeFirst: Int, shapeFirst: Int,
vararg shapeRest: Int, vararg shapeRest: Int,
): StructureND<T> = produce(shapeFirst, *shapeRest) { one } ): StructureND<T> = structureND(shapeFirst, *shapeRest) { one }

View File

@ -7,7 +7,7 @@ package space.kscience.kmath.structures
import space.kscience.kmath.nd.get import space.kscience.kmath.nd.get
import space.kscience.kmath.nd.ndAlgebra import space.kscience.kmath.nd.ndAlgebra
import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
import space.kscience.kmath.testutils.FieldVerifier import space.kscience.kmath.testutils.FieldVerifier
@ -22,7 +22,7 @@ internal class NDFieldTest {
@Test @Test
fun testStrides() { fun testStrides() {
val ndArray = DoubleField.ndAlgebra.produce(10, 10) { (it[0] + it[1]).toDouble() } val ndArray = DoubleField.ndAlgebra.structureND(10, 10) { (it[0] + it[1]).toDouble() }
assertEquals(ndArray[5, 5], 10.0) assertEquals(ndArray[5, 5], 10.0)
} }
} }

View File

@ -20,8 +20,8 @@ import kotlin.test.assertEquals
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
class NumberNDFieldTest { class NumberNDFieldTest {
val algebra = DoubleField.ndAlgebra val algebra = DoubleField.ndAlgebra
val array1 = algebra.produce(3, 3) { (i, j) -> (i + j).toDouble() } val array1 = algebra.structureND(3, 3) { (i, j) -> (i + j).toDouble() }
val array2 = algebra.produce(3, 3) { (i, j) -> (i - j).toDouble() } val array2 = algebra.structureND(3, 3) { (i, j) -> (i - j).toDouble() }
@Test @Test
fun testSum() { fun testSum() {

View File

@ -18,9 +18,9 @@ public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
override val elementAlgebra: A override val elementAlgebra: A
) : RingOpsND<T, A> { ) : RingOpsND<T, A> {
protected fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this) public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
override fun produce(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> { override fun structureND(shape: Shape, initializer: A.(IntArray) -> T): MultikTensor<T> {
val res = mk.zeros<T, DN>(shape, type).asDNArray() val res = mk.zeros<T, DN>(shape, type).asDNArray()
for (index in res.multiIndices) { for (index in res.multiIndices) {
res[index] = elementAlgebra.initializer(index) res[index] = elementAlgebra.initializer(index)
@ -28,10 +28,10 @@ public open class MultikRingOpsND<T, A : Ring<T>> internal constructor(
return res.wrap() return res.wrap()
} }
protected fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) { public fun StructureND<T>.asMultik(): MultikTensor<T> = if (this is MultikTensor) {
this this
} else { } else {
produce(shape) { get(it) } structureND(shape) { get(it) }
} }
override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> { override fun StructureND<T>.map(transform: A.(T) -> T): MultikTensor<T> {

View File

@ -11,7 +11,7 @@ import org.jetbrains.kotlinx.multik.ndarray.data.*
import org.jetbrains.kotlinx.multik.ndarray.operations.* import org.jetbrains.kotlinx.multik.ndarray.operations.*
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.mapInPlace import space.kscience.kmath.nd.mapInPlace
import space.kscience.kmath.operations.Ring import space.kscience.kmath.operations.*
import space.kscience.kmath.tensors.api.Tensor import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra import space.kscience.kmath.tensors.api.TensorAlgebra
@ -31,7 +31,7 @@ public value class MultikTensor<T>(public val array: MutableMultiArray<T, DN>) :
} }
public abstract class MultikTensorAlgebra<T>( public class MultikTensorAlgebra<T> internal constructor(
public val type: DataType, public val type: DataType,
public val elementAlgebra: Ring<T>, public val elementAlgebra: Ring<T>,
public val comparator: Comparator<T> public val comparator: Comparator<T>
@ -41,7 +41,7 @@ public abstract class MultikTensorAlgebra<T>(
* Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor * Convert a tensor to [MultikTensor] if necessary. If tensor is converted, changes on the resulting tensor
* are not reflected back onto the source * are not reflected back onto the source
*/ */
private fun Tensor<T>.asMultik(): MultikTensor<T> { public fun Tensor<T>.asMultik(): MultikTensor<T> {
return if (this is MultikTensor) { return if (this is MultikTensor) {
this this
} else { } else {
@ -53,7 +53,7 @@ public abstract class MultikTensorAlgebra<T>(
} }
} }
private fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this) public fun MutableMultiArray<T, DN>.wrap(): MultikTensor<T> = MultikTensor(this)
override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) { override fun Tensor<T>.valueOrNull(): T? = if (shape contentEquals intArrayOf(1)) {
get(intArrayOf(0)) get(intArrayOf(0))
@ -196,4 +196,19 @@ public abstract class MultikTensorAlgebra<T>(
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> { override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): MultikTensor<T> {
TODO("Not yet implemented") TODO("Not yet implemented")
} }
} }
public val DoubleField.multikTensorAlgebra: MultikTensorAlgebra<Double>
get() = MultikTensorAlgebra(DataType.DoubleDataType, DoubleField) { o1, o2 -> o1.compareTo(o2) }
public val FloatField.multikTensorAlgebra: MultikTensorAlgebra<Float>
get() = MultikTensorAlgebra(DataType.FloatDataType, FloatField) { o1, o2 -> o1.compareTo(o2) }
public val ShortRing.multikTensorAlgebra: MultikTensorAlgebra<Short>
get() = MultikTensorAlgebra(DataType.ShortDataType, ShortRing) { o1, o2 -> o1.compareTo(o2) }
public val IntRing.multikTensorAlgebra: MultikTensorAlgebra<Int>
get() = MultikTensorAlgebra(DataType.IntDataType, IntRing) { o1, o2 -> o1.compareTo(o2) }
public val LongRing.multikTensorAlgebra: MultikTensorAlgebra<Long>
get() = MultikTensorAlgebra(DataType.LongDataType, LongRing) { o1, o2 -> o1.compareTo(o2) }

View File

@ -32,7 +32,7 @@ public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C
*/ */
public val StructureND<T>.ndArray: INDArray public val StructureND<T>.ndArray: INDArray
override fun produce(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> { override fun structureND(shape: Shape, initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap() val struct = Nd4j.create(*shape)!!.wrap()
struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) } struct.indicesIterator().forEach { struct[it] = elementAlgebra.initializer(it) }
return struct return struct

View File

@ -9,7 +9,7 @@ import org.nd4j.linalg.factory.Nd4j
import space.kscience.kmath.misc.PerformancePitfall import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.nd.one import space.kscience.kmath.nd.one
import space.kscience.kmath.nd.produce import space.kscience.kmath.nd.structureND
import space.kscience.kmath.operations.DoubleField import space.kscience.kmath.operations.DoubleField
import space.kscience.kmath.operations.IntRing import space.kscience.kmath.operations.IntRing
import space.kscience.kmath.operations.invoke import space.kscience.kmath.operations.invoke
@ -23,7 +23,7 @@ import kotlin.test.fail
internal class Nd4jArrayAlgebraTest { internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testProduce() { fun testProduce() {
val res = DoubleField.nd4j.produce(2, 2) { it.sum().toDouble() } val res = DoubleField.nd4j.structureND(2, 2) { it.sum().toDouble() }
val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure() val expected = (Nd4j.create(2, 2) ?: fail()).asDoubleStructure()
expected[intArrayOf(0, 0)] = 0.0 expected[intArrayOf(0, 0)] = 0.0
expected[intArrayOf(0, 1)] = 1.0 expected[intArrayOf(0, 1)] = 1.0
@ -58,9 +58,9 @@ internal class Nd4jArrayAlgebraTest {
@Test @Test
fun testSin() = DoubleField.nd4j{ fun testSin() = DoubleField.nd4j{
val initial = produce(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 } val initial = structureND(2, 2) { (i, j) -> if (i == j) PI / 2 else 0.0 }
val transformed = sin(initial) val transformed = sin(initial)
val expected = produce(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 } val expected = structureND(2, 2) { (i, j) -> if (i == j) 1.0 else 0.0 }
println(transformed) println(transformed)
assertTrue { StructureND.contentEquals(transformed, expected) } assertTrue { StructureND.contentEquals(transformed, expected) }

View File

@ -22,7 +22,7 @@ import kotlin.math.*
public open class DoubleTensorAlgebra : public open class DoubleTensorAlgebra :
TensorPartialDivisionAlgebra<Double>, TensorPartialDivisionAlgebra<Double>,
AnalyticTensorAlgebra<Double>, AnalyticTensorAlgebra<Double>,
LinearOpsTensorAlgebra<Double> { LinearOpsTensorAlgebra<Double>{
public companion object : DoubleTensorAlgebra() public companion object : DoubleTensorAlgebra()

View File

@ -21,12 +21,12 @@ public open class ViktorFieldOpsND :
public val StructureND<Double>.f64Buffer: F64Array public val StructureND<Double>.f64Buffer: F64Array
get() = when (this) { get() = when (this) {
is ViktorStructureND -> this.f64Buffer is ViktorStructureND -> this.f64Buffer
else -> produce(shape) { this@f64Buffer[it] }.f64Buffer else -> structureND(shape) { this@f64Buffer[it] }.f64Buffer
} }
override val elementAlgebra: DoubleField get() = DoubleField override val elementAlgebra: DoubleField get() = DoubleField
override fun produce(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND = override fun structureND(shape: IntArray, initializer: DoubleField.(IntArray) -> Double): ViktorStructureND =
F64Array(*shape).apply { F64Array(*shape).apply {
DefaultStrides(shape).indices().forEach { index -> DefaultStrides(shape).indices().forEach { index ->
set(value = DoubleField.initializer(index), indices = index) set(value = DoubleField.initializer(index), indices = index)