Merge pull request #332 from mipt-npm/commandertvis/nd4j

Nd4j based TensorAlgebra implementation
This commit is contained in:
Alexander Nozik 2021-05-20 18:15:59 +03:00 committed by GitHub
commit 958788bc91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 280 additions and 166 deletions

View File

@ -227,7 +227,6 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
* Field of [StructureND].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the type field of structure elements.
*/
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F>, ScaleOperations<StructureND<T>> {

View File

@ -4,7 +4,7 @@ plugins {
}
dependencies {
api(project(":kmath-core"))
api(project(":kmath-tensors"))
api("org.nd4j:nd4j-api:1.0.0-beta7")
testImplementation("org.nd4j:nd4j-native:1.0.0-beta7")
testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
@ -15,19 +15,7 @@ readme {
description = "ND4J NDStructure implementation and according NDAlgebra classes"
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
feature(
id = "nd4jarraystructure",
description = "NDStructure wrapper for INDArray"
)
feature(
id = "nd4jarrayrings",
description = "Rings over Nd4jArrayStructure of Int and Long"
)
feature(
id = "nd4jarrayfields",
description = "Fields over Nd4jArrayStructure of Float and Double"
)
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }
feature(id = "nd4jarrayrings") { "Rings over Nd4jArrayStructure of Int and Long" }
feature(id = "nd4jarrayfields") { "Fields over Nd4jArrayStructure of Float and Double" }
}

View File

@ -6,15 +6,14 @@
package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.scalar.Pow
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.misc.UnstableKMathAPI
import space.kscience.kmath.nd.*
import space.kscience.kmath.operations.*
import space.kscience.kmath.structures.*
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
val arrayShape = array.shape().toIntArray()
@ -29,23 +28,16 @@ internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
* @param T the type of ND-structure element.
* @param C the type of the element context.
*/
public interface Nd4jArrayAlgebra<T, C : Algebra<T>> : AlgebraND<T, C> {
public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C> {
/**
* Wraps [INDArray] to [N].
* Wraps [INDArray] to [Nd4jArrayStructure].
*/
public fun INDArray.wrap(): Nd4jArrayStructure<T>
/**
* Unwraps to or acquires [INDArray] from [StructureND].
*/
public val StructureND<T>.ndArray: INDArray
get() = when {
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
this@Nd4jArrayAlgebra.shape,
shape
)
this is Nd4jArrayStructure -> ndArray //TODO check strides
else -> {
TODO()
}
}
public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
val struct = Nd4j.create(*shape)!!.wrap()
@ -85,7 +77,7 @@ public interface Nd4jArrayAlgebra<T, C : Algebra<T>> : AlgebraND<T, C> {
* @param T the type of the element contained in ND structure.
* @param S the type of space of structure elements.
*/
public interface Nd4JArrayGroup<T, S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
public override val zero: Nd4jArrayStructure<T>
get() = Nd4j.zeros(*shape).wrap()
@ -110,7 +102,7 @@ public interface Nd4JArrayGroup<T, S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebr
* @param R the type of ring of structure elements.
*/
@OptIn(UnstableKMathAPI::class)
public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T, R> {
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> {
public override val one: Nd4jArrayStructure<T>
get() = Nd4j.ones(*shape).wrap()
@ -135,10 +127,7 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
public companion object {
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
ThreadLocal.withInitial { hashMapOf() }
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
@ -146,20 +135,13 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
/**
* Creates an [RingND] for [Long] values or pull it from cache if it was created previously.
*/
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
/**
* Creates a most suitable implementation of [RingND] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
}
}
}
@ -168,11 +150,9 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
* Represents [FieldND] over [Nd4jArrayStructure].
*
* @param T the type of the element contained in ND structure.
* @param N the type of ND structure.
* @param F the type field of structure elements.
*/
public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
public override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
a.ndArray.div(b.ndArray).wrap()
@ -180,10 +160,10 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
public companion object {
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
ThreadLocal.withInitial(::HashMap)
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
ThreadLocal.withInitial { hashMapOf() }
ThreadLocal.withInitial(::HashMap)
/**
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
@ -198,26 +178,64 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
/**
* Creates a most suitable implementation of [RingND] using reified class.
* Creates a most suitable implementation of [FieldND] using reified class.
*/
@Suppress("UNCHECKED_CAST")
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
}
}
}
/**
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
*/
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>,
Nd4jArrayField<T, F> {
public override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
public override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
public override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
public override fun acos(arg: StructureND<T>): StructureND<T> = Transforms.acos(arg.ndArray).wrap()
public override fun atan(arg: StructureND<T>): StructureND<T> = Transforms.atan(arg.ndArray).wrap()
public override fun power(arg: StructureND<T>, pow: Number): StructureND<T> =
Transforms.pow(arg.ndArray, pow).wrap()
public override fun exp(arg: StructureND<T>): StructureND<T> = Transforms.exp(arg.ndArray).wrap()
public override fun ln(arg: StructureND<T>): StructureND<T> = Transforms.log(arg.ndArray).wrap()
public override fun sqrt(arg: StructureND<T>): StructureND<T> = Transforms.sqrt(arg.ndArray).wrap()
public override fun sinh(arg: StructureND<T>): StructureND<T> = Transforms.sinh(arg.ndArray).wrap()
public override fun cosh(arg: StructureND<T>): StructureND<T> = Transforms.cosh(arg.ndArray).wrap()
public override fun tanh(arg: StructureND<T>): StructureND<T> = Transforms.tanh(arg.ndArray).wrap()
public override fun asinh(arg: StructureND<T>): StructureND<T> =
Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap()
public override fun acosh(arg: StructureND<T>): StructureND<T> =
Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap()
public override fun atanh(arg: StructureND<T>): StructureND<T> = Transforms.atanh(arg.ndArray).wrap()
}
/**
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
*/
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
ExtendedField<StructureND<Double>> {
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> {
public override val elementContext: DoubleField get() = DoubleField
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
@OptIn(PerformancePitfall::class)
override val StructureND<Double>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Double> -> checkShape(ndArray)
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
}
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> {
return a.ndArray.mul(value).wrap()
}
@ -245,34 +263,25 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
return arg.ndArray.rsub(this).wrap()
}
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
Transforms.pow(arg.ndArray,pow).wrap()
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
}
/**
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
*/
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
ExtendedField<StructureND<Float>> {
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> {
public override val elementContext: FloatField get() = FloatField
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
@OptIn(PerformancePitfall::class)
public override val StructureND<Float>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Float> -> checkShape(ndArray)
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
}
override fun scale(a: StructureND<Float>, value: Double): StructureND<Float> =
a.ndArray.mul(value).wrap()
@ -293,23 +302,6 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
arg.ndArray.rsub(this).wrap()
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
Pow(arg.ndArray, pow.toDouble()).z().wrap()
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
}
/**
@ -321,6 +313,15 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
@OptIn(PerformancePitfall::class)
public override val StructureND<Int>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Int> -> checkShape(ndArray)
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
}
public override operator fun StructureND<Int>.plus(arg: Int): Nd4jArrayStructure<Int> =
ndArray.add(arg).wrap()
@ -333,25 +334,3 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
public override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
arg.ndArray.rsub(this).wrap()
}
/**
* Represents [RingND] over [Nd4jArrayStructure] of [Long].
*/
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
public override val elementContext: LongRing
get() = LongRing
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
public override operator fun StructureND<Long>.plus(arg: Long): Nd4jArrayStructure<Long> =
ndArray.add(arg).wrap()
public override operator fun StructureND<Long>.minus(arg: Long): Nd4jArrayStructure<Long> =
ndArray.sub(arg).wrap()
public override operator fun StructureND<Long>.times(arg: Long): Nd4jArrayStructure<Long> =
ndArray.mul(arg).wrap()
public override operator fun Long.minus(arg: StructureND<Long>): Nd4jArrayStructure<Long> =
arg.ndArray.rsub(this).wrap()
}

View File

@ -48,12 +48,6 @@ private class Nd4jArrayDoubleIterator(iterateOver: INDArray) : Nd4jArrayIterator
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayDoubleIterator(this)
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
}
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
}

