Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,9 +1324,23 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
if size < 0:
raise ValueError("size argument for fetchmany is %s but must be >= 0", size)

# Hold 0-row chunks aside instead of concatenating them with real chunks.
# CloudFetchQueue may emit a placeholder empty table whose schema does
# not match the real downloaded chunks; pyarrow.concat_tables with
# promote_options="default" would silently merge it in as phantom
# columns.
partial_result_chunks: List["pyarrow.Table"] = []
zero_row_table: Optional["pyarrow.Table"] = None
n_remaining_rows = size

results = self.results.next_n_rows(size)
n_remaining_rows = size - results.num_rows
self._next_row_index += results.num_rows
if results.num_rows == 0:
zero_row_table = results
else:
partial_result_chunks.append(results)
n_remaining_rows -= results.num_rows
self._next_row_index += results.num_rows

while (
n_remaining_rows > 0
Expand All @@ -1335,13 +1349,17 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
results = pyarrow.concat_tables(
[results, partial_results], promote_options="default"
)
if partial_results.num_rows == 0:
continue
partial_result_chunks.append(partial_results)
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

return results
if not partial_result_chunks:
return zero_row_table
return pyarrow.concat_tables(
partial_result_chunks, promote_options="default"
)

def merge_columnar(self, result1, result2):
"""
Expand Down Expand Up @@ -1387,18 +1405,31 @@ def fetchmany_columnar(self, size: int):

def fetchall_arrow(self) -> "pyarrow.Table":
    """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.

    Chunks that contain rows are collected and concatenated once at the end.
    0-row chunks are never concatenated with real chunks: CloudFetchQueue may
    emit a placeholder empty table whose schema does not match the real
    downloaded chunks, and ``pyarrow.concat_tables`` with
    ``promote_options="default"`` would silently merge it in as phantom
    columns.

    Returns:
        The concatenation of every non-empty chunk, or the first 0-row table
        that was read when no chunk contained any rows.
    """
    partial_result_chunks: List["pyarrow.Table"] = []
    zero_row_table: Optional["pyarrow.Table"] = None

    results = self.results.remaining_rows()
    if results.num_rows == 0:
        # Held aside; only returned if nothing with rows ever arrives.
        zero_row_table = results
    else:
        partial_result_chunks.append(results)
    # Counted exactly once per chunk (adding 0 for an empty chunk is a no-op).
    self._next_row_index += results.num_rows

    while not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        partial_results = self.results.remaining_rows()
        if partial_results.num_rows == 0:
            # Drop empty chunks instead of merging their schema into the result.
            continue
        partial_result_chunks.append(partial_results)
        self._next_row_index += partial_results.num_rows

    if not partial_result_chunks:
        return zero_row_table
    return pyarrow.concat_tables(
        partial_result_chunks, promote_options="default"
    )

def fetchall_columnar(self):
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
Expand Down
126 changes: 126 additions & 0 deletions tests/unit/test_fetches.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,34 @@
from databricks.sql.utils import ExecuteResponse, ArrowQueue


class _StubArrowQueue:
"""Minimal queue that hands back a pre-built pyarrow.Table once.

Used to inject a placeholder whose schema differs from the real chunks —
what ``CloudFetchQueue._create_empty_table`` can produce when its
``schema_bytes`` are stale.
"""

def __init__(self, table):
self._table = table
self._consumed = False

def _take(self):
if self._consumed:
return self._table.slice(0, 0)
self._consumed = True
return self._table

def next_n_rows(self, num_rows):
return self._take()

def remaining_rows(self):
return self._take()

def close(self):
pass


class FetchTests(unittest.TestCase):
"""
Unit tests for checking the fetch logic.
Expand Down Expand Up @@ -98,6 +126,42 @@ def fetch_results(
)
return rs

@staticmethod
def make_dummy_result_set_from_queue_list(queue_list, description=None):
    """Build a ``client.ResultSet`` whose backend serves pre-built queues.

    Counterpart of make_dummy_result_set_from_batch_list for tests that need
    queues returning tables with an arbitrary schema — required to reproduce
    the CloudFetch placeholder case that ``ArrowQueue`` would never produce on
    its own.

    Every ``fetch_results`` call on the stubbed backend returns the next entry
    of ``queue_list`` plus a flag that stays true while entries remain.
    ``description`` falls back to an empty list when not given.
    """
    served = 0

    def fetch_results(**_):
        nonlocal served
        current_queue = queue_list[served]
        served += 1
        more_to_come = served < len(queue_list)
        return current_queue, more_to_come

    backend = Mock()
    backend.fetch_results = fetch_results

    # No initial arrow queue is supplied, and has_more_rows starts out true.
    response = ExecuteResponse(
        status=None,
        has_been_closed_server_side=False,
        has_more_rows=True,
        description=description or [],
        lz4_compressed=Mock(),
        command_handle=None,
        arrow_queue=None,
        arrow_schema_bytes=None,
        is_staging_operation=False,
    )

    return client.ResultSet(
        connection=Mock(),
        thrift_backend=backend,
        execute_response=response,
    )

def assertEqualRowValues(self, actual, expected):
self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0)
for act, exp in zip(actual, expected):
Expand Down Expand Up @@ -255,6 +319,68 @@ def test_fetchone_without_initial_results(self):
dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2)
self.assertEqual(dummy_result_set.fetchone(), None)

