forked from kscience/kmath
basics for div algebra
parent 803a88ac2c
commit 0088be99f5
@@ -33,6 +33,8 @@ class JNoa {
     public static native void disposeTensor(long tensorHandle);

+    public static native long emptyTensor();
+
     public static native long fromBlobDouble(double[] data, int[] shape, int device);

     public static native long fromBlobFloat(float[] data, int[] shape, int device);
@@ -185,7 +187,7 @@
     public static native long expTensor(long tensorHandle);

-    public static native long logTensor(long tensorHandle);
+    public static native long lnTensor(long tensorHandle);

     public static native long sumTensor(long tensorHandle);
@@ -213,7 +215,7 @@
     public static native void svdTensor(long tensorHandle, long Uhandle, long Shandle, long Vhandle);

-    public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle, boolean eigenvectors);
+    public static native void symeigTensor(long tensorHandle, long Shandle, long Vhandle);

     public static native boolean requiresGrad(long tensorHandle);
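The JNoa bindings above expose native tensors to the JVM as raw long handles. Below is a minimal sketch of driving that handle API directly; it assumes the code lives in the space.kscience.kmath.noa package (JNoa is package-private) and that device code 0 denotes the CPU. In practice the NoaAlgebra layer in the next file wraps these handles in NoaTensor objects tied to a NoaScope instead of exposing them like this.

// Sketch only: raw JNoa usage. Every handle owns native memory and must be
// disposed explicitly; nothing here is reclaimed by the JVM garbage collector.
internal fun rawHandleDemo() {
    val a = JNoa.fromBlobDouble(doubleArrayOf(1.0, 2.0, 3.0, 4.0), intArrayOf(2, 2), 0)
    val prod = JNoa.timesTensor(a, a)   // element-wise product, returns a new handle
    val total = JNoa.sumTensor(prod)    // scalar tensor holding the sum of all entries
    JNoa.disposeTensor(total)
    JNoa.disposeTensor(prod)
    JNoa.disposeTensor(a)
}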
@@ -1,6 +1,6 @@
 /*
  * Copyright 2018-2021 KMath contributors.
- * Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
  */

 package space.kscience.kmath.noa
@@ -15,112 +15,118 @@ import space.kscience.kmath.tensors.api.TensorAlgebra

 public sealed class NoaAlgebra<T, TensorType : NoaTensor<T>>
 constructor(protected val scope: NoaScope) : TensorAlgebra<T> {

-    protected abstract fun Tensor<T>.cast(): TensorType
+    protected abstract val Tensor<T>.tensor: TensorType

     protected abstract fun wrap(tensorHandle: TensorHandle): TensorType

     /**
-     * A scalar tensor in this.cast() implementation must have empty shape
+     * A scalar tensor must have empty shape
      */
     override fun Tensor<T>.valueOrNull(): T? =
         try {
-            this.cast().cast().item()
+            tensor.item()
         } catch (e: NoaException) {
             null
         }

-    override fun Tensor<T>.value(): T = this.cast().cast().item()
+    override fun Tensor<T>.value(): T = tensor.item()
     override operator fun Tensor<T>.times(other: Tensor<T>): TensorType {
-        return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.timesTensor(tensor.tensorHandle, other.tensor.tensorHandle))
     }

     override operator fun Tensor<T>.timesAssign(other: Tensor<T>): Unit {
-        JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.timesTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }

     override operator fun Tensor<T>.plus(other: Tensor<T>): TensorType {
-        return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.plusTensor(tensor.tensorHandle, other.tensor.tensorHandle))
     }

     override operator fun Tensor<T>.plusAssign(other: Tensor<T>): Unit {
-        JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.plusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }

     override operator fun Tensor<T>.minus(other: Tensor<T>): TensorType {
-        return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.minusTensor(tensor.tensorHandle, other.tensor.tensorHandle))
     }

     override operator fun Tensor<T>.minusAssign(other: Tensor<T>): Unit {
-        JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.minusTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }

     override operator fun Tensor<T>.unaryMinus(): TensorType =
-        wrap(JNoa.unaryMinus(this.cast().tensorHandle))
+        wrap(JNoa.unaryMinus(tensor.tensorHandle))

     override infix fun Tensor<T>.dot(other: Tensor<T>): TensorType {
-        return wrap(JNoa.matmul(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.matmul(tensor.tensorHandle, other.tensor.tensorHandle))
     }

     public infix fun Tensor<T>.dotAssign(other: Tensor<T>): Unit {
-        JNoa.matmulAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.matmulAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }

     public infix fun Tensor<T>.dotRightAssign(other: Tensor<T>): Unit {
-        JNoa.matmulRightAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.matmulRightAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }

     override operator fun Tensor<T>.get(i: Int): TensorType =
-        wrap(JNoa.getIndex(this.cast().tensorHandle, i))
+        wrap(JNoa.getIndex(tensor.tensorHandle, i))
     public operator fun Tensor<T>.get(indexTensor: NoaLongTensor): TensorType =
-        wrap(JNoa.getIndexTensor(this.cast().tensorHandle, indexTensor.tensorHandle))
+        wrap(JNoa.getIndexTensor(tensor.tensorHandle, indexTensor.tensorHandle))

     override fun diagonalEmbedding(
         diagonalEntries: Tensor<T>, offset: Int, dim1: Int, dim2: Int
     ): TensorType =
-        wrap(JNoa.diagEmbed(diagonalEntries.cast().tensorHandle, offset, dim1, dim2))
+        wrap(JNoa.diagEmbed(diagonalEntries.tensor.tensorHandle, offset, dim1, dim2))

     override fun Tensor<T>.transpose(i: Int, j: Int): TensorType {
-        return wrap(JNoa.transposeTensor(this.cast().tensorHandle, i, j))
+        return wrap(JNoa.transposeTensor(tensor.tensorHandle, i, j))
     }

     override fun Tensor<T>.view(shape: IntArray): TensorType {
-        return wrap(JNoa.viewTensor(this.cast().tensorHandle, shape))
+        return wrap(JNoa.viewTensor(tensor.tensorHandle, shape))
     }

     override fun Tensor<T>.viewAs(other: Tensor<T>): TensorType {
-        return wrap(JNoa.viewAsTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.viewAsTensor(tensor.tensorHandle, other.tensor.tensorHandle))
     }

-    public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(this.cast().tensorHandle))
+    public fun Tensor<T>.abs(): TensorType = wrap(JNoa.absTensor(tensor.tensorHandle))

-    public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(this.cast().tensorHandle))
+    public fun Tensor<T>.sumAll(): TensorType = wrap(JNoa.sumTensor(tensor.tensorHandle))
     override fun Tensor<T>.sum(): T = sumAll().item()
     override fun Tensor<T>.sum(dim: Int, keepDim: Boolean): TensorType =
-        wrap(JNoa.sumDimTensor(this.cast().tensorHandle, dim, keepDim))
+        wrap(JNoa.sumDimTensor(tensor.tensorHandle, dim, keepDim))
-    public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(this.cast().tensorHandle))
+    public fun Tensor<T>.minAll(): TensorType = wrap(JNoa.minTensor(tensor.tensorHandle))
     override fun Tensor<T>.min(): T = minAll().item()
     override fun Tensor<T>.min(dim: Int, keepDim: Boolean): TensorType =
-        wrap(JNoa.minDimTensor(this.cast().tensorHandle, dim, keepDim))
+        wrap(JNoa.minDimTensor(tensor.tensorHandle, dim, keepDim))

-    public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(this.cast().tensorHandle))
+    public fun Tensor<T>.maxAll(): TensorType = wrap(JNoa.maxTensor(tensor.tensorHandle))
     override fun Tensor<T>.max(): T = maxAll().item()
     override fun Tensor<T>.max(dim: Int, keepDim: Boolean): TensorType =
-        wrap(JNoa.maxDimTensor(this.cast().tensorHandle, dim, keepDim))
+        wrap(JNoa.maxDimTensor(tensor.tensorHandle, dim, keepDim))

     override fun Tensor<T>.argMax(dim: Int, keepDim: Boolean): NoaIntTensor =
-        NoaIntTensor(scope, JNoa.argMaxTensor(this.cast().tensorHandle, dim, keepDim))
+        NoaIntTensor(scope, JNoa.argMaxTensor(tensor.tensorHandle, dim, keepDim))

     public fun Tensor<T>.flatten(): TensorType =
-        wrap(JNoa.flattenTensor(this.cast().tensorHandle))
+        wrap(JNoa.flattenTensor(tensor.tensorHandle))

+    public fun Tensor<T>.randIntegral(low: Long, high: Long): TensorType =
+        wrap(JNoa.randintLike(tensor.tensorHandle, low, high))
+
+    public fun Tensor<T>.randIntegralAssign(low: Long, high: Long): Unit =
+        JNoa.randintLikeAssign(tensor.tensorHandle, low, high)
+
     public fun Tensor<T>.copy(): TensorType =
-        wrap(JNoa.copyTensor(this.cast().tensorHandle))
+        wrap(JNoa.copyTensor(tensor.tensorHandle))

     public fun Tensor<T>.copyToDevice(device: Device): TensorType =
-        wrap(JNoa.copyToDevice(this.cast().tensorHandle, device.toInt()))
+        wrap(JNoa.copyToDevice(tensor.tensorHandle, device.toInt()))

 }
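As a quick illustration of how this base class is meant to be consumed, here is a hedged sketch in the usual kmath context-receiver style. It only uses members defined above; basicOpsDemo is an illustrative name, the algebra instance is assumed to be obtained elsewhere in the module, and the fact that NoaTensor<T> implements Tensor<T> is inferred from the overridden return types rather than stated in this diff.

import space.kscience.kmath.tensors.api.Tensor

// Sketch only: exercises the member extensions of NoaAlgebra defined above.
// `algebra` is any concrete NoaAlgebra implementation; x and y were produced by it.
fun <T, TT : NoaTensor<T>> basicOpsDemo(algebra: NoaAlgebra<T, TT>, x: Tensor<T>, y: Tensor<T>): TT =
    with(algebra) {
        val p = x + y                // JNoa.plusTensor under the hood
        p *= y                       // in-place JNoa.timesTensorAssign
        val m = x dot y              // JNoa.matmul
        (m - p).transpose(0, 1)      // JNoa.minusTensor, then JNoa.transposeTensor
    }

Operations are only resolvable inside a block that provides the algebra as a receiver, so every call site implicitly has access to the NoaScope the algebra holds.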
@@ -129,13 +135,58 @@ internal constructor(scope: NoaScope) : NoaAlgebra<T, TensorType>(scope), Linear
     AnalyticTensorAlgebra<T> {

     override operator fun Tensor<T>.div(other: Tensor<T>): TensorType {
-        return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle))
+        return wrap(JNoa.divTensor(tensor.tensorHandle, other.tensor.tensorHandle))
     }

     override operator fun Tensor<T>.divAssign(other: Tensor<T>): Unit {
-        JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle)
+        JNoa.divTensorAssign(tensor.tensorHandle, other.tensor.tensorHandle)
     }
+    public fun Tensor<T>.randUniform(): TensorType =
+        wrap(JNoa.randLike(tensor.tensorHandle))
+
+    public fun Tensor<T>.randUniformAssign(): Unit =
+        JNoa.randLikeAssign(tensor.tensorHandle)
+
+    public fun Tensor<T>.randNormal(): TensorType =
+        wrap(JNoa.randnLike(tensor.tensorHandle))
+
+    public fun Tensor<T>.randNormalAssign(): Unit =
+        JNoa.randnLikeAssign(tensor.tensorHandle)
+
+    override fun Tensor<T>.exp(): TensorType =
+        wrap(JNoa.expTensor(tensor.tensorHandle))
+
+    override fun Tensor<T>.ln(): TensorType =
+        wrap(JNoa.lnTensor(tensor.tensorHandle))
+
+    override fun Tensor<T>.svd(): Triple<TensorType, TensorType, TensorType> {
+        val U = JNoa.emptyTensor()
+        val V = JNoa.emptyTensor()
+        val S = JNoa.emptyTensor()
+        JNoa.svdTensor(tensor.tensorHandle, U, S, V)
+        return Triple(wrap(U), wrap(S), wrap(V))
+    }
+
+    override fun Tensor<T>.symEig(): Pair<TensorType, TensorType> {
+        val V = JNoa.emptyTensor()
+        val S = JNoa.emptyTensor()
+        JNoa.symeigTensor(tensor.tensorHandle, S, V)
+        return Pair(wrap(S), wrap(V))
+    }
+
+    public fun TensorType.grad(variable: TensorType, retainGraph: Boolean): TensorType {
+        return wrap(JNoa.autogradTensor(tensorHandle, variable.tensorHandle, retainGraph))
+    }
+
+    public infix fun TensorType.hess(variable: TensorType): TensorType {
+        return wrap(JNoa.autohessTensor(tensorHandle, variable.tensorHandle))
+    }
+
+    public fun TensorType.detachFromGraph(): TensorType =
+        wrap(JNoa.detachFromGraph(tensorHandle))

 }
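A similarly hedged sketch of the members this commit adds (division, exp/ln, and the autograd helpers). The enclosing class name is not visible in this hunk, so NoaPartialDivisionAlgebra below is only an assumed placeholder; gradDemo additionally assumes x was created with gradient tracking enabled, in line with the requiresGrad binding in JNoa.

import space.kscience.kmath.tensors.api.Tensor

// Sketch only: substitute the real subclass of NoaAlgebra for the placeholder name.
fun <T, TT : NoaTensor<T>> divisionDemo(
    algebra: NoaPartialDivisionAlgebra<T, TT>, x: Tensor<T>, y: Tensor<T>
): TT = with(algebra) {
    val q = x / y          // JNoa.divTensor
    q /= y.exp()           // JNoa.expTensor, then in-place JNoa.divTensorAssign
    q.ln()                 // JNoa.lnTensor
}

// loss = sum(x * x); its gradient with respect to x should be 2 * x when x tracks gradients.
fun <T, TT : NoaTensor<T>> gradDemo(algebra: NoaPartialDivisionAlgebra<T, TT>, x: TT): TT =
    with(algebra) {
        val loss = (x * x).sumAll()          // JNoa.timesTensor + JNoa.sumTensor
        loss.grad(x, retainGraph = false)    // JNoa.autogradTensor
    }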