View File

@ -17,7 +17,8 @@ import space.kscience.kmath.nd.StructureND
*/
public sealed class Nd4jArrayStructure<T> : MutableStructureND<T> {
/**
* The wrapped [INDArray].
* The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming that the size of [INDArray] is less or equal to
* [Int.MAX_VALUE].
*/
public abstract val ndArray: INDArray
@ -25,6 +26,7 @@ public sealed class Nd4jArrayStructure<T> : MutableStructureND<T> {
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
@PerformancePitfall
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
}
@ -40,17 +42,6 @@ private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jA
*/
public fun INDArray.asIntStructure(): Nd4jArrayStructure<Int> = Nd4jArrayIntStructure(this)
private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Long>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
}
/**
* Wraps this [INDArray] to [Nd4jArrayStructure].
*/
public fun INDArray.asLongStructure(): Nd4jArrayStructure<Long> = Nd4jArrayLongStructure(this)
private data class Nd4jArrayDoubleStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Double>() {
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
override fun get(index: IntArray): Double = ndArray.getDouble(*index)

View File

@ -0,0 +1,175 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.api.ops.impl.summarystats.Variance
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.factory.ops.NDBase
import org.nd4j.linalg.ops.transforms.Transforms
import space.kscience.kmath.misc.PerformancePitfall
import space.kscience.kmath.nd.StructureND
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.api.TensorAlgebra
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
/**
* ND4J based [TensorAlgebra] implementation.
*/
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
/**
* Wraps [INDArray] to [Nd4jArrayStructure].
*/
public fun INDArray.wrap(): Nd4jArrayStructure<T>
/**
* Unwraps to or acquires [INDArray] from [StructureND].
*/
public val StructureND<T>.ndArray: INDArray
public override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap()
public override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap()
public override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap()
public override fun Tensor<T>.plusAssign(value: T) {
ndArray.addi(value)
}
public override fun Tensor<T>.plusAssign(other: Tensor<T>) {
ndArray.addi(other.ndArray)
}
public override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap()
public override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap()
public override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap()
public override fun Tensor<T>.minusAssign(value: T) {
ndArray.rsubi(value)
}
public override fun Tensor<T>.minusAssign(other: Tensor<T>) {
ndArray.subi(other.ndArray)
}
public override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap()
public override fun Tensor<T>.times(value: T): Tensor<T> =
ndArray.mul(value).wrap()
public override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap()
public override fun Tensor<T>.timesAssign(value: T) {
ndArray.muli(value)
}
public override fun Tensor<T>.timesAssign(other: Tensor<T>) {
ndArray.mmuli(other.ndArray)
}
public override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap()
public override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap()
public override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap()
public override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap()
public override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> =
ndArray.min(keepDim, dim).wrap()
public override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> =
ndArray.sum(keepDim, dim).wrap()
public override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> =
ndArray.max(keepDim, dim).wrap()
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
public override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
public override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap()
public override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap()
public override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap()
public override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap()
public override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap()
public override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap()
public override fun Tensor<T>.acosh(): Tensor<T> =
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
public override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap()
public override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap()
public override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
public override fun Tensor<T>.asinh(): Tensor<T> =
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
public override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap()
public override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap()
public override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap()
public override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap()
public override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap()
public override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap()
public override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap()
public override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap()
public override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
public override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap()
public override fun Tensor<T>.divAssign(value: T) {
ndArray.divi(value)
}
public override fun Tensor<T>.divAssign(other: Tensor<T>) {
ndArray.divi(other.ndArray)
}
public override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> =
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
private companion object {
private val ndBase: ThreadLocal<NDBase> = ThreadLocal.withInitial(::NDBase)
}
}
/**
* [Double] specialization of [Nd4jTensorAlgebra].
*/
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
@OptIn(PerformancePitfall::class)
public override val StructureND<Double>.ndArray: INDArray
get() = when (this) {
is Nd4jArrayStructure<Double> -> ndArray
else -> Nd4j.zeros(*shape).also {
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
}
}
public override fun Tensor<Double>.valueOrNull(): Double? =
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
// TODO rewrite
@PerformancePitfall
public override fun diagonalEmbedding(
diagonalEntries: Tensor<Double>,
offset: Int,
dim1: Int,
dim2: Int,
): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
public override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble()
public override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble()
public override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble()
public override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble()
public override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble()
public override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble()
}

