forked from kscience/kmath
Autograd support drafting
This commit is contained in:
parent
7894799e8e
commit
7105331149
@ -76,7 +76,7 @@ pluginManagement {
|
||||
|
||||
## Usage
|
||||
|
||||
Tensors implement the buffer protocol over `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||
Tensors are implemented over the `MutableNDStructure`. They can only be instantiated through provided factory methods and require scoping:
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
|
||||
@ -94,4 +94,24 @@ TorchTensorRealAlgebra {
|
||||
println(gpuRealTensor)
|
||||
}
|
||||
```
|
||||
Enjoy a high performance automatic differentiation engine:
|
||||
```kotlin
|
||||
TorchTensorRealAlgebra {
|
||||
val dim = 10
|
||||
val device = TorchDevice.TorchCPU //or TorchDevice.TorchCUDA(0)
|
||||
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||
// x is the variable
|
||||
x.requiresGrad = true
|
||||
|
||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||
val Q = X + X.transpose(0,1)
|
||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
val c = randNormal(shape = IntArray(0), device = device)
|
||||
|
||||
// expression to differentiate w.r.t. x
|
||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
||||
// value of the gradient at x
|
||||
val gradf = f grad x
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -40,7 +40,6 @@ extern "C"
|
||||
void dispose_char(char *ptr);
|
||||
void dispose_tensor(TorchTensorHandle tensor_handle);
|
||||
|
||||
|
||||
double get_double(TorchTensorHandle tensor_handle, int *index);
|
||||
float get_float(TorchTensorHandle tensor_handle, int *index);
|
||||
long get_long(TorchTensorHandle tensor_handle, int *index);
|
||||
@ -50,14 +49,42 @@ extern "C"
|
||||
void set_long(TorchTensorHandle tensor_handle, int *index, long value);
|
||||
void set_int(TorchTensorHandle tensor_handle, int *index, int value);
|
||||
|
||||
|
||||
TorchTensorHandle randn_double(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle rand_double(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle randn_float(int *shape, int shape_size, int device);
|
||||
TorchTensorHandle rand_float(int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device);
|
||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device);
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_float(float value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_long(long value, TorchTensorHandle other);
|
||||
TorchTensorHandle times_int(int value, TorchTensorHandle other);
|
||||
|
||||
void times_assign_double(double value, TorchTensorHandle other);
|
||||
void times_assign_float(float value, TorchTensorHandle other);
|
||||
void times_assign_long(long value, TorchTensorHandle other);
|
||||
void times_assign_int(int value, TorchTensorHandle other);
|
||||
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs);
|
||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j);
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle);
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status);
|
||||
void detach_from_graph(TorchTensorHandle tensor_handle);
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -92,4 +92,10 @@ namespace ctorch
|
||||
return torch::rand(shape, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
template <typename Dtype>
|
||||
inline torch::Tensor full(Dtype value, std::vector<int64_t> shape, torch::Device device)
|
||||
{
|
||||
return torch::full(shape, value, torch::TensorOptions().dtype(dtype<Dtype>()).layout(torch::kStrided).device(device));
|
||||
}
|
||||
|
||||
} // namespace ctorch
|
||||
|
@ -156,12 +156,111 @@ TorchTensorHandle rand_float(int *shape, int shape_size, int device)
|
||||
return new torch::Tensor(ctorch::rand<float>(ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle full_double(double value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<double>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_float(float value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<float>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_long(long value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<long>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
TorchTensorHandle full_int(int value, int *shape, int shape_size, int device)
|
||||
{
|
||||
return new torch::Tensor(ctorch::full<int>(value, ctorch::to_vec_int(shape, shape_size), ctorch::int_to_device(device)));
|
||||
}
|
||||
|
||||
TorchTensorHandle matmul(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(torch::matmul(ctorch::cast(lhs), ctorch::cast(rhs)));
|
||||
}
|
||||
|
||||
void matmul_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
void matmul_right_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(rhs) = ctorch::cast(lhs).matmul(ctorch::cast(rhs));
|
||||
}
|
||||
|
||||
TorchTensorHandle times_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
TorchTensorHandle times_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
return new torch::Tensor(value * ctorch::cast(other));
|
||||
}
|
||||
void times_assign_double(double value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_assign_float(float value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_assign_long(long value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
void times_assign_int(int value, TorchTensorHandle other)
|
||||
{
|
||||
ctorch::cast(other) *= value;
|
||||
}
|
||||
|
||||
TorchTensorHandle plus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) + ctorch::cast(rhs));
|
||||
}
|
||||
void plus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) += ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle minus_tensor(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(lhs) - ctorch::cast(rhs));
|
||||
}
|
||||
void minus_tensor_assign(TorchTensorHandle lhs, TorchTensorHandle rhs)
|
||||
{
|
||||
ctorch::cast(lhs) -= ctorch::cast(rhs);
|
||||
}
|
||||
TorchTensorHandle abs_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).abs());
|
||||
}
|
||||
TorchTensorHandle sum_tensor(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).sum());
|
||||
}
|
||||
TorchTensorHandle transpose_tensor(TorchTensorHandle tensor_handle, int i, int j)
|
||||
{
|
||||
return new torch::Tensor(ctorch::cast(tensor_handle).transpose(i,j));
|
||||
}
|
||||
|
||||
bool requires_grad(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
return ctorch::cast(tensor_handle).requires_grad();
|
||||
}
|
||||
void requires_grad_(TorchTensorHandle tensor_handle, bool status)
|
||||
{
|
||||
ctorch::cast(tensor_handle).requires_grad_(status);
|
||||
}
|
||||
void detach_from_graph(TorchTensorHandle tensor_handle)
|
||||
{
|
||||
ctorch::cast(tensor_handle).detach();
|
||||
}
|
||||
TorchTensorHandle autograd_tensor(TorchTensorHandle value, TorchTensorHandle variable)
|
||||
{
|
||||
return new torch::Tensor(torch::autograd::grad({ctorch::cast(value)}, {ctorch::cast(variable)})[0]);
|
||||
}
|
||||
|
@ -6,8 +6,17 @@ import kotlin.test.*
|
||||
class TestTorchTensorAlgebraGPU {
|
||||
|
||||
@Test
|
||||
fun testScalarProduct() = testingScalarProduct(device = TorchDevice.TorchCUDA(0))
|
||||
fun testScalarProduct() =
|
||||
testingScalarProduct(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testMatrixMultiplication() = testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
|
||||
fun testMatrixMultiplication() =
|
||||
testingMatrixMultiplication(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testLinearStructure() =
|
||||
testingLinearStructure(device = TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 3, device = TorchDevice.TorchCUDA(0))
|
||||
}
|
||||
|
@ -13,21 +13,4 @@ internal class TestUtilsGPU {
|
||||
@Test
|
||||
fun testSetSeed() = testingSetSeed(TorchDevice.TorchCUDA(0))
|
||||
|
||||
@Test
|
||||
fun testReadmeFactory() = TorchTensorRealAlgebra {
|
||||
|
||||
val realTensor: TorchTensorReal = copyFromArray(
|
||||
array = (1..10).map { it + 50.0 }.toList().toDoubleArray(),
|
||||
shape = intArrayOf(2,5)
|
||||
)
|
||||
println(realTensor)
|
||||
|
||||
val gpuRealTensor: TorchTensorReal = copyFromArray(
|
||||
array = (1..8).map { it * 2.5 }.toList().toDoubleArray(),
|
||||
shape = intArrayOf(2, 2, 2),
|
||||
device = TorchDevice.TorchCUDA(0)
|
||||
)
|
||||
println(gpuRealTensor)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -62,6 +62,13 @@ public sealed class TorchTensor<T> constructor(
|
||||
outTensorHandle = copy_to_device(tensorHandle, device.toInt())!!
|
||||
)
|
||||
|
||||
public var requiresGrad: Boolean
|
||||
get() = requires_grad(tensorHandle)
|
||||
set(value) = requires_grad_(tensorHandle, value)
|
||||
|
||||
public fun detachFromGraph(): Unit {
|
||||
detach_from_graph(tensorHandle)
|
||||
}
|
||||
}
|
||||
|
||||
public class TorchTensorReal internal constructor(
|
||||
|
@ -14,6 +14,11 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
||||
): TorchTensor<T>
|
||||
|
||||
public abstract fun TorchTensor<T>.copyToArray(): PrimitiveArrayType
|
||||
public abstract fun full(value: T, shape: IntArray, device: TorchDevice): TorchTensor<T>
|
||||
|
||||
public abstract operator fun T.times(other: TorchTensor<T>): TorchTensor<T>
|
||||
public abstract operator fun TorchTensor<T>.times(value: T): TorchTensor<T>
|
||||
public abstract operator fun TorchTensor<T>.timesAssign(value: T): Unit
|
||||
|
||||
public infix fun TorchTensor<T>.dot(other: TorchTensor<T>): TorchTensor<T> =
|
||||
wrap(matmul(this.tensorHandle, other.tensorHandle)!!)
|
||||
@ -21,6 +26,34 @@ public sealed class TorchTensorAlgebra<T, PrimitiveArrayType> constructor(
|
||||
public infix fun TorchTensor<T>.dotAssign(other: TorchTensor<T>): Unit {
|
||||
matmul_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public infix fun TorchTensor<T>.dotRightAssign(other: TorchTensor<T>): Unit {
|
||||
matmul_right_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public operator fun TorchTensor<T>.plus(other: TorchTensor<T>): TorchTensor<T> =
|
||||
wrap(plus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
|
||||
public operator fun TorchTensor<T>.plusAssign(other: TorchTensor<T>): Unit {
|
||||
plus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public operator fun TorchTensor<T>.minus(other: TorchTensor<T>): TorchTensor<T> =
|
||||
wrap(minus_tensor(this.tensorHandle, other.tensorHandle)!!)
|
||||
|
||||
public operator fun TorchTensor<T>.minusAssign(other: TorchTensor<T>): Unit {
|
||||
minus_tensor_assign(this.tensorHandle, other.tensorHandle)
|
||||
}
|
||||
|
||||
public fun TorchTensor<T>.abs(): TorchTensor<T> = wrap(abs_tensor(tensorHandle)!!)
|
||||
|
||||
public fun TorchTensor<T>.sum(): TorchTensor<T> = wrap(sum_tensor(tensorHandle)!!)
|
||||
|
||||
public fun TorchTensor<T>.transpose(i: Int, j: Int): TorchTensor<T> =
|
||||
wrap(transpose_tensor(tensorHandle, i , j)!!)
|
||||
|
||||
public infix fun TorchTensor<T>.grad(variable: TorchTensor<T>): TorchTensor<T> =
|
||||
wrap(autograd_tensor(this.tensorHandle, variable.tensorHandle)!!)
|
||||
}
|
||||
|
||||
public sealed class TorchTensorFieldAlgebra<T, PrimitiveArrayType>(scope: DeferScope) :
|
||||
@ -60,6 +93,27 @@ public class TorchTensorRealAlgebra(scope: DeferScope) : TorchTensorFieldAlgebra
|
||||
scope = scope,
|
||||
tensorHandle = rand_double(shape.toCValues(), shape.size, device.toInt())!!
|
||||
)
|
||||
|
||||
override operator fun Double.times(other: TorchTensor<Double>): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = times_double(this, other.tensorHandle)!!
|
||||
)
|
||||
|
||||
override fun TorchTensor<Double>.times(value: Double): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = times_double(value, this.tensorHandle)!!
|
||||
)
|
||||
|
||||
override fun TorchTensor<Double>.timesAssign(value: Double): Unit {
|
||||
times_assign_double(value, this.tensorHandle)
|
||||
}
|
||||
|
||||
|
||||
override fun full(value: Double, shape: IntArray, device: TorchDevice): TorchTensorReal = TorchTensorReal(
|
||||
scope = scope,
|
||||
tensorHandle = full_double(value, shape.toCValues(), shape.size, device.toInt())!!
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
public fun <R> TorchTensorRealAlgebra(block: TorchTensorRealAlgebra.() -> R): R =
|
||||
|
@ -19,4 +19,14 @@ class TestTorchTensor {
|
||||
|
||||
@Test
|
||||
fun testCopyFromArray() = testingCopyFromArray()
|
||||
|
||||
@Test
|
||||
fun testRequiresGrad() = TorchTensorRealAlgebra {
|
||||
val tensor = randNormal(intArrayOf(3))
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
tensor.requiresGrad = true
|
||||
assertTrue(tensor.requiresGrad)
|
||||
tensor.requiresGrad = false
|
||||
assertTrue(!tensor.requiresGrad)
|
||||
}
|
||||
}
|
@ -4,12 +4,13 @@ import kscience.kmath.linear.RealMatrixContext
|
||||
import kscience.kmath.operations.invoke
|
||||
import kscience.kmath.structures.Matrix
|
||||
import kotlin.math.abs
|
||||
import kotlin.math.exp
|
||||
import kotlin.test.*
|
||||
|
||||
internal fun testingScalarProduct(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
val lhs = randUniform(shape = intArrayOf(10), device = device)
|
||||
val rhs = randUniform(shape = intArrayOf(10), device = device)
|
||||
val lhs = randUniform(shape = intArrayOf(3), device = device)
|
||||
val rhs = randUniform(shape = intArrayOf(3), device = device)
|
||||
val product = lhs dot rhs
|
||||
var expected = 0.0
|
||||
lhs.elements().forEach {
|
||||
@ -23,27 +24,76 @@ internal fun testingMatrixMultiplication(device: TorchDevice = TorchDevice.Torch
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
|
||||
val lhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
|
||||
val rhsTensor = randNormal(shape = intArrayOf(20, 20), device = device)
|
||||
val lhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val rhsTensor = randNormal(shape = intArrayOf(3, 3), device = device)
|
||||
val product = lhsTensor dot rhsTensor
|
||||
|
||||
val expected: Matrix<Double> = RealMatrixContext {
|
||||
val lhs = produce(20, 20) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
||||
val rhs = produce(20, 20) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
||||
val lhs = produce(3, 3) { i, j -> lhsTensor[intArrayOf(i, j)] }
|
||||
val rhs = produce(3, 3) { i, j -> rhsTensor[intArrayOf(i, j)] }
|
||||
lhs dot rhs
|
||||
}
|
||||
|
||||
lhsTensor dotAssign rhsTensor
|
||||
val lhsTensorCopy = lhsTensor.copy()
|
||||
val rhsTensorCopy = rhsTensor.copy()
|
||||
|
||||
lhsTensorCopy dotAssign rhsTensor
|
||||
lhsTensor dotRightAssign rhsTensorCopy
|
||||
|
||||
var error: Double = 0.0
|
||||
product.elements().forEach {
|
||||
error += abs(expected[it.first] - it.second) +
|
||||
abs(expected[it.first] - lhsTensor[it.first])
|
||||
abs(expected[it.first] - lhsTensorCopy[it.first]) +
|
||||
abs(expected[it.first] - rhsTensorCopy[it.first])
|
||||
}
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingLinearStructure(device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
val shape = intArrayOf(3)
|
||||
val tensorA = full(value = -4.5, shape = shape, device = device)
|
||||
val tensorB = full(value = 10.9, shape = shape, device = device)
|
||||
val tensorC = full(value = 789.3, shape = shape, device = device)
|
||||
val result = 15.8 * tensorA - 1.5 * tensorB + 0.02 * tensorC
|
||||
val expected = copyFromArray(
|
||||
array = (1..3).map { 15.8 * (-4.5) - 1.5 * 10.9 + 0.02 * 789.3 }.toDoubleArray(),
|
||||
shape = shape,
|
||||
device = device
|
||||
)
|
||||
|
||||
val assignResult = full(value = 0.0, shape = shape, device = device)
|
||||
tensorA *= 15.8
|
||||
tensorB *= 1.5
|
||||
tensorC *= 0.02
|
||||
assignResult += tensorA
|
||||
assignResult -= tensorB
|
||||
assignResult += tensorC
|
||||
|
||||
val error = (expected - result).abs().sum().value() +
|
||||
(expected - assignResult).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun testingAutoGrad(dim: Int, device: TorchDevice = TorchDevice.TorchCPU): Unit {
|
||||
TorchTensorRealAlgebra {
|
||||
setSeed(SEED)
|
||||
val x = randNormal(shape = intArrayOf(dim), device = device)
|
||||
x.requiresGrad = true
|
||||
val X = randNormal(shape = intArrayOf(dim,dim), device = device)
|
||||
val Q = X + X.transpose(0,1)
|
||||
val mu = randNormal(shape = intArrayOf(dim), device = device)
|
||||
val c = randNormal(shape = IntArray(0), device = device)
|
||||
val f = 0.5 * (x dot (Q dot x)) + (mu dot x) + c
|
||||
val gradf = f grad x
|
||||
val error = (gradf - ((Q dot x) + mu)).abs().sum().value()
|
||||
assertTrue(error < TOLERANCE)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
internal class TestTorchTensorAlgebra {
|
||||
|
||||
@Test
|
||||
@ -52,4 +102,10 @@ internal class TestTorchTensorAlgebra {
|
||||
@Test
|
||||
fun testMatrixMultiplication() = testingMatrixMultiplication()
|
||||
|
||||
@Test
|
||||
fun testLinearStructure() = testingLinearStructure()
|
||||
|
||||
@Test
|
||||
fun testAutoGrad() = testingAutoGrad(dim = 100)
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user