From 623c96e4bb55f683db308f2ced757b6c7d5c3fdc Mon Sep 17 00:00:00 2001 From: Roland Grinis Date: Thu, 8 Jul 2021 11:03:20 +0100 Subject: [PATCH] basic algebra ops --- .../space/kscience/kmath/noa/algebras.kt | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt index deea592d4..fa164df29 100644 --- a/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt +++ b/kmath-noa/src/main/kotlin/space/kscience/kmath/noa/algebras.kt @@ -1,6 +1,6 @@ /* * Copyright 2018-2021 KMath contributors. - * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. + * Use of this.cast() source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file. */ package space.kscience.kmath.noa @@ -20,18 +20,56 @@ constructor(protected val scope: NoaScope) : TensorAlgebra { protected abstract fun wrap(tensorHandle: TensorHandle): TensorType /** - * A scalar tensor in this implementation must have empty shape + * A scalar tensor in this.cast() implementation must have empty shape */ override fun Tensor.valueOrNull(): T? = - try { this.cast().item() } catch (e: NoaException) { null } + try { + this.cast().cast().item() + } catch (e: NoaException) { + null + } - override fun Tensor.value(): T = this.cast().item() + override fun Tensor.value(): T = this.cast().cast().item() + + override operator fun Tensor.times(other: Tensor): TensorType { + return wrap(JNoa.timesTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + } + + override operator fun Tensor.timesAssign(other: Tensor): Unit { + JNoa.timesTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + } + + override operator fun Tensor.plus(other: Tensor): TensorType { + return wrap(JNoa.plusTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + } + + override operator fun Tensor.plusAssign(other: Tensor): Unit { + JNoa.plusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + } + + override operator fun Tensor.minus(other: Tensor): TensorType { + return wrap(JNoa.minusTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + } + + override operator fun Tensor.minusAssign(other: Tensor): Unit { + JNoa.minusTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + } + + override operator fun Tensor.unaryMinus(): TensorType = + wrap(JNoa.unaryMinus(this.cast().tensorHandle)) } public abstract class NoaPartialDivisionAlgebra> internal constructor(scope: NoaScope) : NoaAlgebra(scope), LinearOpsTensorAlgebra, AnalyticTensorAlgebra { + override operator fun Tensor.div(other: Tensor): TensorType { + return wrap(JNoa.divTensor(this.cast().tensorHandle, other.cast().tensorHandle)) + } + + override operator fun Tensor.divAssign(other: Tensor): Unit { + JNoa.divTensorAssign(this.cast().tensorHandle, other.cast().tensorHandle) + } }