hello") is False
+
+ def test_self_closing(self):
+ assert helper.is_valid_html_content("
") is True
+
+ def test_empty_string(self):
+ assert helper.is_valid_html_content("") is False
+
+ def test_whitespace_only(self):
+ assert helper.is_valid_html_content(" ") is False
+
+ def test_nested_valid(self):
+ assert helper.is_valid_html_content("
") is True
+
+
+class TestHtmlToJsonSafe:
+ def test_roundtrip(self):
+ original = '
Hello & world\n
'
+ safe = helper.html_to_json_safe(original)
+ recovered = helper.html_from_json_safe(safe)
+ assert recovered == original
+
+ def test_empty_input(self):
+ assert helper.html_to_json_safe("") == ""
+
+ def test_returns_string(self):
+ result = helper.html_to_json_safe("
hi
")
+ assert isinstance(result, str)
+
+
+class TestHtmlFromJsonSafe:
+ def test_empty_input(self):
+ assert helper.html_from_json_safe("") == ""
+
+ def test_unescapes_quotes(self):
+ safe = helper.html_to_json_safe('
link')
+ result = helper.html_from_json_safe(safe)
+ assert '"url"' in result
+
+
+class TestDatetimeString:
+ def test_returns_formatted_string(self):
+ dt = datetime(2025, 1, 15, 10, 30, 45)
+ result = helper.datetime_string(dt)
+ assert result == "2025-01-15--10-30-45"
+
+ def test_custom_format(self):
+ dt = datetime(2025, 6, 1)
+ result = helper.datetime_string(dt, format="%Y/%m/%d")
+ assert result == "2025/06/01"
+
+
+class TestConvertDatetimeFormat:
+ def test_datetime_object_returns_string(self):
+ dt = datetime(2025, 3, 15, 12, 0, 0, tzinfo=pytz.UTC)
+ result = helper.convert_datetime_format(dt)
+ assert isinstance(result, str)
+ assert len(result) > 0
+
+ def test_empty_string_returns_empty(self):
+ assert helper.convert_datetime_format("") == ""
+
+ def test_invalid_string_returns_empty(self):
+ assert helper.convert_datetime_format("not-a-date") == ""
+
+ def test_date_only(self):
+ dt = datetime(2025, 3, 15, 12, 0, 0, tzinfo=pytz.UTC)
+ result = helper.convert_datetime_format(dt, include_time=False)
+ assert "2025" in result
+ assert ":" not in result
+
+
+class TestCompareSemver:
+ def test_equal(self):
+ assert helper.compare_semver("1.2.3", "1.2.3") == 0
+
+ def test_less_than(self):
+ assert helper.compare_semver("1.0.0", "2.0.0") == -1
+
+ def test_greater_than(self):
+ assert helper.compare_semver("2.0.0", "1.9.9") == 1
+
+ def test_patch_difference(self):
+ assert helper.compare_semver("1.0.1", "1.0.0") == 1
+
+ def test_minor_difference(self):
+ assert helper.compare_semver("1.1.0", "1.2.0") == -1
diff --git a/tests/unit/test_html_to_markdown.py b/tests/unit/test_html_to_markdown.py
new file mode 100644
index 0000000..45e9cde
--- /dev/null
+++ b/tests/unit/test_html_to_markdown.py
@@ -0,0 +1,153 @@
+"""Tests for the html_to_markdown module."""
+
+
+from ruf_common.html_to_markdown import html_to_markdown
+
+
+class TestHtmlToMarkdownBasic:
+ def test_empty_string_returns_empty(self):
+ assert html_to_markdown("") == ""
+
+ def test_plain_text_passthrough(self):
+ result = html_to_markdown("Hello world")
+ assert "Hello world" in result
+
+ def test_whitespace_only_returns_empty_marker(self):
+ result = html_to_markdown("
")
+ assert result == "_Empty_" or result.strip() == ""
+
+
+class TestHeaders:
+ def test_h1(self):
+ result = html_to_markdown("
Title
")
+ assert result.startswith("# Title")
+
+ def test_h2(self):
+ result = html_to_markdown("
Sub
")
+ assert result.startswith("## Sub")
+
+ def test_h6(self):
+ result = html_to_markdown("
Tiny
")
+ assert result.startswith("###### Tiny")
+
+
+class TestInlineFormatting:
+ def test_bold_strong(self):
+ result = html_to_markdown("
bold")
+ assert "**bold**" in result
+
+ def test_bold_b_tag(self):
+ result = html_to_markdown("
bold")
+ assert "**bold**" in result
+
+ def test_italic_em(self):
+ result = html_to_markdown("
italic")
+ assert "*italic*" in result
+
+ def test_italic_i_tag(self):
+ result = html_to_markdown("
italic")
+ assert "*italic*" in result
+
+ def test_inline_code(self):
+ result = html_to_markdown("
x = 1")
+ assert "`x = 1`" in result
+
+ def test_link(self):
+ result = html_to_markdown('
click')
+ assert "[click](https://example.com)" in result
+
+
+class TestLineBreakAndHr:
+ def test_br_becomes_newline(self):
+ result = html_to_markdown("line1
line2")
+ assert "line1" in result
+ assert "line2" in result
+
+ def test_hr_becomes_dashes(self):
+ result = html_to_markdown("
")
+ assert "---" in result
+
+
+class TestLists:
+ def test_unordered_list(self):
+ html = "
"
+ result = html_to_markdown(html)
+ assert "- one" in result
+ assert "- two" in result
+
+ def test_ordered_list(self):
+ html = "
- first
- second
"
+ result = html_to_markdown(html)
+ assert "1. first" in result
+ assert "2. second" in result
+
+
+class TestParagraph:
+ def test_paragraph_content_preserved(self):
+ result = html_to_markdown("
Hello paragraph
")
+ assert "Hello paragraph" in result
+
+
+class TestCodeBlock:
+ def test_pre_becomes_fenced(self):
+ result = html_to_markdown("
code here
")
+ assert "```" in result
+ assert "code here" in result
+
+
+class TestBlockquote:
+ def test_blockquote_prefix(self):
+ result = html_to_markdown("
quoted text
")
+ assert "> " in result
+ assert "quoted text" in result
+
+
+class TestImage:
+ def test_image_with_alt(self):
+ result = html_to_markdown('

')
+ assert "" in result
+
+ def test_image_without_alt(self):
+ result = html_to_markdown('

')
+ assert "img.png" in result
+
+
+class TestTable:
+ def test_table_with_headers(self):
+ html = (
+ "
"
+ )
+ result = html_to_markdown(html)
+ assert "| A |" in result or "| A" in result
+ assert "---" in result
+
+ def test_table_rows(self):
+ html = (
+ "
"
+ )
+ result = html_to_markdown(html)
+ assert "val" in result
+
+
+class TestHtmlEntities:
+ def test_ampersand_entity(self):
+ result = html_to_markdown("
a & b
")
+ assert "a & b" in result
+
+ def test_nbsp_entity(self):
+ result = html_to_markdown("
hello world
")
+ assert "hello" in result and "world" in result
+
+
+class TestStripRemainingTags:
+ def test_unknown_tags_removed(self):
+ result = html_to_markdown("
text
")
+ assert "
" not in result
+ assert "
" not in result
+ assert "text" in result
diff --git a/tests/test_imports.py b/tests/unit/test_imports.py
similarity index 98%
rename from tests/test_imports.py
rename to tests/unit/test_imports.py
index 162f149..8bde192 100644
--- a/tests/test_imports.py
+++ b/tests/unit/test_imports.py
@@ -6,7 +6,7 @@
import pytest
-PACKAGE_DIR = Path(__file__).resolve().parent.parent / "ruf_common"
+PACKAGE_DIR = Path(__file__).resolve().parent.parent.parent / "ruf_common"
# Internal sub-modules that are not part of the public API
# (imported by other modules, not re-exported from __init__.py)
diff --git a/tests/unit/test_lfs.py b/tests/unit/test_lfs.py
new file mode 100644
index 0000000..4ad8c41
--- /dev/null
+++ b/tests/unit/test_lfs.py
@@ -0,0 +1,164 @@
+"""Tests for the lfs (local file system) module."""
+
+import json
+import os
+import zipfile
+
+from ruf_common import lfs
+
+
+class TestChkfile:
+ def test_existing_file(self, tmp_path):
+ f = tmp_path / "test.txt"
+ f.write_text("hello")
+ assert lfs.chkfile(str(f)) is True
+
+ def test_missing_file(self, tmp_path):
+ assert lfs.chkfile(str(tmp_path / "no_such_file.txt")) is False
+
+
+class TestChkdir:
+ def test_existing_dir(self, tmp_path):
+ assert lfs.chkdir(str(tmp_path)) is True
+
+ def test_missing_dir(self, tmp_path):
+ assert lfs.chkdir(str(tmp_path / "nonexistent")) is False
+
+ def test_make_if_not_present(self, tmp_path):
+ new_dir = str(tmp_path / "newdir")
+ assert lfs.chkdir(new_dir, make_if_not_present=True) is True
+ assert os.path.isdir(new_dir)
+
+
+class TestMkdir:
+ def test_creates_directory(self, tmp_path):
+ new_dir = str(tmp_path / "created")
+ result = lfs.mkdir(new_dir)
+ assert result is True
+ assert os.path.isdir(new_dir)
+
+ def test_already_exists(self, tmp_path):
+ result = lfs.mkdir(str(tmp_path))
+ assert result is True
+
+
+class TestPutfile:
+    def test_writes_content(self, tmp_path):
+        target = tmp_path / "out.txt"
+        assert lfs.putfile(str(target), "hello world") is True
+        # Path.read_text() closes the handle; a bare open(...).read() leaks it.
+        assert target.read_text() == "hello world"
+
+    def test_returns_false_on_invalid_path(self):
+        # "/no_such_dir" is assumed to be absent, so the write should fail.
+        assert lfs.putfile("/no_such_dir/file.txt", "data") is False
+
+
+class TestGetfile:
+ def test_reads_text_file(self, tmp_path):
+ f = tmp_path / "test.txt"
+ f.write_text("content here", encoding="utf-8")
+ result = lfs.getfile(str(f))
+ assert "content here" in result
+
+ def test_missing_file_returns_empty(self, tmp_path):
+ result = lfs.getfile(str(tmp_path / "missing.txt"))
+ assert result == ""
+
+ def test_normalize_false_may_return_bytes(self, tmp_path):
+ f = tmp_path / "test.txt"
+ f.write_bytes(b"raw bytes")
+ result = lfs.getfile(str(f), normalize=False)
+ assert result is not None
+
+
+class TestGetJson:
+ def test_reads_valid_json(self, tmp_path):
+ f = tmp_path / "data.json"
+ f.write_text('{"key": "val"}', encoding="utf-8")
+ result = lfs.get_json(str(f))
+ assert result == {"key": "val"}
+
+ def test_missing_file_returns_empty_dict(self, tmp_path):
+ result = lfs.get_json(str(tmp_path / "missing.json"))
+ assert result == {}
+
+ def test_invalid_json_returns_empty_dict(self, tmp_path):
+ f = tmp_path / "bad.json"
+ f.write_text("{bad json}", encoding="utf-8")
+ result = lfs.get_json(str(f))
+ assert result == {}
+
+
+class TestSaveJson:
+    """Check save_json's return value and the file it leaves on disk."""
+
+    def test_writes_json_file(self, tmp_path):
+        target = tmp_path / "out.json"
+        assert lfs.save_json({"a": 1}, str(target)) is True
+        # read_text() closes the file; open(...).read() left that to the GC.
+        assert json.loads(target.read_text()) == {"a": 1}
+
+    def test_json_is_indented(self, tmp_path):
+        target = tmp_path / "out.json"
+        lfs.save_json({"x": 1}, str(target))
+        assert "\n" in target.read_text()
+
+
+class TestGetjsonfile:
+ def test_reads_json_file(self, tmp_path):
+ f = tmp_path / "data.json"
+ f.write_text('{"k": "v"}', encoding="utf-8")
+ result = lfs.getjsonfile(str(f))
+ assert result == {"k": "v"}
+
+ def test_missing_returns_empty(self, tmp_path):
+ result = lfs.getjsonfile(str(tmp_path / "nope.json"))
+ assert result == {}
+
+
+class TestBackupFile:
+ def test_creates_backup(self, tmp_path):
+ f = tmp_path / "config.json"
+ f.write_text('{"a": 1}', encoding="utf-8")
+ result = lfs.backup_file(str(f))
+ assert result is True
+ backups = [p for p in tmp_path.iterdir() if "config_" in p.name]
+ assert len(backups) == 1
+
+ def test_backup_has_same_content(self, tmp_path):
+ f = tmp_path / "config.json"
+ f.write_text('{"a": 1}', encoding="utf-8")
+ lfs.backup_file(str(f))
+ backups = [p for p in tmp_path.iterdir() if "config_" in p.name]
+ assert backups[0].read_text() == '{"a": 1}'
+
+ def test_missing_file_returns_false(self, tmp_path):
+ result = lfs.backup_file(str(tmp_path / "nonexistent.txt"))
+ assert result is False
+
+
+class TestZipFile:
+ def test_zips_single_file(self, tmp_path):
+ src = tmp_path / "source.txt"
+ src.write_text("hello")
+ zip_path = str(tmp_path / "out.zip")
+ result = lfs.zip_file(str(src), zip_path)
+ assert result is True
+ assert zipfile.is_zipfile(zip_path)
+
+ def test_overwrite_false_blocked_if_exists(self, tmp_path):
+ src = tmp_path / "source.txt"
+ src.write_text("hello")
+ zip_path = str(tmp_path / "out.zip")
+ lfs.zip_file(str(src), zip_path)
+ result = lfs.zip_file(str(src), zip_path, overwrite=False)
+ assert result is None or result is False
+
+ def test_overwrite_true_replaces(self, tmp_path):
+ src = tmp_path / "source.txt"
+ src.write_text("hello")
+ zip_path = str(tmp_path / "out.zip")
+ lfs.zip_file(str(src), zip_path)
+ result = lfs.zip_file(str(src), zip_path, overwrite=True)
+ assert result is True
diff --git a/tests/unit/test_logging_module.py b/tests/unit/test_logging_module.py
new file mode 100644
index 0000000..a1d587f
--- /dev/null
+++ b/tests/unit/test_logging_module.py
@@ -0,0 +1,120 @@
+"""Tests for the logging module (DictSink and LoggableMixin)."""
+
+import pytest
+from loguru import logger
+
+from ruf_common.logging import DictSink, LoggableMixin
+
+
+class TestDictSink:
+ @pytest.fixture
+ def sink(self):
+ s = DictSink()
+ handler_id = logger.add(s, format="{message}", level="DEBUG")
+ yield s
+ logger.remove(handler_id)
+
+ def test_captures_log_record(self, sink):
+ logger.info("test message")
+ records = sink.get_records()
+ assert any(r["message"] == "test message" for r in records)
+
+ def test_record_has_required_fields(self, sink):
+ logger.warning("check fields")
+ records = sink.get_records()
+ assert len(records) > 0
+ record = records[-1]
+ assert "timestamp" in record
+ assert "level" in record
+ assert "message" in record
+ assert "module" in record
+ assert "function" in record
+ assert "line" in record
+
+ def test_level_captured_correctly(self, sink):
+ logger.error("an error")
+ records = sink.get_records()
+ error_records = [r for r in records if r["message"] == "an error"]
+ assert len(error_records) == 1
+ assert error_records[0]["level"] == "ERROR"
+
+ def test_get_records_returns_copy(self, sink):
+ logger.info("msg")
+ r1 = sink.get_records()
+ r2 = sink.get_records()
+ assert r1 == r2
+ r1.clear()
+ assert len(sink.get_records()) > 0 # original unaffected
+
+ def test_clear_removes_records(self, sink):
+ logger.info("something")
+ sink.clear()
+ assert sink.get_records() == []
+
+ def test_multiple_records(self, sink):
+ logger.info("one")
+ logger.info("two")
+ logger.info("three")
+ records = sink.get_records()
+ messages = [r["message"] for r in records]
+ assert "one" in messages
+ assert "two" in messages
+ assert "three" in messages
+
+
+class TestLoggableMixin:
+    """Exercise LoggableMixin's setup_logging, get_logs, clear_logs and cleanup_logging."""
+
+    class _MyClass(LoggableMixin):
+        """Minimal concrete host for the mixin."""
+
+    @pytest.fixture
+    def obj(self):
+        """Yield a fresh instance and run cleanup_logging() once the test ends."""
+        instance = self._MyClass()
+        yield instance
+        try:
+            instance.cleanup_logging()
+        except ValueError:
+            pass  # handlers already removed in the test
+
+    def test_setup_dict_mode(self, obj):
+        obj.setup_logging(log_mode="dict")
+        assert obj._dict_sink is not None
+
+    def test_get_logs_returns_list(self, obj):
+        obj.setup_logging(log_mode="dict")
+        assert isinstance(obj.get_logs(), list)
+
+    def test_clear_logs(self, obj):
+        obj.setup_logging(log_mode="dict")
+        logger.debug("captured")
+        obj.clear_logs()
+        assert obj.get_logs() == []
+
+    def test_get_logs_without_dict_mode_returns_empty(self, obj):
+        assert obj.get_logs() == []
+
+    def test_cleanup_removes_handlers(self, obj):
+        obj.setup_logging(log_mode="dict")
+        handler_ids_before = list(obj._handler_ids)
+        obj.cleanup_logging()
+        assert obj._handler_ids == handler_ids_before  # ids still stored...
+        # ...but loguru itself must no longer hold them. loguru raises
+        # ValueError when asked to remove a handler id it does not know,
+        # which makes the removal observable instead of merely assumed.
+        for handler_id in handler_ids_before:
+            with pytest.raises(ValueError):
+                logger.remove(handler_id)
+
+    def test_file_mode_requires_log_file(self, obj):
+        with pytest.raises(ValueError):
+            obj.setup_logging(log_mode="file", log_file=None)
+
+    def test_file_mode_creates_file(self, obj, tmp_path):
+        log_file = str(tmp_path / "test.log")
+        obj.setup_logging(log_mode="file", log_file=log_file)
+        logger.info("file test")
+        obj.cleanup_logging()
+        assert os.path.exists(log_file)
+
+    def test_setup_reinitializes(self, obj):
+        # NOTE(review): the "or len(...) > 0" arm lets this pass even when the
+        # ids are unchanged, so it only shows that handlers exist afterwards.
+        obj.setup_logging(log_mode="dict")
+        first_id = list(obj._handler_ids)
+        obj.setup_logging(log_mode="dict")
+        assert obj._handler_ids != first_id or len(obj._handler_ids) > 0
+
+
+import os # noqa: E402 — needed for test_file_mode_creates_file
diff --git a/tests/unit/test_network.py b/tests/unit/test_network.py
new file mode 100644
index 0000000..2d3d23d
--- /dev/null
+++ b/tests/unit/test_network.py
@@ -0,0 +1,85 @@
+"""Tests for the network module (mocked external calls)."""
+
+from unittest.mock import MagicMock, patch
+
+from ruf_common import network
+
+
+class TestCheckInternetConnection:
+    """Every case patches socket.create_connection instead of using the network."""
+
+    def test_returns_bool(self):
+        with patch("socket.create_connection", return_value=MagicMock()):
+            assert isinstance(network.check_internet_connection(), bool)
+
+    def test_true_when_connection_succeeds(self):
+        with patch("socket.create_connection", return_value=MagicMock()):
+            assert network.check_internet_connection() is True
+
+    def test_false_when_socket_raises(self):
+        with patch("socket.create_connection", side_effect=OSError("unreachable")):
+            assert network.check_internet_connection() is False
+
+
+class TestApiGet:
+    """Check what api_get returns when requests.get succeeds or raises."""
+
+    def test_returns_response_on_success(self):
+        fake_response = MagicMock(status_code=200)
+        with patch("requests.get", return_value=fake_response):
+            outcome = network.api_get("https://example.com/api")
+        assert outcome is not None
+        assert outcome.status_code == 200
+
+    def test_returns_none_on_connection_error(self):
+        from requests.exceptions import ConnectionError as RequestsConnectionError
+
+        failure = RequestsConnectionError("no connection")
+        with patch("requests.get", side_effect=failure):
+            assert network.api_get("https://example.com/api") is None
+
+    def test_returns_none_on_timeout(self):
+        from requests.exceptions import Timeout
+
+        with patch("requests.get", side_effect=Timeout("timeout")):
+            assert network.api_get("https://example.com/api") is None
+
+    def test_passes_headers(self):
+        # NOTE(review): this only establishes that requests.get was invoked;
+        # the headers actually forwarded are never inspected.
+        custom_headers = {"Authorization": "Bearer token"}
+        with patch("requests.get", return_value=MagicMock(status_code=200)) as mock_get:
+            network.api_get("https://example.com/api", custom_headers)
+        assert mock_get.called
+
+
+class TestDownloadFile:
+    # Destinations live under tmp_path so a real write cannot litter the cwd.
+    def test_returns_content_on_success(self, tmp_path):
+        dest = str(tmp_path / "output.bin")
+        with patch("requests.get", return_value=MagicMock(content=b"file content")):
+            result = network.download_file("https://example.com/file.bin", dest)
+        assert result is not None
+
+    def test_returns_empty_string_on_error(self, tmp_path):
+        import requests
+        dest = str(tmp_path / "output.bin")
+        with patch("requests.get", side_effect=requests.exceptions.RequestException("error")):
+            assert network.download_file("https://example.com/file.bin", dest) == ""
+
+
+class TestAsyncApiGet:
+ def test_is_coroutine(self):
+ import inspect
+ assert inspect.iscoroutinefunction(network.async_api_get)
+
+ def test_returns_none_on_error(self):
+ import asyncio
+
+ async def run():
+ with patch("aiohttp.ClientSession") as mock_session:
+ mock_session.return_value.__aenter__ = MagicMock(side_effect=Exception("network error"))
+ mock_session.return_value.__aexit__ = MagicMock(return_value=False)
+ return await network.async_api_get("https://example.com/api")
+
+ result = asyncio.run(run())
+ assert result is None
diff --git a/tests/unit/test_stats.py b/tests/unit/test_stats.py
new file mode 100644
index 0000000..52f1e77
--- /dev/null
+++ b/tests/unit/test_stats.py
@@ -0,0 +1,77 @@
+"""Tests for the stats module."""
+
+import pytest
+
+import ruf_common.stats as stats_module
+
+
+@pytest.fixture(autouse=True)
+def reset_stats():
+ """Reset the global stats dict before each test."""
+ stats_module.stats.clear()
+ yield
+ stats_module.stats.clear()
+
+
+class TestIncrementStat:
+ def test_creates_new_stat(self):
+ stats_module.increment_stat("hits")
+ assert stats_module.stats["hits"] == 1
+
+ def test_increments_existing_stat(self):
+ stats_module.increment_stat("hits")
+ stats_module.increment_stat("hits")
+ assert stats_module.stats["hits"] == 2
+
+ def test_custom_increment(self):
+ stats_module.increment_stat("bytes", 100)
+ assert stats_module.stats["bytes"] == 100
+
+ def test_multiple_stats_independent(self):
+ stats_module.increment_stat("a")
+ stats_module.increment_stat("b")
+ stats_module.increment_stat("b")
+ assert stats_module.stats["a"] == 1
+ assert stats_module.stats["b"] == 2
+
+
+class TestGetStat:
+ def test_returns_zero_for_missing(self):
+ assert stats_module.get_stat("nonexistent") == 0
+
+ def test_returns_correct_value(self):
+ stats_module.increment_stat("score", 5)
+ assert stats_module.get_stat("score") == 5
+
+ def test_after_multiple_increments(self):
+ stats_module.increment_stat("x")
+ stats_module.increment_stat("x")
+ stats_module.increment_stat("x")
+ assert stats_module.get_stat("x") == 3
+
+
+class TestStatsSummary:
+ def test_empty_stats_has_heading(self):
+ result = stats_module.stats_summary()
+ assert "Summary" in result
+
+ def test_custom_heading(self):
+ result = stats_module.stats_summary("My Report")
+ assert "My Report" in result
+
+ def test_contains_stat_names_and_values(self):
+ stats_module.increment_stat("errors", 3)
+ result = stats_module.stats_summary()
+ assert "errors" in result
+ assert "3" in result
+
+ def test_empty_heading_no_label(self):
+ result = stats_module.stats_summary(heading="")
+ assert "Summary" not in result
+
+ def test_multiple_stats_all_present(self):
+ stats_module.increment_stat("a", 1)
+ stats_module.increment_stat("b", 2)
+ result = stats_module.stats_summary()
+ assert "a" in result
+ assert "b" in result
diff --git a/tests/unit/test_timezone_lookup.py b/tests/unit/test_timezone_lookup.py
new file mode 100644
index 0000000..5277228
--- /dev/null
+++ b/tests/unit/test_timezone_lookup.py
@@ -0,0 +1,95 @@
+"""Tests for the timezone_lookup module (external calls mocked)."""
+
+from unittest.mock import MagicMock, patch
+
+from ruf_common import timezone_lookup
+
+# The module imports Nominatim and TimezoneFinder at the top level, so we must
+# patch them within the ruf_common.timezone_lookup namespace.
+_NOMINATIM = "ruf_common.timezone_lookup.Nominatim"
+_TIMEZONEFINDER = "ruf_common.timezone_lookup.TimezoneFinder"
+_TIME_SLEEP = "ruf_common.timezone_lookup.time.sleep"
+
+
+def _make_location(lat=40.7128, lng=-74.0060):
+    """Return a mock with ``latitude`` and ``longitude`` set to *lat* and *lng*."""
+    # Keyword arguments to MagicMock are set as attributes on the new mock.
+    coords = {"latitude": lat, "longitude": lng}
+    return MagicMock(**coords)
+
+
+class TestLookupTimezone:
+ def test_returns_timezone_string_on_success(self):
+ mock_location = _make_location()
+
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf:
+ mock_nom.return_value.geocode.return_value = mock_location
+ mock_tf.return_value.timezone_at.return_value = "America/New_York"
+
+ result = timezone_lookup.lookup_timezone("New York", "US")
+ assert result == "America/New_York"
+
+ def test_returns_none_when_city_not_found(self):
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER):
+ mock_nom.return_value.geocode.return_value = None
+
+ result = timezone_lookup.lookup_timezone("Nonexistentville", "XX")
+ assert result is None
+
+ def test_returns_none_on_geocoding_exception(self):
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER):
+ mock_nom.return_value.geocode.side_effect = Exception("geocode error")
+
+ result = timezone_lookup.lookup_timezone("Berlin", "DE")
+ assert result is None
+
+ def test_returns_none_when_no_timezone_found(self):
+ mock_location = _make_location(lat=0.0, lng=0.0)
+
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf:
+ mock_nom.return_value.geocode.return_value = mock_location
+ mock_tf.return_value.timezone_at.return_value = None
+
+ result = timezone_lookup.lookup_timezone("Nowhere", "ZZ")
+ assert result is None
+
+
+class TestLookupTimezoneBatch:
+ def test_returns_dict(self):
+ mock_location = _make_location(lat=51.5074, lng=-0.1278)
+
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf, \
+ patch(_TIME_SLEEP):
+ mock_nom.return_value.geocode.return_value = mock_location
+ mock_tf.return_value.timezone_at.return_value = "Europe/London"
+
+ result = timezone_lookup.lookup_timezone_batch([("London", "UK")])
+ assert isinstance(result, dict)
+ assert result.get("London, UK") == "Europe/London"
+
+ def test_empty_list_returns_empty_dict(self):
+ result = timezone_lookup.lookup_timezone_batch([])
+ assert result == {}
+
+ def test_failed_lookup_stored_as_none(self):
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER), \
+ patch(_TIME_SLEEP):
+ mock_nom.return_value.geocode.return_value = None
+
+ result = timezone_lookup.lookup_timezone_batch([("Nowhere", "ZZ")])
+ assert result.get("Nowhere, ZZ") is None
+
+ def test_multiple_locations(self):
+ def geocode_side_effect(query, timeout=10):
+ return _make_location()
+
+ with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf, \
+ patch(_TIME_SLEEP):
+ mock_nom.return_value.geocode.side_effect = geocode_side_effect
+ mock_tf.return_value.timezone_at.return_value = "UTC"
+
+ result = timezone_lookup.lookup_timezone_batch([
+ ("City1", "C1"),
+ ("City2", "C2"),
+ ])
+ assert len(result) == 2
diff --git a/tests/unit/test_xml_formatter.py b/tests/unit/test_xml_formatter.py
new file mode 100644
index 0000000..5074bba
--- /dev/null
+++ b/tests/unit/test_xml_formatter.py
@@ -0,0 +1,145 @@
+"""Tests for the xml_formatter module."""
+
+
+import pytest
+
+from ruf_common import xml_formatter
+
+SIMPLE_XML = "text"
+MULTI_CHILD_XML = "123"
+
+
+class TestFormatXmlString:
+ def test_returns_string(self):
+ result = xml_formatter.format_xml_string(SIMPLE_XML)
+ assert isinstance(result, str)
+
+ def test_output_is_valid_xml(self):
+ import xml.etree.ElementTree as ET
+ result = xml_formatter.format_xml_string(SIMPLE_XML)
+ # Should parse without error
+ ET.fromstring(result.strip())
+
+ def test_adds_indentation(self):
+ result = xml_formatter.format_xml_string(SIMPLE_XML)
+ assert "\n" in result
+
+ def test_preserves_content(self):
+ result = xml_formatter.format_xml_string(SIMPLE_XML)
+ assert "text" in result
+ assert "child" in result
+ assert "root" in result
+
+ def test_invalid_xml_raises(self):
+ with pytest.raises((ValueError, Exception)):
+ xml_formatter.format_xml_string("")
+
+ def test_custom_line_wrap(self):
+ result = xml_formatter.format_xml_string(SIMPLE_XML, line_wrap_column=120)
+ assert result is not None
+
+ def test_ends_with_newline(self):
+ result = xml_formatter.format_xml_string(SIMPLE_XML)
+ assert result.endswith("\n")
+
+
+class TestFormatXmlFileToString:
+    @pytest.fixture
+    def xml_file(self, tmp_path):
+        path = tmp_path / "test.xml"
+        path.write_text(SIMPLE_XML, encoding="utf-8")
+        return path  # a Path, so tests can read_text() without leaking handles
+
+    def test_returns_formatted_string(self, xml_file):
+        result = xml_formatter.format_xml_file_to_string(str(xml_file))
+        assert isinstance(result, str)
+        assert "root" in result
+
+    def test_does_not_modify_file(self, xml_file):
+        original = xml_file.read_text()
+        xml_formatter.format_xml_file_to_string(str(xml_file))
+        assert xml_file.read_text() == original
+
+
+class TestFormatXmlFileProgrammatic:
+    @pytest.fixture
+    def xml_file(self, tmp_path):
+        path = tmp_path / "test.xml"
+        path.write_text(SIMPLE_XML, encoding="utf-8")
+        return path  # a Path, so tests can read_text() without leaking handles
+
+    def test_in_place_false_returns_string(self, xml_file):
+        result = xml_formatter.format_xml_file_programmatic(str(xml_file), in_place=False)
+        assert isinstance(result, str)
+        assert "root" in result
+
+    def test_in_place_false_does_not_modify(self, xml_file):
+        original = xml_file.read_text()
+        xml_formatter.format_xml_file_programmatic(str(xml_file), in_place=False)
+        assert xml_file.read_text() == original
+
+    def test_in_place_true_modifies_file(self, xml_file):
+        result = xml_formatter.format_xml_file_programmatic(str(xml_file), in_place=True)
+        assert result is True
+        # Same check as before, read through the Path so the file is closed.
+        assert "\n" in xml_file.read_text()
+
+
+class TestFormatXmlFolder:
+ @pytest.fixture
+ def xml_dir(self, tmp_path):
+ (tmp_path / "a.xml").write_text(SIMPLE_XML, encoding="utf-8")
+ (tmp_path / "b.xml").write_text(MULTI_CHILD_XML, encoding="utf-8")
+ (tmp_path / "other.txt").write_text("not xml", encoding="utf-8")
+ return str(tmp_path)
+
+ def test_in_place_true_returns_dict_of_bools(self, xml_dir):
+ result = xml_formatter.format_xml_folder(xml_dir, in_place=True)
+ assert isinstance(result, dict)
+ assert len(result) == 2
+ for v in result.values():
+ assert v is True
+
+ def test_in_place_false_returns_dict_of_strings(self, xml_dir):
+ result = xml_formatter.format_xml_folder(xml_dir, in_place=False)
+ assert isinstance(result, dict)
+ for v in result.values():
+ assert isinstance(v, str)
+
+ def test_only_xml_files_processed(self, xml_dir):
+ result = xml_formatter.format_xml_folder(xml_dir, in_place=False)
+ for path in result.keys():
+ assert path.endswith(".xml")
+
+
+class TestFindXmlFiles:
+ @pytest.fixture
+ def dir_with_files(self, tmp_path):
+ (tmp_path / "a.xml").write_text(SIMPLE_XML)
+ (tmp_path / "b.xml").write_text(SIMPLE_XML)
+ (tmp_path / "c.txt").write_text("text")
+ sub = tmp_path / "sub"
+ sub.mkdir()
+ (sub / "d.xml").write_text(SIMPLE_XML)
+ return str(tmp_path)
+
+ def test_finds_xml_in_dir(self, dir_with_files):
+ result = xml_formatter.find_xml_files(dir_with_files)
+ assert len(result) == 2
+
+ def test_recursive_finds_all(self, dir_with_files):
+ result = xml_formatter.find_xml_files(dir_with_files, recursive=True)
+ assert len(result) == 3
+
+
+class TestWrapXmlElement:
+ def test_short_line_unchanged(self):
+ line = ''
+ result = xml_formatter.wrap_xml_element(line)
+ assert result == [line]
+
+ def test_long_line_wrapped(self):
+ long_line = ''
+ result = xml_formatter.wrap_xml_element(long_line)
+ assert isinstance(result, list)
+ assert len(result) >= 1