basics for div algebra

This commit is contained in:
Roland Grinis 2021-07-08 22:09:53 +01:00
parent 803a88ac2c
commit 0088be99f5
2 changed files with 89 additions and 36 deletions

View File

@ -33,6 +33,8 @@ class JNoa {
public static native void disposeTensor(long tensorHandle);
public static native long emptyTensor();
public static native long fromBlobDouble(double[] data, int[] shape, int device);
public static native long fromBlobFloat(float[] data, int[] shape, int device);
@ -185,7 +187,7 @@ class JNoa {
public static native long expTensor(long tensorHandle);
public static native long logTensor(long tensorHandle);
public static native long lnTensor(long tensorHandle);
public static native long sumTensor(long tensorHandle);
@ -213,7 +215,7 @@ class JNoa {
public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle);
public static native boolean requiresGrad(long tensorHandle);

View File

@ -1,6 +1,6 @@
/*
* Copyright 2018-2021 KMath contributors.
* Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
* Use of tensor source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package space.kscience.kmath.noa
@ -15,112 +15,118 @@ import space.kscience.kmath.tensors.api.TensorAlgebra
public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
constructor(protected val scope: NoaScope) : TensorAlgebra<T> {
protected abstract fun Tensor<T>.cast(): TensorType
protected abstract val Tensor<T>.tensor: TensorType
protected abstract fun wrap(tensorHandle: TensorHandle): TensorType
/**
* A scalar tensor in this.cast() implementation must have empty shape
* A scalar tensor must have empty shape
*/
override fun Tensor<T>.valueOrNull(): T? =
try {
this.cast().cast().item()
tensor.item()
} catch (e: NoaException) {
null
}
override fun Tensor<T>.value(): T = this.cast().cast().item()
override fun Tensor<T>.value(): T = tensor.item()
override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
override operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit {
JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.timesTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
override operator fun Tensor<T>.plus(other: Tensor<T>): TensorType {
return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.plusTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
override operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit {
JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.plusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
override operator fun Tensor<T>.minus(other: Tensor<T>): TensorType {
return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.minusTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
override operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit {
JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.minusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
override operator fun Tensor<T>.unaryMinus(): TensorType =
wrap(JNoa.unaryMinus(this.cast().tensorHandle))
wrap(JNoa.unaryMinus(tensor.tensorHandle))
override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.matmul(tensor.tensorHandle, other.tensor.tensorHandle))
}
public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.matmulAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.matmulRightAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
override operator fun Tensor<T>.get(i: Int): TensorType =
wrap(JNoa.getIndex(this.cast().tensorHandle, i))
wrap(JNoa.getIndex(tensor.tensorHandle, i))
public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
wrap(JNoa.getIndexTensor(tensor.tensorHandle, indexTensor.tensorHandle))
override fun diagonalEmbedding(
diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
): TensorType =
wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
wrap(JNoa.diagEmbed(diagonalEntries.tensor.tensorHandle, offset, dim1, dim2))
override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
return wrap(JNoa.transposeTensor(tensor.tensorHandle, i, j))
}
override fun Tensor<T>.view(shape: IntArray): TensorType {
return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
return wrap(JNoa.viewTensor(tensor.tensorHandle, shape))
}
override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.viewAsTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(tensor.tensorHandle))
public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(tensor.tensorHandle))
override fun Tensor<T>.sum(): T = sumAll().item()
override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
wrap(JNoa.sumDimTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle))
public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(tensor.tensorHandle))
override fun Tensor<T>.min(): T = minAll().item()
override fun Tensor<T>.min(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim))
wrap(JNoa.minDimTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle))
public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(tensor.tensorHandle))
override fun Tensor<T>.max(): T = maxAll().item()
override fun Tensor<T>.max(dim: Int, keepDim: Boolean): TensorType =
wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim))
wrap(JNoa.maxDimTensor(tensor.tensorHandle, dim, keepDim))
override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor =
NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim))
NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim))
public fun Tensor<T>.flatten(): TensorType =
wrap(JNoa.flattenTensor(this.cast().tensorHandle))
wrap(JNoa.flattenTensor(tensor.tensorHandle))
public fun Tensor<T>.randIntegral(low: Long, high: Long): TensorType =
wrap(JNoa.randintLike(tensor.tensorHandle, low, high))
public fun Tensor<T>.randIntegralAssign(low: Long, high: Long): Unit =
JNoa.randintLikeAssign(tensor.tensorHandle, low, high)
public fun Tensor<T>.copy(): TensorType =
wrap(JNoa.copyTensor(this.cast().tensorHandle))
wrap(JNoa.copyTensor(tensor.tensorHandle))
public fun Tensor<T>.copyToDevice(device: Device): TensorType =
wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))
}
@ -129,13 +135,58 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
AnalyticTensorAlgebra<T> {
override operator fun Tensor<T>.div(other: Tensor<T>): TensorType {
return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle))
return wrap(JNoa.divTensor(tensor.tensorHandle, other.tensor.tensorHandle))
}
override operator fun Tensor<T>.divAssign(other: Tensor<T>): Unit {
JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
}
public fun Tensor<T>.randUniform(): TensorType =
wrap(JNoa.randLike(tensor.tensorHandle))
public fun Tensor<T>.randUniformAssign(): Unit =
JNoa.randLikeAssign(tensor.tensorHandle)
public fun Tensor<T>.randNormal(): TensorType =
wrap(JNoa.randnLike(tensor.tensorHandle))
public fun Tensor<T>.randNormalAssign(): Unit =
JNoa.randnLikeAssign(tensor.tensorHandle)
override fun Tensor<T>.exp(): TensorType =
wrap(JNoa.expTensor(tensor.tensorHandle))
override fun Tensor<T>.ln(): TensorType =
wrap(JNoa.lnTensor(tensor.tensorHandle))
override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> {
val U = JNoa.emptyTensor()
val V = JNoa.emptyTensor()
val S = JNoa.emptyTensor()
JNoa.svdTensor(tensor.tensorHandle, U, S, V)
return Triple(wrap(U), wrap(S), wrap(V))
}
override fun Tensor<T>.symEig(): Pair<TensorType, TensorType> {
val V = JNoa.emptyTensor()
val S = JNoa.emptyTensor()
JNoa.symeigTensor(tensor.tensorHandle, S, V)
return Pair(wrap(S), wrap(V))
}
public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType {
return wrap(JNoa.autogradTensor(tensorHandle, variable.tensorHandle, retainGraph))
}
public infix fun TensorType.hess(variable: TensorType): TensorType {
return wrap(JNoa.autohessTensor(tensorHandle, variable.tensorHandle))
}
public fun TensorType.detachFromGraph(): TensorType =
wrap(JNoa.detachFromGraph(tensorHandle))
}