diff --git a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java index f2d65aa5c..ce7eddb93 100644 --- a/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java +++ b/kmath-noa/src/main/java/space/kscience/kmath/noa/JNoa.java @@ -33,6 +33,8 @@ class JNoa { public static native void disposeTensor(long tensorHandle); + public static native long emptyTensor(); + public static native long fromBlobDouble(double[] data, int[] shape, int device); public static native long fromBlobFloat(float[] data, int[] shape, int device); @@ -185,7 +187,7 @@ class JNoa { public static native long expTensor(long tensorHandle); - public static native long logTensor(long tensorHandle); + public static native long lnTensor(long tensorHandle); public static native long sumTensor(long tensorHandle); @@ -213,7 +215,7 @@ class JNoa { public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle); - public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors); + public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle); public static native boolean requiresGrad(long tensorHandle); diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index 3a0ce28df..56bd712c0 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of tensor source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ package space.kscience.kmath.noa @@ -15,112 +15,118 @@ import space.kscience.kmath.tensors.api.TensorAlgebra public sealed class NoaAlgebra> constructor(protected val scope: NoaScope) : TensorAlgebra { - protected abstract fun Tensor.cast(): TensorType + protected abstract val Tensor.tensor: TensorType protected abstract fun wrap(tensorHandle: TensorHandle): TensorType /** - * A scalar tensor in this.cast() implementation must have empty shape + * A scalar tensor must have empty shape */ override fun Tensor.valueOrNull(): T? = try { - this.cast().cast().item() + tensor.item() } catch (e: NoaException) { null } - override fun Tensor.value(): T = this.cast().cast().item() + override fun Tensor.value(): T = tensor.item() override operator fun Tensor.times(other: Tensor): TensorType { - return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } override operator fun Tensor.timesAssign(other: Tensor): Unit { - JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.timesTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) } override operator fun Tensor.plus(other: Tensor): TensorType { - return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.plusTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } override operator fun Tensor.plusAssign(other: Tensor): Unit { - JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.plusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) } override operator fun Tensor.minus(other: Tensor): TensorType { - return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.minusTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } override operator fun Tensor.minusAssign(other: Tensor): Unit { - JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.minusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) } override operator fun Tensor.unaryMinus(): TensorType = - wrap(JNoa.unaryMinus(this.cast().tensorHandle)) + wrap(JNoa.unaryMinus(tensor.tensorHandle)) override infix fun Tensor.dot(other: Tensor): TensorType { - return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.matmul(tensor.tensorHandle, other.tensor.tensorHandle)) } public infix fun Tensor.dotAssign(other: Tensor): Unit { - JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.matmulAssign(tensor.tensorHandle, other.tensor.tensorHandle) } public infix fun Tensor.dotRightAssign(other: Tensor): Unit { - JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.matmulRightAssign(tensor.tensorHandle, other.tensor.tensorHandle) } override operator fun Tensor.get(i: Int): TensorType = - wrap(JNoa.getIndex(this.cast().tensorHandle, i)) + wrap(JNoa.getIndex(tensor.tensorHandle, i)) public operator fun Tensor.get(indexTensor: NoaLongTensor): TensorType = - wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle)) + wrap(JNoa.getIndexTensor(tensor.tensorHandle, indexTensor.tensorHandle)) override fun diagonalEmbedding( diagonalEntries: Tensor, offset: Int, dim1: Int, dim2: Int ): TensorType = - wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2)) + wrap(JNoa.diagEmbed(diagonalEntries.tensor.tensorHandle, offset, dim1, dim2)) override fun Tensor.transpose(i: Int, j: Int): TensorType { - return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j)) + return wrap(JNoa.transposeTensor(tensor.tensorHandle, i, j)) } override fun Tensor.view(shape: IntArray): TensorType { - return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape)) + return wrap(JNoa.viewTensor(tensor.tensorHandle, shape)) } override fun Tensor.viewAs(other: Tensor): TensorType { - return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.viewAsTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } - public fun Tensor.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle)) + public fun Tensor.abs(): TensorType = wrap(JNoa.absTensor(tensor.tensorHandle)) - public fun Tensor.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle)) + public fun Tensor.sumAll(): TensorType = wrap(JNoa.sumTensor(tensor.tensorHandle)) override fun Tensor.sum(): T = sumAll().item() override fun Tensor.sum(dim: Int, keepDim: Boolean): TensorType = - wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim)) + wrap(JNoa.sumDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle)) + public fun Tensor.minAll(): TensorType = wrap(JNoa.minTensor(tensor.tensorHandle)) override fun Tensor.min(): T = minAll().item() override fun Tensor.min(dim: Int, keepDim: Boolean): TensorType = - wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim)) + wrap(JNoa.minDimTensor(tensor.tensorHandle, dim, keepDim)) - public fun Tensor.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle)) + public fun Tensor.maxAll(): TensorType = wrap(JNoa.maxTensor(tensor.tensorHandle)) override fun Tensor.max(): T = maxAll().item() override fun Tensor.max(dim: Int, keepDim: Boolean): TensorType = - wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim)) + wrap(JNoa.maxDimTensor(tensor.tensorHandle, dim, keepDim)) override fun Tensor.argMax(dim: Int, keepDim: Boolean): NoaIntTensor = - NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim)) + NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim)) public fun Tensor.flatten(): TensorType = - wrap(JNoa.flattenTensor(this.cast().tensorHandle)) + wrap(JNoa.flattenTensor(tensor.tensorHandle)) + + public fun Tensor.randIntegral(low: Long, high: Long): TensorType = + wrap(JNoa.randintLike(tensor.tensorHandle, low, high)) + + public fun Tensor.randIntegralAssign(low: Long, high: Long): Unit = + JNoa.randintLikeAssign(tensor.tensorHandle, low, high) public fun Tensor.copy(): TensorType = - wrap(JNoa.copyTensor(this.cast().tensorHandle)) + wrap(JNoa.copyTensor(tensor.tensorHandle)) public fun Tensor.copyToDevice(device: Device): TensorType = - wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt())) + wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt())) } @@ -129,13 +135,58 @@ internal constructor(scope: NoaScope) : NoaAlgebra(scope), Linear AnalyticTensorAlgebra { override operator fun Tensor.div(other: Tensor): TensorType { - return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + return wrap(JNoa.divTensor(tensor.tensorHandle, other.tensor.tensorHandle)) } override operator fun Tensor.divAssign(other: Tensor): Unit { - JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle) } + public fun Tensor.randUniform(): TensorType = + wrap(JNoa.randLike(tensor.tensorHandle)) + + public fun Tensor.randUniformAssign(): Unit = + JNoa.randLikeAssign(tensor.tensorHandle) + + public fun Tensor.randNormal(): TensorType = + wrap(JNoa.randnLike(tensor.tensorHandle)) + + public fun Tensor.randNormalAssign(): Unit = + JNoa.randnLikeAssign(tensor.tensorHandle) + + override fun Tensor.exp(): TensorType = + wrap(JNoa.expTensor(tensor.tensorHandle)) + + override fun Tensor.ln(): TensorType = + wrap(JNoa.lnTensor(tensor.tensorHandle)) + + + override fun Tensor.svd(): Triple { + val U = JNoa.emptyTensor() + val V = JNoa.emptyTensor() + val S = JNoa.emptyTensor() + JNoa.svdTensor(tensor.tensorHandle, U, S, V) + return Triple(wrap(U), wrap(S), wrap(V)) + } + + override fun Tensor.symEig(): Pair { + val V = JNoa.emptyTensor() + val S = JNoa.emptyTensor() + JNoa.symeigTensor(tensor.tensorHandle, S, V) + return Pair(wrap(S), wrap(V)) + } + + public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType { + return wrap(JNoa.autogradTensor(tensorHandle, variable.tensorHandle, retainGraph)) + } + + public infix fun TensorType.hess(variable: TensorType): TensorType { + return wrap(JNoa.autohessTensor(tensorHandle, variable.tensorHandle)) + } + + public fun TensorType.detachFromGraph(): TensorType = + wrap(JNoa.detachFromGraph(tensorHandle)) + }