forked from kscience/kmath
Clone some APIs from DoubleTensorAlgebra to Nd4j, fix #348
This commit is contained in:
parent
c24cf90262
commit
86b8cd1c97
@ -15,9 +15,9 @@ private class Nd4jArrayIndicesIterator(private val iterateOver: INDArray) : Iter
|
||||
|
||||
override fun next(): IntArray {
|
||||
val la = if (iterateOver.ordering() == 'c')
|
||||
Shape.ind2subC(iterateOver, i++.toLong())!!
|
||||
Shape.ind2subC(iterateOver, i++.toLong())
|
||||
else
|
||||
Shape.ind2sub(iterateOver, i++.toLong())!!
|
||||
Shape.ind2sub(iterateOver, i++.toLong())
|
||||
|
||||
return la.toIntArray()
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ import space.kscience.kmath.nd.StructureND
|
||||
*/
|
||||
public sealed class Nd4jArrayStructure<T> : MutableStructureND<T> {
|
||||
/**
|
||||
* The wrapped [INDArray]. Since KMath uses [Int] indexes, assuming the size of [INDArray] is less or equal to
|
||||
* The wrapped [INDArray]. Since KMath uses [Int] indices, assuming the size of [INDArray] is less or equal to
|
||||
* [Int.MAX_VALUE].
|
||||
*/
|
||||
public abstract val ndArray: INDArray
|
||||
|
@ -9,6 +9,7 @@ import org.nd4j.linalg.api.ndarray.INDArray
|
||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh
|
||||
import org.nd4j.linalg.api.shape.Shape
|
||||
import org.nd4j.linalg.factory.Nd4j
|
||||
import org.nd4j.linalg.factory.ops.NDBase
|
||||
import org.nd4j.linalg.ops.transforms.Transforms
|
||||
@ -18,10 +19,14 @@ import space.kscience.kmath.nd.Shape
|
||||
import space.kscience.kmath.nd.StructureND
|
||||
import space.kscience.kmath.operations.DoubleField
|
||||
import space.kscience.kmath.operations.Field
|
||||
import space.kscience.kmath.samplers.GaussianSampler
|
||||
import space.kscience.kmath.stat.RandomGenerator
|
||||
import space.kscience.kmath.structures.toDoubleArray
|
||||
import space.kscience.kmath.tensors.api.AnalyticTensorAlgebra
|
||||
import space.kscience.kmath.tensors.api.Tensor
|
||||
import space.kscience.kmath.tensors.api.TensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensorAlgebra
|
||||
import space.kscience.kmath.tensors.core.DoubleTensor
|
||||
import kotlin.math.abs
|
||||
|
||||
/**
|
||||
* ND4J based [TensorAlgebra] implementation.
|
||||
@ -156,10 +161,23 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
||||
Nd4j.getExecutioner().exec(Variance(ndArray, true, true, dim)).wrap()
|
||||
|
||||
private companion object {
|
||||
private val ndBase: ThreadLocal<NDBase> = ThreadLocal.withInitial(::NDBase)
|
||||
private val ndBase = NDBase()
|
||||
}
|
||||
|
||||
|
||||
private fun minusIndexFrom(n: Int, i: Int): Int = if (i >= 0) i else {
|
||||
val ii = n + i
|
||||
check(ii >= 0) { "Out of bound index $i for tensor of dim $n" }
|
||||
ii
|
||||
}
|
||||
|
||||
private fun getRandomNormals(n: Int, seed: Long): DoubleArray {
|
||||
val distribution = GaussianSampler(0.0, 1.0)
|
||||
val generator = RandomGenerator.default(seed)
|
||||
return distribution.sample(generator).nextBufferBlocking(n).toDoubleArray()
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* [Double] specialization of [Nd4jTensorAlgebra].
|
||||
*/
|
||||
@ -180,7 +198,7 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Double>.ndArray: INDArray
|
||||
public override val StructureND<Double>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
is Nd4jArrayStructure<Double> -> ndArray
|
||||
else -> Nd4j.zeros(*shape).also {
|
||||
@ -197,7 +215,47 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||
offset: Int,
|
||||
dim1: Int,
|
||||
dim2: Int,
|
||||
): Tensor<Double> = DoubleTensorAlgebra.diagonalEmbedding(diagonalEntries, offset, dim1, dim2)
|
||||
): Tensor<Double> {
|
||||
val diagonalEntriesNDArray = diagonalEntries.ndArray
|
||||
val n = diagonalEntries.shape.size
|
||||
val d1 = minusIndexFrom(n + 1, dim1)
|
||||
val d2 = minusIndexFrom(n + 1, dim2)
|
||||
check(d1 != d2) { "Diagonal dimensions cannot be identical $d1, $d2" }
|
||||
check(d1 <= n && d2 <= n) { "Dimension out of range" }
|
||||
var lessDim = d1
|
||||
var greaterDim = d2
|
||||
var realOffset = offset
|
||||
|
||||
if (lessDim > greaterDim) {
|
||||
realOffset *= -1
|
||||
lessDim = greaterDim.also { greaterDim = lessDim }
|
||||
}
|
||||
|
||||
val resShape = diagonalEntries.shape.sliceArray(0 until lessDim) +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.sliceArray(lessDim until greaterDim - 1) +
|
||||
intArrayOf(diagonalEntries.shape[n - 1] + abs(realOffset)) +
|
||||
diagonalEntries.shape.sliceArray(greaterDim - 1 until n - 1)
|
||||
val resTensor = Nd4j.zeros(*resShape).wrap()
|
||||
|
||||
for (i in 0 until diagonalEntriesNDArray.length()) {
|
||||
val multiIndex = (if (diagonalEntriesNDArray.ordering() == 'c')
|
||||
Shape.ind2subC(diagonalEntriesNDArray, i)
|
||||
else
|
||||
Shape.ind2sub(diagonalEntriesNDArray, i)).toIntArray()
|
||||
|
||||
var offset1 = 0
|
||||
var offset2 = abs(realOffset)
|
||||
if (realOffset < 0) offset1 = offset2.also { offset2 = offset1 }
|
||||
|
||||
val diagonalMultiIndex = multiIndex.sliceArray(0 until lessDim) +
|
||||
intArrayOf(multiIndex[n - 1] + offset1) +
|
||||
multiIndex.sliceArray(lessDim until greaterDim - 1) +
|
||||
intArrayOf(multiIndex[n - 1] + offset2) +
|
||||
multiIndex.sliceArray(greaterDim - 1 until n - 1)
|
||||
|
||||
resTensor[diagonalMultiIndex] = diagonalEntries[multiIndex]
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.sum(): Double = ndArray.sumNumber().toDouble()
|
||||
override fun StructureND<Double>.min(): Double = ndArray.minNumber().toDouble()
|
||||
@ -205,4 +263,125 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||
override fun StructureND<Double>.mean(): Double = ndArray.meanNumber().toDouble()
|
||||
override fun StructureND<Double>.std(): Double = ndArray.stdNumber().toDouble()
|
||||
override fun StructureND<Double>.variance(): Double = ndArray.varNumber().toDouble()
|
||||
return resTensor
|
||||
}
|
||||
|
||||
/**
|
||||
* Compares element-wise two tensors with a specified precision.
|
||||
*
|
||||
* @param other the tensor to compare with `input` tensor.
|
||||
* @param epsilon permissible error when comparing two Double values.
|
||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||
*/
|
||||
public fun Tensor<Double>.eq(other: Tensor<Double>, epsilon: Double): Boolean =
|
||||
ndArray.equalsWithEps(other, epsilon)
|
||||
|
||||
/**
|
||||
* Compares element-wise two tensors.
|
||||
* Comparison of two Double values occurs with 1e-5 precision.
|
||||
*
|
||||
* @param other the tensor to compare with `input` tensor.
|
||||
* @return true if two tensors have the same shape and elements, false otherwise.
|
||||
*/
|
||||
public infix fun Tensor<Double>.eq(other: Tensor<Double>): Boolean = eq(other, 1e-5)
|
||||
|
||||
public override fun Tensor<Double>.sum(): Double = ndArray.sumNumber().toDouble()
|
||||
public override fun Tensor<Double>.min(): Double = ndArray.minNumber().toDouble()
|
||||
public override fun Tensor<Double>.max(): Double = ndArray.maxNumber().toDouble()
|
||||
public override fun Tensor<Double>.mean(): Double = ndArray.meanNumber().toDouble()
|
||||
public override fun Tensor<Double>.std(): Double = ndArray.stdNumber().toDouble()
|
||||
public override fun Tensor<Double>.variance(): Double = ndArray.varNumber().toDouble()
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and data.
|
||||
*
|
||||
* @param shape the desired shape for the tensor.
|
||||
* @param buffer one-dimensional data array.
|
||||
* @return tensor with the [shape] shape and [buffer] data.
|
||||
*/
|
||||
public fun fromArray(shape: IntArray, buffer: DoubleArray): Nd4jArrayStructure<Double> =
|
||||
Nd4j.create(buffer, shape).wrap()
|
||||
|
||||
/**
|
||||
* Constructs a tensor with the specified shape and initializer.
|
||||
*
|
||||
* @param shape the desired shape for the tensor.
|
||||
* @param initializer mapping tensor indices to values.
|
||||
* @return tensor with the [shape] shape and data generated by the [initializer].
|
||||
*/
|
||||
public fun produce(shape: IntArray, initializer: (IntArray) -> Double): Nd4jArrayStructure<Double> {
|
||||
val struct = Nd4j.create(*shape)!!.wrap()
|
||||
struct.indicesIterator().forEach { struct[it] = initializer(it) }
|
||||
return struct
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a tensor of random numbers drawn from normal distributions with 0.0 mean and 1.0 standard deviation.
|
||||
*
|
||||
* @param shape the desired shape for the output tensor.
|
||||
* @param seed the random seed of the pseudo-random number generator.
|
||||
* @return tensor of a given shape filled with numbers from the normal distribution
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*/
|
||||
public fun randomNormal(shape: IntArray, seed: Long = 0): Nd4jArrayStructure<Double> =
|
||||
fromArray(shape, getRandomNormals(shape.reduce(Int::times), seed))
|
||||
|
||||
/**
|
||||
* Returns a tensor with the same shape as `input` of random numbers drawn from normal distributions
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*
|
||||
* @param seed the random seed of the pseudo-random number generator.
|
||||
* @return tensor with the same shape as `input` filled with numbers from the normal distribution
|
||||
* with 0.0 mean and 1.0 standard deviation.
|
||||
*/
|
||||
public fun Tensor<Double>.randomNormalLike(seed: Long = 0): Nd4jArrayStructure<Double> =
|
||||
fromArray(shape, getRandomNormals(shape.reduce(Int::times), seed))
|
||||
|
||||
/**
|
||||
* Creates a tensor of a given shape and fills all elements with a given value.
|
||||
*
|
||||
* @param value the value to fill the output tensor with.
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor with the [shape] shape and filled with [value].
|
||||
*/
|
||||
public fun full(value: Double, shape: IntArray): Nd4jArrayStructure<Double> = Nd4j.valueArrayOf(shape, value).wrap()
|
||||
|
||||
/**
|
||||
* Returns a tensor with the same shape as `input` filled with [value].
|
||||
*
|
||||
* @param value the value to fill the output tensor with.
|
||||
* @return tensor with the `input` tensor shape and filled with [value].
|
||||
*/
|
||||
public fun Tensor<Double>.fullLike(value: Double): Nd4jArrayStructure<Double> =
|
||||
Nd4j.valueArrayOf(ndArray.shape(), value).wrap()
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 0.0, with the shape defined by the variable argument [shape].
|
||||
*
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor filled with the scalar value 0.0, with the [shape] shape.
|
||||
*/
|
||||
public fun zeros(shape: IntArray): Nd4jArrayStructure<Double> = full(0.0, shape)
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 0.0, with the same shape as a given array.
|
||||
*
|
||||
* @return tensor filled with the scalar value 0.0, with the same shape as `input` tensor.
|
||||
*/
|
||||
public fun Tensor<Double>.zeroesLike(): Nd4jArrayStructure<Double> = Nd4j.zerosLike(ndArray).wrap()
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 1.0, with the shape defined by the variable argument [shape].
|
||||
*
|
||||
* @param shape array of integers defining the shape of the output tensor.
|
||||
* @return tensor filled with the scalar value 1.0, with the [shape] shape.
|
||||
*/
|
||||
public fun ones(shape: IntArray): Nd4jArrayStructure<Double> = Nd4j.ones(*shape).wrap()
|
||||
|
||||
/**
|
||||
* Returns a tensor filled with the scalar value 1.0, with the same shape as a given array.
|
||||
*
|
||||
* @return tensor filled with the scalar value 1.0, with the same shape as `input` tensor.
|
||||
*/
|
||||
public fun Tensor<Double>.onesLike(): Nd4jArrayStructure<Double> = Nd4j.onesLike(ndArray).wrap()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user