diff --git a/src/databricks/sql/result_set.py b/src/databricks/sql/result_set.py index 6c4c3a43a..54af3e534 100644 --- a/src/databricks/sql/result_set.py +++ b/src/databricks/sql/result_set.py @@ -306,10 +306,21 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": """ if size < 0: raise ValueError("size argument for fetchmany is %s but must be >= 0", size) + + # Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``. + # CloudFetchQueue may return a placeholder empty table whose schema does not + # match the real downloaded chunks; concatenating it would corrupt the result. + partial_result_chunks: List["pyarrow.Table"] = [] + zero_row_table: Optional["pyarrow.Table"] = None + n_remaining_rows = size + results = self.results.next_n_rows(size) - partial_result_chunks = [results] - n_remaining_rows = size - results.num_rows - self._next_row_index += results.num_rows + if results.num_rows == 0: + zero_row_table = results + else: + partial_result_chunks.append(results) + n_remaining_rows -= results.num_rows + self._next_row_index += results.num_rows while ( n_remaining_rows > 0 @@ -318,10 +329,14 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table": ): self._fill_results_buffer() partial_results = self.results.next_n_rows(n_remaining_rows) + if partial_results.num_rows == 0: + continue partial_result_chunks.append(partial_results) n_remaining_rows -= partial_results.num_rows self._next_row_index += partial_results.num_rows + if not partial_result_chunks: + partial_result_chunks.append(zero_row_table) return concat_table_chunks(partial_result_chunks) def fetchmany_columnar(self, size: int): @@ -351,15 +366,30 @@ def fetchmany_columnar(self, size: int): def fetchall_arrow(self) -> "pyarrow.Table": """Fetch all (remaining) rows of a query result, returning them as a PyArrow table.""" + # Hold 0-row chunks aside instead of appending them to ``partial_result_chunks``. + # CloudFetchQueue may return a placeholder empty table whose schema does not + # match the real downloaded chunks; concatenating it would corrupt the result. + partial_result_chunks: List = [] + zero_row_table: Optional["pyarrow.Table"] = None + results = self.results.remaining_rows() - self._next_row_index += results.num_rows - partial_result_chunks = [results] + if results.num_rows == 0: + zero_row_table = results + else: + partial_result_chunks.append(results) + self._next_row_index += results.num_rows + while not self.has_been_closed_server_side and self.has_more_rows: self._fill_results_buffer() partial_results = self.results.remaining_rows() + if partial_results.num_rows == 0: + continue partial_result_chunks.append(partial_results) self._next_row_index += partial_results.num_rows + if not partial_result_chunks: + partial_result_chunks.append(zero_row_table) + result_table = concat_table_chunks(partial_result_chunks) # If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table # Valid only for metadata commands result set diff --git a/tests/unit/test_fetches.py b/tests/unit/test_fetches.py index 7a0706838..e3a963355 100644 --- a/tests/unit/test_fetches.py +++ b/tests/unit/test_fetches.py @@ -14,6 +14,34 @@ from databricks.sql.result_set import ThriftResultSet +class _StubArrowQueue: + """Minimal queue that hands back a pre-built pyarrow.Table once. + + Used to inject a schemaless / wrong-schema placeholder that the real + ArrowQueue would never produce — this is what CloudFetchQueue emits + when ``self.table is None`` and ``schema_bytes`` is missing. + """ + + def __init__(self, table): + self._table = table + self._consumed = False + + def _take(self): + if self._consumed: + return self._table.slice(0, 0) + self._consumed = True + return self._table + + def next_n_rows(self, num_rows): + return self._take() + + def remaining_rows(self): + return self._take() + + def close(self): + pass + + @pytest.mark.skipif(pa is None, reason="PyArrow is not installed") class FetchTests(unittest.TestCase): """ @@ -110,6 +138,39 @@ def fetch_results( ) return rs + @staticmethod + def make_dummy_result_set_from_queue_list(queue_list, description=None): + """Like make_dummy_result_set_from_batch_list but yields pre-built queues. + + Lets tests inject queues whose returned tables have an arbitrary schema + (or no schema at all) — needed to reproduce the CloudFetch placeholder + case that ``ArrowQueue`` would never produce. + """ + queue_index = 0 + + def fetch_results(**_): + nonlocal queue_index + q = queue_list[queue_index] + queue_index += 1 + return q, queue_index < len(queue_list), 0 + + mock_thrift_backend = Mock(spec=ThriftDatabricksClient) + mock_thrift_backend.fetch_results = fetch_results + + rs = ThriftResultSet( + connection=Mock(), + execute_response=ExecuteResponse( + command_id=None, + status=None, + has_been_closed_server_side=False, + description=description or [], + lz4_compressed=True, + is_staging_operation=False, + ), + thrift_client=mock_thrift_backend, + ) + return rs + def assertEqualRowValues(self, actual, expected): self.assertEqual(len(actual) if actual else 0, len(expected) if expected else 0) for act, exp in zip(actual, expected): @@ -267,6 +328,68 @@ def test_fetchone_without_initial_results(self): dummy_result_set = self.make_dummy_result_set_from_batch_list(batch_list_2) self.assertEqual(dummy_result_set.fetchone(), None) + # Regression tests for fetchmany_arrow / fetchall_arrow handling of + # the schemaless CloudFetch placeholder. + def test_fetchall_arrow_drops_mismatched_empty_placeholder(self): + # First fetch_results() call hands back a 0-row placeholder whose + # schema does not match the real chunks. The second call + # hands back real data. + placeholder = pa.Table.from_pydict( + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) + ) + _, real_table = self.make_arrow_table([[1], [2], [3]]) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], + description=[("col0", "integer", None, None, None, None, None)], + ) + + result = rs.fetchall_arrow() + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ["col0"]) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_fetchall_arrow_all_empty_returns_zero_row_table(self): + # Every queue call returns the schemaless placeholder — the + # call site should fall back to zero_row_table without crashing. + placeholder = pa.Table.from_pydict({}) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder)], + ) + + result = rs.fetchall_arrow() + + self.assertIsInstance(result, pa.Table) + self.assertEqual(result.num_rows, 0) + + def test_fetchmany_arrow_drops_mismatched_empty_placeholder(self): + # See ``test_fetchall_arrow_drops_mismatched_empty_placeholder``. + placeholder = pa.Table.from_pydict( + {"stale_col": []}, schema=pa.schema({"stale_col": pa.string()}) + ) + _, real_table = self.make_arrow_table([[1], [2], [3]]) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder), _StubArrowQueue(real_table)], + description=[("col0", "integer", None, None, None, None, None)], + ) + + result = rs.fetchmany_arrow(3) + + self.assertEqual(result.num_rows, 3) + self.assertEqual(result.schema.names, ["col0"]) + self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) + + def test_fetchmany_arrow_all_empty_returns_zero_row_table(self): + placeholder = pa.Table.from_pydict({}) + rs = self.make_dummy_result_set_from_queue_list( + [_StubArrowQueue(placeholder)], + ) + + result = rs.fetchmany_arrow(10) + + self.assertIsInstance(result, pa.Table) + self.assertEqual(result.num_rows, 0) + if __name__ == "__main__": unittest.main()