RuntimeError: CUDA error: device-side assert triggered - Compile with TORCH_USE_CUDA_DSA to enable device-side assertions

I’m encountering the following error while running a TTS (text-to-speech) model on a GPU using PyTorch:

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This issue occurs when multiple requests are made in quick succession using threading. When requests are spaced out (i.e., single or slower sequential requests), the function works as expected, and no errors are thrown. However, when multiple threads simultaneously process text, the error is triggered inconsistently.

Here is the function I use to split text into smaller parts, synthesize it into audio, and save the results as WAV files:

!git clone https://github.com/coqui-ai/TTS.git
%cd TTS
!pip install -r requirements.txt
!pip install .
!pip install numpy==1.24.3

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import torch
import soundfile as sf
from pydub import AudioSegment
import base64
import TTS.tts.layers.xtts.tokenizer as xttsTokenizer
import numpy as np
import io
import os

config_path = "/content/drive/MyDrive/XTTS-v2/config.json"
model_path = "/content/drive/MyDrive/XTTS-v2/"

config = XttsConfig()
config.load_json(config_path)
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
model.cuda()

def TTS_XTTSv2(prompt, speaker_wav_path, id, lang, speed, text_split_length=226):
    # Resolve the language first, because split_sentence needs it.
    # lang_detect and voice_test_path are assumed to be defined elsewhere in the notebook.
    if lang is None or lang == "":
        lang = lang_detect(prompt)

    split_tts_sentence = xttsTokenizer.split_sentence(text=prompt, lang=lang, text_split_length=text_split_length)

    output_files = []
    for i, part in enumerate(split_tts_sentence):
        splitted_text_voice_output_path = f"{voice_test_path}/{id}_{i+1}.wav"
        # GPU-heavy step; this is where the error is raised.
        outputs = model.synthesize(
            part,
            config=config,
            speaker_wav=speaker_wav_path,
            language=lang,
            speed=speed
        )

        wav_output = outputs['wav']

        # Write each part as a 24 kHz WAV file.
        sf.write(splitted_text_voice_output_path, wav_output, 24000)
        output_files.append(splitted_text_voice_output_path)

    return output_files

The error is triggered at the model.synthesize step, which is computationally heavy and runs on the GPU. This function is called within a threaded API using threading.Thread to parallelize text processing.
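
When several threads fail at once, each one writes its traceback to stderr independently, so the lines can interleave, which is what happened to the output below. A minimal sketch of one way to keep each traceback intact, assuming a hypothetical wrapper safe_target around the thread target (failing_job is a made-up stand-in for generate_tts_response):

```python
import threading
import traceback

print_lock = threading.Lock()
captured = []  # (thread name, full traceback text), one entry per failed thread


def safe_target(func, *args, **kwargs):
    """Run func; on failure, record and print the whole traceback as one block."""
    try:
        func(*args, **kwargs)
    except Exception:
        text = traceback.format_exc()
        name = threading.current_thread().name
        # Holding the lock while printing keeps tracebacks from interleaving.
        with print_lock:
            captured.append((name, text))
            print(f"--- {name} ---\n{text}", end="")


def failing_job(n):
    # Hypothetical stand-in for the real thread target.
    raise RuntimeError(f"job {n} failed")


threads = [
    threading.Thread(target=safe_target, args=(failing_job, n)) for n in range(3)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Each thread's traceback is formatted into a single string first and only then printed under the lock, so the blocks come out one after another.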

Full Error Traceback:

Three threads (Thread-33, Thread-34 and Thread-35, all running generate_tts_response) failed at the same time, so their tracebacks were printed interleaved. Separated out, all three share the same frames down to generate:

Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-9-d8f16acdbf2d>", line 1168, in generate_tts_response
  File "<ipython-input-9-d8f16acdbf2d>", line 1034, in TTS_XTTSv2
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/models/xtts.py", line 419, in synthesize
    return self.full_inference(text, speaker_wav, language, **settings)
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/models/xtts.py", line 488, in full_inference
    return self.inference(
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/models/xtts.py", line 541, in inference
    gpt_codes = self.gpt.generate(
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/layers/xtts/gpt.py", line 590, in generate
    gen = self.gpt_inference.generate(
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 2252, in generate
    result = self._sample(

One thread then failed in the stopping-criteria check:

  File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 3310, in _sample
    unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
  File "/usr/local/lib/python3.11/dist-packages/transformers/generation/stopping_criteria.py", line 494, in __call__
    is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Another failed in the positional-embedding lookup:

  File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 3254, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/layers/xtts/gpt_inference.py", line 94, in forward
    emb = emb + self.pos_embedding.get_fixed_embedding(
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/layers/xtts/gpt.py", line 40, in get_fixed_embedding
    return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/sparse.py", line 190, in forward
    return F.embedding(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/functional.py", line 2551, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The third was inside the GPT-2 block forward; its traceback is cut off at this point:

  File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 3251, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/TTS/tts/layers/xtts/gpt_inference.py", line 97, in forward
    transformer_outputs = self.transformer(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 1133, in forward
    outputs = block(
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)

Here’s what I’ve tried so far:

Answers (1)

Brandon Pardi

The first thing I'd try is compiling PyTorch with TORCH_USE_CUDA_DSA, to see whether there are underlying issues that don't affect single-threaded runs but surface when multiple threads are running.

See here how to set that up
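
A lighter-weight check that needs no rebuild, in case it helps: CUDA kernels run asynchronously, so the Python frame shown in the traceback is not necessarily the call that tripped the assert. Setting the CUDA_LAUNCH_BLOCKING environment variable to 1 makes kernel launches synchronous, which usually makes the traceback point at the operation that actually failed. A minimal sketch; the helper name enable_sync_cuda_errors is made up for illustration:

```python
import os


def enable_sync_cuda_errors(env=os.environ):
    """Request synchronous CUDA kernel launches so errors surface at the failing call."""
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


# Must run before CUDA is initialized, i.e. before the model is moved to the GPU.
enable_sync_cuda_errors()

# import torch            # import and load the model only after this point
# model.cuda()
```

In a notebook where the model is already on the GPU, restart the runtime and run this cell first. Expect slower inference while it is enabled, so treat it as a debugging aid only.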

Alternatively, though you lose most of the benefit of threading, you could guard model.synthesize with a lock so that only one thread can call it at a time:

import threading

synthesize_lock = threading.Lock()  # create this once, outside your function

def TTS_XTTSv2(prompt, speaker_wav_path, id, lang, speed, text_split_length=226):
    ...
    for i, part in enumerate(split_tts_sentence):
        splitted_text_voice_output_path = f"{voice_test_path}/{id}_{i+1}.wav"
        with synthesize_lock:  # Ensure exclusive access to the model
            outputs = model.synthesize(
                part,
                config=config,
                speaker_wav=speaker_wav_path,
                language=lang,
                speed=speed
            )
    ...
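
To see what the lock buys you without touching the GPU, here is a small self-contained sketch. FakeModel and run_requests are hypothetical stand-ins, not part of TTS; the fake synthesize just sleeps and counts how many callers are inside it at once:

```python
import threading
import time

synthesize_lock = threading.Lock()


class FakeModel:
    """Hypothetical stand-in for the XTTS model; tracks overlapping calls."""

    def __init__(self):
        self.active = 0
        self.max_active = 0
        self._counter_lock = threading.Lock()

    def synthesize(self, text):
        with self._counter_lock:
            self.active += 1
            self.max_active = max(self.max_active, self.active)
        time.sleep(0.01)  # pretend to do GPU work
        with self._counter_lock:
            self.active -= 1
        return {"wav": [0.0] * len(text)}


def run_requests(model, texts, use_lock):
    """Call model.synthesize from one thread per text; return peak concurrency."""

    def worker(text):
        if use_lock:
            with synthesize_lock:  # only one thread inside synthesize at a time
                model.synthesize(text)
        else:
            model.synthesize(text)

    threads = [threading.Thread(target=worker, args=(t,)) for t in texts]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return model.max_active


locked = run_requests(FakeModel(), ["a", "b", "c", "d"], use_lock=True)
print("peak concurrent calls with lock:", locked)
```

With the lock held, the peak concurrency is 1, which is the guarantee the snippet above relies on. Without it, the count can climb as high as the number of threads.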

Good luck and lemme know how it goes
