halfspring

Reputation: 1

Using TestStream for testing REAL_TIME timers in Apache Beam for Python with DirectRunner

I want to run unit tests as part of the build process to test timers in an Apache Beam pipeline.

The actual pipeline runs on Dataflow, but the tests use the DirectRunner so they can run locally. I’m using a TestStream to simulate an unbounded data source.

The problem is that the timer doesn’t seem to fire correctly when TimeDomain is set to REAL_TIME. The timer fires as expected the first time, but after that it fires for every new element that enters the pipeline.

When setting the TimeDomain to WATERMARK, everything works as expected.

Why is this happening? Is the DirectRunner not suitable to be used for testing REAL_TIME timers with a TestStream? Or is it something else I'm missing here? All help will be greatly appreciated!

I've written some code to test this specifically. It's a simple pipeline that prints to the console when the timer fires after 1 s of inactivity. It should run out of the box.

The TIMEDOMAIN constant on line 12 is currently set to REAL_TIME, which does not work; if it is changed to WATERMARK, the timer fires as expected.

import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline as dppPipeline
from apache_beam.options.pipeline_options import DirectOptions
from apache_beam.testing.test_stream import TestStream as tStream
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec, BagStateSpec, TimerSpec, on_timer 
from apache_beam.utils.timestamp import Timestamp, Duration
from apache_beam.coders import StrUtf8Coder, FloatCoder
from typing import Tuple
import pytest

# Time domain used by InactivityTimerDoFn's timer: WATERMARK (event time) or
# REAL_TIME (processing time). Switch this to compare the two behaviours.
TIMEDOMAIN = beam.TimeDomain.REAL_TIME


@beam.typehints.with_input_types(Tuple[str, str])
class InactivityTimerDoFn(beam.DoFn):
    """
    A stateful DoFn that sets a real-time or watermark timer for 1.0s after receiving an element.

    The timer's time domain is selected by the module-level TIMEDOMAIN constant.
    Values are buffered per key in STATE_BUFFER. If no new element arrives for
    that key within 1.0 seconds, the timer fires and the buffered values are
    printed and cleared. Nothing is emitted downstream; all results are
    reported via print().
    """
    # Bag of the string values received since the last timer firing.
    STATE_BUFFER = BagStateSpec("state_buffer", StrUtf8Coder())
    # Inactivity timer; its domain (REAL_TIME or WATERMARK) comes from TIMEDOMAIN.
    INACTIVITY_TIMER = TimerSpec('inactivity_timer', TIMEDOMAIN)
    EXPIRATION_STATE = ReadModifyWriteStateSpec('expiration_time', FloatCoder())    # Store expiration timestamp for debugging

    def process(self,
                element,
                timer=beam.DoFn.TimerParam(INACTIVITY_TIMER),                       # Timer to fire after 1.0s of inactivity
                state_buffer=beam.DoFn.StateParam(STATE_BUFFER),                    # Buffer to store elements in until timer fires
                expiration_state=beam.DoFn.StateParam(EXPIRATION_STATE),            # Store expiration timestamp for debugging
                element_timestamp=beam.DoFn.TimestampParam                          # Timestamp of the element entering the pipeline
                ):
        """Buffer the element's value and (re)set the inactivity timer.

        For WATERMARK the deadline is the element's event timestamp + 1s.
        For REAL_TIME it is the current wall-clock time + 1s.
        """

        key, value = element  # `key` itself is not used below
        state_buffer.add(value)
        if TIMEDOMAIN == beam.TimeDomain.WATERMARK:
            timestamp = element_timestamp + 1
        else:
            # NOTE(review): Timestamp.now() reads the real wall clock, whereas the
            # test drives processing time with TestStream.advance_processing_time.
            # If the runner compares REAL_TIME timers against that simulated clock,
            # the simulated clock can get ahead of the wall clock (e.g. after the
            # 1.1s advance in the test). A deadline of now + 1s would then already
            # be in the past and fire immediately for every new element. This
            # likely explains the reported behaviour; confirm against the runner.
            timestamp= Timestamp.now() + Duration(seconds=1)

        # NOTE(review): `timestamp` is a Timestamp but the state uses FloatCoder;
        # this presumably relies on Timestamp's float conversion. Confirm.
        expiration_state.write(timestamp)
        timer.set(timestamp)                                                        # Fire 1.0s from current processing time or watermark
        print(f'Timer set for: {timestamp}')

    @on_timer(INACTIVITY_TIMER)
    def on_inactivity_expired(self,
                              state_buffer=beam.DoFn.StateParam(STATE_BUFFER),
                              timer=beam.DoFn.TimerParam(INACTIVITY_TIMER),
                              expiration_state=beam.DoFn.StateParam(EXPIRATION_STATE)):
        """Print the buffered values for this key, then reset buffer, timer and debug state."""


        expiration_timestamp = expiration_state.read()                              # Read expiration timestamp for debugging
        if TIMEDOMAIN == beam.TimeDomain.WATERMARK:
            print(f"Timer fired! Expected watermark: {expiration_timestamp}")
        else:
            print(f"Timer fired! Expected expiration time: {expiration_timestamp}, actual time: {Timestamp.now()}")
        buffer = list(state_buffer.read())
        state_buffer.clear()
        if buffer:
            print(f'Elements in buffer when timer triggered: {buffer}')
        # NOTE(review): the firing timer is presumably already consumed by the
        # runner, so this clear is likely redundant. Confirm.
        timer.clear()
        expiration_state.clear()  # Clear stored timestamp

