xXAI-botXx

Reputation: 21

Is it possible to train Mask R-CNN with an RTX 4090?

I want to train a Mask R-CNN model on an Nvidia RTX 4090 GPU, but it seems impossible. There seems to be an issue with weight loading.

I tried the following implementations:

With Python 3.7.12, TensorFlow 1.14/1.15, and Keras 2.3.1, Mask R-CNN works very well. But I can't use the RTX 4090 with these versions. I need at least TensorFlow 2.12.1 to get the CUDA version that the RTX 4090 requires (CUDA 11.8).
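A quick way to see which CUDA version the installed TensorFlow was built for, and whether it sees the GPU at all, is sketched below. The helper name `describe_tensorflow_gpu_setup` is made up for this example. `tf.sysconfig.get_build_info()` and `tf.config.list_physical_devices()` are only present in newer TensorFlow 2.x releases, so the sketch looks them up defensively and returns `None` where they are missing.

```python
def describe_tensorflow_gpu_setup():
    """Collect the TensorFlow version, its CUDA/cuDNN build versions, and visible GPUs."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return {"tensorflow": None}

    info = {"tensorflow": tf.__version__}

    # get_build_info() is missing in older releases, so look it up defensively.
    get_build_info = getattr(tf.sysconfig, "get_build_info", None)
    build = dict(get_build_info()) if get_build_info else {}
    info["cuda"] = build.get("cuda_version")
    info["cudnn"] = build.get("cudnn_version")

    # An empty list means TensorFlow does not see a usable GPU.
    list_devices = getattr(tf.config, "list_physical_devices", None)
    info["gpus"] = [d.name for d in list_devices("GPU")] if list_devices else None
    return info


print(describe_tensorflow_gpu_setup())
```

If `gpus` comes back empty, the weight comparison further down is running on the CPU and says nothing about the RTX 4090 itself.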

With TensorFlow 2.12.1 or 2.13.1 there is a specific issue: the weights are loaded incorrectly. They appear to be loaded randomly every time. I train the model for 1 epoch, then load the same weights 3 times, and the loaded weights differ each time. The result looks like this:

Results when loading the same weights 3 Times

Note that the results will of course look sketchy and odd because the model is not really trained, but I expected to get the same results for the same weights.

I tried this with Tom Gross's implementation and with mrk1992's implementation, and both show this bug.

I looked through many issues but didn't find anything. Maybe you can help.

Does anybody know how to fix this?


Reproducing the bug

To reproduce it, here is my conda environment. Create an environment.yml:

name: maskrcnn
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - _tflow_select=2.3.0=mkl
  - absl-py=0.15.0=pyhd8ed1ab_0
  - aiohttp=3.7.4.post0=py37h5e8e339_1
  - alembic=1.13.1=pyhd8ed1ab_0
  - alsa-lib=1.2.8=h166bdaf_0
  - anyio=3.7.1=pyhd8ed1ab_0
  - aom=3.5.0=h27087fc_0
  - argon2-cffi=23.1.0=pyhd8ed1ab_0
  - argon2-cffi-bindings=21.2.0=py37h540881e_2
  - astor=0.8.1=pyh9f0ad1d_0
  - async-timeout=3.0.1=py_1000
  - attr=2.5.1=h166bdaf_1
  - attrs=23.2.0=pyh71513ae_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=2.0.0=pyhd8ed1ab_0
  - bcrypt=3.2.2=py37h540881e_0
  - beautifulsoup4=4.12.3=pyha770c72_0
  - binutils_impl_linux-64=2.40=ha885e6a_0
  - binutils_linux-64=2.40=hdade7a5_3
  - bleach=6.1.0=pyhd8ed1ab_0
  - blinker=1.6.3=pyhd8ed1ab_0
  - bottleneck=1.3.5=py37hda87dfa_0
  - brotli=1.1.0=hd590300_1
  - brotli-bin=1.1.0=hd590300_1
  - brotli-python=1.0.9=py37hd23a5d3_7
  - bzip2=1.0.8=hd590300_5
  - c-ares=1.28.1=hd590300_0
  - ca-certificates=2024.2.2=hbcca054_0
  - cached-property=1.5.2=hd8ed1ab_1
  - cached_property=1.5.2=pyha770c72_1
  - cachetools=5.3.3=pyhd8ed1ab_0
  - cairo=1.16.0=ha61ee94_1014
  - certifi=2024.2.2=pyhd8ed1ab_0
  - cffi=1.15.1=py37h43b0acd_1
  - chardet=4.0.0=py37h89c1867_3
  - charset-normalizer=3.3.2=pyhd8ed1ab_0
  - click=8.1.3=py37h89c1867_0
  - cloudpickle=2.2.1=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - colorlog=6.7.0=py37h89c1867_0
  - comm=0.2.2=pyhd8ed1ab_0
  - cryptography=38.0.2=py37h5994e8b_1
  - cycler=0.11.0=pyhd8ed1ab_0
  - cython=0.29.32=py37hd23a5d3_0
  - cytoolz=0.12.0=py37h540881e_0
  - dask-core=2022.2.0=pyhd8ed1ab_0
  - dbus=1.13.6=h5008d03_3
  - debugpy=1.6.3=py37hd23a5d3_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - defusedxml=0.7.1=pyhd8ed1ab_0
  - dill=0.3.8=pyhd8ed1ab_0
  - distro=1.9.0=pyhd8ed1ab_0
  - docker-py=6.1.3=pyhd8ed1ab_0
  - entrypoints=0.4=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - expat=2.6.2=h59595ed_0
  - ffmpeg=4.4.2=gpl_h8dda1f0_112
  - fftw=3.3.10=nompi_hc118613_108
  - flask=1.1.2=pyh9f0ad1d_0
  - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
  - font-ttf-inconsolata=3.000=h77eed37_0
  - font-ttf-source-code-pro=2.038=h77eed37_0
  - font-ttf-ubuntu=0.83=h77eed37_2
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-conda-ecosystem=1=0
  - fonts-conda-forge=1=0
  - fonttools=4.38.0=py37h540881e_0
  - freeglut=3.2.2=h9c3ff4c_1
  - freetype=2.12.1=h267a509_2
  - fsspec=2023.1.0=pyhd8ed1ab_0
  - gast=0.2.2=py_0
  - gcc_impl_linux-64=13.2.0=h9eb54c0_7
  - gcc_linux-64=13.2.0=h1ed452b_3
  - geos=3.11.0=h27087fc_0
  - gettext=0.22.5=h59595ed_2
  - gettext-tools=0.22.5=h59595ed_2
  - gitdb=4.0.11=pyhd8ed1ab_0
  - gitpython=3.1.43=pyhd8ed1ab_0
  - glib=2.80.2=hf974151_0
  - glib-tools=2.80.2=hb6ce0ca_0
  - gmp=6.3.0=h59595ed_1
  - gnutls=3.7.9=hb077bed_0
  - google-auth=2.23.0=pyh1a96a4e_0
  - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
  - google-pasta=0.2.0=pyh8c360ce_0
  - graphite2=1.3.13=h59595ed_1003
  - greenlet=1.1.3=py37hd23a5d3_0
  - grpc-cpp=1.48.1=hc2bec63_1
  - grpcio=1.48.1=py37h42e856d_1
  - gst-plugins-base=1.21.3=h4243ec0_1
  - gstreamer=1.21.3=h25f0c4b_1
  - gstreamer-orc=0.4.38=hd590300_0
  - gunicorn=20.1.0=py37h89c1867_2
  - gxx_impl_linux-64=13.2.0=h2a599c4_7
  - gxx_linux-64=13.2.0=he8deefe_3
  - h5py=3.7.0=nompi_py37hf1ce037_101
  - harfbuzz=5.3.0=h418a68e_0
  - hdf5=1.12.2=nompi_h4df4325_101
  - icu=70.1=h27087fc_0
  - idna=3.7=pyhd8ed1ab_0
  - imagecodecs-lite=2019.12.3=py37hc105733_5
  - imageio=2.34.1=pyh4b66e23_0
  - imgaug=0.4.0=pyhd8ed1ab_1
  - importlib-metadata=4.11.4=py37h89c1867_0
  - importlib_resources=6.0.0=pyhd8ed1ab_0
  - imutils=0.5.4=py37h89c1867_2
  - ipykernel=6.16.2=pyh210e3f2_0
  - ipython=7.33.0=py37h89c1867_0
  - ipython_genutils=0.2.0=py_1
  - ipywidgets=8.1.2=pyhd8ed1ab_1
  - itsdangerous=2.1.2=pyhd8ed1ab_0
  - jack=1.9.22=h11f4161_0
  - jasper=2.0.33=h0ff4b12_1
  - jedi=0.19.1=pyhd8ed1ab_0
  - jinja2=3.1.4=pyhd8ed1ab_0
  - joblib=1.2.0=pyhd8ed1ab_0
  - jpeg=9e=h0b41bf4_3
  - jsonschema=4.17.3=pyhd8ed1ab_0
  - jupyter=1.0.0=pyhd8ed1ab_10
  - jupyter_client=7.4.9=pyhd8ed1ab_0
  - jupyter_console=6.5.1=pyhd8ed1ab_0
  - jupyter_core=4.11.1=py37h89c1867_0
  - jupyter_server=1.23.4=pyhd8ed1ab_0
  - jupyterlab_pygments=0.3.0=pyhd8ed1ab_1
  - jupyterlab_widgets=3.0.10=pyhd8ed1ab_0
  - keras=2.3.1=py37_0
  - keras-applications=1.0.8=py_1
  - keras-preprocessing=1.1.2=pyhd8ed1ab_0
  - kernel-headers_linux-64=2.6.32=he073ed8_17
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.4=py37h7cecad7_0
  - krb5=1.20.1=h81ceb04_0
  - lame=3.100=h166bdaf_1003
  - lcms2=2.14=h6ed2654_0
  - ld_impl_linux-64=2.40=h55db66e_0
  - lerc=4.0.0=h27087fc_0
  - libabseil=20220623.0=cxx17_h05df665_6
  - libaec=1.1.3=h59595ed_0
  - libasprintf=0.22.5=h661eb56_2
  - libasprintf-devel=0.22.5=h661eb56_2
  - libblas=3.9.0=20_linux64_openblas
  - libbrotlicommon=1.1.0=hd590300_1
  - libbrotlidec=1.1.0=hd590300_1
  - libbrotlienc=1.1.0=hd590300_1
  - libcap=2.67=he9d0100_0
  - libcblas=3.9.0=20_linux64_openblas
  - libclang=15.0.7=default_h127d8a8_5
  - libclang13=15.0.7=default_h5d6823c_5
  - libcups=2.3.3=h36d4200_3
  - libcurl=8.1.2=h409715c_0
  - libdb=6.2.32=h9c3ff4c_0
  - libdeflate=1.14=h166bdaf_0
  - libdrm=2.4.120=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=hd590300_2
  - libevent=2.1.10=h28343ad_4
  - libexpat=2.6.2=h59595ed_0
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-devel_linux-64=13.2.0=hceb6213_107
  - libgcc-ng=13.2.0=h77fa898_7
  - libgcrypt=1.10.3=hd590300_0
  - libgettextpo=0.22.5=h59595ed_2
  - libgettextpo-devel=0.22.5=h59595ed_2
  - libgfortran-ng=13.2.0=h69a702a_7
  - libgfortran5=13.2.0=hca663fb_7
  - libglib=2.80.2=hf974151_0
  - libglu=9.0.0=he1b5a44_1001
  - libgomp=13.2.0=h77fa898_7
  - libgpg-error=1.49=h4f305b6_0
  - libgpuarray=0.7.6=h7f98852_1003
  - libiconv=1.17=hd590300_2
  - libidn2=2.3.7=hd590300_0
  - liblapack=3.9.0=20_linux64_openblas
  - liblapacke=3.9.0=20_linux64_openblas
  - libllvm11=11.1.0=he0ac6c6_5
  - libllvm15=15.0.7=hadd5161_1
  - libnghttp2=1.58.0=h47da74e_0
  - libnsl=2.0.1=hd590300_0
  - libogg=1.3.4=h7f98852_1
  - libopenblas=0.3.25=pthreads_h413a1c8_0
  - libopencv=4.6.0=py37hfe11ba8_3
  - libopus=1.3.1=h7f98852_1
  - libpciaccess=0.18=hd590300_0
  - libpng=1.6.43=h2797004_0
  - libpq=15.3=hbcd7760_1
  - libprotobuf=3.20.1=h6239696_4
  - libsanitizer=13.2.0=h6ddb7a1_7
  - libsndfile=1.2.2=hc60ed4a_1
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.45.3=h2797004_0
  - libssh2=1.11.0=h0841786_0
  - libstdcxx-devel_linux-64=13.2.0=hceb6213_107
  - libstdcxx-ng=13.2.0=hc0a3c3a_7
  - libsystemd0=253=h8c4010b_1
  - libtasn1=4.19.0=h166bdaf_0
  - libtiff=4.4.0=h82bc61c_5
  - libtool=2.4.7=h27087fc_0
  - libudev1=253=h0b41bf4_1
  - libunistring=0.9.10=h7f98852_0
  - libuuid=2.38.1=h0b41bf4_0
  - libva=2.18.0=h0b41bf4_0
  - libvorbis=1.3.7=h9c3ff4c_0
  - libvpx=1.11.0=h9c3ff4c_3
  - libwebp-base=1.4.0=hd590300_0
  - libxcb=1.13=h7f98852_1004
  - libxkbcommon=1.5.0=h79f4944_1
  - libxml2=2.10.3=hca2bb57_4
  - libzlib=1.2.13=hd590300_5
  - llvmlite=0.39.1=py37h0761922_0
  - locket=1.0.0=pyhd8ed1ab_0
  - lz4-c=1.9.4=hcb278e6_0
  - mako=1.3.5=pyhd8ed1ab_0
  - markdown=3.6=pyhd8ed1ab_0
  - markupsafe=2.1.1=py37h540881e_1
  - matplotlib-base=3.5.3=py37hf395dca_2
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - mistune=3.0.2=pyhd8ed1ab_0
  - mlflow=1.30.0=py37h02d9ccd_0
  - mpg123=1.32.6=h59595ed_0
  - multidict=6.0.2=py37h540881e_1
  - multiprocess=0.70.14=py37h540881e_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - mysql-common=8.0.33=hf1915f5_6
  - mysql-libs=8.0.33=hca2cd23_6
  - nbclassic=1.0.0=pyhb4ecaf3_1
  - nbclient=0.7.0=pyhd8ed1ab_0
  - nbconvert=7.6.0=pyhd8ed1ab_0
  - nbconvert-core=7.6.0=pyhd8ed1ab_0
  - nbconvert-pandoc=7.6.0=pyhd8ed1ab_0
  - nbformat=5.8.0=pyhd8ed1ab_0
  - ncurses=6.5=h59595ed_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - nettle=3.9.1=h7ab15ed_0
  - networkx=2.6.3=pyhd8ed1ab_1
  - nomkl=1.0=h5ca1d4c_0
  - notebook=6.5.7=pyha770c72_0
  - notebook-shim=0.2.4=pyhd8ed1ab_0
  - nspr=4.35=h27087fc_0
  - nss=3.100=hca3bf56_0
  - numba=0.56.3=py37hf081915_0
  - numexpr=2.8.3=py37h85a3170_100
  - numpy=1.21.6=py37h976b520_0
  - oauthlib=3.2.2=pyhd8ed1ab_0
  - opencv=4.6.0=py37h89c1867_3
  - openh264=2.3.1=hcb278e6_2
  - openjpeg=2.5.0=h7d73246_1
  - openssl=3.1.5=hd590300_0
  - opt_einsum=3.3.0=pyhc1e730c_2
  - p11-kit=0.24.1=hc5aa10d_0
  - packaging=21.3=pyhd8ed1ab_0
  - pandas=1.3.5=py37h8c16a72_0
  - pandoc=3.2=ha770c72_0
  - pandocfilters=1.5.0=pyhd8ed1ab_0
  - paramiko=3.4.0=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - partd=1.4.1=pyhd8ed1ab_0
  - pcre2=10.43=hcad00b1_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pillow=9.2.0=py37h850a105_2
  - pip=24.0=pyhd8ed1ab_0
  - pixman=0.43.2=h59595ed_0
  - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
  - prometheus_client=0.17.1=pyhd8ed1ab_0
  - prometheus_flask_exporter=0.23.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - prompt_toolkit=3.0.42=hd8ed1ab_0
  - protobuf=3.20.1=py37hd23a5d3_0
  - psutil=5.9.3=py37h540881e_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pulseaudio=16.1=hcb278e6_3
  - pulseaudio-client=16.1=h5195f5e_3
  - pulseaudio-daemon=16.1=ha8d29e2_3
  - py-opencv=4.6.0=py37h25bab4e_3
  - pyasn1=0.5.1=pyhd8ed1ab_0
  - pyasn1-modules=0.3.0=pyhd8ed1ab_0
  - pycocotools=2.0.4=py37hda87dfa_2
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.17.2=pyhd8ed1ab_0
  - pygpu=0.7.6=py37hb1e94ed_1003
  - pyjwt=2.8.0=pyhd8ed1ab_1
  - pynacl=1.5.0=py37h540881e_1
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pyparsing=3.1.2=pyhd8ed1ab_0
  - pyrsistent=0.18.1=py37h540881e_1
  - pysocks=1.7.1=py37h89c1867_5
  - python=3.7.12=hf930737_100_cpython
  - python-dateutil=2.9.0=pyhd8ed1ab_0
  - python-fastjsonschema=2.19.1=pyhd8ed1ab_0
  - python_abi=3.7=4_cp37m
  - pytz=2022.7.1=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pywavelets=1.3.0=py37hda87dfa_1
  - pywin32-on-windows=0.1.0=pyh1179c8e_3
  - pyyaml=6.0=py37h540881e_4
  - pyzmq=24.0.1=py37h0c0c2a8_0
  - qt-main=5.15.6=hf6cd601_5
  - qtconsole-base=5.4.4=pyha770c72_0
  - qtpy=2.4.1=pyhd8ed1ab_0
  - querystring_parser=1.2.4=py_0
  - re2=2022.06.01=h27087fc_1
  - readline=8.2=h8228510_1
  - requests=2.31.0=pyhd8ed1ab_0
  - requests-oauthlib=2.0.0=pyhd8ed1ab_0
  - rsa=4.9=pyhd8ed1ab_0
  - ruamel.yaml=0.17.10=py37h5e8e339_0
  - ruamel.yaml.clib=0.2.6=py37h540881e_1
  - scikit-build=0.17.6=pyh4af843d_0
  - scikit-image=0.19.3=py37hfb7772e_1
  - scikit-learn=1.0.2=py37hf9e9bfc_0
  - scipy=1.7.3=py37hf2a6cf1_0
  - send2trash=1.8.3=pyh0d859eb_0
  - setproctitle=1.3.2=py37h540881e_0
  - setuptools=69.0.3=pyhd8ed1ab_0
  - shapely=1.8.5=py37ha4e3bd1_0
  - six=1.16.0=pyh6c4a22f_0
  - smmap=5.0.0=pyhd8ed1ab_0
  - sniffio=1.3.1=pyhd8ed1ab_0
  - soupsieve=2.3.2.post1=pyhd8ed1ab_0
  - sqlalchemy=1.4.42=py37h540881e_0
  - sqlite=3.45.3=h2c6b66d_0
  - sqlparse=0.4.4=pyhd8ed1ab_0
  - svt-av1=1.4.1=hcb278e6_0
  - sysroot_linux-64=2.12=he073ed8_17
  - tensorboard=2.8.0=pyhd8ed1ab_1
  - tensorboard-data-server=0.6.1=py37h52d8a92_0
  - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
  - tensorflow=2.0.0=mkl_py37h66b46cc_0
  - tensorflow-base=2.0.0=mkl_py37h9204916_0
  - tensorflow-estimator=2.6.0=py37hcd2ae1e_0
  - termcolor=1.1.0=pyhd8ed1ab_3
  - terminado=0.17.1=pyh41d4057_0
  - theano=1.0.4=py37hf484d3e_1000
  - threadpoolctl=3.1.0=pyh8a188c0_0
  - tifffile=2020.6.3=py_0
  - tinycss2=1.3.0=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tomli=2.0.1=pyhd8ed1ab_0
  - toolz=0.12.1=pyhd8ed1ab_0
  - tornado=6.2=py37h540881e_0
  - tqdm=4.66.4=pyhd8ed1ab_0
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing-extensions=4.7.1=hd8ed1ab_0
  - typing_extensions=4.7.1=pyha770c72_0
  - unicodedata2=14.0.0=py37h540881e_1
  - urllib3=1.26.18=pyhd8ed1ab_0
  - wcwidth=0.2.10=pyhd8ed1ab_0
  - webencodings=0.5.1=pyhd8ed1ab_2
  - websocket-client=1.6.1=pyhd8ed1ab_0
  - werkzeug=0.16.1=py_0
  - wheel=0.42.0=pyhd8ed1ab_0
  - widgetsnbextension=4.0.10=pyhd8ed1ab_0
  - wrapt=1.14.1=py37h540881e_0
  - x264=1!164.3095=h166bdaf_2
  - x265=3.5=h924138e_3
  - xcb-util=0.4.0=h516909a_0
  - xcb-util-image=0.4.0=h166bdaf_0
  - xcb-util-keysyms=0.4.0=h516909a_0
  - xcb-util-renderutil=0.3.9=h166bdaf_0
  - xcb-util-wm=0.4.1=h516909a_0
  - xkeyboard-config=2.38=h0b41bf4_0
  - xorg-fixesproto=5.0=h7f98852_1002
  - xorg-inputproto=2.3.2=h7f98852_1002
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.4=h0b41bf4_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxfixes=5.0.3=h7f98852_1004
  - xorg-libxi=1.7.10=h7f98852_0
  - xorg-libxrender=0.9.10=h7f98852_1003
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xproto=7.0.31=h7f98852_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - yarl=1.7.2=py37h540881e_2
  - zeromq=4.3.5=h59595ed_1
  - zipp=3.15.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.6=ha6fb4c9_0
  - pip:
      - databricks-cli==0.18.0
      - tabulate==0.9.0
prefix: /home/local-admin/.conda/envs/maskrcnn

Then run

conda env create -n maskrcnn -f environment.yml
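Before creating the environment, it can help to confirm which versions the file actually pins for the packages that matter here. The following is a minimal sketch that assumes the `name=version=build` line format of a conda export shown above. The helper name `pinned_versions` and the short `sample` excerpt are illustrative only.

```python
def pinned_versions(env_text, packages):
    """Map each requested package to the version pinned in a conda export."""
    wanted = set(packages)
    found = {}
    for line in env_text.splitlines():
        entry = line.strip()
        if not entry.startswith("- "):
            continue
        spec = entry[2:].strip()
        # conda lines look like name=version=build, pip lines like name==version
        name, _, rest = spec.partition("=")
        version = rest.lstrip("=").split("=")[0]
        if name in wanted and version:
            found[name] = version
    return found


# A few lines copied from the environment.yml above, plus one pip entry.
sample = """\
dependencies:
  - keras=2.3.1=py37_0
  - python=3.7.12=hf930737_100_cpython
  - tensorflow=2.0.0=mkl_py37h66b46cc_0
  - pip:
      - tabulate==0.9.0
"""

print(pinned_versions(sample, ["python", "tensorflow", "keras", "tabulate"]))
```

To check the real file, pass `open("environment.yml").read()` instead of `sample`.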

Also download a Mask R-CNN model:

git clone -b tensorflow-2.0 https://github.com/tomgross/Mask_RCNN.git Mask_RCNN

Here is the code to test the model:

import sys
sys.path.append("./Mask_RCNN")

import os

import numpy as np
import cv2
import imutils
import matplotlib.pyplot as plt

import tensorflow as tf

from IPython.display import clear_output


from mrcnn.config import Config
from mrcnn import model as modellib
from mrcnn import visualize

img_path = f"./test.jpg"

image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = imutils.resize(image, width=512)

plt.figure(figsize=(10, 10))
plt.axis("off")
plt.imshow(image);

class TestConfig(Config):
    NAME = "mask-rcnn test"
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    NUM_CLASSES = 2
    # BACKBONE = "resnet50"
    # IMAGE_MIN_DIM = 800
    # IMAGE_MAX_DIM = 1024

CLASS_NAMES = ['BG', 'FG']
TEST_MODEL_PATH = "./test_model.h5"
config = TestConfig()
# config.display()

# create new model with random weights
model = modellib.MaskRCNN(mode="training", config=config, model_dir=os.getcwd())
model.keras_model.save_weights(TEST_MODEL_PATH)

clear_output()

# prepare visualization
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(20, 5))
fig.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.01, hspace=None)

# load the created random weights 3 times
weights = []
differences = []
for i in range(3):
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=os.getcwd())
    model.load_weights(TEST_MODEL_PATH, by_name=True)
    result = model.detect([image], verbose=0)[0]
    ax[i].set_title(f"{i}. Results")
    ax[i].axis("off")
    visualize.display_instances(image, result['rois'], result['masks'], result['class_ids'], CLASS_NAMES, result['scores'],
                                ax=ax[i], show_mask=True, show_bbox=True)
    weights += [model.keras_model.get_weights()]

clear_output()

# detect differences between the models
for cur_weights_1 in weights:
    cur_differences = []
    for cur_weights_2 in weights:
        all_weight_differences = [np.abs(w1-w2) for w1, w2 in zip(cur_weights_1, cur_weights_2)]
        all_layer_differences = [np.sum(diff) for diff in all_weight_differences]
        cur_differences += [np.sum(all_layer_differences)]
    differences += [cur_differences]
complete_sum_difference = np.sum(differences)
if complete_sum_difference == 0.0:
    print("Congratulations!\nYour implementation works and is ready to segment!")
else:
    print("Difference detected. Your Mask R-CNN seems to have a problem.")
    print("Check if you correctly applied all steps of the installation.")
    print("\nMore detail of differences:")
    print("Absolute Weight Differences:", complete_sum_difference)
    print("Comparison: Weight 1, Weight 2, Weight 3")
    for idx in range(len(differences)):
        print(f"Weight {idx+1}: {differences[idx]}")

# try to remove the test model file
try:
    os.remove(TEST_MODEL_PATH)
except Exception:
    print(f"Wasn't able to delete the test model at: {TEST_MODEL_PATH}")

plt.show();

And here is the test image:

Test Image

I hope this makes the bug easy to reproduce.

Does anyone have a solution?


I tried:

I also tried TensorFlow versions older than 2.12.1, but these don't work. You can check which CUDA version each TensorFlow version is built for here: https://www.tensorflow.org/install/source.

I expect the same weights when loading the same weights 3 times. For comparison, I previously trained and used the model on a GTX 1080 Ti with Python 3.7.12, TensorFlow 1.15, and Keras 2.3.1. There the code above succeeded and the weights were equal, as expected. With the new GPU the result is different.
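The test script above only reports one summed difference per pair of loads. To narrow down which weight arrays actually change between loads, a per-array comparison could look like the sketch below. The helper `layerwise_differences` is hypothetical and not part of any Mask R-CNN implementation. It works on plain lists of NumPy arrays such as those returned by `model.keras_model.get_weights()`.

```python
import numpy as np


def layerwise_differences(weights_a, weights_b, names=None, atol=0.0):
    """Return (name, max_abs_diff) for every weight array that differs."""
    if len(weights_a) != len(weights_b):
        raise ValueError("weight lists have different lengths")
    if names is None:
        names = [f"array_{i}" for i in range(len(weights_a))]
    report = []
    for name, a, b in zip(names, weights_a, weights_b):
        a, b = np.asarray(a), np.asarray(b)
        if a.shape != b.shape:
            # Shapes that do not match cannot be compared element-wise.
            report.append((name, float("inf")))
            continue
        max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
        if max_diff > atol:
            report.append((name, max_diff))
    return report


# Toy data standing in for two loads of the same weight file.
first_load = [np.zeros((2, 2)), np.ones(3)]
second_load = [np.zeros((2, 2)), np.array([1.0, 1.5, 1.0])]
print(layerwise_differences(first_load, second_load, names=["conv", "bias"]))
```

In the script above this could be called as `layerwise_differences(weights[0], weights[1])`. Passing `names=[w.name for w in model.keras_model.weights]` should label the arrays, assuming that list is in the same order as `get_weights()`. If only some arrays differ, their names show which layers are not being restored from the file.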

Upvotes: 1

Views: 261

Answers (1)

Johnny Cheesecutter

Reputation: 2853

Sorry, but why aren't you using the model from torchvision? You can load it for inference, and PyTorch also has recipes for training it from scratch:

from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_Weights,
    maskrcnn_resnet50_fpn,
)
import torch
from PIL import Image

weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1

# image preprocessing that matches the pretrained weights
transform = weights.transforms()
model = maskrcnn_resnet50_fpn(weights=weights)
model.eval()  # switch to inference mode

img = Image.open("airplane.jpeg").convert("RGB")

# preprocess and add batch dimension
batch = transform(img)[None, ...]

# gradients are not needed for inference
with torch.no_grad():
    preds = model(batch)
print(preds[0]['labels']) # tensor([5])


# for training, check these scripts
# https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn

Upvotes: 1