View File

@ -5,5 +5,4 @@
package space.kscience.kmath.nd4j
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }

View File

@ -52,7 +52,7 @@ internal class Nd4jArrayAlgebraTest {
@Test
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 }
val transformed = sin(initial)
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }

View File

@ -72,7 +72,7 @@ internal class Nd4jArrayStructureTest {
@Test
fun testSet() {
val nd = Nd4j.rand(17, 12, 4, 8)!!
val struct = nd.asLongStructure()
val struct = nd.asIntStructure()
struct[intArrayOf(1, 2, 3, 4)] = 777
assertEquals(777, struct[1, 2, 3, 4])
}

View File

@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
*
* @param T the type of items in the tensors.
*/
public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
/**
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
@ -27,7 +27,8 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
*
* @return the value of a scalar tensor.
*/
public fun Tensor<T>.value(): T
public fun Tensor<T>.value(): T =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
/**
* Each element of the tensor [other] is added to this value.
@ -60,15 +61,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
*
* @param value the number to be added to each element of this tensor.
*/
public operator fun Tensor<T>.plusAssign(value: T): Unit
public operator fun Tensor<T>.plusAssign(value: T)
/**
* Each element of the tensor [other] is added to each element of this tensor.
*
* @param other tensor to be added.
*/
public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
public operator fun Tensor<T>.plusAssign(other: Tensor<T>)
/**
* Each element of the tensor [other] is subtracted from this value.
@ -101,14 +101,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
*
* @param value the number to be subtracted from each element of this tensor.
*/
public operator fun Tensor<T>.minusAssign(value: T): Unit
public operator fun Tensor<T>.minusAssign(value: T)
/**
* Each element of the tensor [other] is subtracted from each element of this tensor.
*
* @param other tensor to be subtracted.
*/
public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
public operator fun Tensor<T>.minusAssign(other: Tensor<T>)
/**
@ -142,14 +142,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
*
* @param value the number to be multiplied by each element of this tensor.
*/
public operator fun Tensor<T>.timesAssign(value: T): Unit
public operator fun Tensor<T>.timesAssign(value: T)
/**
* Each element of the tensor [other] is multiplied by each element of this tensor.
*
* @param other tensor to be multiplied.
*/
public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
public operator fun Tensor<T>.timesAssign(other: Tensor<T>)
/**
* Numerical negative, element-wise.
@ -217,7 +217,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
* multiple and removed after.
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
* The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable).
* For example, if `input` is a (j &times; 1 &times; n &times; n) tensor and `other` is a
* (k &times; n &times; n) tensor, out will be a (j &times; k &times; n &times; n) tensor.
*
@ -255,7 +255,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
diagonalEntries: Tensor<T>,
offset: Int = 0,
dim1: Int = -2,
dim2: Int = -1
dim2: Int = -1,
): Tensor<T>
/**

View File

@ -9,20 +9,9 @@ import space.kscience.kmath.nd.as1D
import space.kscience.kmath.nd.as2D
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.api.Tensor
import space.kscience.kmath.tensors.core.internal.dotHelper
import space.kscience.kmath.tensors.core.internal.getRandomNormals
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
import space.kscience.kmath.tensors.core.internal.*
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency
import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer
import space.kscience.kmath.tensors.core.internal.checkEmptyShape
import space.kscience.kmath.tensors.core.internal.checkShapesCompatible
import space.kscience.kmath.tensors.core.internal.checkSquareMatrix
import space.kscience.kmath.tensors.core.internal.checkTranspose
import space.kscience.kmath.tensors.core.internal.checkView
import space.kscience.kmath.tensors.core.internal.minusIndexFrom
import kotlin.math.*
/**
@ -38,8 +27,8 @@ public open class DoubleTensorAlgebra :
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
tensor.mutableBuffer.array()[tensor.bufferStart] else null
override fun Tensor<Double>.value(): Double =
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
override fun Tensor<Double>.value(): Double = valueOrNull()
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
/**
* Constructs a tensor with the specified shape and data.
@ -466,7 +455,7 @@ public open class DoubleTensorAlgebra :
private fun Tensor<Double>.eq(
other: Tensor<Double>,
eqFunction: (Double, Double) -> Boolean
eqFunction: (Double, Double) -> Boolean,
): Boolean {
checkShapesCompatible(tensor, other)
val n = tensor.numElements
@ -540,7 +529,7 @@ public open class DoubleTensorAlgebra :
internal fun Tensor<Double>.foldDim(
foldFunction: (DoubleArray) -> Double,
dim: Int,
keepDim: Boolean
keepDim: Boolean,
): DoubleTensor {
check(dim < dimension) { "Dimension $dim out of range $dimension" }
val resShape = if (keepDim) {
@ -729,7 +718,7 @@ public open class DoubleTensorAlgebra :
*/
public fun luPivot(
luTensor: Tensor<Double>,
pivotsTensor: Tensor<Int>
pivotsTensor: Tensor<Int>,
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
checkSquareMatrix(luTensor.shape)
check(