# ===========================
# Test Case Using TestStream
# ===========================

def test_inactivity_timer():
    """Drive InactivityTimerDoFn with a scripted TestStream on the DirectRunner.

    The stream interleaves element batches with equal advances of processing
    time and watermark, so both TIMEDOMAIN settings see the same gaps between
    elements. The expected timer firings are documented in the comments below.

    NOTE(review): the test makes no assertions. Results are only visible in
    the printed output, which requires running pytest with -s.
    """
    runner_options = {
        'runner': 'DirectRunner',
        'streaming': True,  # streaming mode, since TestStream simulates an unbounded source
    }

    pipeline_options = DirectOptions([], **runner_options)

    # Create a TestStream to simulate real-time processing.
    # After the initial offset, every step advances the watermark and the
    # processing time by the same amount (0.5, 0.1, 1.1, 0.3, 0.9, 1.3 s).
    test_stream = (
        'TestStream' >> tStream()
        .advance_watermark_to(0)
        # Advance the simulated processing clock by the wall-clock time at which
        # this stream is *built*. TestStream only accepts millisecond precision.
        # NOTE(review): this equals "now" only if the simulated clock starts at
        # 0. After this point the clock moves only via the explicit advances
        # below; it does not follow the wall clock while the pipeline runs.
        # Confirm against the runner.
        .advance_processing_time(advance_by=round(float(Timestamp.now()), 3))
        .add_elements([('key1', 'A')])
        .advance_processing_time(Duration(seconds=0, micros=500_000))
        .advance_watermark_to(0.5)
        .add_elements([('key1', 'B')])
        .advance_processing_time(Duration(seconds=0, micros=100_000))
        .advance_watermark_to(0.6)
        .add_elements([('key1', 'C')])
        .advance_processing_time(Duration(seconds=1, micros=100_000))   # more than 1s passed. Fire timer and print ['A', 'B', 'C']
        .advance_watermark_to(1.7)
        .add_elements([('key1', 'D'), ('key1', 'E')])
        .add_elements([('key1', 'F')])
        .advance_processing_time(Duration(seconds=0, micros=300_000))
        .advance_watermark_to(2.0)
        .add_elements([('key1', 'G')])
        .advance_processing_time(Duration(seconds=0, micros=900_000))
        .advance_watermark_to(2.9)
        .add_elements([('key1', 'H')])
        .advance_processing_time(Duration(seconds=1, micros=300_000))   # more than 1s passed. Fire timer and print ['D', 'E', 'F', 'G', 'H']
        .advance_watermark_to(4.2)
        .add_elements([('key1', 'I')])
        .add_elements([('key1', 'J')])
        .advance_watermark_to_infinity()                                # End the stream and print ['I', 'J']
    )
    # Run the pipeline; it is executed when the `with` block exits.
    with dppPipeline(options=pipeline_options) as p:

        # Apply the DoFn on the test stream
        outputs = (
            p
            | test_stream
            # Identity map that declares Tuple[str, str] output, presumably so a
            # key/value coder is inferred for the stateful DoFn. Confirm.
            | "Force Coder" >> beam.Map(lambda x: x).with_output_types(Tuple[str, str])
            | 'InactivityTimerDoFn' >> beam.ParDo(InactivityTimerDoFn())
        )

        # Prints the PCollection object, not its elements. The DoFn emits no
        # output, and this line runs before the pipeline executes.
        print('output', outputs)

if __name__ == "__main__":
    # Running this file directly hands it to pytest; "-s" turns off output
    # capturing so the DoFn's print() diagnostics reach the console.
    _pytest_args = ["-s", __file__]
    pytest.main(_pytest_args)

When the timer fires, the buffered elements printed to the console should be:

Elements in buffer when timer triggered: ['A', 'B', 'C']

Elements in buffer when timer triggered: ['D', 'E', 'F', 'G', 'H']

Elements in buffer when timer triggered: ['I', 'J']

Upvotes: 0

Views: 24

Answers (0)

Related Questions