diff --git a/kmath-torch/README.md b/kmath-torch/README.md index 6e0368036..7d5c76833 100644 --- a/kmath-torch/README.md +++ b/kmath-torch/README.md @@ -89,7 +89,7 @@ memScoped { scope = this, array = floatArrayOf(7f,8.9f,2.6f,5.6f), shape = intArrayOf(4)) - println(intTensor) + println(floatTensor) } ``` diff --git a/kmath-torch/build.gradle.kts b/kmath-torch/build.gradle.kts index 08f74fd88..5db0af2f7 100644 --- a/kmath-torch/build.gradle.kts +++ b/kmath-torch/build.gradle.kts @@ -12,6 +12,10 @@ val home = System.getProperty("user.home") val thirdPartyDir = "$home/.konan/third-party/kmath-torch-${project.property("version")}" val cppBuildDir = "$thirdPartyDir/cpp-build" +val cudaHome: String? = System.getenv("CUDA_HOME") +val cudaDefault = file("/usr/local/cuda").exists() +val cudaFound = cudaHome?.isNotEmpty() ?: false or cudaDefault + val cmakeArchive = "cmake-3.19.2-Linux-x86_64" val torchArchive = "libtorch" @@ -32,8 +36,11 @@ val downloadNinja by tasks.registering(Download::class) { } val downloadTorch by tasks.registering(Download::class) { - val zipFile = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2Bcu110.zip" - src("https://download.pytorch.org/libtorch/cu110/$zipFile") + val abiMeta = "$torchArchive-cxx11-abi-shared-with-deps-1.7.1%2B" + val cudaUrl = "https://download.pytorch.org/libtorch/cu110/${abiMeta}cu110.zip" + val cpuUrl = "https://download.pytorch.org/libtorch/cpu/${abiMeta}cpu.zip" + val url = if (cudaFound) cudaUrl else cpuUrl + src(url) dest(File(thirdPartyDir, "$torchArchive.zip")) overwrite(false) } diff --git a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt index 3a174b186..d49cb3720 100644 --- a/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt +++ b/kmath-torch/src/nativeTest/kotlin/kscience/kmath/torch/TestUtils.kt @@ -1,5 +1,6 @@ package kscience.kmath.torch +import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -12,8 +13,4 @@ internal class TestUtils { setNumThreads(numThreads) assertEquals(numThreads, getNumThreads()) } - @Test - fun cudaAvailability(){ - assertTrue(cudaAvailable()) - } } \ No newline at end of file