This post works through an example of deploying YOLOv5 on Android.
It is based mainly on the official PyTorch demo: https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp
The app's home screen looks like this:

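Main features:
Switch test image: three (or any number of) images are specified directly in the code; tapping the test-image button cycles through them.
Select image: tapping the select-image button lets you pick a photo from the gallery or take a new one with the camera.
Live video: tapping the live-video button opens the camera and draws the detection results directly on the camera preview.
Switch model (a feature I added): tapping the switch-model button lets you choose a different model for detection.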
First, get the official demo running. Download the official yolov5s.torchscript.ptl model:
Download link: https://pytorch-mobile-demo-apps.s3.us-east-2.amazonaws.com/yolov5s.torchscript.ptl
After downloading, put it in the app's assets folder.
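If you run the app as-is, picking an image from the gallery fails with:
Unable to decode stream: java.io.FileNotFoundException:/…/open failed: EACCES (Permission denied)
To fix this, add the following attribute to the application tag in AndroidManifest.xml:
android:requestLegacyExternalStorage="true"
After that the app runs normally. (This attribute opts the app out of scoped storage on Android 10, API level 29; on Android 11 and later devices the system ignores it for apps targeting API level 30 or higher.)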
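Next, train your own model with YOLOv5 v6.0. Training is not covered here; see the earlier posts in this series.
Then modify the export_torchscript function in export.py. Essentially three lines are added so that it also exports a model with the .torchscript.ptl suffix, the lite-interpreter format that the Android app loads with LiteModuleLoader.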
def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export
    # optimize_for_mobile: from torch.utils.mobile_optimizer import optimize_for_mobile
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = file.with_suffix('.torchscript.pt')
        f = str(f)
        fl = file.with_suffix('.torchscript.ptl')  # output path for the lite-interpreter model

        ts = torch.jit.trace(model, im, strict=False)
        (optimize_for_mobile(ts) if optimize else ts).save(f)
        # Save a second copy in the lite-interpreter format loaded by LiteModuleLoader on Android
        (optimize_for_mobile(ts) if optimize else ts)._save_for_lite_interpreter(str(fl))

        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')
Then run in a terminal:
python export.py --weights runs/train/exp/weights/best.pt --include torchscript
This produces best.torchscript.ptl in the same folder as best.pt.
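Why the output is called best.torchscript.ptl: the modified function derives both output paths from the --weights path with pathlib's with_suffix(), which replaces only the final .pt suffix. The sketch below reproduces just that naming step, so it runs without PyTorch; export_names is a hypothetical helper, not part of export.py.

```python
from pathlib import Path


def export_names(weights):
    """Hypothetical helper: reproduce how the modified export_torchscript()
    derives its two output paths from the --weights path."""
    file = Path(weights)
    # with_suffix() swaps only the last suffix (".pt"), so the stem "best"
    # is kept and the two-part suffix is appended to it.
    f = file.with_suffix('.torchscript.pt')
    fl = file.with_suffix('.torchscript.ptl')
    # as_posix() keeps the printed paths identical on Windows and Linux.
    return f.as_posix(), fl.as_posix()


for path in export_names('runs/train/exp/weights/best.pt'):
    print(path)
```

Because the stem is preserved, weights named last.pt would likewise produce last.torchscript.ptl.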
Now let's add the model-switching feature and use the model we trained ourselves.
First, update the PyTorch dependencies in the app's build.gradle:
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
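Keep these versions in line with the PyTorch version you use to train and export the model: for example, if you export with PyTorch 1.9.0, use 1.9.0 here as well. A .ptl file written by a newer PyTorch may fail to load in an older Android runtime.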
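To catch a mismatch before building the APK, you could compare the Gradle dependency version with the torch.__version__ of the environment used for export. This is a minimal sketch that assumes matching major.minor releases is a good enough rule of thumb; the hard constraint is the lite-interpreter bytecode version, which this does not check. Both helper names are hypothetical.

```python
import re


def base_version(version):
    """Parse (major, minor, patch) from a version string, ignoring local
    suffixes such as '+cu111' that torch.__version__ often carries."""
    m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version)
    if m is None:
        raise ValueError('unrecognised version string: %r' % version)
    major, minor, patch = m.groups()
    return int(major), int(minor), int(patch or 0)


def gradle_matches_torch(gradle_version, torch_version):
    """Hypothetical sanity check: True if the Android dependency and the
    PyTorch used for export share the same major.minor release."""
    return base_version(gradle_version)[:2] == base_version(torch_version)[:2]


print(gradle_matches_torch('1.9.0', '1.9.0+cu111'))
print(gradle_matches_torch('1.9.0', '1.10.1'))
```

In the export environment you would call it as gradle_matches_torch('1.9.0', torch.__version__).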
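Then, in PrePostProcessor.java, remove the private modifier from the static field mOutputColumn (the number of values per output row, 85 for the 80-class yolov5s model) so that MainActivity, which is in the same package, can set it: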

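Next, edit the layout: add a switch-model button to activity_main.xml (MainActivity looks it up as R.id.select) and adjust the layout.
Then modify MainActivity.java and add the following three fields: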
private String model_name = "yolov5s.torchscript.ptl";
private String model_class = "classes.txt";
private int num_class = 80;
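num_class matters because the handler below sets PrePostProcessor.mOutputColumn = num_class + 5: each row of the YOLOv5 detection output holds 4 box values (cx, cy, w, h), 1 objectness score and one score per class, while the number of rows depends only on the input size. A small sketch of that arithmetic, assuming a square 640×640 input and the standard three-scale head (strides 8, 16 and 32, three anchors per scale); the function name is hypothetical.

```python
def yolov5_output_shape(num_class, img_size=640, strides=(8, 16, 32), anchors_per_scale=3):
    """Return (rows, columns) of the YOLOv5 detection output for a square input,
    assuming the standard 3-scale head with 3 anchors per scale."""
    if img_size % max(strides):
        raise ValueError('img_size must be a multiple of the largest stride')
    # One row per anchor per grid cell on each scale: 80x80, 40x40, 20x20 at 640.
    rows = sum(anchors_per_scale * (img_size // s) ** 2 for s in strides)
    # Each row: cx, cy, w, h, objectness, then one score per class.
    columns = num_class + 5
    return rows, columns


print(yolov5_output_shape(80))  # (25200, 85): yolov5s trained on COCO
print(yolov5_output_shape(10))  # (25200, 15): a 10-class custom model
```

Only the column count changes between the two models here; the row count would change only if the model were exported at a different --imgsz, in which case the app's input size (PrePostProcessor.mInputWidth/mInputHeight) would have to change with it.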
Add the click handler for the model-selection button:
private void ShowChoise()
{
    AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
    // builder.setIcon(R.drawable.ic_launcher_foreground);
    builder.setTitle("选择一个模型");  // "Choose a model"
    // Entries shown in the list; index 0 is the stock yolov5s model
    final String[] models = {"YOLOv5s", "王者荣耀模型"};  // "Honor of Kings model"
    builder.setItems(models, new DialogInterface.OnClickListener()
    {
        @Override
        public void onClick(DialogInterface dialog, int which)
        {
            // "Selected model: ..."
            Toast.makeText(MainActivity.this, "选择的模型为:" + models[which], Toast.LENGTH_SHORT).show();
            if (which == 0)
            {
                model_name = "yolov5s.torchscript.ptl";
                model_class = "classes.txt";
                num_class = 80;
            }
            else
            {
                model_name = "mymodel.ptl";
                model_class = "classes_wzry.txt";
                num_class = 10;
            }
            // Reload the model and its class labels
            try {
                mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                    String line;
                    List<String> classes = new ArrayList<>();
                    while ((line = br.readLine()) != null) {
                        classes.add(line);
                    }
                    PrePostProcessor.mClasses = new String[classes.size()];
                    // 4 box values + 1 objectness score + one score per class
                    PrePostProcessor.mOutputColumn = num_class + 5;
                    classes.toArray(PrePostProcessor.mClasses);
                }
            } catch (IOException e) {
                Log.e("Object Detection", "Error reading assets", e);
                finish();
            }
        }
    });
    builder.show();
}
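Add one if branch per selectable model. model_class is that model's label file, which you create yourself in the same format as classes.txt (one class name per line, in the class order used for training); num_class is the number of classes.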
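Because num_class and the label file are maintained separately, they can drift apart, which can lead to wrongly parsed output or out-of-range label lookups. A minimal sketch of generating the label file from a list of names and reading it back the way the Java readLine() loop does; the helper names and the placeholder class names are hypothetical, and the real list should follow the class order used in training.

```python
import tempfile
from pathlib import Path


def write_classes(names, path):
    """Write one label per line (the layout of classes.txt) and return the
    count to use as num_class in MainActivity."""
    cleaned = [n.strip() for n in names]
    if not cleaned or any(not n for n in cleaned):
        raise ValueError('class names must be non-empty')
    # A single trailing newline is harmless, but a blank line would be read
    # by BufferedReader.readLine() as an extra, empty class name.
    Path(path).write_text('\n'.join(cleaned) + '\n', encoding='utf-8')
    return len(cleaned)


def read_classes(path):
    """Mirror of the Java loop: each line becomes one label. Like readLine(),
    splitlines() yields no empty entry for a single trailing newline."""
    return Path(path).read_text(encoding='utf-8').splitlines()


with tempfile.TemporaryDirectory() as tmp:
    label_file = Path(tmp) / 'classes_wzry.txt'
    names = ['class%d' % i for i in range(10)]  # hypothetical placeholder labels
    num_class = write_classes(names, label_file)
    labels = read_classes(label_file)
    print(num_class, len(labels))  # 10 10
```

Writing the file with UTF-8 also keeps Chinese class names readable on Android, whose default charset is UTF-8.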
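Finally, copy best.torchscript.ptl from the previous step into the assets folder, together with classes_wzry.txt, and rename the model to mymodel.ptl by hand, since that is the file name the code loads; without the rename the app fails with a file-not-found error. Then run the app.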
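The copy-and-rename step can be scripted so that the file name always matches what MainActivity loads. A sketch, assuming the standard Android Studio assets location app/src/main/assets (adjust to your project); install_model is a hypothetical helper. Note that assetFilePath() in the MainActivity.java listing below reuses any non-empty copy already in the app's files directory, so after replacing mymodel.ptl with a retrained model you may need to clear the app's data or uninstall it before the new file is picked up.

```python
import shutil
import tempfile
from pathlib import Path


def install_model(exported, assets_dir, asset_name='mymodel.ptl'):
    """Copy the exported lite model into the app's assets folder under the
    exact name that MainActivity passes to assetFilePath()."""
    src = Path(exported)
    if not src.is_file():
        raise FileNotFoundError(str(src))
    dst_dir = Path(assets_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    dst = dst_dir / asset_name
    shutil.copyfile(str(src), str(dst))  # overwrites an older mymodel.ptl
    return dst


# Self-contained demo: a dummy file stands in for best.torchscript.ptl.
with tempfile.TemporaryDirectory() as tmp:
    exported = Path(tmp) / 'best.torchscript.ptl'
    exported.write_bytes(b'dummy')
    installed = install_model(exported, Path(tmp) / 'app' / 'src' / 'main' / 'assets')
    print(installed.name)  # mymodel.ptl
```

For this project the call would be install_model('runs/train/exp/weights/best.torchscript.ptl', 'app/src/main/assets').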
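Besides the changes above, I also localized the UI into Chinese and slightly adjusted image loading. The complete source of the modified files follows: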
activity_main.xml
MainActivity.java
// Copyright (c) 2020 Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

package org.pytorch.demo.objectdetection;

import androidx.appcompat.app.AlertDialog;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.Context;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;
import android.widget.Toast;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity implements Runnable {
    private int mImageIndex = 0;
    private String[] mTestImages = {"test1.png", "test2.jpg", "test3.png"};

    private ImageView mImageView;
    private ResultView mResultView;
    private Button mButtonDetect;
    private Button mButtonSelect;
    private ProgressBar mProgressBar;
    private Bitmap mBitmap = null;
    private Module mModule = null;
    private float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;

    // Currently selected model file, its label file and its number of classes
    private String model_name = "yolov5s.torchscript.ptl";
    private String model_class = "classes.txt";
    private int num_class = 80;

    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

    private void ShowChoise()
    {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
        // builder.setIcon(R.drawable.ic_launcher_foreground);
        builder.setTitle("选择一个模型");  // "Choose a model"
        // Entries shown in the list; index 0 is the stock yolov5s model
        final String[] models = {"YOLOv5s", "王者荣耀模型"};  // "Honor of Kings model"
        builder.setItems(models, new DialogInterface.OnClickListener()
        {
            @Override
            public void onClick(DialogInterface dialog, int which)
            {
                // "Selected model: ..."
                Toast.makeText(MainActivity.this, "选择的模型为:" + models[which], Toast.LENGTH_SHORT).show();
                if (which == 0)
                {
                    model_name = "yolov5s.torchscript.ptl";
                    model_class = "classes.txt";
                    num_class = 80;
                }
                else
                {
                    model_name = "mymodel.ptl";
                    model_class = "classes_wzry.txt";
                    num_class = 10;
                }
                // Reload the model and its class labels
                try {
                    mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                    try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                        String line;
                        List<String> classes = new ArrayList<>();
                        while ((line = br.readLine()) != null) {
                            classes.add(line);
                        }
                        PrePostProcessor.mClasses = new String[classes.size()];
                        // 4 box values + 1 objectness score + one score per class
                        PrePostProcessor.mOutputColumn = num_class + 5;
                        classes.toArray(PrePostProcessor.mClasses);
                    }
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });
        builder.show();
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.READ_EXTERNAL_STORAGE}, 1);
        }

        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, 1);
        }

        setContentView(R.layout.activity_main);

        try {
            mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }

        mImageView = findViewById(R.id.imageView);
        mImageView.setImageBitmap(mBitmap);
        mResultView = findViewById(R.id.resultView);
        mResultView.setVisibility(View.INVISIBLE);

        // "Test image i/n"; derived from mTestImages so any number of images works
        final Button buttonTest = findViewById(R.id.testButton);
        buttonTest.setText(String.format("测试图片 %d/%d", mImageIndex + 1, mTestImages.length));
        buttonTest.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);
                mImageIndex = (mImageIndex + 1) % mTestImages.length;
                buttonTest.setText(String.format("测试图片 %d/%d", mImageIndex + 1, mTestImages.length));

                try {
                    mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
                    mImageView.setImageBitmap(mBitmap);
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });

        final Button buttonSelect = findViewById(R.id.selectButton);
        buttonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);

                // "Choose from gallery", "Take photo", "Cancel"
                final CharSequence[] options = { "从相册选择", "拍照", "取消" };
                AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
                builder.setTitle("新测试图片");  // "New test image"

                builder.setItems(options, new DialogInterface.OnClickListener() {
                    @Override
                    public void onClick(DialogInterface dialog, int item) {
                        if (options[item].equals("拍照")) {
                            Intent takePicture = new Intent(android.provider.MediaStore.ACTION_IMAGE_CAPTURE);
                            startActivityForResult(takePicture, 0);
                        }
                        else if (options[item].equals("从相册选择")) {
                            Intent pickPhoto = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.INTERNAL_CONTENT_URI);
                            startActivityForResult(pickPhoto, 1);
                        }
                        else if (options[item].equals("取消")) {
                            dialog.dismiss();
                        }
                    }
                });
                builder.show();
            }
        });

        final Button buttonLive = findViewById(R.id.liveButton);
        buttonLive.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                final Intent intent = new Intent(MainActivity.this, ObjectDetectionActivity.class);
                startActivity(intent);
            }
        });

        mButtonDetect = findViewById(R.id.detectButton);
        mProgressBar = (ProgressBar) findViewById(R.id.progressBar);
        mButtonDetect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mButtonDetect.setEnabled(false);
                mProgressBar.setVisibility(ProgressBar.VISIBLE);
                mButtonDetect.setText(getString(R.string.run_model));

                // Model-input coordinates -> bitmap coordinates
                mImgScaleX = (float)mBitmap.getWidth() / PrePostProcessor.mInputWidth;
                mImgScaleY = (float)mBitmap.getHeight() / PrePostProcessor.mInputHeight;

                // Bitmap coordinates -> ImageView coordinates (scale, then centring offset)
                mIvScaleX = (mBitmap.getWidth() > mBitmap.getHeight() ? (float)mImageView.getWidth() / mBitmap.getWidth() : (float)mImageView.getHeight() / mBitmap.getHeight());
                mIvScaleY = (mBitmap.getHeight() > mBitmap.getWidth() ? (float)mImageView.getHeight() / mBitmap.getHeight() : (float)mImageView.getWidth() / mBitmap.getWidth());

                mStartX = (mImageView.getWidth() - mIvScaleX * mBitmap.getWidth()) / 2;
                mStartY = (mImageView.getHeight() - mIvScaleY * mBitmap.getHeight()) / 2;

                Thread thread = new Thread(MainActivity.this);
                thread.start();
            }
        });

        // New: model-selection button
        mButtonSelect = findViewById(R.id.select);
        mButtonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                ShowChoise();
            }
        });

        try {
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
            try (BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)))) {
                String line;
                List<String> classes = new ArrayList<>();
                while ((line = br.readLine()) != null) {
                    classes.add(line);
                }
                PrePostProcessor.mClasses = new String[classes.size()];
                // Same as in ShowChoise(): box (4) + objectness (1) + class scores.
                // Setting plain num_class here would misalign every output row.
                PrePostProcessor.mOutputColumn = num_class + 5;
                classes.toArray(PrePostProcessor.mClasses);
            }
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode != RESULT_CANCELED) {
            switch (requestCode) {
                case 0:  // photo taken with the camera
                    if (resultCode == RESULT_OK && data != null) {
                        mBitmap = (Bitmap) data.getExtras().get("data");
                        Matrix matrix = new Matrix();
                        //matrix.postRotate(90.0f);
                        matrix.postRotate(0);
                        mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                        mImageView.setImageBitmap(mBitmap);
                    }
                    break;
                case 1:  // image picked from the gallery
                    if (resultCode == RESULT_OK && data != null) {
                        Uri selectedImage = data.getData();
                        String[] filePathColumn = {MediaStore.Images.Media.DATA};
                        if (selectedImage != null) {
                            Cursor cursor = getContentResolver().query(selectedImage,
                                    filePathColumn, null, null, null);
                            if (cursor != null) {
                                cursor.moveToFirst();
                                int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
                                String picturePath = cursor.getString(columnIndex);
                                mBitmap = BitmapFactory.decodeFile(picturePath);
                                Matrix matrix = new Matrix();
                                //matrix.postRotate(90.0f);
                                matrix.postRotate(0);
                                mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                                mImageView.setImageBitmap(mBitmap);
                                cursor.close();
                            }
                        }
                    }
                    break;
            }
        }
    }

    @Override
    public void run() {
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
        IValue[] outputTuple = mModule.forward(IValue.from(inputTensor)).toTuple();
        final Tensor outputTensor = outputTuple[0].toTensor();
        final float[] outputs = outputTensor.getDataAsFloatArray();
        final ArrayList<Result> results = PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);

        runOnUiThread(() -> {
            mButtonDetect.setEnabled(true);
            mButtonDetect.setText(getString(R.string.detect));
            mProgressBar.setVisibility(ProgressBar.INVISIBLE);
            mResultView.setResults(results);
            mResultView.invalidate();
            mResultView.setVisibility(View.VISIBLE);
        });
    }
}
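The run() method hands the flat getDataAsFloatArray() output, together with mImgScaleX/Y, mIvScaleX/Y and mStartX/Y, to PrePostProcessor.outputsToNMSPredictions(), whose source is not listed in this post. The Python sketch below mirrors the scale and offset arithmetic of the detect-button handler above and shows one plausible way to walk the flat array with a stride of mOutputColumn. The per-row layout (cx, cy, w, h, objectness, class scores) and the threshold rule are assumptions, non-maximum suppression is left out, and both function names are hypothetical. It also shows why mOutputColumn must equal num_class + 5: with any other stride, every row after the first would be read from the wrong offset.

```python
def view_mapping(bmp_w, bmp_h, view_w, view_h, input_w=640, input_h=640):
    """Mirror of the scale/offset fields set in the detect-button handler."""
    img_scale = (bmp_w / input_w, bmp_h / input_h)                # mImgScaleX/Y
    iv_sx = view_w / bmp_w if bmp_w > bmp_h else view_h / bmp_h   # mIvScaleX
    iv_sy = view_h / bmp_h if bmp_h > bmp_w else view_w / bmp_w   # mIvScaleY
    start = ((view_w - iv_sx * bmp_w) / 2,                        # mStartX
             (view_h - iv_sy * bmp_h) / 2)                        # mStartY
    return img_scale, (iv_sx, iv_sy), start


def decode_rows(outputs, num_cols, threshold, img_scale, iv_scale, start):
    """Assumed per-row decoding of the flat output array (no NMS).
    Each row holds cx, cy, w, h, objectness, then num_cols - 5 class scores."""
    (img_sx, img_sy), (iv_sx, iv_sy), (x0, y0) = img_scale, iv_scale, start
    detections = []
    for i in range(len(outputs) // num_cols):
        row = outputs[i * num_cols:(i + 1) * num_cols]
        if row[4] <= threshold:
            continue
        cx, cy, w, h = row[:4]
        # Model-input coordinates -> original bitmap coordinates
        left, right = img_sx * (cx - w / 2), img_sx * (cx + w / 2)
        top, bottom = img_sy * (cy - h / 2), img_sy * (cy + h / 2)
        scores = row[5:]
        cls = max(range(len(scores)), key=scores.__getitem__)  # first best class
        # Bitmap coordinates -> ImageView coordinates
        box = (x0 + iv_sx * left, y0 + iv_sy * top,
               x0 + iv_sx * right, y0 + iv_sy * bottom)
        detections.append((cls, row[4], box))
    return detections


# 1280x960 photo shown in a 1080x1080 ImageView; 2 classes -> 7 values per row.
img_scale, iv_scale, start = view_mapping(1280, 960, 1080, 1080)
outputs = [320, 320, 64, 64, 0.9, 0.1, 0.8,   # kept: objectness 0.9, class 1
           100, 100, 10, 10, 0.1, 0.5, 0.5]   # dropped: objectness 0.1
for det in decode_rows(outputs, 7, 0.3, img_scale, iv_scale, start):
    print(det)  # (1, 0.9, (486.0, 499.5, 594.0, 580.5))
```

Here the wide photo is scaled to fill the view's width, so mStartX is 0 and the image is centred vertically with a 135-pixel offset; the Java side additionally truncates the final coordinates to ints for the Rect.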
strings.xml
String values: YOLOv5, Image View, 检测 (Detect), 正在运行,请稍后 (Running, please wait), Restart, 选择图片 (Select image), 实时视频 (Live video), 切换模型 (Switch model)
button_selector.xml
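In my tests the packaged APK came out at just over 1 GB, so adding the PyTorch runtime clearly inflates the app size a lot, and there is room for further slimming. Real-time video detection also runs at a very low frame rate, essentially a slideshow, probably because the phone's compute is limited; this too needs further investigation and optimization.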