I've defined the following dataclass:
"""This module declares the SubtitleItem dataclass."""
import re
from dataclasses import dataclass
from time_utils import Timestamp


@dataclass
class SubtitleItem:
    """Class for storing all the information for
    a subtitle item."""
    index: int
    start_time: Timestamp
    end_time: Timestamp
    text: str

    @staticmethod
    def load_from_text_item(text_item: str) -> "SubtitleItem":
        """Create new subtitle item from their .srt file text.
        Example, if your .srt file contains the following subtitle item:
        ```
        3
        00:00:05,847 --> 00:00:06,916
        The robot.
        ```
        This function will return:
        ```
        SubtitleItem(
            index=3,
            start_time=Timestamp(seconds=5, milliseconds=847),
            end_time=Timestamp(seconds=6, milliseconds=916),
            text='The robot.')
        ```
        Args:
            text_item (str): The .srt text for a subtitle item.
        Returns:
            SubtitleItem: A corresponding SubtitleItem.
        """
        # Build regex
        index_re = r"\d+"
        timestamp = lambda prefix: rf"(?P<{prefix}_hours>\d\d):" + \
            rf"(?P<{prefix}_minutes>\d\d):" + \
            rf"(?P<{prefix}_seconds>\d\d)," + \
            rf"(?P<{prefix}_milliseconds>\d\d\d)"
        start_timestamp_re = timestamp("start")
        end_timestamp_re = timestamp("end")
        text_re = r".+"
        complete_re = f"^(?P<index>{index_re})\n"
        complete_re += f"{start_timestamp_re} --> {end_timestamp_re}\n"
        complete_re += f"(?P<text>{text_re})$"
        regex = re.compile(complete_re)

        # Match and extract groups
        match = regex.match(text_item)
        if match is None:
            raise ValueError(f"Index item invalid format:\n'{text_item}'")
        groups = match.groupdict()

        # Extract values
        index = int(groups['index'])
        group_items = filter(lambda kv: kv[0].startswith("start_"), groups.items())
        args = { k[len("start_"):]: int(v) for k, v in group_items }
        start = Timestamp(**args)
        group_items = filter(lambda kv: kv[0].startswith("end_"), groups.items())
        args = { k[len("end_"):]: int(v) for k, v in group_items }
        end = Timestamp(**args)
        text = groups['text']

        if start >= end:
            raise ValueError(
                f"Start timestamp must be earlier than end timestamp: start={start}, end={end}")
        return SubtitleItem(index, start, end, text)

    @staticmethod
    def _format_timestamp(t: Timestamp) -> str:
        """Format a timestamp in the .srt format.
        Args:
            t (Timestamp): The timestamp to convert.
        Returns:
            str: The textual representation for the .srt format.
        """
        return f"{t.get_hours()}:{t.get_minutes()}:{t.get_seconds()},{t.get_milliseconds()}"

    def __str__(self):
        res = f"{self.index}\n"
        res += f"{SubtitleItem._format_timestamp(self.start_time)}"
        res += " --> "
        res += f"{SubtitleItem._format_timestamp(self.end_time)}\n"
        res += self.text
        return res
... which I use in the following test:
import unittest

from src.subtitle_item import SubtitleItem
from src.time_utils import Timestamp


class SubtitleItemTest(unittest.TestCase):
    def testLoadFromText(self):
        text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."
        res = SubtitleItem.load_from_text_item(text)
        exp = SubtitleItem(
            21, Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
            Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
            "Test subtitle."
        )
        self.assertEqual(res, exp)
This test fails, but I don't understand why.
I've checked with the debugger: exp and res have exactly the same fields. The Timestamp class is another separate dataclass. I've also checked equality field by field in the debugger, and all fields are identical:
>>> exp == res
False
>>> exp.index == res.index
True
>>> exp.start_time == res.start_time
True
>>> exp.end_time == res.end_time
True
>>> exp.text == res.text
True
Furthermore, asdict() on each object returns identical dictionaries:
>>> dataclasses.asdict(exp) == dataclasses.asdict(res)
True
Is there something I'm misunderstanding regarding the implementation of the equality operator with dataclasses?
Thanks.
EDIT: here is my time_utils module; sorry for not including it earlier:
"""
This module declares the Delta and Timestamp classes.
"""
from dataclasses import dataclass


@dataclass(frozen=True)
class _TimeBase:
    hours: int = 0
    minutes: int = 0
    seconds: int = 0
    milliseconds: int = 0

    def __post_init__(self):
        BOUNDS_H = range(0, 100)
        BOUNDS_M = range(0, 60)
        BOUNDS_S = range(0, 60)
        BOUNDS_MS = range(0, 1000)
        if self.hours not in BOUNDS_H:
            raise ValueError(
                f"{self.hours=} not in [{BOUNDS_H.start}, {BOUNDS_H.stop})")
        if self.minutes not in BOUNDS_M:
            raise ValueError(
                f"{self.minutes=} not in [{BOUNDS_M.start}, {BOUNDS_M.stop})")
        if self.seconds not in BOUNDS_S:
            raise ValueError(
                f"{self.seconds=} not in [{BOUNDS_S.start}, {BOUNDS_S.stop})")
        if self.milliseconds not in BOUNDS_MS:
            raise ValueError(
                f"{self.milliseconds=} not in [{BOUNDS_MS.start}, {BOUNDS_MS.stop})")

    def _to_ms(self):
        return self.milliseconds + 1000 * (self.seconds + 60 * (self.minutes + 60 * self.hours))


@dataclass(frozen=True)
class Delta(_TimeBase):
    """A time difference, with milliseconds accuracy.
    Must be less than 100h long."""
    sign: int = 1

    def __post_init__(self):
        if self.sign not in (1, -1):
            raise ValueError(
                f"{self.sign=} should either be 1 or -1")
        super().__post_init__()

    def __add__(self, other: "Delta") -> "Delta":
        self_ms = self.sign * self._to_ms()
        other_ms = other.sign * other._to_ms()
        ms_sum = self_ms + other_ms
        sign = -1 if ms_sum < 0 else 1
        ms_sum = abs(ms_sum)
        ms_n, s_rem = ms_sum % 1000, ms_sum // 1000
        s_n, m_rem = s_rem % 60, s_rem // 60
        m_n, h_n = m_rem % 60, m_rem // 60
        return Delta(hours=h_n, minutes=m_n, seconds=s_n, milliseconds=ms_n, sign=sign)


@dataclass(frozen=True)
class Timestamp(_TimeBase):
    """A timestamp with milliseconds accuracy. Must be
    less than 100h long."""

    def __add__(self, other: Delta) -> "Timestamp":
        ms_sum = self._to_ms() + other.sign * other._to_ms()
        ms_n, s_rem = ms_sum % 1000, ms_sum // 1000
        s_n, m_rem = s_rem % 60, s_rem // 60
        m_n, h_n = m_rem % 60, m_rem // 60
        return Timestamp(hours=h_n, minutes=m_n, seconds=s_n, milliseconds=ms_n)

    def __ge__(self, other: "Timestamp") -> bool:
        return self._to_ms() >= other._to_ms()
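For reference, the `__eq__` that `@dataclass` generates for classes like `Timestamp` compares instances as tuples of their fields, but the `dataclasses` documentation requires both instances to be of the identical type; for any other operand the generated method returns `NotImplemented`, and `==` then falls back to an identity check. A minimal sketch of that behaviour, using a hypothetical factory `make_timestamp_class` to create two distinct classes that share a name and fields:

```python
from dataclasses import asdict, dataclass


def make_timestamp_class():
    """Hypothetical helper: every call creates a brand-new class object
    named Timestamp, with the same fields as the one in time_utils."""

    @dataclass(frozen=True)
    class Timestamp:
        hours: int = 0
        minutes: int = 0
        seconds: int = 0
        milliseconds: int = 0

    return Timestamp


TimestampA = make_timestamp_class()
TimestampB = make_timestamp_class()

a = TimestampA(hours=1, minutes=2, seconds=3, milliseconds=4)
b = TimestampB(hours=1, minutes=2, seconds=3, milliseconds=4)

print(asdict(a) == asdict(b))       # True: same field names and values
print(a == TimestampA(1, 2, 3, 4))  # True: same class, same fields
print(a.__eq__(b))                  # NotImplemented: classes differ
print(a == b)                       # False: falls back to identity
```

This is the same pattern as in the question: `asdict()` only looks at field values, while `==` also requires the class objects to be the same.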
Okay, I think I found what's going wrong here.
First, I made a mistake when I reported the issue before: in the unit test, exp.start_time != res.start_time and exp.end_time != res.end_time. Sorry about that. That narrows down the issue to comparison of timestamps.
My sources are in project/src/, and the test that fails is in project/tests/. To make the source modules accessible to the test, I had to add the source directory to PYTHONPATH:
$ PYTHONPATH=src/ python -m unittest discover -s tests/ -v
In the unit test, even though res.start_time and exp.start_time do have the same fields, they do not have the same type:
>>> print(type(res.start_time), type(exp.start_time))
<class 'time_utils.Timestamp'> <class 'src.time_utils.Timestamp'>
I've added a new post with a minimally reproducible example, and more details about the file structure here: Minimally reproducible example.
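The two types printed above are consistent with how the test is launched: `python -m` puts the current directory (`project/`) at the start of `sys.path`, so `src.time_utils` resolves through the `src` directory, while `PYTHONPATH=src/` also makes the same file importable as plain `time_utils`. The import system caches modules by name in `sys.modules`, so the file is executed once per name and each execution creates its own `Timestamp` class. A minimal sketch reproducing this with `importlib`, assuming a reduced stand-in for `time_utils.py` (the helper `load_as` is hypothetical, written only for this illustration):

```python
import importlib.util
import sys
import tempfile
from dataclasses import asdict
from pathlib import Path

# Reduced stand-in for time_utils.py (assumption: only the fields matter here).
SOURCE = """
from dataclasses import dataclass

@dataclass(frozen=True)
class Timestamp:
    hours: int = 0
    minutes: int = 0
    seconds: int = 0
    milliseconds: int = 0
"""


def load_as(module_name, path):
    """Hypothetical helper: execute the file at `path` as a module named
    `module_name` and register it in sys.modules, as an import would."""
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "time_utils.py"
    path.write_text(SOURCE)
    plain = load_as("time_utils", path)         # like `from time_utils import ...`
    packaged = load_as("src.time_utils", path)  # like `from src.time_utils import ...`

a = plain.Timestamp(1, 2, 3, 4)
b = packaged.Timestamp(1, 2, 3, 4)

print(type(a), type(b))  # <class 'time_utils.Timestamp'> <class 'src.time_utils.Timestamp'>
print(asdict(a) == asdict(b), a == b)  # True False
```

Importing `Timestamp` under the same module name everywhere (for example, the way `subtitle_item.py` does) would leave only one entry in `sys.modules`, and therefore a single class object.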
class Timestamp:
    def __init__(self, hours=0, minutes=0, seconds=0, milliseconds=0):
        self.ms = ((hours*60+minutes)*60+seconds)*1000+milliseconds
    def get_hours(self):
        return self.ms // (60*60*1000)
    def get_minutes(self):
        return (self.ms // (60*1000)) % 60
    def get_seconds(self):
        return (self.ms // 1000) % 60
    def get_milliseconds(self):
        return self.ms % 1000
    def __add__(self, other):
        return Timestamp(milliseconds=self.ms + other.ms)
    def __eq__(self, other):
        return self.ms == other.ms
    def __lt__(self, other):
        return self.ms < other.ms
    def __le__(self, other):
        return self.ms <= other.ms
... your code ...
text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."
res = SubtitleItem.load_from_text_item(text)
exp = SubtitleItem(
    21, Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
    Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
    "Test subtitle."
)
print(res)
print(exp)
print(res==exp)
Produces:
21
1:2:3,4 --> 5:6:7,8
Test subtitle.
21
1:2:3,4 --> 5:6:7,8
Test subtitle.
True
with no assert exception.
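Note that this class defines `__eq__` by hand as `self.ms == other.ms`, without the type check that the dataclass-generated `__eq__` performs, and that the whole example lives in one script, so only one `Timestamp` class exists. A minimal sketch of the first point, assuming two copies of the class above produced by a hypothetical factory `make_ms_timestamp_class` to stand in for two imports of the same module:

```python
def make_ms_timestamp_class():
    """Hypothetical factory: each call returns a new class with the same
    millisecond-based design as the Timestamp in this answer."""

    class Timestamp:
        def __init__(self, hours=0, minutes=0, seconds=0, milliseconds=0):
            # Everything is normalised to a single millisecond count.
            self.ms = ((hours * 60 + minutes) * 60 + seconds) * 1000 + milliseconds

        def __eq__(self, other):
            # Compares values only; the class of `other` is never checked.
            return self.ms == other.ms

    return Timestamp


TimestampA = make_ms_timestamp_class()
TimestampB = make_ms_timestamp_class()

a = TimestampA(hours=1, minutes=2, seconds=3, milliseconds=4)
b = TimestampB(hours=1, minutes=2, seconds=3, milliseconds=4)

print(a.ms)                # 3723004
print(type(a) is type(b))  # False
print(a == b)              # True: only the ms values are compared
```

The trade-off is that this `__eq__` raises `AttributeError` when compared with an object that has no `ms` attribute (for example `a == 3`), instead of returning `NotImplemented`.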