From d03db7a40998fb8db0d8ab1ee06e33fdd681ac8b Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Thu, 28 May 2026 00:09:35 +0530 Subject: [PATCH] Hold 0-row Arrow chunks aside in fetchmany_arrow and fetchall_arrow --- src/databricks/sql/client.py | 53 ++++++++++++--- tests/unit/test_fetches.py | 126 +++++++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 59c74c6ba..9d4f828bf 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1324,9 +1324,23 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + # Hold 0-row chunks aside instead of concatenating them with real chunks. + # CloudFetchQueue may emit a placeholder empty table whose schema does + # not match the real downloaded chunks; pyarrow.concat_tables with + # promote_options="default" would silently merge it in as phantom + # columns. + partial_result_chunks: List["pyarrow.Table"] = [] + zero_row_table: Optional["pyarrow.Table"] = None + n_remaining_rows = size + results = self.results.next_n_rows(size) - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows + if results.num_rows == 0: + zero_row_table = results + else: + partial_result_chunks.append(results) + n_remaining_rows -= results.num_rows + self._next_row_index += results.num_rows while ( n_remaining_rows > 0 @@ -1335,13 +1349,17 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) - results = pyarrow.concat_tables( - [results, partial_results], promote_options="default" - ) + if partial_results.num_rows == 0: + continue + partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows - return results + if not partial_result_chunks: + return zero_row_table +
return pyarrow.concat_tables( + partial_result_chunks, promote_options="default" + ) def merge_columnar(self, result1, result2): """ @@ -1387,18 +1405,31 @@ def fetchmany_columnar(self, size: int): def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + # See ``fetchmany_arrow`` for why 0-row chunks are held aside rather than + # concatenated with the real chunks. + partial_result_chunks: List["pyarrow.Table"] = [] + zero_row_table: Optional["pyarrow.Table"] = None + results = self.results.remaining_rows() - self._next_row_index += results.num_rows + if results.num_rows == 0: + zero_row_table = results + else: + partial_result_chunks.append(results) + self._next_row_index += results.num_rows while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() - results = pyarrow.concat_tables( - [results, partial_results], promote_options="default" - ) + if partial_results.num_rows == 0: + continue + partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows - return results + if not partial_result_chunks: + return zero_row_table + return pyarrow.concat_tables( + partial_result_chunks, promote_options="default" + ) def fetchall_columnar(self): """Fetch all (remaining) rows of a query result, returning them as a Columnar table.""" diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 89cedcfae..d245169fd 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -7,6 +7,34 @@ from databricks.sql.utils import ExecuteResponse, ArrowQueue +class _StubArrowQueue: + """Minimal queue that hands back a pre-built pyarrow.Table once. + + Used to inject a placeholder whose schema differs from the real chunks — + what ``CloudFetchQueue._create_empty_table`` can produce when its + ``schema_bytes`` are stale. 
+ """ + + def __init__(self, table): + self._table = table + self._consumed = False + + def _take(self): + if self._consumed: + return self._table.slice(0, 0) + self._consumed = True + return self._table + + def next_n_rows(self, num_rows): + return self._take() + + def remaining_rows(self): + return self._take() + + def close(self): + pass + + class FetchTests(unittest.TestCase): """ Unit tests for checking the fetch logic. @@ -98,6 +126,42 @@ def fetch_results( ) return rs + @staticmethod + def make_dummy_result_set_from_queue_list(queue_list, description=None): + """Like make_dummy_result_set_from_batch_list but yields pre-built queues. + + Lets tests inject queues whose returned tables have an arbitrary schema + — needed to reproduce the CloudFetch placeholder case that ``ArrowQueue`` + would never produce on its own. + """ + queue_index = 0 + + def fetch_results(**_): + nonlocal queue_index + q = queue_list[queue_index] + queue_index += 1 + return q, queue_index < len(queue_list) + + mock_thrift_backend = Mock() + mock_thrift_backend.fetch_results = fetch_results + + rs = client.ResultSet( + connection=Mock(), + thrift_backend=mock_thrift_backend, + execute_response=ExecuteResponse( + status=None, + has_been_closed_server_side=False, + has_more_rows=True, + description=description or [], + lz4_compressed=Mock(), + command_handle=None, + arrow_queue=None, + arrow_schema_bytes=None, + is_staging_operation=False, + ), + ) + return rs + def assertEqualRowValues(self, actual, expected): self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) for act, exp in zip(actual, expected): @@ -255,6 +319,68 @@ def test_fetchone_without_initial_results(self): dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) self.assertEqual(dummy_result_set.fetchone(), None) + # Regression tests for fetchmany_arrow / fetchall_arrow handling of + # the CloudFetch empty placeholder + def 
test_fetchall_arrow_drops_mismatched_empty_placeholder(self): + # First fetch_results call hands back a 0-row placeholder whose + # schema does not match the real chunks. The second + # call hands back the real data. + placeholder = pa.Table.from_pydict( + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) + ) + _, real_table = self.make_arrow_table([[1], [2], [3]]) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], + description=[("col0", "integer", None, None, None, None, None)], + ) + + result = rs.fetchall_arrow() + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ["col0"]) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_fetchall_arrow_all_empty_returns_zero_row_table(self): + # Every queue call returns the placeholder — the call site should + # fall back to ``zero_row_table`` and return a real pa.Table. + placeholder = pa.Table.from_pydict({}) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder)], + ) + + result = rs.fetchall_arrow() + + self.assertIsInstance(result, pa.Table) + self.assertEqual(result.num_rows, 0) + + def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self): + # See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``.
+ placeholder = pa.Table.from_pydict( + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) + ) + _, real_table = self.make_arrow_table([[1], [2], [3]]) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], + description=[("col0", "integer", None, None, None, None, None)], + ) + + result = rs.fetchmany_arrow(3) + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ["col0"]) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_fetchmany_arrow_all_empty_returns_zero_row_table(self): + placeholder = pa.Table.from_pydict({}) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder)], + ) + + result = rs.fetchmany_arrow(10) + + self.assertIsInstance(result, pa.Table) + self.assertEqual(result.num_rows, 0) + if __name__ == "__main__": unittest.main()