Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
import google.cloud.bigtable.data.exceptions as bt_exceptions
import google.cloud.bigtable_v2.types.bigtable as types_pb
from google.cloud.bigtable.data._cross_sync import CrossSync
from google.cloud.bigtable.data._helpers import (
_attempt_timeout_generator,
_retry_exception_factory,
)
from google.cloud.bigtable.data._helpers import _attempt_timeout_generator
from google.cloud.bigtable.data._metrics import tracked_retry

# mutate_rows requests are limited to this number of mutations
from google.cloud.bigtable.data.mutations import (
Expand All @@ -34,6 +32,7 @@
)

if TYPE_CHECKING:
from google.cloud.bigtable.data._metrics import ActiveOperationMetric
from google.cloud.bigtable.data.mutations import RowMutationEntry

if CrossSync.is_async:
Expand Down Expand Up @@ -72,6 +71,8 @@ class _MutateRowsOperationAsync:
operation_timeout: the timeout to use for the entire operation, in seconds.
attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds.
If not specified, the request will run until operation_timeout is reached.
metric: the metric object representing the active operation
retryable_exceptions: a list of exceptions that should be retried
"""

@CrossSync.convert
Expand All @@ -82,6 +83,7 @@ def __init__(
mutation_entries: list["RowMutationEntry"],
operation_timeout: float,
attempt_timeout: float | None,
metric: ActiveOperationMetric,
retryable_exceptions: Sequence[type[Exception]] = (),
):
# check that mutations are within limits
Expand All @@ -101,13 +103,12 @@ def __init__(
# Entry level errors
bt_exceptions._MutateRowsIncomplete,
)
sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60)
self._operation = lambda: CrossSync.retry_target(
self._run_attempt,
self.is_retryable,
sleep_generator,
operation_timeout,
exception_factory=_retry_exception_factory,
self._operation = lambda: tracked_retry(
retry_fn=CrossSync.retry_target,
operation=metric,
target=self._run_attempt,
predicate=self.is_retryable,
timeout=operation_timeout,
)
# initialize state
self.timeout_generator = _attempt_timeout_generator(
Expand All @@ -116,6 +117,8 @@ def __init__(
self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries]
self.remaining_indices = list(range(len(self.mutations)))
self.errors: dict[int, list[Exception]] = {}
# set up metrics
self._operation_metric = metric

@CrossSync.convert
async def start(self):
Expand All @@ -125,34 +128,35 @@ async def start(self):
Raises:
MutationsExceptionGroup: if any mutations failed
"""
try:
# trigger mutate_rows
await self._operation()
except Exception as exc:
# exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations
incomplete_indices = self.remaining_indices.copy()
for idx in incomplete_indices:
self._handle_entry_error(idx, exc)
finally:
# raise exception detailing incomplete mutations
all_errors: list[Exception] = []
for idx, exc_list in self.errors.items():
if len(exc_list) == 0:
raise core_exceptions.ClientError(
f"Mutation {idx} failed with no associated errors"
with self._operation_metric:
try:
# trigger mutate_rows
await self._operation()
except Exception as exc:
# exceptions raised by retryable are added to the list of exceptions for all unfinalized mutations
incomplete_indices = self.remaining_indices.copy()
for idx in incomplete_indices:
self._handle_entry_error(idx, exc)
finally:
# raise exception detailing incomplete mutations
all_errors: list[Exception] = []
for idx, exc_list in self.errors.items():
if len(exc_list) == 0:
raise core_exceptions.ClientError(
f"Mutation {idx} failed with no associated errors"
)
elif len(exc_list) == 1:
cause_exc = exc_list[0]
else:
cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
entry = self.mutations[idx].entry
all_errors.append(
bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
)
if all_errors:
raise bt_exceptions.MutationsExceptionGroup(
all_errors, len(self.mutations)
)
elif len(exc_list) == 1:
cause_exc = exc_list[0]
else:
cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
entry = self.mutations[idx].entry
all_errors.append(
bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
)
if all_errors:
raise bt_exceptions.MutationsExceptionGroup(
all_errors, len(self.mutations)
)

@CrossSync.convert
async def _run_attempt(self):
Expand All @@ -164,6 +168,8 @@ async def _run_attempt(self):
retry after the attempt is complete
GoogleAPICallError: if the gapic rpc fails
"""
# register attempt start
self._operation_metric.start_attempt()
request_entries = [self.mutations[idx].proto for idx in self.remaining_indices]
# track mutations in this request that have not been finalized yet
active_request_indices = {
Expand Down
Loading
Loading