Mobile Kotlin TensorFlow
This is a Kotlin Multiplatform library that provides access to TensorFlow Lite functionality from the common source set.
Table of Contents
Features
Requirements
- Gradle version 5.6.4+
- Android API 19+
- iOS version 9.0+
Versions
Bintray
- kotlin 1.3.72
- 0.1.0
mavenCentral
- kotlin 1.4.31
- 0.1.1
Installation
root build.gradle
allprojects {
    repositories {
        // moko-tensorflow releases from 0.1.1 onward are published to Maven Central
        // (0.1.0 was published to Bintray).
        mavenCentral()
    }
}
project build.gradle.kts (the snippets below use the Gradle Kotlin DSL)
dependencies {
    // moko-tensorflow added as an `api` dependency of the common source set.
    commonMainApi("dev.icerock.moko:tensorflow:0.1.1")
}
cocoaPods {
    podsProject = file("../ios-app/Pods/Pods.xcodeproj") // path to the Pods Xcode project of your iOS app
    // TensorFlowLiteObjC pod, imported under the module name TFLTensorFlowLite.
    // NOTE(review): `onlyLink = true` presumably means "link only, no cinterop bindings" — confirm in the plugin docs.
    pod("TensorFlowLiteObjC", module = "TFLTensorFlowLite", onlyLink = true)
}
kotlin {
    // Directory holding TensorFlowLiteC.framework inside the iOS app's CocoaPods install.
    val tensorFlowLiteCFrameworksDir = project.file("../ios-app/Pods/TensorFlowLiteC/Frameworks").path

    targets
        .filterIsInstance<org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget>()
        .flatMap { nativeTarget -> nativeTarget.binaries }
        .filterIsInstance<org.jetbrains.kotlin.gradle.plugin.mpp.Framework>()
        .forEach { framework ->
            // -F adds a framework search path; -framework links TensorFlowLiteC into every
            // Kotlin/Native framework binary produced by this module.
            framework.linkerOpts("-F$tensorFlowLiteCFrameworksDir", "-framework", "TensorFlowLiteC")
        }
}
Podfile
# TensorFlow Lite Objective-C pod. It presumably also installs TensorFlowLiteC, whose
# Frameworks directory the Gradle linkerOpts point at — confirm after `pod install`.
pod 'TensorFlowLiteObjC', '~> 2.2.0'
Usage
First, place the model file in the multiplatform resources folder `commonMain/resources/MR/files`.
In common code:
/**
 * Runs a TensorFlow Lite classification model through [interpreter].
 */
class Classifier(private val interpreter: Interpreter) {
    /**
     * Feeds [inputData] to the model's first input and returns what the model wrote to
     * output tensor 0.
     *
     * @param inputData model input; presumably it must already match the shape reported by
     *   `interpreter.getInputTensor(0).shape` — confirm against your model.
     * @return one value per class ([OUTPUT_CLASSES_COUNT] entries) for the single batch item.
     */
    fun classify(inputData: Any): FloatArray {
        // Output buffer of shape [1][OUTPUT_CLASSES_COUNT]: a batch of one.
        val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
        interpreter.run(listOf(inputData), mapOf(0 to result))
        // Previously the filled buffer was discarded; hand it back to the caller.
        return result[0]
    }
}
Getting the shared model file (in common code):
/** Gives platform code access to the model bundled in the shared `MR` resources. */
object ResHolder {
    /** Returns the `mymodel` file resource from [MR.files]. */
    fun getModelFile(): FileResource = MR.files.mymodel
}
Android:
class MainActivity : AppCompatActivity() {
    // Created in onCreate and closed in onDestroy.
    private lateinit var interpreter: Interpreter

    override fun onCreate(savedInstanceState: Bundle?) {
        // Activities must call through to super.onCreate; otherwise Android throws
        // SuperNotCalledException.
        super.onCreate(savedInstanceState)

        interpreter = Interpreter(
            ResHolder.getModelFile(),
            InterpreterOptions(numThreads = 2, useNNAPI = true),
            this
        )
        val classifier = Classifier(interpreter)
        // `data` is a placeholder for your preprocessed model input.
        classifier.classify(data)
    }

    override fun onDestroy() {
        super.onDestroy()
        // Free the resources held by the interpreter created in onCreate.
        interpreter.close()
    }
}
iOS:
class ViewController: UIViewController {
    // Stored so the interpreter can be closed when the controller is deallocated.
    private var interpreter: TensorflowInterpreter?

    override func viewDidLoad() {
        super.viewDidLoad()

        let interpreterOptions = TensorflowInterpreterOptions(numThreads: 2)
        let modelFile: ResourcesFileResource = ResHolder().getModelFile()

        // Keep a non-optional local so no force-unwrap is needed below.
        let tfInterpreter = TensorflowInterpreter(fileResource: modelFile, options: interpreterOptions)
        interpreter = tfInterpreter

        // `data` is a placeholder for your preprocessed model input.
        Classifier(interpreter: tfInterpreter).classify(data)
    }

    deinit {
        interpreter?.close()
    }
}
Samples
Please see more examples in the sample directory.
Set Up Locally
- The `tensorflow` directory contains the `tensorflow` library;
- The `sample` directory contains sample apps for Android and iOS, plus the `mpp-library` connected to the apps;
- For local testing, use the `./publishToMavenLocal.sh` script so that the sample apps use the locally published version.
Contributing
All development (both new features and bug fixes) is performed in the `develop` branch. This way `master` always contains the sources of the most recently released version. Please send PRs with bug fixes to the `develop` branch. Documentation fixes in the markdown files are an exception to this rule: they are updated directly in `master`.
The `develop` branch is pushed to `master` on release.
For more details on contributing please see the contributing guide.
License
Copyright 2020 IceRock MAG Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.