Change tensor algebra operations to prefix notations
This commit is contained in:
parent
7e4d223044
commit
bc8ffd02c2
@ -48,7 +48,7 @@ fun reluDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
class ReLU : Activation(::relu, ::reluDer)
|
||||
|
||||
fun sigmoid(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
1.0 / (1.0 + (-x).exp())
|
||||
1.0 / (1.0 + exp(-x))
|
||||
}
|
||||
|
||||
fun sigmoidDer(x: DoubleTensor): DoubleTensor = DoubleTensorAlgebra {
|
||||
@ -116,7 +116,7 @@ class NeuralNetwork(private val layers: List<Layer>) {
|
||||
onesForAnswers[intArrayOf(index, label)] = 1.0
|
||||
}
|
||||
|
||||
val softmaxValue = yPred.exp() / yPred.exp().sum(dim = 1, keepDim = true)
|
||||
val softmaxValue = exp(yPred) / exp(yPred).sum(dim = 1, keepDim = true)
|
||||
|
||||
(-onesForAnswers + softmaxValue) / (yPred.shape[0].toDouble())
|
||||
}
|
||||
|
@ -114,27 +114,29 @@ public sealed interface Nd4jTensorAlgebra<T : Number, A : Field<T>> : AnalyticTe
|
||||
override fun StructureND<T>.mean(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||
ndArray.mean(keepDim, dim).wrap()
|
||||
|
||||
override fun StructureND<T>.exp(): Nd4jArrayStructure<T> = Transforms.exp(ndArray).wrap()
|
||||
override fun StructureND<T>.ln(): Nd4jArrayStructure<T> = Transforms.log(ndArray).wrap()
|
||||
override fun StructureND<T>.sqrt(): Nd4jArrayStructure<T> = Transforms.sqrt(ndArray).wrap()
|
||||
override fun StructureND<T>.cos(): Nd4jArrayStructure<T> = Transforms.cos(ndArray).wrap()
|
||||
override fun StructureND<T>.acos(): Nd4jArrayStructure<T> = Transforms.acos(ndArray).wrap()
|
||||
override fun StructureND<T>.cosh(): Nd4jArrayStructure<T> = Transforms.cosh(ndArray).wrap()
|
||||
|
||||
override fun StructureND<T>.acosh(): Nd4jArrayStructure<T> =
|
||||
Nd4j.getExecutioner().exec(ACosh(ndArray, ndArray.ulike())).wrap()
|
||||
override fun exp(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.exp(arg.ndArray).wrap()
|
||||
override fun ln(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.log(arg.ndArray).wrap()
|
||||
override fun cos(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.cos(arg.ndArray).wrap()
|
||||
override fun acos(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.acos(arg.ndArray).wrap()
|
||||
override fun cosh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.cosh(arg.ndArray).wrap()
|
||||
|
||||
override fun StructureND<T>.sin(): Nd4jArrayStructure<T> = Transforms.sin(ndArray).wrap()
|
||||
override fun StructureND<T>.asin(): Nd4jArrayStructure<T> = Transforms.asin(ndArray).wrap()
|
||||
override fun StructureND<T>.sinh(): Tensor<T> = Transforms.sinh(ndArray).wrap()
|
||||
override fun acosh(arg: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
Nd4j.getExecutioner().exec(ACosh(arg.ndArray, arg.ndArray.ulike())).wrap()
|
||||
|
||||
override fun StructureND<T>.asinh(): Nd4jArrayStructure<T> =
|
||||
Nd4j.getExecutioner().exec(ASinh(ndArray, ndArray.ulike())).wrap()
|
||||
override fun sin(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.sin(arg.ndArray).wrap()
|
||||
override fun asin(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.asin(arg.ndArray).wrap()
|
||||
override fun sinh(arg: StructureND<T>): Tensor<T> = Transforms.sinh(arg.ndArray).wrap()
|
||||
|
||||
override fun asinh(arg: StructureND<T>): Nd4jArrayStructure<T> =
|
||||
Nd4j.getExecutioner().exec(ASinh(arg.ndArray, arg.ndArray.ulike())).wrap()
|
||||
|
||||
override fun tan(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.tan(arg.ndArray).wrap()
|
||||
override fun atan(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.atan(arg.ndArray).wrap()
|
||||
override fun tanh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.tanh(arg.ndArray).wrap()
|
||||
override fun atanh(arg: StructureND<T>): Nd4jArrayStructure<T> = Transforms.atanh(arg.ndArray).wrap()
|
||||
|
||||
override fun StructureND<T>.tan(): Nd4jArrayStructure<T> = Transforms.tan(ndArray).wrap()
|
||||
override fun StructureND<T>.atan(): Nd4jArrayStructure<T> = Transforms.atan(ndArray).wrap()
|
||||
override fun StructureND<T>.tanh(): Nd4jArrayStructure<T> = Transforms.tanh(ndArray).wrap()
|
||||
override fun StructureND<T>.atanh(): Nd4jArrayStructure<T> = Transforms.atanh(ndArray).wrap()
|
||||
override fun StructureND<T>.ceil(): Nd4jArrayStructure<T> = Transforms.ceil(ndArray).wrap()
|
||||
override fun StructureND<T>.floor(): Nd4jArrayStructure<T> = Transforms.floor(ndArray).wrap()
|
||||
override fun StructureND<T>.std(dim: Int, keepDim: Boolean): Nd4jArrayStructure<T> =
|
||||
@ -178,7 +180,6 @@ public object DoubleNd4jTensorAlgebra : Nd4jTensorAlgebra<Double, DoubleField> {
|
||||
return array.wrap()
|
||||
}
|
||||
|
||||
|
||||
@OptIn(PerformancePitfall::class)
|
||||
override val StructureND<Double>.ndArray: INDArray
|
||||
get() = when (this) {
|
||||
|
@ -72,78 +72,95 @@ public interface AnalyticTensorAlgebra<T, A : Field<T>> :
|
||||
*/
|
||||
public fun StructureND<T>.variance(dim: Int, keepDim: Boolean): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||
public fun StructureND<T>.exp(): Tensor<T>
|
||||
override fun sin(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.log.html
|
||||
public fun StructureND<T>.ln(): Tensor<T>
|
||||
override fun cos(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun asin(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun acos(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun tan(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun atan(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun exp(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun ln(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun sinh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun cosh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun tanh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun asinh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun acosh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
override fun atanh(arg: StructureND<T>): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.sqrt.html
|
||||
public fun StructureND<T>.sqrt(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
||||
public fun StructureND<T>.cos(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
|
||||
public fun StructureND<T>.acos(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
|
||||
public fun StructureND<T>.cosh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
|
||||
public fun StructureND<T>.acosh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
|
||||
public fun StructureND<T>.sin(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
|
||||
public fun StructureND<T>.asin(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
|
||||
public fun StructureND<T>.sinh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
|
||||
public fun StructureND<T>.asinh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
|
||||
public fun StructureND<T>.tan(): Tensor<T>
|
||||
|
||||
//https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
|
||||
public fun StructureND<T>.atan(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
|
||||
public fun StructureND<T>.tanh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
|
||||
public fun StructureND<T>.atanh(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.ceil.html#torch.ceil
|
||||
public fun StructureND<T>.ceil(): Tensor<T>
|
||||
|
||||
//For information: https://pytorch.org/docs/stable/generated/torch.floor.html#torch.floor
|
||||
public fun StructureND<T>.floor(): Tensor<T>
|
||||
|
||||
override fun sin(arg: StructureND<T>): Tensor<T> = arg.sin()
|
||||
|
||||
override fun cos(arg: StructureND<T>): Tensor<T> = arg.cos()
|
||||
|
||||
override fun asin(arg: StructureND<T>): Tensor<T> = arg.asin()
|
||||
|
||||
override fun acos(arg: StructureND<T>): Tensor<T> = arg.acos()
|
||||
|
||||
override fun atan(arg: StructureND<T>): Tensor<T> = arg.atan()
|
||||
|
||||
override fun exp(arg: StructureND<T>): Tensor<T> = arg.exp()
|
||||
|
||||
override fun ln(arg: StructureND<T>): Tensor<T> = arg.ln()
|
||||
|
||||
override fun sinh(arg: StructureND<T>): Tensor<T> = arg.sinh()
|
||||
|
||||
override fun cosh(arg: StructureND<T>): Tensor<T> = arg.cosh()
|
||||
|
||||
override fun asinh(arg: StructureND<T>): Tensor<T> = arg.asinh()
|
||||
|
||||
override fun acosh(arg: StructureND<T>): Tensor<T> = arg.acosh()
|
||||
|
||||
override fun atanh(arg: StructureND<T>): Tensor<T> = arg.atanh()
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.exp.html
|
||||
// public fun StructureND<T>.exp(): Tensor<T> = exp(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.log.html
|
||||
// @JvmName("lnChain")
|
||||
// public fun StructureND<T>.ln(): Tensor<T> = ln(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.cos
|
||||
// @JvmName("cosChain")
|
||||
// public fun StructureND<T>.cos(): Tensor<T> = cos(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.acos.html#torch.acos
|
||||
// @JvmName("acosChain")
|
||||
// public fun StructureND<T>.acos(): Tensor<T> = acos(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.cosh
|
||||
// @JvmName("coshChain")
|
||||
// public fun StructureND<T>.cosh(): Tensor<T> = cosh(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.acosh.html#torch.acosh
|
||||
// @JvmName("acoshChain")
|
||||
// public fun StructureND<T>.acosh(): Tensor<T> = acosh(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sin
|
||||
// @JvmName("sinChain")
|
||||
// public fun StructureND<T>.sin(): Tensor<T> = sin(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asin
|
||||
// @JvmName("asinChain")
|
||||
// public fun StructureND<T>.asin(): Tensor<T> = asin(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.sinh
|
||||
// @JvmName("sinhChain")
|
||||
// public fun StructureND<T>.sinh(): Tensor<T> = sinh(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.asin.html#torch.asinh
|
||||
// @JvmName("asinhChain")
|
||||
// public fun StructureND<T>.asinh(): Tensor<T> = asinh(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.atan.html#torch.tan
|
||||
// @JvmName("tanChain")
|
||||
// public fun StructureND<T>.tan(): Tensor<T> = tan(this)
|
||||
//
|
||||
// //https://pytorch.org/docs/stable/generated/torch.atan.html#torch.atan
|
||||
// @JvmName("atanChain")
|
||||
// public fun StructureND<T>.atan(): Tensor<T> = atan(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.tanh
|
||||
// @JvmName("tanhChain")
|
||||
// public fun StructureND<T>.tanh(): Tensor<T> = tanh(this)
|
||||
//
|
||||
// //For information: https://pytorch.org/docs/stable/generated/torch.atanh.html#torch.atanh
|
||||
// @JvmName("atanhChain")
|
||||
// public fun StructureND<T>.atanh(): Tensor<T> = atanh(this)
|
||||
}
|
@ -685,35 +685,35 @@ public open class DoubleTensorAlgebra :
|
||||
return resTensor
|
||||
}
|
||||
|
||||
override fun StructureND<Double>.exp(): DoubleTensor = tensor.map { exp(it) }
|
||||
override fun exp(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { exp(it) }
|
||||
|
||||
override fun StructureND<Double>.ln(): DoubleTensor = tensor.map { ln(it) }
|
||||
override fun ln(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { ln(it) }
|
||||
|
||||
override fun StructureND<Double>.sqrt(): DoubleTensor = tensor.map { sqrt(it) }
|
||||
|
||||
override fun StructureND<Double>.cos(): DoubleTensor = tensor.map { cos(it) }
|
||||
override fun cos(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { cos(it) }
|
||||
|
||||
override fun StructureND<Double>.acos(): DoubleTensor = tensor.map { acos(it) }
|
||||
override fun acos(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { acos(it) }
|
||||
|
||||
override fun StructureND<Double>.cosh(): DoubleTensor = tensor.map { cosh(it) }
|
||||
override fun cosh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { cosh(it) }
|
||||
|
||||
override fun StructureND<Double>.acosh(): DoubleTensor = tensor.map { acosh(it) }
|
||||
override fun acosh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { acosh(it) }
|
||||
|
||||
override fun StructureND<Double>.sin(): DoubleTensor = tensor.map { sin(it) }
|
||||
override fun sin(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { sin(it) }
|
||||
|
||||
override fun StructureND<Double>.asin(): DoubleTensor = tensor.map { asin(it) }
|
||||
override fun asin(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { asin(it) }
|
||||
|
||||
override fun StructureND<Double>.sinh(): DoubleTensor = tensor.map { sinh(it) }
|
||||
override fun sinh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { sinh(it) }
|
||||
|
||||
override fun StructureND<Double>.asinh(): DoubleTensor = tensor.map { asinh(it) }
|
||||
override fun asinh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { asinh(it) }
|
||||
|
||||
override fun StructureND<Double>.tan(): DoubleTensor = tensor.map { tan(it) }
|
||||
override fun tan(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { tan(it) }
|
||||
|
||||
override fun StructureND<Double>.atan(): DoubleTensor = tensor.map { atan(it) }
|
||||
override fun atan(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { atan(it) }
|
||||
|
||||
override fun StructureND<Double>.tanh(): DoubleTensor = tensor.map { tanh(it) }
|
||||
override fun tanh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { tanh(it) }
|
||||
|
||||
override fun StructureND<Double>.atanh(): DoubleTensor = tensor.map { atanh(it) }
|
||||
override fun atanh(arg: StructureND<Double>): DoubleTensor = arg.tensor.map { atanh(it) }
|
||||
|
||||
override fun StructureND<Double>.ceil(): DoubleTensor = tensor.map { ceil(it) }
|
||||
|
||||
|
@ -32,12 +32,12 @@ internal class TestDoubleAnalyticTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testExp() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.exp() eq expectedTensor(::exp) }
|
||||
assertTrue { exp(tensor) eq expectedTensor(::exp) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testLog() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.ln() eq expectedTensor(::ln) }
|
||||
assertTrue { ln(tensor) eq expectedTensor(::ln) }
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -47,48 +47,48 @@ internal class TestDoubleAnalyticTensorAlgebra {
|
||||
|
||||
@Test
|
||||
fun testCos() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.cos() eq expectedTensor(::cos) }
|
||||
assertTrue { cos(tensor) eq expectedTensor(::cos) }
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun testCosh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.cosh() eq expectedTensor(::cosh) }
|
||||
assertTrue { cosh(tensor) eq expectedTensor(::cosh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAcosh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.acosh() eq expectedTensor(::acosh) }
|
||||
assertTrue { acosh(tensor) eq expectedTensor(::acosh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSin() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.sin() eq expectedTensor(::sin) }
|
||||
assertTrue { sin(tensor) eq expectedTensor(::sin) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testSinh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.sinh() eq expectedTensor(::sinh) }
|
||||
assertTrue { sinh(tensor) eq expectedTensor(::sinh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAsinh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.asinh() eq expectedTensor(::asinh) }
|
||||
assertTrue { asinh(tensor) eq expectedTensor(::asinh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTan() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.tan() eq expectedTensor(::tan) }
|
||||
assertTrue { tan(tensor) eq expectedTensor(::tan) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testAtan() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.atan() eq expectedTensor(::atan) }
|
||||
assertTrue { atan(tensor) eq expectedTensor(::atan) }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testTanh() = DoubleTensorAlgebra {
|
||||
assertTrue { tensor.tanh() eq expectedTensor(::tanh) }
|
||||
assertTrue { tanh(tensor) eq expectedTensor(::tanh) }
|
||||
}
|
||||
|
||||
@Test
|
||||
|
Loading…
Reference in New Issue
Block a user