diff --git a/cassandra/cluster.py b/cassandra/cluster.py index a9c1d00e97..099043eae0 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -3831,14 +3831,16 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host_by_host_id(host_id) if host and host.endpoint != endpoint: log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) + old_endpoint = host.endpoint + host.endpoint = endpoint + self._cluster.metadata.update_host(host, old_endpoint) + self._cluster.on_up(host) + if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..5ffc9e5bd8 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -14,7 +14,7 @@ import random from collections import namedtuple -from itertools import islice, cycle, groupby, repeat +from itertools import islice, cycle, groupby, repeat, chain import logging from random import randint, shuffle from threading import Lock @@ -157,6 +157,18 @@ def make_query_plan(self, working_keyspace=None, query=None): """ raise NotImplementedError() + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + """ + Same as :meth:`make_query_plan`, but with an additional `excluded` parameter. + `excluded` should be a container (set, list, etc.) of hosts to skip. + + The default implementation simply delegates to `make_query_plan` and filters the result. + Subclasses may override this for performance. + """ + for host in self.make_query_plan(working_keyspace, query): + if host not in excluded: + yield host + def check_supported(self): """ This will be called after the cluster Metadata has been initialized. @@ -198,6 +210,20 @@ def make_query_plan(self, working_keyspace=None, query=None): else: return [] + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + pos = self._position + self._position += 1 + + hosts = self._live_hosts + length = len(hosts) + if length: + pos %= length + for host in islice(cycle(hosts), pos, pos + length): + if host not in excluded: + yield host + else: + return + def on_up(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.union((host, )) @@ -244,34 +270,36 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0): self.local_dc = local_dc self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + self._remote_hosts = {} self._position = 0 LoadBalancingPolicy.__init__(self) def _dc(self, host): return host.datacenter or self.local_dc + def _refresh_remote_hosts(self): + remote_hosts = {} + if self.used_hosts_per_remote_dc > 0: + for datacenter, hosts in self._dc_live_hosts.items(): + if datacenter != self.local_dc: + remote_hosts.update(dict.fromkeys(hosts[:self.used_hosts_per_remote_dc])) + self._remote_hosts = remote_hosts + def populate(self, cluster, hosts): for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) self._position = randint(0, len(hosts) - 1) if hosts else 0 + self._refresh_remote_hosts() def distance(self, host): dc = self._dc(host) if dc == self.local_dc: return HostDistance.LOCAL - if not self.used_hosts_per_remote_dc: - return HostDistance.IGNORED - else: - dc_hosts = self._dc_live_hosts.get(dc) - if not dc_hosts: - return HostDistance.IGNORED - - if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]: - return HostDistance.REMOTE - else: - return HostDistance.IGNORED + if host in self._remote_hosts: + return HostDistance.REMOTE + return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): # not thread-safe, but we don't care much about lost increments @@ -280,26 +308,59 @@ def make_query_plan(self, working_keyspace=None, query=None): self._position += 1 local_live = self._dc_live_hosts.get(self.local_dc, ()) - pos = (pos % len(local_live)) if local_live else 0 - for host in islice(cycle(local_live), pos, pos + len(local_live)): + length = len(local_live) + if length: + pos %= length + for i in range(length): + yield local_live[(pos + i) % length] + + for host in self._remote_hosts: yield host - # the dict can change, so get candidate DCs iterating over keys of a copy - other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc] - for dc in other_dcs: - remote_live = self._dc_live_hosts.get(dc, ()) - for host in remote_live[:self.used_hosts_per_remote_dc]: + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + # not thread-safe, but we don't care much about lost increments + # for the purposes of load balancing + pos = self._position + self._position += 1 + + local_live = self._dc_live_hosts.get(self.local_dc, ()) + length = len(local_live) + if not excluded: + if length: + pos %= length + for i in range(length): + yield local_live[(pos + i) % length] + for host in self._remote_hosts: + yield host + return + + if not isinstance(excluded, set): + excluded = set(excluded) + + if length: + pos %= length + for i in range(length): + host = local_live[(pos + i) % length] + if host in excluded: + continue yield host + for host in self._remote_hosts: + if host in excluded: + continue + yield host + def on_up(self, host): # not worrying about threads because this will happen during # control connection startup/refresh + refresh_remote = False if not self.local_dc and host.datacenter: self.local_dc = host.datacenter log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); " "if incorrect, please specify a local_dc to the constructor, " "or limit contact points to local cluster nodes" % (self.local_dc, host.endpoint)) + refresh_remote = True dc = self._dc(host) with self._hosts_lock: @@ -307,6 +368,9 @@ def on_up(self, host): if host not in current_hosts: self._dc_live_hosts[dc] = current_hosts + (host, ) + if refresh_remote or dc != self.local_dc: + self._refresh_remote_hosts() + def on_down(self, host): dc = self._dc(host) with self._hosts_lock: @@ -318,6 +382,9 @@ def on_down(self, host): else: del self._dc_live_hosts[dc] + if dc != self.local_dc: + self._refresh_remote_hosts() + def on_add(self, host): self.on_up(host) @@ -353,6 +420,8 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0): self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._live_hosts = {} self._dc_live_hosts = {} + self._remote_hosts = {} + self._non_local_rack_hosts = [] self._endpoints = [] self._position = 0 LoadBalancingPolicy.__init__(self) @@ -363,6 +432,18 @@ def _rack(self, host): def _dc(self, host): return host.datacenter or self.local_dc + def _refresh_remote_hosts(self): + remote_hosts = {} + if self.used_hosts_per_remote_dc > 0: + for datacenter, hosts in self._dc_live_hosts.items(): + if datacenter != self.local_dc: + remote_hosts.update(dict.fromkeys(hosts[:self.used_hosts_per_remote_dc])) + self._remote_hosts = remote_hosts + + def _refresh_non_local_rack_hosts(self): + local_live = self._dc_live_hosts.get(self.local_dc, ()) + self._non_local_rack_hosts = [h for h in local_live if self._rack(h) != self.local_rack] + def populate(self, cluster, hosts): for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))): self._live_hosts[(dc, rack)] = tuple({*rack_hosts, *self._live_hosts.get((dc, rack), [])}) @@ -370,71 +451,113 @@ def populate(self, cluster, hosts): self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])}) self._position = randint(0, len(hosts) - 1) if hosts else 0 + self._refresh_remote_hosts() + self._refresh_non_local_rack_hosts() def distance(self, host): - rack = self._rack(host) dc = self._dc(host) - if rack == self.local_rack and dc == self.local_dc: - return HostDistance.LOCAL_RACK - if dc == self.local_dc: + if self._rack(host) == self.local_rack: + return HostDistance.LOCAL_RACK return HostDistance.LOCAL - if not self.used_hosts_per_remote_dc: - return HostDistance.IGNORED - - dc_hosts = self._dc_live_hosts.get(dc, ()) - if not dc_hosts: - return HostDistance.IGNORED - if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc: + if host in self._remote_hosts: return HostDistance.REMOTE - else: - return HostDistance.IGNORED + return HostDistance.IGNORED def make_query_plan(self, working_keyspace=None, query=None): pos = self._position self._position += 1 local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) - pos = (pos % len(local_rack_live)) if local_rack_live else 0 - # Slice the cyclic iterator to start from pos and include the next len(local_live) elements - # This ensures we get exactly one full cycle starting from pos - for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): - yield host + length = len(local_rack_live) + if length: + p = pos % length + for i in range(length): + yield local_rack_live[(p + i) % length] + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + yield local_non_rack[(p + i) % length] - local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] - pos = (pos % len(local_live)) if local_live else 0 - for host in islice(cycle(local_live), pos, pos + len(local_live)): + for host in self._remote_hosts: yield host + + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + pos = self._position + self._position += 1 - # the dict can change, so get candidate DCs iterating over keys of a copy - for dc, remote_live in self._dc_live_hosts.copy().items(): - if dc != self.local_dc: - for host in remote_live[:self.used_hosts_per_remote_dc]: - yield host + local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + length = len(local_rack_live) + if not excluded: + if length: + p = pos % length + for i in range(length): + yield local_rack_live[(p + i) % length] + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + yield local_non_rack[(p + i) % length] + + for host in self._remote_hosts: + yield host + return + + if not isinstance(excluded, set): + excluded = set(excluded) + + if length: + p = pos % length + for i in range(length): + host = local_rack_live[(p + i) % length] + if host in excluded: + continue + yield host + + local_non_rack = self._non_local_rack_hosts + length = len(local_non_rack) + if length: + p = pos % length + for i in range(length): + host = local_non_rack[(p + i) % length] + if host in excluded: + continue + yield host + + for host in self._remote_hosts: + if host in excluded: + continue + yield host def on_up(self, host): dc = self._dc(host) rack = self._rack(host) with self._hosts_lock: - current_rack_hosts = self._live_hosts.get((dc, rack), ()) - if host not in current_rack_hosts: - self._live_hosts[(dc, rack)] = current_rack_hosts + (host, ) current_dc_hosts = self._dc_live_hosts.get(dc, ()) if host not in current_dc_hosts: self._dc_live_hosts[dc] = current_dc_hosts + (host, ) + if dc != self.local_dc: + self._refresh_remote_hosts() + else: + self._refresh_non_local_rack_hosts() + + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host not in current_rack_hosts: + self._live_hosts[(dc, rack)] = current_rack_hosts + (host, ) + if dc == self.local_dc: + self._refresh_non_local_rack_hosts() + def on_down(self, host): dc = self._dc(host) rack = self._rack(host) with self._hosts_lock: - current_rack_hosts = self._live_hosts.get((dc, rack), ()) - if host in current_rack_hosts: - hosts = tuple(h for h in current_rack_hosts if h != host) - if hosts: - self._live_hosts[(dc, rack)] = hosts - else: - del self._live_hosts[(dc, rack)] current_dc_hosts = self._dc_live_hosts.get(dc, ()) if host in current_dc_hosts: hosts = tuple(h for h in current_dc_hosts if h != host) @@ -443,6 +566,22 @@ def on_down(self, host): else: del self._dc_live_hosts[dc] + if dc != self.local_dc: + self._refresh_remote_hosts() + else: + self._refresh_non_local_rack_hosts() + + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host in current_rack_hosts: + hosts = tuple(h for h in current_rack_hosts if h != host) + if hosts: + self._live_hosts[(dc, rack)] = hosts + else: + del self._live_hosts[(dc, rack)] + + if dc == self.local_dc: + self._refresh_non_local_rack_hosts() + def on_add(self, host): self.on_up(host) @@ -466,16 +605,12 @@ class TokenAwarePolicy(LoadBalancingPolicy): policy's query plan will be used as is. """ - _child_policy = None - _cluster_metadata = None - shuffle_replicas = True - """ - Yield local replicas in a random order. - """ + __slots__ = ('_child_policy', '_cluster_metadata', 'shuffle_replicas') def __init__(self, child_policy, shuffle_replicas=True): self._child_policy = child_policy self.shuffle_replicas = shuffle_replicas + self._cluster_metadata = None def populate(self, cluster, hosts): self._cluster_metadata = cluster.metadata @@ -498,35 +633,73 @@ def make_query_plan(self, working_keyspace=None, query=None): child = self._child_policy if query is None or query.routing_key is None or keyspace is None: - for host in child.make_query_plan(keyspace, query): - yield host + yield from child.make_query_plan(keyspace, query) return + cluster_metadata = self._cluster_metadata + token_map = cluster_metadata.token_map replicas = [] - tablet = self._cluster_metadata._tablets.get_tablet_for_key( - keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key)) - if tablet is not None: - replicas_mapped = set(map(lambda r: r[0], tablet.replicas)) - child_plan = child.make_query_plan(keyspace, query) + if token_map: + try: + token = token_map.token_class.from_key(query.routing_key) + tablet = cluster_metadata._tablets.get_tablet_for_key( + keyspace, query.table, token) + + if tablet is not None: + replicas_mapped = {r[0] for r in tablet.replicas} + for host_id in replicas_mapped: + host = cluster_metadata.get_host_by_host_id(host_id) + if host: + replicas.append(host) + else: + try: + replicas = token_map.get_replicas(keyspace, token) + except Exception: + replicas = cluster_metadata.get_replicas(keyspace, query.routing_key) + except Exception: + pass - replicas = [host for host in child_plan if host.host_id in replicas_mapped] - else: - replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key) if self.shuffle_replicas and not query.is_lwt(): + replicas = list(replicas) shuffle(replicas) - def yield_in_order(hosts): - for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]: - for replica in hosts: - if replica.is_up and child.distance(replica) == distance: - yield replica - - # yield replicas: local_rack, local, remote - yield from yield_in_order(replicas) - # yield rest of the cluster: local_rack, local, remote - yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas]) + local_rack = [] + local = [] + remote = [] + + child_distance = child.distance + + for replica in replicas: + if replica.is_up: + d = child_distance(replica) + if d == HostDistance.LOCAL_RACK: + local_rack.append(replica) + elif d == HostDistance.LOCAL: + local.append(replica) + elif d == HostDistance.REMOTE: + remote.append(replica) + + if local_rack or local or remote: + yielded = set() + + for replica in local_rack: + yielded.add(replica) + yield replica + + for replica in local: + yielded.add(replica) + yield replica + + for replica in remote: + yielded.add(replica) + yield replica + + # yield rest of the cluster + yield from child.make_query_plan_with_exclusion(keyspace, query, yielded) + else: + yield from child.make_query_plan(keyspace, query) def on_up(self, *args, **kwargs): return self._child_policy.on_up(*args, **kwargs) @@ -702,6 +875,16 @@ def make_query_plan(self, working_keyspace=None, query=None): if self.predicate(host): yield host + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + if excluded: + excluded = set(excluded) + child_qp = self._child_policy.make_query_plan_with_exclusion( + working_keyspace=working_keyspace, query=query, excluded=excluded + ) + for host in child_qp: + if self.predicate(host): + yield host + def check_supported(self): return self._child_policy.check_supported() @@ -1356,6 +1539,27 @@ def make_query_plan(self, working_keyspace=None, query=None): for h in child.make_query_plan(keyspace, query): yield h + def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()): + if query and query.keyspace: + keyspace = query.keyspace + else: + keyspace = working_keyspace + + addr = getattr(query, 'target_host', None) if query else None + target_host = self._cluster_metadata.get_host(addr) + + if excluded: + excluded = set(excluded) + + child = self._child_policy + if target_host and target_host.is_up and target_host not in excluded: + yield target_host + for h in child.make_query_plan_with_exclusion(keyspace, query, excluded): + if h != target_host: + yield h + else: + yield from child.make_query_plan_with_exclusion(keyspace, query, excluded) + # TODO for backward compatibility, remove in next major class DSELoadBalancingPolicy(DefaultLoadBalancingPolicy): diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..20ea672f45 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -274,6 +274,16 @@ def test_get_distance(self, policy_specialization, constructor_args): assert policy.distance(host) == HostDistance.LOCAL_RACK # same dc different rack + # Reset policy state to simulate a fresh view or handle the "move" correctly + # In a real scenario, a host moving racks would be handled by on_down/on_up or distinct host objects. + # Here we are reusing the same policy instance with populate(), which merges hosts. + # To avoid the host existing in both rack1 and rack2 buckets due to address equality, + # we clear the internal state. + if hasattr(policy, '_live_hosts'): + policy._live_hosts.clear() + if hasattr(policy, '_dc_live_hosts'): + policy._dc_live_hosts.clear() + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) host.set_location_info("dc1", "rack2") policy.populate(Mock(), [host]) @@ -577,6 +587,8 @@ def test_wrap_round_robin(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -586,8 +598,9 @@ def get_replicas(keyspace, packed_key): return list(islice(cycle(hosts), index, index + 2)) cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(RoundRobinPolicy()) + policy = TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -610,6 +623,8 @@ def test_wrap_dc_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() @@ -627,8 +642,9 @@ def get_replicas(keyspace, packed_key): return [hosts[1], hosts[3]] cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2)) + policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -659,6 +675,8 @@ def test_wrap_rack_aware(self): cluster.metadata = Mock(spec=Metadata) cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() @@ -680,8 +698,9 @@ def get_replicas(keyspace, packed_key): return [hosts[4], hosts[5], hosts[6], hosts[7]] cluster.metadata.get_replicas.side_effect = get_replicas + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=4)) + policy = TokenAwarePolicy(RackAwareRoundRobinPolicy("dc1", "rack1", used_hosts_per_remote_dc=4), shuffle_replicas=False) policy.populate(cluster, hosts) for i in range(4): @@ -804,12 +823,16 @@ def test_statement_keyspace(self): replicas = hosts[2:] cluster.metadata.get_replicas.return_value = replicas cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas child_policy = Mock() child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] child_policy.distance.return_value = HostDistance.LOCAL - policy = TokenAwarePolicy(child_policy) + policy = TokenAwarePolicy(child_policy, shuffle_replicas=False) policy.populate(cluster, hosts) # no keyspace, child policy is called @@ -897,6 +920,9 @@ def _prepare_cluster_with_vnodes(self): cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas return cluster def _prepare_cluster_with_tablets(self): @@ -909,14 +935,22 @@ def _prepare_cluster_with_tablets(self): cluster.metadata.all_hosts.return_value = hosts cluster.metadata.get_replicas.return_value = hosts[2:] cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(replicas=[(h.host_id, 0) for h in hosts[2:]]) + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas return cluster @patch('cassandra.policies.shuffle') def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): hosts = cluster.metadata.all_hosts() - replicas = cluster.metadata.get_replicas() + # Configure get_host_by_host_id to return hosts from the list + host_map = {h.host_id: h for h in hosts} + cluster.metadata.get_host_by_host_id.side_effect = lambda hid: host_map.get(hid) + + replicas = list(cluster.metadata.get_replicas()) child_policy = Mock() child_policy.make_query_plan.return_value = hosts + child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [h for h in hosts if h not in e] child_policy.distance.return_value = HostDistance.LOCAL policy = TokenAwarePolicy(child_policy, shuffle_replicas=True) @@ -926,6 +960,7 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): cluster.metadata.get_replicas.reset_mock() child_policy.make_query_plan.reset_mock() + child_policy.make_query_plan_with_exclusion.reset_mock() query = Statement(routing_key=routing_key) qplan = list(policy.make_query_plan(keyspace, query)) if keyspace is None or routing_key is None: @@ -936,7 +971,10 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key): else: assert set(replicas) == set(qplan[:2]) assert hosts[:2] == qplan[2:] - if is_tablets: + + if child_policy.make_query_plan_with_exclusion.called: + child_policy.make_query_plan_with_exclusion.assert_called() + elif is_tablets: child_policy.make_query_plan.assert_called_with(keyspace, query) assert child_policy.make_query_plan.call_count == 2 else: @@ -1630,8 +1668,11 @@ def get_replicas(keyspace, packed_key): cluster.metadata.get_replicas.side_effect = get_replicas cluster.metadata._tablets = Mock(spec=Tablets) cluster.metadata._tablets.get_tablet_for_key.return_value = None + cluster.metadata.token_map = Mock() + cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key + cluster.metadata.token_map.get_replicas.side_effect = cluster.metadata.get_replicas - child_policy = TokenAwarePolicy(RoundRobinPolicy()) + child_policy = TokenAwarePolicy(RoundRobinPolicy(), shuffle_replicas=False) hfp = HostFilterPolicy( child_policy=child_policy,