CYC

Reputation: 325

How to improve the speed of the BigQuery Storage Write API with pending mode in Python

Our scenario is to insert roughly 4 million rows into BigQuery at once.

We followed the Google document, Batch load data using the Storage Write API: https://cloud.google.com/bigquery/docs/write-api-batch

We are using the BigQuery Storage Write API in pending mode to insert the data into BigQuery, but the speed is slow when the data has a huge number of rows.

The packages we use are:

protobuf==3.20.0

google-cloud-bigquery==3.18.0

google-cloud-storage==2.14.0

google-cloud-container==2.12.0

google-cloud-bigquery-storage==2.14.0

Does anyone know how to improve this? We also found that proto_rows.serialized_rows.append becomes slow when appending more data, which is why we chose a batch size of 300.
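For reference, below is a sketch of building one request per batch in a single call instead of appending row by row (the field names match the sample code further down; whether this actually avoids the per-row overhead is an assumption we have not verified):

from google.cloud.bigquery_storage_v1 import types

def build_append_request(serialized_batch: list, offset: int) -> types.AppendRowsRequest:
    # Hypothetical helper: hand the whole batch of already-serialized rows to the
    # ProtoRows constructor (or use serialized_rows.extend) instead of calling
    # append() once per row.
    proto_rows = types.ProtoRows(serialized_rows=serialized_batch)
    return types.AppendRowsRequest(
        offset=offset,
        proto_rows=types.AppendRowsRequest.ProtoData(rows=proto_rows),
    )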

For example, inserting 1,000,000 rows takes 115 secs. The sample code is below:

import pandas as pd
from configs import logger  # project-specific logger instance
import time
from logging import Logger
from google.cloud.bigquery_storage_v1 import types, writer, BigQueryWriteClient
from google.protobuf import descriptor_pb2
from google.cloud.bigquery import Client

import sample_demo_pb2
import google.auth
from google.auth import impersonated_credentials


def timing():
    import time
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            time1 = time.time()
            ret = func(*args, **kwargs)
            time2 = time.time()
            print(
                "%s function complete in %.6f secs"
                % (func.__name__, (time2 - time1))
            )
            return ret

        return func_wrapper

    return decorator


class ServiceBase:
    """A base class for setting up service account credentials and project_id."""

    def __init__(self, lifetime: int = 3600) -> None:
        """Setup credentials and project_id from the host environment.

        Args:
            lifetime (int, optional): The lifetime of a service credential. Defaults to 3600.

        Raises:
            LookupError: project_id can not be found from the host environment.
        """
        super().__init__()
        source_credentials, project_id = google.auth.default()
        if project_id is None:
            if source_credentials.quota_project_id is not None:
                project_id = source_credentials.quota_project_id
            else:
                raise LookupError(
                    "Required project_id can not be found. Please try to setup GOOGLE_CLOUD_PROJECT environment variable."
                )
        service_account = (
            f"shared-service-account@{project_id}.iam.gserviceaccount.com"
        )
        target_scopes = [
            "https://www.googleapis.com/auth/cloud-platform",
            "https://www.googleapis.com/auth/userinfo.email",
        ]
        self.credentials = impersonated_credentials.Credentials(
            source_credentials=source_credentials,
            target_principal=service_account,
            target_scopes=target_scopes,
            lifetime=lifetime,
        )
        self.project_id = project_id


class BigQueryService(ServiceBase):
    __slots__ = ("client",)

    def __init__(self, logger: Logger) -> None:
        super().__init__()
        self.client = Client(
            project=self.project_id, credentials=self.credentials
        )
        self.writer_client = BigQueryWriteClient(credentials=self.credentials)
        self._logger = logger

    @timing()
    def to_bqtable(
        self,
        data: pd.DataFrame,
        full_table_name: str,
    ) -> None:
        init_time = time.time()
        table_name = full_table_name.split(".")[-1]

        project_id, dataset_id, table_id = full_table_name.split(".")
        bigquery_table = self.writer_client.table_path(
            project_id, dataset_id, table_id
        )
        write_stream = types.WriteStream()
        write_stream.type_ = types.WriteStream.Type.PENDING
        write_stream = self.writer_client.create_write_stream(
            parent=bigquery_table, write_stream=write_stream
        )

        stream_name = write_stream.name
        request_template = types.AppendRowsRequest()
        request_template.write_stream = stream_name

        proto_schema = types.ProtoSchema()
        proto_descriptor = descriptor_pb2.DescriptorProto()
        sample_demo_pb2.demo_data.DESCRIPTOR.CopyToProto(proto_descriptor)

        proto_schema.proto_descriptor = proto_descriptor
        proto_data = types.AppendRowsRequest.ProtoData()
        proto_data.writer_schema = proto_schema
        request_template.proto_rows = proto_data

        append_rows_stream = writer.AppendRowsStream(
            self.writer_client, request_template
        )

        init_convert_time = time.time()
        list_data = data.to_dict("records")
        int2_time = time.time()
        print(f"convert list_data using {int2_time - init_convert_time} secs")
        batch_size = 1000
        n = len(list_data)
        request_tmps = []
        for batch_idx in range(0, n, batch_size):
            print(f"{n} {batch_idx} to {min(batch_idx + batch_size, n)}")
            request = self.serialized_data(
                list_data[batch_idx : min(batch_idx + batch_size, n)],
                batch_idx,
                table_name,
            )
            request_tmp = append_rows_stream.send(request)
            request_tmps.append(request_tmp)
        for r in request_tmps:
            r.result()
        append_rows_stream.close()
        self.writer_client.finalize_write_stream(name=write_stream.name)

        batch_commit_write_streams_request = (
            types.BatchCommitWriteStreamsRequest()
        )
        batch_commit_write_streams_request.parent = bigquery_table
        batch_commit_write_streams_request.write_streams = [write_stream.name]
        self.writer_client.batch_commit_write_streams(
            batch_commit_write_streams_request
        )

        print(f"Writes to stream: '{write_stream.name}' have been committed.")

    @timing()
    def serialized_data(self, data: list, offset: int, table_name: str):
        proto_rows = types.ProtoRows()

        init_time = time.time()
        serialized_data = [self.create_serialized_data(row) for row in data]
        # print(f"serialized_data using {time.time() - init_time} secs")

        for idx, ele in enumerate(serialized_data):
            init2_time = time.time()
            proto_rows.serialized_rows.append(
                ele
            )  # this line becomes very slow when appending a lot of data, e.g. more than 1k rows
            # if idx % 500 == 0:
            #     print(
            #         f"{idx} extend single serialized_data using {time.time() - init2_time} secs"
            #     )
        request = types.AppendRowsRequest()
        request.offset = offset
        proto_data = types.AppendRowsRequest.ProtoData()
        proto_data.rows = proto_rows
        request.proto_rows = proto_data
        return request

    # @timing()
    def create_serialized_data(self, rows_to_insert: dict):
        # Note XXX_pb2 is the file compiled by the protocol buffer compiler
        return sample_demo_pb2.demo_data(**rows_to_insert).SerializeToString()

    def __del__(self) -> None:
        self.client.close()


rows = 1000000
bq_client = BigQueryService(logger)
list_data = [{"Index": str(i), "Name": f"name_{i}"} for i in range(rows)]
data = pd.DataFrame(list_data)
bq_client.to_bqtable(
    data,
    "XXX.default_dataset.demo_data",
)
print("=============================")

The sample_demo_pb2.py is as below:

# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: sample_demo.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor.FileDescriptor(
  name='sample_demo.proto',
  package='',
  syntax='proto2',
  serialized_options=None,
  serialized_pb=_b('\n\x11sample_demo.proto\"(\n\tdemo_data\x12\r\n\x05Index\x18\x01 \x01(\t\x12\x0c\n\x04Name\x18\x02 \x01(\t')
)




_DEMO_DATA = _descriptor.Descriptor(
  name='demo_data',
  full_name='demo_data',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='Index', full_name='demo_data.Index', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    _descriptor.FieldDescriptor(
      name='Name', full_name='demo_data.Name', index=1,
      number=2, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=21,
  serialized_end=61,
)

DESCRIPTOR.message_types_by_name['demo_data'] = _DEMO_DATA
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

demo_data = _reflection.GeneratedProtocolMessageType('demo_data', (_message.Message,), dict(
  DESCRIPTOR = _DEMO_DATA,
  __module__ = 'sample_demo_pb2'
  # @@protoc_insertion_point(class_scope:demo_data)
  ))
_sym_db.RegisterMessage(demo_data)


# @@protoc_insertion_point(module_scope)
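For reference, the sample_demo.proto that this generated module corresponds to should look roughly like the following (reconstructed from the descriptor above, so treat it as an approximation; it is compiled with protoc --python_out=.):

syntax = "proto2";

message demo_data {
  optional string Index = 1;
  optional string Name = 2;
}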


Upvotes: 0

Views: 353

Answers (1)

guillaume blaquiere

Reputation: 75920

In your case, the Write API is not the right option. It is designed for streaming.

If you want to batch insert your data, use a load job: it is faster and cheaper. Since you told me the data comes from Cloud Storage, I think it's the best design for your 4M rows.
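A minimal sketch of such a load job with the Python client (the bucket, file and table names are placeholders, and the source format should match whatever you actually store on Cloud Storage):

from google.cloud import bigquery

client = bigquery.Client()

job_config = bigquery.LoadJobConfig(
    source_format=bigquery.SourceFormat.CSV,  # or PARQUET / AVRO / ORC
    autodetect=True,                          # or pass an explicit schema instead
    skip_leading_rows=1,                      # only meaningful for CSV with a header row
)

# The gs:// URI and the destination table are placeholders.
load_job = client.load_table_from_uri(
    "gs://your-bucket/exports/demo_data.csv",
    "your-project.default_dataset.demo_data",
    job_config=job_config,
)
load_job.result()  # waits until the load job completes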

If you need to transform your data into a BigQuery-compliant format (CSV, Avro, ORC, ...), perform a pre-processing step and create a new file on Cloud Storage, then load this file.
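For example, a sketch of that pre-processing step, assuming the data starts as a pandas DataFrame as in your sample (the bucket and paths are placeholders, and to_parquet needs pyarrow or fastparquet installed):

import pandas as pd
from google.cloud import storage

def stage_dataframe(df: pd.DataFrame, bucket_name: str, blob_path: str) -> str:
    # Write the DataFrame as Parquet locally, then upload it to Cloud Storage.
    local_path = "/tmp/demo_data.parquet"
    df.to_parquet(local_path)

    bucket = storage.Client().bucket(bucket_name)
    bucket.blob(blob_path).upload_from_filename(local_path)
    return f"gs://{bucket_name}/{blob_path}"

The returned gs:// URI can then be passed to load_table_from_uri with source_format=bigquery.SourceFormat.PARQUET. If the DataFrame is already in memory, client.load_table_from_dataframe(df, table_id) also runs a load job under the hood and skips the manual staging entirely.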

You can use orchestration services such as Cloud Workflows or Cloud Composer (managed Airflow) to perform these steps in sequence.

Upvotes: 1