# Regression tests for fetchmany_arrow / fetchall_arrow handling of
# the CloudFetch empty placeholder
def test_fetchall_arrow_drops_mismatched_empty_placeholder(self):
    # Queue 1 serves a 0-row placeholder whose schema does not match the
    # real chunks; queue 2 serves the real data.
    stale_schema = pa.schema({"stale_col": pa.string()})
    empty_stale_table = pa.Table.from_pydict({"stale_col": []}, schema=stale_schema)
    _schema, real_table = self.make_arrow_table([[1], [2], [3]])

    queues = [_StubArrowQueue(empty_stale_table), _StubArrowQueue(real_table)]
    column_description = [("col0", "integer", None, None, None, None, None)]
    result_set = self.make_dummy_result_set_from_queue_list(
        queues, description=column_description
    )

    fetched = result_set.fetchall_arrow()

    # Only the real chunk's rows and column may appear in the result.
    self.assertEqual(fetched.num_rows, 3)
    self.assertEqual(fetched.schema.names, ["col0"])
    self.assertEqual(fetched.column(0).to_pylist(), [1, 2, 3])

def test_fetchall_arrow_all_empty_returns_zero_row_table(self):
    # The only queue serves an empty placeholder, so the call site has to
    # fall back to ``zero_row_table`` and still return a real pa.Table.
    empty_table = pa.Table.from_pydict({})
    queues = [_StubArrowQueue(empty_table)]
    result_set = self.make_dummy_result_set_from_queue_list(queues)

    fetched = result_set.fetchall_arrow()

    self.assertIsInstance(fetched, pa.Table)
    self.assertEqual(fetched.num_rows, 0)

def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self):
    # Queue 1 serves a 0-row placeholder whose schema does not match the
    # real chunks; queue 2 serves the real data.
    stale_schema = pa.schema({"stale_col": pa.string()})
    empty_stale_table = pa.Table.from_pydict({"stale_col": []}, schema=stale_schema)
    _schema, real_table = self.make_arrow_table([[1], [2], [3]])

    queues = [_StubArrowQueue(empty_stale_table), _StubArrowQueue(real_table)]
    column_description = [("col0", "integer", None, None, None, None, None)]
    result_set = self.make_dummy_result_set_from_queue_list(
        queues, description=column_description
    )

    fetched = result_set.fetchmany_arrow(3)

    # Only the real chunk's rows and column may appear in the result.
    self.assertEqual(fetched.num_rows, 3)
    self.assertEqual(fetched.schema.names, ["col0"])
    self.assertEqual(fetched.column(0).to_pylist(), [1, 2, 3])

def test_fetchmany_arrow_all_empty_returns_zero_row_table(self):
    # The only queue serves an empty placeholder; asking for rows must still
    # return a real, 0-row pa.Table.
    empty_table = pa.Table.from_pydict({})
    queues = [_StubArrowQueue(empty_table)]
    result_set = self.make_dummy_result_set_from_queue_list(queues)

    fetched = result_set.fetchmany_arrow(10)

    self.assertIsInstance(fetched, pa.Table)
    self.assertEqual(fetched.num_rows, 0)


# Run the unittest runner when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading