Reproducing GaitSet


Preparation

Install the dependencies the program needs.

# scipy
pip3 install scipy

# skbuild
pip3 install scikit-build

# cmake
pip3 install cmake

# opencv
pip3 install opencv-python

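OpenCV for Python takes a long time to build, so expect a fairly long wait.

The build failed because Python.h could not be found; the Python development headers (python-dev) need to be installed.

# Look up what the package is called in yum; the name differs between systems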
yum search python3 | grep devel

yum install python3-devel

# Reinstall opencv; the Tsinghua mirror can be used
pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
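To see whether the headers are visible to the interpreter that pip3 builds against, a small check like the following may help. It is only a sketch: find_python_header is a hypothetical helper name, and it assumes the build looks for Python.h in the include directory reported by sysconfig.

```python
import os
import sysconfig


def find_python_header():
    """Return (expected path of Python.h, whether the file exists)."""
    # sysconfig reports where this interpreter expects its C headers.
    include_dir = sysconfig.get_paths()["include"]
    header = os.path.join(include_dir, "Python.h")
    return header, os.path.isfile(header)


if __name__ == "__main__":
    header, found = find_python_header()
    status = "found" if found else "missing (install the Python development package)"
    print(f"{header}: {status}")
```

If the file is reported as missing, the development package for that Python version is probably not installed yet.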

Once the installation succeeds, start preprocessing the images: the original images are cropped to 64 × 64 input images.

First arrange the CASIA-B dataset in the layout your_dataset_path/subject_ids/walking_conditions/views, e.g. CASIA-B/001/nm-01/000/.
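Before running the preprocessing, it can be worth confirming that the folders really follow this three-level layout. The sketch below walks a dataset root and lists every subject/condition/view folder it finds; list_sequences is a hypothetical helper, not part of GaitSet, and it only checks directory nesting, not the image files themselves.

```python
import os
import sys


def list_sequences(dataset_root):
    """Yield (subject_id, walking_condition, view) for each sequence folder."""
    for subject in sorted(os.listdir(dataset_root)):
        subject_dir = os.path.join(dataset_root, subject)
        if not os.path.isdir(subject_dir):
            continue  # skip stray files at the top level
        for condition in sorted(os.listdir(subject_dir)):
            condition_dir = os.path.join(subject_dir, condition)
            if not os.path.isdir(condition_dir):
                continue
            for view in sorted(os.listdir(condition_dir)):
                if os.path.isdir(os.path.join(condition_dir, view)):
                    yield subject, condition, view


if __name__ == "__main__":
    # The dataset root is taken from the command line; "." is only a placeholder.
    root = sys.argv[1] if len(sys.argv) > 1 else "."
    sequences = list(list_sequences(root))
    print(f"{len(sequences)} sequence folders under {root}")
```

For the example path above, the generator would yield ("001", "nm-01", "000").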


python3 pretreatment.py --input_path='/root/inger/casia-b' --output_path='/root/inger/cut-casia-b'

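This fails with AttributeError: module 'scipy.misc' has no attribute 'imsave'. The cause is that imsave(), imread() and related functions in scipy.misc were deprecated and then removed, so they are missing in newer versions (such as scipy 1.4.1). Downgrading (manually installing scipy 1.2.1 or lower) works, but another Python 3 module can be used instead.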

pip3 install imageio

Modify pretreatment.py:

# from scipy import misc as scisc
import imageio

# scisc.imsave(save_path, img)
imageio.imsave(save_path, img)

Run it again:

[root@localhost GaitSet]# python3 pretreatment.py --input_path='/root/inger/casia-b' --output_path='/root/inger/cut-casia-b'
Pretreatment Start.
Input path: /root/inger/casia-b
Output path: /root/inger/cut-casia-b
Log file: ./pretreatment.log
Worker num: 1
# JOB 0 : --START-- 001-bg-01-000
# JOB 0 : --FINISH-- Contain 89 valid frames. Saved to /root/inger/cut-casia-b/001/bg-01/000.
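With many sequences, it may be handy to summarise how many valid frames each one kept. The sketch below parses lines in the format printed above; it assumes the output was captured to a list of lines (the exact format of ./pretreatment.log may differ), and summarize is a hypothetical helper name.

```python
import re

# Matches lines such as:
# "# JOB 0 : --FINISH-- Contain 89 valid frames. Saved to /path/001/bg-01/000."
FINISH_RE = re.compile(r"--FINISH-- Contain (\d+) valid frames\. Saved to (\S+?)\.?$")


def summarize(lines):
    """Return a list of (output_path, valid_frame_count) from FINISH lines."""
    results = []
    for line in lines:
        match = FINISH_RE.search(line.strip())
        if match:
            results.append((match.group(2), int(match.group(1))))
    return results


if __name__ == "__main__":
    sample = [
        "# JOB 0 : --START-- 001-bg-01-000",
        "# JOB 0 : --FINISH-- Contain 89 valid frames. "
        "Saved to /root/inger/cut-casia-b/001/bg-01/000.",
    ]
    for path, frames in summarize(sample):
        print(f"{frames:4d} frames  {path}")
```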

Change the dataset path in config.py to your own path.

Then train the model:

python3 train.py

An error occurs:

# python3 train.py
Initialzing...
Initializing data source...
/usr/local/lib64/python3.6/site-packages/numpy/core/_asarray.py:136: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray
  return array(a, dtype, copy=False, order=order, subok=True)
Traceback (most recent call last):
  File "train.py", line 18, in <module>
    m = initialization(conf, train=opt.cache)[0]
  File "/root/inger/GaitSet/model/initialization.py", line 57, in initialization
    train_source, test_source = initialize_data(config, train, test)
  File "/root/inger/GaitSet/model/initialization.py", line 15, in initialize_data
    train_source, test_source = load_data(**config['data'], cache=(train or test))
  File "/root/inger/GaitSet/model/utils/data_loader.py", line 42, in load_data
    pid_list = np.load(pid_fname)
  File "/usr/local/lib64/python3.6/site-packages/numpy/lib/npyio.py", line 440, in load
    pickle_kwargs=pickle_kwargs)
  File "/usr/local/lib64/python3.6/site-packages/numpy/lib/format.py", line 727, in read_array
    raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False

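Since NumPy 1.16.3, np.load defaults to allow_pickle=False, so object arrays are rejected. numpy needs to be downgraded: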

# pip3 show numpy
Name: numpy
Version: 1.19.2

pip3 install numpy==1.16.1

[root@localhost GaitSet]# python3 train.py
Initialzing...
Initializing data source...
Loading training data...
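Downgrading is one way out. A possible alternative, assuming you would rather keep the newer numpy and edit the np.load call in model/utils/data_loader.py (line 42 in the traceback), is to pass allow_pickle=True. The sketch below shows that behaviour on a throwaway file; save_partition and load_partition are hypothetical helpers, not part of GaitSet, and the IDs are made up for illustration. Only load pickled arrays from files you created yourself, since unpickling can execute arbitrary code.

```python
import os
import tempfile

import numpy as np


def save_partition(path, train_ids, test_ids):
    """Save two lists of different lengths as one object array."""
    # Build the object array explicitly so ragged lists are accepted.
    partition = np.empty(2, dtype=object)
    partition[0] = list(train_ids)
    partition[1] = list(test_ids)
    np.save(path, partition)


def load_partition(path):
    """Load the object array; allow_pickle=True must be opted into."""
    partition = np.load(path, allow_pickle=True)
    return list(partition[0]), list(partition[1])


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "partition.npy")
        save_partition(path, ["001", "002", "003"], ["004"])
        train_ids, test_ids = load_partition(path)
        print(len(train_ids), "training IDs,", len(test_ids), "test IDs")
```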


Training started at 4 p.m. on September 17 and took about 12 hours to finish.

[root@localhost GaitSet]# ls -l work/
总用量 0
drwxr-xr-x. 3 root root 21 9月  17 04:18 checkpoint
drwxr-xr-x. 2 root root 47 9月  17 04:18 OUMVLP_network
drwxr-xr-x. 2 root root 34 9月  18 04:36 partition
[root@localhost GaitSet]# python3 test.py
Initialzing...
Initializing data source...
Data initialization complete.
Initializing model...
Model initialization complete.
Loading the model of iteration 80000...
Transforming...
Evaluating...
Evaluation complete. Cost: 0:00:49.330972
===Rank-1 (Include identical-view cases)===
NM: 95.587,	BG: 89.026,	CL: 73.198
===Rank-1 (Exclude identical-view cases)===
NM: 95.145,	BG: 88.029,	CL: 71.645
===Rank-1 of each angle (Exclude identical-view cases)===
NM: [91.80 98.30 98.40 97.50 93.30 92.10 94.10 97.40 98.00 97.80 87.90]
BG: [86.30 91.30 93.30 89.90 85.50 80.50 84.50 90.40 92.30 92.12 82.20]
CL: [69.20 78.90 79.10 76.80 69.20 67.70 70.20 73.40 74.50 70.70 58.40]

The output above shows the test performance on CASIA-B.
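As a sanity check on the log, the overall "Exclude identical-view cases" figures appear to be the plain mean of the 11 per-angle values. The sketch below recomputes them from the printed numbers; that averaging rule is an assumption read off the output, not taken from the GaitSet code, and because the per-angle values are already rounded to two decimals, exact agreement is not guaranteed in general.

```python
from statistics import mean

# Per-angle Rank-1 accuracy copied from the test output above.
per_angle = {
    "NM": [91.80, 98.30, 98.40, 97.50, 93.30, 92.10, 94.10, 97.40, 98.00, 97.80, 87.90],
    "BG": [86.30, 91.30, 93.30, 89.90, 85.50, 80.50, 84.50, 90.40, 92.30, 92.12, 82.20],
    "CL": [69.20, 78.90, 79.10, 76.80, 69.20, 67.70, 70.20, 73.40, 74.50, 70.70, 58.40],
}


def overall_rank1(values):
    """Average the per-angle accuracies, rounded to three decimals like the log."""
    return round(mean(values), 3)


if __name__ == "__main__":
    for condition, values in per_angle.items():
        print(f"{condition}: {overall_rank1(values):.3f}")
```

For this run the recomputed means match the reported 95.145, 88.029 and 71.645.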

Inger Notes © 2024