Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,15 +1219,19 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message,
# queue the decoder function with the request
# this allows us to inject custom functions per request to encode, decode messages
self._requests[request_id] = (cb, decoder, result_metadata)
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
allow_beta_protocol_version=self.allow_beta_protocol_version)
try:
msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor,
allow_beta_protocol_version=self.allow_beta_protocol_version)

if self._is_checksumming_enabled:
buffer = io.BytesIO()
self._segment_codec.encode(buffer, msg)
msg = buffer.getvalue()
if self._is_checksumming_enabled:
buffer = io.BytesIO()
self._segment_codec.encode(buffer, msg)
msg = buffer.getvalue()

self.push(msg)
self.push(msg)
except Exception:
self._requests.pop(request_id, None)
raise
return len(msg)

def wait_for_response(self, msg, timeout=None, **kwargs):
Expand Down Expand Up @@ -1262,9 +1266,16 @@ def wait_for_responses(self, *msgs, **kwargs):
self.in_flight += available

for i, request_id in enumerate(request_ids):
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
try:
self.send_msg(msgs[messages_sent + i],
request_id,
partial(waiter.got_response, index=messages_sent + i))
except Exception:
unsent_request_ids = request_ids[i:]
with self.lock:
self.in_flight -= len(unsent_request_ids)
self.request_ids.extend(unsent_request_ids)
raise
messages_sent += available

if messages_sent == len(msgs):
Expand Down
31 changes: 28 additions & 3 deletions tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from threading import Lock
from unittest.mock import Mock, ANY, call, patch

from cassandra import OperationTimedOut
from cassandra import ConsistencyLevel, OperationTimedOut
from cassandra.cluster import Cluster
from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError,
locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager,
ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator)
ConnectionException, ConnectionShutdown, ConnectionBusy, DefaultEndPoint, ShardAwarePortGenerator)
from cassandra.marshal import uint8_pack, uint32_pack, int32_pack
from cassandra.protocol import (write_stringmultimap, write_int, write_string,
SupportedMessage, ProtocolHandler, ResultMessage,
SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage,
RESULT_KIND_SET_KEYSPACE)

from tests.util import wait_until, assertRegex
Expand Down Expand Up @@ -363,6 +363,31 @@ def test_wait_for_responses_shutdown_includes_last_error(self):
assert "already closed" in error_message
assert "Bad file descriptor" in error_message

def test_wait_for_responses_releases_request_id_when_send_fails(self):
c = self.make_connection()
c._socket_writable = False
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)

with pytest.raises(ConnectionBusy):
c.wait_for_responses(Mock())

assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests

def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self):
c = self.make_connection()
c.push = Mock(side_effect=ConnectionException("write failed"))
initial_in_flight = c.in_flight
initial_request_ids = len(c.request_ids)

with pytest.raises(ConnectionException):
c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE))

assert c.in_flight == initial_in_flight
assert len(c.request_ids) == initial_request_ids
assert not c._requests

@patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped')
class ConnectionHeartbeatTest(unittest.TestCase):
Expand Down
Loading