Merge pull request #332 from mipt-npm/commandertvis/nd4j
Nd4j based TensorAlgebra implementation
This commit is contained in:
commit
958788bc91
kmath-core/src/commonMain/kotlin/space/kscience/kmath/nd
kmath-nd4j
kmath-tensors/src/commonMain/kotlin/space/kscience/kmath/tensors
@ -227,7 +227,6 @@ public interface RingND<T, out R : Ring<T>> : Ring<StructureND<T>>, GroupND<T, R
|
||||
* Field of [StructureND].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public interface FieldND<T, out F : Field<T>> : Field<StructureND<T>>, RingND<T, F>, ScaleOperations<StructureND<T>> {
|
||||
|
@ -4,7 +4,7 @@ plugins {
|
||||
}
|
||||
|
||||
dependencies {
|
||||
api(project(":kmath-core"))
|
||||
api(project(":kmath-tensors"))
|
||||
api("org.nd4j:nd4j-api:1.0.0-beta7")
|
||||
testImplementation("org.nd4j:nd4j-native:1.0.0-beta7")
|
||||
testImplementation("org.nd4j:nd4j-native-platform:1.0.0-beta7")
|
||||
@ -15,19 +15,7 @@ readme {
|
||||
description = "ND4J NDStructure implementation and according NDAlgebra classes"
|
||||
maturity = ru.mipt.npm.gradle.Maturity.EXPERIMENTAL
|
||||
propertyByTemplate("artifact", rootProject.file("docs/templates/ARTIFACT-TEMPLATE.md"))
|
||||
|
||||
feature(
|
||||
id = "nd4jarraystructure",
|
||||
description = "NDStructure wrapper for INDArray"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "nd4jarrayrings",
|
||||
description = "Rings over Nd4jArrayStructure of Int and Long"
|
||||
)
|
||||
|
||||
feature(
|
||||
id = "nd4jarrayfields",
|
||||
description = "Fields over Nd4jArrayStructure of Float and Double"
|
||||
)
|
||||
feature(id = "nd4jarraystructure") { "NDStructure wrapper for INDArray" }
|
||||
feature(id = "nd4jarrayrings") { "Rings over Nd4jArrayStructure of Int and Long" }
|
||||
feature(id = "nd4jarrayfields") { "Fields over Nd4jArrayStructure of Float and Double" }
|
||||
}
|
||||
|
@ -6,15 +6,14 @@
|
||||
package space.kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.ops.impl.scalar.Pow
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4j.linalg.ops.transforms.Transforms
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.misc.UnstableKMathAPI
|
||||
import space.kscience.kmath.nd.*
|
||||
import space.kscience.kmath.operations.*
|
||||
import space.kscience.kmath.structures.*
|
||||
|
||||
internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
|
||||
val arrayShape = array.shape().toIntArray()
|
||||
@ -29,23 +28,16 @@ internal fun AlgebraND<*, *>.checkShape(array: INDArray): INDArray {
|
||||
* @param T the type of ND-structure element.
|
||||
* @param C the type of the element context.
|
||||
*/
|
||||
public interface Nd4jArrayAlgebra<T, C : Algebra<T>> : AlgebraND<T, C> {
|
||||
public sealed interface Nd4jArrayAlgebra<T, out C : Algebra<T>> : AlgebraND<T, C> {
|
||||
/**
|
||||
* Wraps [INDArray] to [N].
|
||||
* Wraps [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||
|
||||
/**
|
||||
* Unwraps to or acquires [INDArray] from [StructureND].
|
||||
*/
|
||||
public val StructureND<T>.ndArray: INDArray
|
||||
get() = when {
|
||||
!shape.contentEquals(this@Nd4jArrayAlgebra.shape) -> throw ShapeMismatchException(
|
||||
this@Nd4jArrayAlgebra.shape,
|
||||
shape
|
||||
)
|
||||
this is Nd4jArrayStructure -> ndArray //TODO check strides
|
||||
else -> {
|
||||
TODO()
|
||||
}
|
||||
}
|
||||
|
||||
public override fun produce(initializer: C.(IntArray) -> T): Nd4jArrayStructure<T> {
|
||||
val struct = Nd4j.create(*shape)!!.wrap()
|
||||
@ -85,7 +77,7 @@ public interface Nd4jArrayAlgebra<T, C : Algebra<T>> : AlgebraND<T, C> {
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param S the type of space of structure elements.
|
||||
*/
|
||||
public interface Nd4JArrayGroup<T, S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||
public sealed interface Nd4jArrayGroup<T, out S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebra<T, S> {
|
||||
|
||||
public override val zero: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.zeros(*shape).wrap()
|
||||
@ -110,7 +102,7 @@ public interface Nd4JArrayGroup<T, S : Ring<T>> : GroupND<T, S>, Nd4jArrayAlgebr
|
||||
* @param R the type of ring of structure elements.
|
||||
*/
|
||||
@OptIn(UnstableKMathAPI::class)
|
||||
public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T, R> {
|
||||
public sealed interface Nd4jArrayRing<T, out R : Ring<T>> : RingND<T, R>, Nd4jArrayGroup<T, R> {
|
||||
|
||||
public override val one: Nd4jArrayStructure<T>
|
||||
get() = Nd4j.ones(*shape).wrap()
|
||||
@ -135,10 +127,7 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
|
||||
|
||||
public companion object {
|
||||
private val intNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, IntNd4jArrayRing>> =
|
||||
ThreadLocal.withInitial { hashMapOf() }
|
||||
|
||||
private val longNd4jArrayRingCache: ThreadLocal<MutableMap<IntArray, LongNd4jArrayRing>> =
|
||||
ThreadLocal.withInitial { hashMapOf() }
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
/**
|
||||
* Creates an [RingND] for [Int] values or pull it from cache if it was created previously.
|
||||
@ -146,20 +135,13 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
|
||||
public fun int(vararg shape: Int): Nd4jArrayRing<Int, IntRing> =
|
||||
intNd4jArrayRingCache.get().getOrPut(shape) { IntNd4jArrayRing(shape) }
|
||||
|
||||
/**
|
||||
* Creates an [RingND] for [Long] values or pull it from cache if it was created previously.
|
||||
*/
|
||||
public fun long(vararg shape: Int): Nd4jArrayRing<Long, LongRing> =
|
||||
longNd4jArrayRingCache.get().getOrPut(shape) { LongNd4jArrayRing(shape) }
|
||||
|
||||
/**
|
||||
* Creates a most suitable implementation of [RingND] using reified class.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayRing<T, out Ring<T>> = when {
|
||||
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||
T::class == Long::class -> long(*shape) as Nd4jArrayRing<T, out Ring<T>>
|
||||
else -> throw UnsupportedOperationException("This factory method only supports Int and Long types.")
|
||||
public inline fun <reified T : Number> auto(vararg shape: Int): Nd4jArrayRing<T, Ring<T>> = when {
|
||||
T::class == Int::class -> int(*shape) as Nd4jArrayRing<T, Ring<T>>
|
||||
else -> throw UnsupportedOperationException("This factory method only supports Long type.")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -168,11 +150,9 @@ public interface Nd4jArrayRing<T, R : Ring<T>> : RingND<T, R>, Nd4JArrayGroup<T,
|
||||
* Represents [FieldND] over [Nd4jArrayStructure].
|
||||
*
|
||||
* @param T the type of the element contained in ND structure.
|
||||
* @param N the type of ND structure.
|
||||
* @param F the type field of structure elements.
|
||||
*/
|
||||
public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
|
||||
|
||||
public sealed interface Nd4jArrayField<T, out F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<T, F> {
|
||||
public override fun divide(a: StructureND<T>, b: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
a.ndArray.div(b.ndArray).wrap()
|
||||
|
||||
@ -180,10 +160,10 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
|
||||
|
||||
public companion object {
|
||||
private val floatNd4jArrayFieldCache: ThreadLocal<MutableMap<IntArray, FloatNd4jArrayField>> =
|
||||
ThreadLocal.withInitial { hashMapOf() }
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
private val doubleNd4JArrayFieldCache: ThreadLocal<MutableMap<IntArray, DoubleNd4jArrayField>> =
|
||||
ThreadLocal.withInitial { hashMapOf() }
|
||||
ThreadLocal.withInitial(::HashMap)
|
||||
|
||||
/**
|
||||
* Creates an [FieldND] for [Float] values or pull it from cache if it was created previously.
|
||||
@ -198,26 +178,64 @@ public interface Nd4jArrayField<T, F : Field<T>> : FieldND<T, F>, Nd4jArrayRing<
|
||||
doubleNd4JArrayFieldCache.get().getOrPut(shape) { DoubleNd4jArrayField(shape) }
|
||||
|
||||
/**
|
||||
* Creates a most suitable implementation of [RingND] using reified class.
|
||||
* Creates a most suitable implementation of [FieldND] using reified class.
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, out Field<T>> = when {
|
||||
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, out Field<T>>
|
||||
public inline fun <reified T : Any> auto(vararg shape: Int): Nd4jArrayField<T, Field<T>> = when {
|
||||
T::class == Float::class -> float(*shape) as Nd4jArrayField<T, Field<T>>
|
||||
T::class == Double::class -> real(*shape) as Nd4jArrayField<T, Field<T>>
|
||||
else -> throw UnsupportedOperationException("This factory method only supports Float and Double types.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents intersection of [ExtendedField] and [Field] over [Nd4jArrayStructure].
|
||||
*/
|
||||
public sealed interface Nd4jArrayExtendedField<T, out F : ExtendedField<T>> : ExtendedField<StructureND<T>>,
|
||||
Nd4jArrayField<T, F> {
|
||||
public override fun sin(arg: StructureND<T>): StructureND<T> = Transforms.sin(arg.ndArray).wrap()
|
||||
public override fun cos(arg: StructureND<T>): StructureND<T> = Transforms.cos(arg.ndArray).wrap()
|
||||
public override fun asin(arg: StructureND<T>): StructureND<T> = Transforms.asin(arg.ndArray).wrap()
|
||||
public override fun acos(arg: StructureND<T>): StructureND<T> = Transforms.acos(arg.ndArray).wrap()
|
||||
public override fun atan(arg: StructureND<T>): StructureND<T> = Transforms.atan(arg.ndArray).wrap()
|
||||
|
||||
public override fun power(arg: StructureND<T>, pow: Number): StructureND<T> =
|
||||
Transforms.pow(arg.ndArray, pow).wrap()
|
||||
|
||||
public override fun exp(arg: StructureND<T>): StructureND<T> = Transforms.exp(arg.ndArray).wrap()
|
||||
public override fun ln(arg: StructureND<T>): StructureND<T> = Transforms.log(arg.ndArray).wrap()
|
||||
public override fun sqrt(arg: StructureND<T>): StructureND<T> = Transforms.sqrt(arg.ndArray).wrap()
|
||||
public override fun sinh(arg: StructureND<T>): StructureND<T> = Transforms.sinh(arg.ndArray).wrap()
|
||||
public override fun cosh(arg: StructureND<T>): StructureND<T> = Transforms.cosh(arg.ndArray).wrap()
|
||||
public override fun tanh(arg: StructureND<T>): StructureND<T> = Transforms.tanh(arg.ndArray).wrap()
|
||||
|
||||
public override fun asinh(arg: StructureND<T>): StructureND<T> =
|
||||
Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap()
|
||||
|
||||
public override fun acosh(arg: StructureND<T>): StructureND<T> =
|
||||
Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap()
|
||||
|
||||
public override fun atanh(arg: StructureND<T>): StructureND<T> = Transforms.atanh(arg.ndArray).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayDoubleStructure].
|
||||
*/
|
||||
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Double, DoubleField>,
|
||||
ExtendedField<StructureND<Double>> {
|
||||
public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField<Double, DoubleField> {
|
||||
public override val elementContext: DoubleField get() = DoubleField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = checkShape(this).asDoubleStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Double>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Double> -> checkShape(ndArray)
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun scale(a: StructureND<Double>, value: Double): Nd4jArrayStructure<Double> {
|
||||
return a.ndArray.mul(value).wrap()
|
||||
}
|
||||
@ -245,34 +263,25 @@ public class DoubleNd4jArrayField(public override val shape: IntArray) : Nd4jArr
|
||||
public override operator fun Double.minus(arg: StructureND<Double>): Nd4jArrayStructure<Double> {
|
||||
return arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
|
||||
override fun sin(arg: StructureND<Double>): StructureND<Double> = Transforms.sin(arg.ndArray).wrap()
|
||||
|
||||
override fun cos(arg: StructureND<Double>): StructureND<Double> = Transforms.cos(arg.ndArray).wrap()
|
||||
|
||||
override fun asin(arg: StructureND<Double>): StructureND<Double> = Transforms.asin(arg.ndArray).wrap()
|
||||
|
||||
override fun acos(arg: StructureND<Double>): StructureND<Double> = Transforms.acos(arg.ndArray).wrap()
|
||||
|
||||
override fun atan(arg: StructureND<Double>): StructureND<Double> = Transforms.atan(arg.ndArray).wrap()
|
||||
|
||||
override fun power(arg: StructureND<Double>, pow: Number): StructureND<Double> =
|
||||
Transforms.pow(arg.ndArray,pow).wrap()
|
||||
|
||||
override fun exp(arg: StructureND<Double>): StructureND<Double> = Transforms.exp(arg.ndArray).wrap()
|
||||
|
||||
override fun ln(arg: StructureND<Double>): StructureND<Double> = Transforms.log(arg.ndArray).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [FieldND] over [Nd4jArrayStructure] of [Float].
|
||||
*/
|
||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayField<Float, FloatField>,
|
||||
ExtendedField<StructureND<Float>> {
|
||||
public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArrayExtendedField<Float, FloatField> {
|
||||
public override val elementContext: FloatField get() = FloatField
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Float> = checkShape(this).asFloatStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public override val StructureND<Float>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Float> -> checkShape(ndArray)
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun scale(a: StructureND<Float>, value: Double): StructureND<Float> =
|
||||
a.ndArray.mul(value).wrap()
|
||||
|
||||
@ -293,23 +302,6 @@ public class FloatNd4jArrayField(public override val shape: IntArray) : Nd4jArra
|
||||
|
||||
public override operator fun Float.minus(arg: StructureND<Float>): Nd4jArrayStructure<Float> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
|
||||
override fun sin(arg: StructureND<Float>): StructureND<Float> = Sin(arg.ndArray).z().wrap()
|
||||
|
||||
override fun cos(arg: StructureND<Float>): StructureND<Float> = Cos(arg.ndArray).z().wrap()
|
||||
|
||||
override fun asin(arg: StructureND<Float>): StructureND<Float> = ASin(arg.ndArray).z().wrap()
|
||||
|
||||
override fun acos(arg: StructureND<Float>): StructureND<Float> = ACos(arg.ndArray).z().wrap()
|
||||
|
||||
override fun atan(arg: StructureND<Float>): StructureND<Float> = ATan(arg.ndArray).z().wrap()
|
||||
|
||||
override fun power(arg: StructureND<Float>, pow: Number): StructureND<Float> =
|
||||
Pow(arg.ndArray, pow.toDouble()).z().wrap()
|
||||
|
||||
override fun exp(arg: StructureND<Float>): StructureND<Float> = Exp(arg.ndArray).z().wrap()
|
||||
|
||||
override fun ln(arg: StructureND<Float>): StructureND<Float> = Log(arg.ndArray).z().wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
@ -321,6 +313,15 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Int> = checkShape(this).asIntStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public override val StructureND<Int>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Int> -> checkShape(ndArray)
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
}
|
||||
|
||||
public override operator fun StructureND<Int>.plus(arg: Int): Nd4jArrayStructure<Int> =
|
||||
ndArray.add(arg).wrap()
|
||||
|
||||
@ -333,25 +334,3 @@ public class IntNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRi
|
||||
public override operator fun Int.minus(arg: StructureND<Int>): Nd4jArrayStructure<Int> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents [RingND] over [Nd4jArrayStructure] of [Long].
|
||||
*/
|
||||
public class LongNd4jArrayRing(public override val shape: IntArray) : Nd4jArrayRing<Long, LongRing> {
|
||||
public override val elementContext: LongRing
|
||||
get() = LongRing
|
||||
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Long> = checkShape(this).asLongStructure()
|
||||
|
||||
public override operator fun StructureND<Long>.plus(arg: Long): Nd4jArrayStructure<Long> =
|
||||
ndArray.add(arg).wrap()
|
||||
|
||||
public override operator fun StructureND<Long>.minus(arg: Long): Nd4jArrayStructure<Long> =
|
||||
ndArray.sub(arg).wrap()
|
||||
|
||||
public override operator fun StructureND<Long>.times(arg: Long): Nd4jArrayStructure<Long> =
|
||||
ndArray.mul(arg).wrap()
|
||||
|
||||
public override operator fun Long.minus(arg: StructureND<Long>): Nd4jArrayStructure<Long> =
|
||||
arg.ndArray.rsub(this).wrap()
|
||||
}
|
||||
|
@ -48,12 +48,6 @@ private class Nd4jArrayDoubleIterator(iterateOver: INDArray) : Nd4jArrayIterator
|
||||
|
||||
internal fun INDArray.realIterator(): Iterator<Pair<IntArray, Double>> = Nd4jArrayDoubleIterator(this)
|
||||
|
||||
private class Nd4jArrayLongIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Long>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getLong(*indices)
|
||||
}
|
||||
|
||||
internal fun INDArray.longIterator(): Iterator<Pair<IntArray, Long>> = Nd4jArrayLongIterator(this)
|
||||
|
||||
private class Nd4jArrayIntIterator(iterateOver: INDArray) : Nd4jArrayIteratorBase<Int>(iterateOver) {
|
||||
override fun getSingle(indices: LongArray) = iterateOver.getInt(*indices.toIntArray())
|
||||
}
|
||||
|
@ -17,7 +17,8 @@ import space.kscience.kmath.nd.StructureND
|
||||
*/
|
||||
public sealed class Nd4jArrayStructure<T> : MutableStructureND<T> {
|
||||
/**
|
||||
* The wrapped [INDArray].
|
||||
* The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming that the size of [INDArray] is less or equal to
|
||||
* [Int.MAX_VALUE].
|
||||
*/
|
||||
public abstract val ndArray: INDArray
|
||||
|
||||
@ -25,6 +26,7 @@ public sealed class Nd4jArrayStructure<T> : MutableStructureND<T> {
|
||||
|
||||
internal abstract fun elementsIterator(): Iterator<Pair<IntArray, T>>
|
||||
internal fun indicesIterator(): Iterator<IntArray> = ndArray.indicesIterator()
|
||||
|
||||
@PerformancePitfall
|
||||
public override fun elements(): Sequence<Pair<IntArray, T>> = Sequence(::elementsIterator)
|
||||
}
|
||||
@ -40,17 +42,6 @@ private data class Nd4jArrayIntStructure(override val ndArray: INDArray) : Nd4jA
|
||||
*/
|
||||
public fun INDArray.asIntStructure(): Nd4jArrayStructure<Int> = Nd4jArrayIntStructure(this)
|
||||
|
||||
private data class Nd4jArrayLongStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Long>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Long>> = ndArray.longIterator()
|
||||
override fun get(index: IntArray): Long = ndArray.getLong(*index.toLongArray())
|
||||
override fun set(index: IntArray, value: Long): Unit = run { ndArray.putScalar(index, value.toDouble()) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps this [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.asLongStructure(): Nd4jArrayStructure<Long> = Nd4jArrayLongStructure(this)
|
||||
|
||||
private data class Nd4jArrayDoubleStructure(override val ndArray: INDArray) : Nd4jArrayStructure<Double>() {
|
||||
override fun elementsIterator(): Iterator<Pair<IntArray, Double>> = ndArray.realIterator()
|
||||
override fun get(index: IntArray): Double = ndArray.getDouble(*index)
|
||||
|
@ -0,0 +1,175 @@
|
||||
/*
|
||||
* Copyright 2018-2021 KMath contributors.
|
||||
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
|
||||
*/
|
||||
|
||||
package space.kscience.kmath.nd4j
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4j.linalg.factory.ops.NDBase
|
||||
import org.nd4j.linalg.ops.transforms.Transforms
|
||||
import space.kscience.kmath.misc.PerformancePitfall
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
|
||||
/**
|
||||
* ND4J based [TensorAlgebra] implementation.
|
||||
*/
|
||||
public sealed interface Nd4jTensorAlgebra<T : Number> : AnalyticTensorAlgebra<T> {
|
||||
/**
|
||||
* Wraps [INDArray] to [Nd4jArrayStructure].
|
||||
*/
|
||||
public fun INDArray.wrap(): Nd4jArrayStructure<T>
|
||||
|
||||
/**
|
||||
* Unwraps to or acquires [INDArray] from [StructureND].
|
||||
*/
|
||||
public val StructureND<T>.ndArray: INDArray
|
||||
|
||||
public override fun T.plus(other: Tensor<T>): Tensor<T> = other.ndArray.add(this).wrap()
|
||||
public override fun Tensor<T>.plus(value: T): Tensor<T> = ndArray.add(value).wrap()
|
||||
|
||||
public override fun Tensor<T>.plus(other: Tensor<T>): Tensor<T> = ndArray.add(other.ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.plusAssign(value: T) {
|
||||
ndArray.addi(value)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.plusAssign(other: Tensor<T>) {
|
||||
ndArray.addi(other.ndArray)
|
||||
}
|
||||
|
||||
public override fun T.minus(other: Tensor<T>): Tensor<T> = other.ndArray.rsub(this).wrap()
|
||||
public override fun Tensor<T>.minus(value: T): Tensor<T> = ndArray.sub(value).wrap()
|
||||
public override fun Tensor<T>.minus(other: Tensor<T>): Tensor<T> = ndArray.sub(other.ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.minusAssign(value: T) {
|
||||
ndArray.rsubi(value)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.minusAssign(other: Tensor<T>) {
|
||||
ndArray.subi(other.ndArray)
|
||||
}
|
||||
|
||||
public override fun T.times(other: Tensor<T>): Tensor<T> = other.ndArray.mul(this).wrap()
|
||||
|
||||
public override fun Tensor<T>.times(value: T): Tensor<T> =
|
||||
ndArray.mul(value).wrap()
|
||||
|
||||
public override fun Tensor<T>.times(other: Tensor<T>): Tensor<T> = ndArray.mul(other.ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.timesAssign(value: T) {
|
||||
ndArray.muli(value)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.timesAssign(other: Tensor<T>) {
|
||||
ndArray.mmuli(other.ndArray)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.unaryMinus(): Tensor<T> = ndArray.neg().wrap()
|
||||
public override fun Tensor<T>.get(i: Int): Tensor<T> = ndArray.slice(i.toLong()).wrap()
|
||||
public override fun Tensor<T>.transpose(i: Int, j: Int): Tensor<T> = ndArray.swapAxes(i, j).wrap()
|
||||
public override fun Tensor<T>.dot(other: Tensor<T>): Tensor<T> = ndArray.mmul(other.ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.min(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
ndArray.min(keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
ndArray.sum(keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.max(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
ndArray.max(keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.view(shape: IntArray): Tensor<T> = ndArray.reshape(shape).wrap()
|
||||
public override fun Tensor<T>.viewAs(other: Tensor<T>): Tensor<T> = view(other.shape)
|
||||
|
||||
public override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
ndBase.get().argmax(ndArray, keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.mean(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.mean(keepDim, dim).wrap()
|
||||
|
||||
public override fun Tensor<T>.exp(): Tensor<T> = Transforms.exp(ndArray).wrap()
|
||||
public override fun Tensor<T>.ln(): Tensor<T> = Transforms.log(ndArray).wrap()
|
||||
public override fun Tensor<T>.sqrt(): Tensor<T> = Transforms.sqrt(ndArray).wrap()
|
||||
public override fun Tensor<T>.cos(): Tensor<T> = Transforms.cos(ndArray).wrap()
|
||||
public override fun Tensor<T>.acos(): Tensor<T> = Transforms.acos(ndArray).wrap()
|
||||
public override fun Tensor<T>.cosh(): Tensor<T> = Transforms.cosh(ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.acosh(): Tensor<T> =
|
||||
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
|
||||
|
||||
public override fun Tensor<T>.sin(): Tensor<T> = Transforms.sin(ndArray).wrap()
|
||||
public override fun Tensor<T>.asin(): Tensor<T> = Transforms.asin(ndArray).wrap()
|
||||
public override fun Tensor<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.asinh(): Tensor<T> =
|
||||
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
|
||||
|
||||
public override fun Tensor<T>.tan(): Tensor<T> = Transforms.tan(ndArray).wrap()
|
||||
public override fun Tensor<T>.atan(): Tensor<T> = Transforms.atan(ndArray).wrap()
|
||||
public override fun Tensor<T>.tanh(): Tensor<T> = Transforms.tanh(ndArray).wrap()
|
||||
public override fun Tensor<T>.atanh(): Tensor<T> = Transforms.atanh(ndArray).wrap()
|
||||
public override fun Tensor<T>.ceil(): Tensor<T> = Transforms.ceil(ndArray).wrap()
|
||||
public override fun Tensor<T>.floor(): Tensor<T> = Transforms.floor(ndArray).wrap()
|
||||
public override fun Tensor<T>.std(dim: Int, keepDim: Boolean): Tensor<T> = ndArray.std(true, keepDim, dim).wrap()
|
||||
public override fun T.div(other: Tensor<T>): Tensor<T> = other.ndArray.rdiv(this).wrap()
|
||||
public override fun Tensor<T>.div(value: T): Tensor<T> = ndArray.div(value).wrap()
|
||||
public override fun Tensor<T>.div(other: Tensor<T>): Tensor<T> = ndArray.div(other.ndArray).wrap()
|
||||
|
||||
public override fun Tensor<T>.divAssign(value: T) {
|
||||
ndArray.divi(value)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.divAssign(other: Tensor<T>) {
|
||||
ndArray.divi(other.ndArray)
|
||||
}
|
||||
|
||||
public override fun Tensor<T>.variance(dim: Int, keepDim: Boolean): Tensor<T> =
|
||||
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
|
||||
|
||||
private companion object {
|
||||
private val ndBase: ThreadLocal<NDBase> = ThreadLocal.withInitial(::NDBase)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* [Double] specialization of [Nd4jTensorAlgebra].
|
||||
*/
|
||||
public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double> {
|
||||
public override fun INDArray.wrap(): Nd4jArrayStructure<Double> = asDoubleStructure()
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
public override val StructureND<Double>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Double> -> ndArray
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
elements().forEach { (idx, value) -> it.putScalar(idx, value) }
|
||||
}
|
||||
}
|
||||
|
||||
public override fun Tensor<Double>.valueOrNull(): Double? =
|
||||
if (shape contentEquals intArrayOf(1)) ndArray.getDouble(0) else null
|
||||
|
||||
// TODO rewrite
|
||||
@PerformancePitfall
|
||||
public override fun diagonalEmbedding(
|
||||
diagonalEntries: Tensor<Double>,
|
||||
offset: Int,
|
||||
dim1: Int,
|
||||
dim2: Int,
|
||||
): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
|
||||
|
||||
public override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble()
|
||||
public override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble()
|
||||
public override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble()
|
||||
public override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble()
|
||||
public override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble()
|
||||
public override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble()
|
||||
}
|
@ -5,5 +5,4 @@
|
||||
|
||||
package space.kscience.kmath.nd4j
|
||||
|
||||
internal fun IntArray.toLongArray(): LongArray = LongArray(size) { this[it].toLong() }
|
||||
internal fun LongArray.toIntArray(): IntArray = IntArray(size) { this[it].toInt() }
|
||||
|
@ -52,7 +52,7 @@ internal class Nd4jArrayAlgebraTest {
|
||||
|
||||
@Test
|
||||
fun testSin() = DoubleNd4jArrayField(intArrayOf(2, 2)).invoke {
|
||||
val initial = produce { (i, j) -> if (i == j) PI/2 else 0.0 }
|
||||
val initial = produce { (i, j) -> if (i == j) PI / 2 else 0.0 }
|
||||
val transformed = sin(initial)
|
||||
val expected = produce { (i, j) -> if (i == j) 1.0 else 0.0 }
|
||||
|
||||
|
@ -72,7 +72,7 @@ internal class Nd4jArrayStructureTest {
|
||||
@Test
|
||||
fun testSet() {
|
||||
val nd = Nd4j.rand(17, 12, 4, 8)!!
|
||||
val struct = nd.asLongStructure()
|
||||
val struct = nd.asIntStructure()
|
||||
struct[intArrayOf(1, 2, 3, 4)] = 777
|
||||
assertEquals(777, struct[1, 2, 3, 4])
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ import space.kscience.kmath.operations.Algebra
|
||||
*
|
||||
* @param T the type of items in the tensors.
|
||||
*/
|
||||
public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
public interface TensorAlgebra<T> : Algebra<Tensor<T>> {
|
||||
|
||||
/**
|
||||
* Returns a single tensor value of unit dimension if tensor shape equals to [1].
|
||||
@ -27,7 +27,8 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
*
|
||||
* @return the value of a scalar tensor.
|
||||
*/
|
||||
public fun Tensor<T>.value(): T
|
||||
public fun Tensor<T>.value(): T =
|
||||
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to this value.
|
||||
@ -60,15 +61,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
*
|
||||
* @param value the number to be added to each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.plusAssign(value: T): Unit
|
||||
public operator fun Tensor<T>.plusAssign(value: T)
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is added to each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be added.
|
||||
*/
|
||||
public operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit
|
||||
|
||||
public operator fun Tensor<T>.plusAssign(other: Tensor<T>)
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is subtracted from this value.
|
||||
@ -101,14 +101,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
*
|
||||
* @param value the number to be subtracted from each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.minusAssign(value: T): Unit
|
||||
public operator fun Tensor<T>.minusAssign(value: T)
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is subtracted from each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be subtracted.
|
||||
*/
|
||||
public operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit
|
||||
public operator fun Tensor<T>.minusAssign(other: Tensor<T>)
|
||||
|
||||
|
||||
/**
|
||||
@ -142,14 +142,14 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
*
|
||||
* @param value the number to be multiplied by each element of this tensor.
|
||||
*/
|
||||
public operator fun Tensor<T>.timesAssign(value: T): Unit
|
||||
public operator fun Tensor<T>.timesAssign(value: T)
|
||||
|
||||
/**
|
||||
* Each element of the tensor [other] is multiplied by each element of this tensor.
|
||||
*
|
||||
* @param other tensor to be multiplied.
|
||||
*/
|
||||
public operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit
|
||||
public operator fun Tensor<T>.timesAssign(other: Tensor<T>)
|
||||
|
||||
/**
|
||||
* Numerical negative, element-wise.
|
||||
@ -217,7 +217,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
* a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.
|
||||
* If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix
|
||||
* multiple and removed after.
|
||||
* The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable).
|
||||
* The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable).
|
||||
* For example, if `input` is a (j × 1 × n × n) tensor and `other` is a
|
||||
* (k × n × n) tensor, out will be a (j × k × n × n) tensor.
|
||||
*
|
||||
@ -255,7 +255,7 @@ public interface TensorAlgebra<T>: Algebra<Tensor<T>> {
|
||||
diagonalEntries: Tensor<T>,
|
||||
offset: Int = 0,
|
||||
dim1: Int = -2,
|
||||
dim2: Int = -1
|
||||
dim2: Int = -1,
|
||||
): Tensor<T>
|
||||
|
||||
/**
|
||||
|
@ -9,20 +9,9 @@ import space.kscience.kmath.nd.as1D
|
||||
import space.kscience.kmath.nd.as2D
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.LinearOpsTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.core.internal.dotHelper
|
||||
import space.kscience.kmath.tensors.core.internal.getRandomNormals
|
||||
import space.kscience.kmath.tensors.api.TensorPartialDivisionAlgebra
|
||||
import space.kscience.kmath.tensors.core.internal.*
|
||||
import space.kscience.kmath.tensors.core.internal.broadcastOuterTensors
|
||||
import space.kscience.kmath.tensors.core.internal.checkBufferShapeConsistency
|
||||
import space.kscience.kmath.tensors.core.internal.checkEmptyDoubleBuffer
|
||||
import space.kscience.kmath.tensors.core.internal.checkEmptyShape
|
||||
import space.kscience.kmath.tensors.core.internal.checkShapesCompatible
|
||||
import space.kscience.kmath.tensors.core.internal.checkSquareMatrix
|
||||
import space.kscience.kmath.tensors.core.internal.checkTranspose
|
||||
import space.kscience.kmath.tensors.core.internal.checkView
|
||||
import space.kscience.kmath.tensors.core.internal.minusIndexFrom
|
||||
import kotlin.math.*
|
||||
|
||||
/**
|
||||
@ -38,8 +27,8 @@ public open class DoubleTensorAlgebra :
|
||||
override fun Tensor<Double>.valueOrNull(): Double? = if (tensor.shape contentEquals intArrayOf(1))
|
||||
tensor.mutableBuffer.array()[tensor.bufferStart] else null
|
||||
|
||||
override fun Tensor<Double>.value(): Double =
|
||||
valueOrNull() ?: throw IllegalArgumentException("Inconsistent value for tensor of with $shape shape")
|
||||
override fun Tensor<Double>.value(): Double = valueOrNull()
|
||||
?: throw IllegalArgumentException("The tensor shape is $shape, but value method is allowed only for shape [1]")
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and data.
|
||||
@ -466,7 +455,7 @@ public open class DoubleTensorAlgebra :
|
||||
|
||||
private fun Tensor<Double>.eq(
|
||||
other: Tensor<Double>,
|
||||
eqFunction: (Double, Double) -> Boolean
|
||||
eqFunction: (Double, Double) -> Boolean,
|
||||
): Boolean {
|
||||
checkShapesCompatible(tensor, other)
|
||||
val n = tensor.numElements
|
||||
@ -540,7 +529,7 @@ public open class DoubleTensorAlgebra :
|
||||
internal fun Tensor<Double>.foldDim(
|
||||
foldFunction: (DoubleArray) -> Double,
|
||||
dim: Int,
|
||||
keepDim: Boolean
|
||||
keepDim: Boolean,
|
||||
): DoubleTensor {
|
||||
check(dim < dimension) { "Dimension $dim out of range $dimension" }
|
||||
val resShape = if (keepDim) {
|
||||
@ -729,7 +718,7 @@ public open class DoubleTensorAlgebra :
|
||||
*/
|
||||
public fun luPivot(
|
||||
luTensor: Tensor<Double>,
|
||||
pivotsTensor: Tensor<Int>
|
||||
pivotsTensor: Tensor<Int>,
|
||||
): Triple<DoubleTensor, DoubleTensor, DoubleTensor> {
|
||||
checkSquareMatrix(luTensor.shape)
|
||||
check(
|
||||
|
Loading…
x
Reference in New Issue
Block a user