forked from kscience/kmath
Autograd support drafting
This commit is contained in:
parent
7894799e8e
commit
7105331149
@ -76,7 +76,7 @@ pluginManagement {
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||||
```kotlin
|
```kotlin
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
|
|
||||||
@ -94,4 +94,24 @@ TorchTensorRealAlgebra {
|
|||||||
println(gpuRealTensor)
|
println(gpuRealTensor)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
Enjoy a high performance automatic differentiation engine:
|
||||||
|
```kotlin
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val dim = 10
|
||||||
|
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
||||||
|
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
// x is the variable
|
||||||
|
x.requiresGrad = true
|
||||||
|
|
||||||
|
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||||
|
val Q = X + X.transpose(0,1)
|
||||||
|
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
val c = randNormal(shape = IntArray(0), device = device)
|
||||||
|
|
||||||
|
// expression to differentiate w.r.t. x
|
||||||
|
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
||||||
|
// value of the gradient at x
|
||||||
|
val gradf = f grad x
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
@ -40,24 +40,51 @@ extern "C"
|
|||||||
void dispose_char(char *ptr);
|
void dispose_char(char *ptr);
|
||||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
|
||||||
|
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
long get_long(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
int get_int(TorchTensorHandle tensor_handle, int *index);
|
||||||
|
void set_double(TorchTensorHandle tensor_handle, int *index, double value);
|
||||||
|
void set_float(TorchTensorHandle tensor_handle, int *index, float value);
|
||||||
|
void set_long(TorchTensorHandle tensor_handle, int *index, long value);
|
||||||
|
void set_int(TorchTensorHandle tensor_handle, int *index, int value);
|
||||||
|
|
||||||
double get_double(TorchTensorHandle tensor_handle, int* index);
|
TorchTensorHandle randn_double(int *shape, int shape_size, int device);
|
||||||
float get_float(TorchTensorHandle tensor_handle, int* index);
|
TorchTensorHandle rand_double(int *shape, int shape_size, int device);
|
||||||
long get_long(TorchTensorHandle tensor_handle, int* index);
|
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||||
int get_int(TorchTensorHandle tensor_handle, int* index);
|
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||||
void set_double(TorchTensorHandle tensor_handle, int* index, double value);
|
|
||||||
void set_float(TorchTensorHandle tensor_handle, int* index, float value);
|
|
||||||
void set_long(TorchTensorHandle tensor_handle, int* index, long value);
|
|
||||||
void set_int(TorchTensorHandle tensor_handle, int* index, int value);
|
|
||||||
|
|
||||||
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle randn_double(int* shape, int shape_size, int device);
|
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle rand_double(int* shape, int shape_size, int device);
|
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle randn_float(int* shape, int shape_size, int device);
|
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||||
TorchTensorHandle rand_float(int* shape, int shape_size, int device);
|
|
||||||
|
|
||||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
|
||||||
|
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||||
|
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
void times_assign_double(double value, TorchTensorHandle other);
|
||||||
|
void times_assign_float(float value, TorchTensorHandle other);
|
||||||
|
void times_assign_long(long value, TorchTensorHandle other);
|
||||||
|
void times_assign_int(int value, TorchTensorHandle other);
|
||||||
|
|
||||||
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||||
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||||
|
void detach_from_graph(TorchTensorHandle tensor_handle);
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
@ -92,4 +92,10 @@ namespace ctorch
|
|||||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Dtype>
|
||||||
|
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
||||||
|
{
|
||||||
|
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ctorch
|
} // namespace ctorch
|
||||||
|
@ -156,12 +156,111 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
|||||||
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<float>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<long>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||||
|
}
|
||||||
|
|
||||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
{
|
{
|
||||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
{
|
{
|
||||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
}
|
}
|
||||||
|
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
TorchTensorHandle times_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(value * ctorch::cast(other));
|
||||||
|
}
|
||||||
|
void times_assign_double(double value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_assign_float(float value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_assign_long(long value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
void times_assign_int(int value, TorchTensorHandle other)
|
||||||
|
{
|
||||||
|
ctorch::cast(other) *= value;
|
||||||
|
}
|
||||||
|
|
||||||
|
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||||
|
}
|
||||||
|
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||||
|
{
|
||||||
|
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||||
|
}
|
||||||
|
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||||
|
}
|
||||||
|
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||||
|
}
|
||||||
|
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
return ctorch::cast(tensor_handle).requires_grad();
|
||||||
|
}
|
||||||
|
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||||
|
}
|
||||||
|
void detach_from_graph(TorchTensorHandle tensor_handle)
|
||||||
|
{
|
||||||
|
ctorch::cast(tensor_handle).detach();
|
||||||
|
}
|
||||||
|
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||||
|
{
|
||||||
|
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||||
|
}
|
||||||
|
@ -6,8 +6,17 @@ import kotlin.test.*
|
|||||||
class TestTorchTensorAlgebraGPU {
|
class TestTorchTensorAlgebraGPU {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testScalarProduct() = testingScalarProduct(device = TorchDevice.TorchCUDA(0))
|
fun testScalarProduct() =
|
||||||
|
testingScalarProduct(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testMatrixMultiplication() = testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
|
fun testMatrixMultiplication() =
|
||||||
|
testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLinearStructure() =
|
||||||
|
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
||||||
}
|
}
|
||||||
|
@ -13,21 +13,4 @@ internal class TestUtilsGPU {
|
|||||||
@Test
|
@Test
|
||||||
fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0))
|
fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0))
|
||||||
|
|
||||||
@Test
|
|
||||||
fun testReadmeFactory() = TorchTensorRealAlgebra {
|
|
||||||
|
|
||||||
val realTensor: TorchTensorReal = copyFromArray(
|
|
||||||
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
|
||||||
shape = intArrayOf(2,5)
|
|
||||||
)
|
|
||||||
println(realTensor)
|
|
||||||
|
|
||||||
val gpuRealTensor: TorchTensorReal = copyFromArray(
|
|
||||||
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
|
|
||||||
shape = intArrayOf(2, 2, 2),
|
|
||||||
device = TorchDevice.TorchCUDA(0)
|
|
||||||
)
|
|
||||||
println(gpuRealTensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -62,6 +62,13 @@ public sealed class TorchTensor<T> constructor(
|
|||||||
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
|
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
|
||||||
)
|
)
|
||||||
|
|
||||||
|
public var requiresGrad: Boolean
|
||||||
|
get() = requires_grad(tensorHandle)
|
||||||
|
set(value) = requires_grad_(tensorHandle, value)
|
||||||
|
|
||||||
|
public fun detachFromGraph(): Unit {
|
||||||
|
detach_from_graph(tensorHandle)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TorchTensorReal internal constructor(
|
public class TorchTensorReal internal constructor(
|
||||||
|
@ -14,6 +14,11 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
|||||||
): TorchTensor<T>
|
): TorchTensor<T>
|
||||||
|
|
||||||
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
||||||
|
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
||||||
|
|
||||||
|
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
|
||||||
|
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
|
||||||
|
|
||||||
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||||
@ -21,6 +26,34 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
|||||||
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
|
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
|
||||||
matmul_assign(this.tensorHandle, other.tensorHandle)
|
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit {
|
||||||
|
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
|
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit {
|
||||||
|
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> =
|
||||||
|
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||||
|
|
||||||
|
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
|
||||||
|
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
|
||||||
|
|
||||||
|
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!)
|
||||||
|
|
||||||
|
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> =
|
||||||
|
wrap(transpose_tensor(tensorHandle, i , j)!!)
|
||||||
|
|
||||||
|
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> =
|
||||||
|
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||||
}
|
}
|
||||||
|
|
||||||
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
|
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
|
||||||
@ -60,6 +93,27 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
|||||||
scope = scope,
|
scope = scope,
|
||||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||||
)
|
)
|
||||||
|
|
||||||
|
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = times_double(this, other.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = times_double(value, this.tensorHandle)!!
|
||||||
|
)
|
||||||
|
|
||||||
|
override fun TorchTensor<Double>.timesAssign(value: Double): Unit {
|
||||||
|
times_assign_double(value, this.tensorHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||||
|
scope = scope,
|
||||||
|
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
|
||||||
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||||
|
@ -19,4 +19,14 @@ class TestTorchTensor {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun testCopyFromArray() = testingCopyFromArray()
|
fun testCopyFromArray() = testingCopyFromArray()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||||
|
val tensor = randNormal(intArrayOf(3))
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = true
|
||||||
|
assertTrue(tensor.requiresGrad)
|
||||||
|
tensor.requiresGrad = false
|
||||||
|
assertTrue(!tensor.requiresGrad)
|
||||||
|
}
|
||||||
}
|
}
|
@ -4,12 +4,13 @@ import kscience.kmath.linear.RealMatrixContext
|
|||||||
import kscience.kmath.operations.invoke
|
import kscience.kmath.operations.invoke
|
||||||
import kscience.kmath.structures.Matrix
|
import kscience.kmath.structures.Matrix
|
||||||
import kotlin.math.abs
|
import kotlin.math.abs
|
||||||
|
import kotlin.math.exp
|
||||||
import kotlin.test.*
|
import kotlin.test.*
|
||||||
|
|
||||||
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
val lhs = randUniform(shape = intArrayOf(10), device = device)
|
val lhs = randUniform(shape = intArrayOf(3), device = device)
|
||||||
val rhs = randUniform(shape = intArrayOf(10), device = device)
|
val rhs = randUniform(shape = intArrayOf(3), device = device)
|
||||||
val product = lhs dot rhs
|
val product = lhs dot rhs
|
||||||
var expected = 0.0
|
var expected = 0.0
|
||||||
lhs.elements().forEach {
|
lhs.elements().forEach {
|
||||||
@ -23,27 +24,76 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
|
|||||||
TorchTensorRealAlgebra {
|
TorchTensorRealAlgebra {
|
||||||
setSeed(SEED)
|
setSeed(SEED)
|
||||||
|
|
||||||
val lhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
|
val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
val rhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
|
val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||||
val product = lhsTensor dot rhsTensor
|
val product = lhsTensor dot rhsTensor
|
||||||
|
|
||||||
val expected: Matrix<Double> = RealMatrixContext {
|
val expected: Matrix<Double> = RealMatrixContext {
|
||||||
val lhs = produce(20, 20) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
||||||
val rhs = produce(20, 20) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
||||||
lhs dot rhs
|
lhs dot rhs
|
||||||
}
|
}
|
||||||
|
|
||||||
lhsTensor dotAssign rhsTensor
|
val lhsTensorCopy = lhsTensor.copy()
|
||||||
|
val rhsTensorCopy = rhsTensor.copy()
|
||||||
|
|
||||||
|
lhsTensorCopy dotAssign rhsTensor
|
||||||
|
lhsTensor dotRightAssign rhsTensorCopy
|
||||||
|
|
||||||
var error: Double = 0.0
|
var error: Double = 0.0
|
||||||
product.elements().forEach {
|
product.elements().forEach {
|
||||||
error += abs(expected[it.first] - it.second) +
|
error += abs(expected[it.first] - it.second) +
|
||||||
abs(expected[it.first] - lhsTensor[it.first])
|
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
||||||
|
abs(expected[it.first] - rhsTensorCopy[it.first])
|
||||||
}
|
}
|
||||||
assertTrue(error < TOLERANCE)
|
assertTrue(error < TOLERANCE)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
val shape = intArrayOf(3)
|
||||||
|
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||||
|
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||||
|
val tensorC = full(value = 789.3, shape = shape, device = device)
|
||||||
|
val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
|
||||||
|
val expected = copyFromArray(
|
||||||
|
array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
|
||||||
|
shape = shape,
|
||||||
|
device = device
|
||||||
|
)
|
||||||
|
|
||||||
|
val assignResult = full(value = 0.0, shape = shape, device = device)
|
||||||
|
tensorA *= 15.8
|
||||||
|
tensorB *= 1.5
|
||||||
|
tensorC *= 0.02
|
||||||
|
assignResult += tensorA
|
||||||
|
assignResult -= tensorB
|
||||||
|
assignResult += tensorC
|
||||||
|
|
||||||
|
val error = (expected - result).abs().sum().value() +
|
||||||
|
(expected - assignResult).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||||
|
TorchTensorRealAlgebra {
|
||||||
|
setSeed(SEED)
|
||||||
|
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
x.requiresGrad = true
|
||||||
|
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||||
|
val Q = X + X.transpose(0,1)
|
||||||
|
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||||
|
val c = randNormal(shape = IntArray(0), device = device)
|
||||||
|
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
||||||
|
val gradf = f grad x
|
||||||
|
val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
|
||||||
|
assertTrue(error < TOLERANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
internal class TestTorchTensorAlgebra {
|
internal class TestTorchTensorAlgebra {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -52,4 +102,10 @@ internal class TestTorchTensorAlgebra {
|
|||||||
@Test
|
@Test
|
||||||
fun testMatrixMultiplication() = testingMatrixMultiplication()
|
fun testMatrixMultiplication() = testingMatrixMultiplication()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testLinearStructure() = testingLinearStructure()
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||||
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user