Ngo Cuong

Reputation: 83

How is the str.join(iterable) method implemented in Python? / Linear-time string concatenation

I am trying to implement my own str.join method in Python, e.g. ''.join(['aa','bbb','cccc']) returns 'aabbbcccc'. I know that string concatenation with join has linear complexity (in the number of characters of the result), and I want to know how that is achieved, since using the '+' operator in a for loop results in quadratic complexity, e.g.:

res = ''
for word in ['aa', 'bbb', 'cccc']:
    res = res + word

As strings are immutable, each iteration builds a new string by copying everything accumulated so far, resulting in quadratic running time. However, I want to know how to do it in linear time, or to find out how ''.join works exactly.
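
To make the quadratic cost concrete, here is a minimal sketch that counts how many characters get written if every `res + word` really builds a fresh string. The helper name `chars_copied_by_concat` is made up for illustration, and the model is an assumption: CPython can sometimes extend the left operand in place, so real timings may look better than this count suggests.

```python
def chars_copied_by_concat(words):
    """Count characters written if each `res + word` allocates a new string."""
    copied = 0
    length = 0  # length of `res` after each iteration
    for word in words:
        length += len(word)
        # the new string holds all previous characters plus `word`
        copied += length
    return copied


words = ["ab"] * 1000
print(chars_copied_by_concat(words))  # 1001000 characters written
print(sum(len(w) for w in words))     # 2000 characters in the final result
```

Under this model the number of characters written grows with the square of the number of words, while the result itself only grows linearly.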

I could not find a linear-time algorithm or the implementation of str.join(iterable) anywhere. Any help is much appreciated.

Upvotes: 7

Views: 2449

Answers (1)

MisterMiyagi

Reputation: 52159

Joining str as actual str is a red herring and not what Python itself does: Python operates on mutable bytes rather than on str objects, which also removes the need to know string internals. Specifically, str.join converts its arguments to bytes, then pre-allocates its result and mutates it in place.

This directly corresponds to:

  1. a wrapper to encode/decode str arguments to/from bytes
  2. summing the len of elements and separators
  3. allocating a mutable bytearray to construct the result
  4. copying each element/separator directly into the result

# helper to convert to/from joinable bytes
def str_join(sep: "str", elements: "list[str]") -> "str":
    joined_bytes = bytes_join(
        sep.encode(),
        [elem.encode() for elem in elements],
    )
    return joined_bytes.decode()

# actual joining at bytes level
def bytes_join(sep: "bytes", elements: "list[bytes]") -> "bytes":
    # create a mutable buffer that is long enough to hold the result
    total_length = sum(len(elem) for elem in elements)
    # an empty list has no separators; clamp to avoid a negative length
    total_length += max(len(elements) - 1, 0) * len(sep)
    result = bytearray(total_length)
    # copy all characters from the inputs to the result
    insert_idx = 0
    for elem in elements:
        result[insert_idx:insert_idx+len(elem)] = elem
        insert_idx += len(elem)
        if insert_idx < total_length:
            result[insert_idx:insert_idx+len(sep)] = sep
            insert_idx += len(sep)
    return bytes(result)

print(str_join(" ", ["Hello", "World!"]))

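Notably, although iterating over the elements and copying each element amount to two nested loops, the loops run over separate things, so their costs add up rather than multiply. The algorithm still touches each character only three times (encode, copy, decode) and copies each byte into the result only once.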
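
The "copied only once" claim can be checked with a small instrumented variant of bytes_join. This is only a sketch: the name `bytes_join_counting` and the `copied` counter are additions for illustration, and the separator is written before every element except the first, which is a slightly different but equivalent loop shape.

```python
def bytes_join_counting(sep: "bytes", elements: "list[bytes]") -> "tuple[bytes, int]":
    """Join like bytes_join, also returning how many bytes were copied."""
    total_length = sum(len(elem) for elem in elements)
    total_length += max(len(elements) - 1, 0) * len(sep)
    result = bytearray(total_length)
    copied = 0
    insert_idx = 0
    for position, elem in enumerate(elements):
        if position > 0:
            # separator goes between elements, never before the first one
            result[insert_idx:insert_idx + len(sep)] = sep
            insert_idx += len(sep)
            copied += len(sep)
        result[insert_idx:insert_idx + len(elem)] = elem
        insert_idx += len(elem)
        copied += len(elem)
    return bytes(result), copied


joined, copied = bytes_join_counting(b", ", [b"aa", b"bbb", b"cccc"])
print(joined)                 # b'aa, bbb, cccc'
print(copied == len(joined))  # True
```

The number of bytes copied equals the length of the result, so the joining step is linear in the output size.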

Upvotes: 6
