diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4acebc8..aba300a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,12 +7,33 @@ on: branches: [main, master] jobs: + lint: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install ruff + run: pip install ruff>=0.4.0 + + - name: Lint with ruff + run: ruff check ruf_common/ + test: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 diff --git a/.gitignore b/.gitignore index cf16bd6..d8bdffb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__/ .venv/ venv/ .vscode/ +.coverage +build/ diff --git a/README.md b/README.md index 567e095..0ed7121 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ pip install ruf-common from ruf_common import * # Or import specific modules -from ruf_common import data, helper, lfs +from ruf_common import data, helper, lfs # etc. ``` ## Modules @@ -45,3 +45,18 @@ The following modules are available: ## License MIT + +## Use of AI for Creating/Maintaining This Library + +**No portion of this library was "vibe coded".** + +Early versions of this library were written entirely without the use of AI tools. + +Claude/Claude Code and GitHub Co-pilot have been used in a manner similar to pair-programming. This includes: +- improving alignment with "pythonic" best practices +- targeted code reviews +- resolving linter issues +- aiding in debugging and testing +- drafting individual functions/methods that I refine and test +- drafting portions of documentation +- drafting unit tests diff --git a/docs/PUBLISHING.md b/docs/PUBLISHING.md index ada2eb4..ce6b7e6 100644 --- a/docs/PUBLISHING.md +++ b/docs/PUBLISHING.md @@ -34,7 +34,9 @@ git tag vX.Y.Z git push origin vX.Y.Z ``` -### 4. Create the GitHub Release +### 4. Merge with Main + +### 5. Create the GitHub Release 1. Go to [Releases → New release](https://github.com/brian-ruf/ruf-common-python/releases/new) 2. Select tag **`vX.Y.Z`** diff --git a/pyproject.toml b/pyproject.toml index 44b2a9c..bf29c1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "ruf-common" -version = "2.0.3" +version = "2.1.0" description = "Functions common to several of Brian's Python projects." requires-python = ">=3.9" license = "MIT" @@ -43,6 +43,7 @@ dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", "pytest-asyncio>=0.21.0", + "ruff>=0.4.0", ] [tool.setuptools.packages.find] @@ -65,3 +66,10 @@ exclude_lines = [ "pragma: no cover", "if __name__ == .__main__.:", ] + +[tool.ruff] +target-version = "py310" +line-length = 120 + +[tool.ruff.lint] +select = ["E9", "F"] diff --git a/ruf_common/country_code_converter.py b/ruf_common/country_code_converter.py index e0c8972..e066a2d 100644 --- a/ruf_common/country_code_converter.py +++ b/ruf_common/country_code_converter.py @@ -216,7 +216,7 @@ def country_name_to_code_api(country_name: str) -> str: except Exception: return country_name_to_code_fuzzy(country_name) -def demonstrate_conversion(): +def demonstrate_conversion() -> None: """Demonstrate different country code conversion methods.""" test_countries = [ diff --git a/ruf_common/data.py b/ruf_common/data.py index 2819a53..9ee2723 100644 --- a/ruf_common/data.py +++ b/ruf_common/data.py @@ -1,14 +1,17 @@ """ Functions for managing and manipulating XML, JSON and YAML content. """ +from __future__ import annotations import elementpath import xml.etree.ElementTree as ET from xml.etree.ElementTree import tostring from loguru import logger +from typing import Any, cast + # ------------------------------------------------------------------------- -def detect_data_format(content): +def detect_data_format(content: str) -> str: """Detect whether the content is XML, JSON, or YAML based on its starting characters.""" content = content.lstrip() # Remove leading whitespace @@ -25,7 +28,7 @@ def detect_data_format(content): # ------------------------------------------------------------------------- # ------------------------------------------------------------------------- -def safe_load(content, data_format=""): +def safe_load(content: str, data_format: str = "") -> object | None: """Check if the provided content string is well-formed based on its format.""" data_object = None if data_format == "": @@ -43,7 +46,7 @@ def safe_load(content, data_format=""): return data_object # ------------------------------------------------------------------------- -def safe_load_xml(content): +def safe_load_xml(content: str) -> ET.Element | None: """ Returns an XML tree if the provided XML string is well-formed. If not well-formed, returns None. @@ -64,7 +67,7 @@ def safe_load_xml(content): # ------------------------------------------------------------------------- -def safe_load_json(content): +def safe_load_json(content: str) -> dict | None: """ Returns a dict if the provided JSON string is well-formed. If not well-formed, returns None. @@ -82,7 +85,7 @@ def safe_load_json(content): return data_object # ------------------------------------------------------------------------- -def safe_load_yaml(content): +def safe_load_yaml(content: str) -> dict | None: """ Returns a dict if the provided YAML string is well-formed. If not well-formed, returns None. @@ -100,7 +103,7 @@ def safe_load_yaml(content): return data_object # ------------------------------------------------------------------------- -def xpath(tree, nsmap, xExpr, context=None): +def xpath(tree: Any, nsmap: dict, xExpr: str, context: ET.Element | None = None) -> object | None: """ Performs an xpath query either on the entire XML document or on a context within the document. @@ -143,7 +146,7 @@ def xpath(tree, nsmap, xExpr, context=None): return result # ------------------------------------------------------------------------- -def xpath_atomic(tree, nsmap, xExpr, context=None): +def xpath_atomic(tree: Any, nsmap: dict, xExpr: str, context: ET.Element | None = None) -> str: """ Performs an xpath query either on the entire XML document or on a context within the document. @@ -182,7 +185,7 @@ def xpath_atomic(tree, nsmap, xExpr, context=None): return str(ret_value) # ------------------------------------------------------------------------- -def remove_namespace(element): +def remove_namespace(element: ET.Element) -> None: """Remove namespace from an element and all its children""" # Remove namespace from this element if '}' in element.tag: @@ -200,7 +203,7 @@ def remove_namespace(element): # ------------------------------------------------------------------------- -def get_markup_content(tree, nsmap, xExpr, context=None): +def get_markup_content(tree: Any, nsmap: dict, xExpr: str, context: ET.Element | None = None) -> str: """ Get the content of a specific XML element using XPath, preserving HTML formatting. @@ -214,6 +217,7 @@ def get_markup_content(tree, nsmap, xExpr, context=None): The content of the element as a string with HTML preserved, or empty string if not found """ ret_value = "" + element = None try: # First, try to get the entire element (not just its children) @@ -244,7 +248,7 @@ def get_markup_content(tree, nsmap, xExpr, context=None): # Now we have the element, let's extract its complete content if hasattr(element, 'tag'): # This is an Element object - ret_value = extract_element_content(element) + ret_value = extract_element_content(cast(ET.Element, element)) else: # This might be a text node or something else ret_value = str(element) @@ -255,7 +259,7 @@ def get_markup_content(tree, nsmap, xExpr, context=None): return ret_value # ------------------------------------------------------------------------- -def xml_to_string(element): +def xml_to_string(element: Any) -> str: """Convert an XML element or list of elements to a string.""" import copy element_str = "" @@ -283,7 +287,7 @@ def xml_to_string(element): # ------------------------------------------------------------------------- -def extract_element_content(element): +def extract_element_content(element: ET.Element | None) -> str: """ Extract the complete inner content of an XML element, preserving all HTML formatting but removing namespaces. Handles both simple text content and complex mixed content. @@ -345,7 +349,7 @@ def extract_element_content(element): # ------------------------------------------------------------------------- -def remove_namespace_from_html(html_str): +def remove_namespace_from_html(html_str: str) -> str: """ Remove XML namespace declarations from HTML string. @@ -369,7 +373,7 @@ def remove_namespace_from_html(html_str): return html_str # ------------------------------------------------------------------------- -def deserialize_xml(xml_string, nsmap): +def deserialize_xml(xml_string: str, nsmap: str) -> ET.Element | None: """Deserialize an XML string into a Python dictionary.""" ret_value = None try: @@ -385,8 +389,7 @@ def deserialize_xml(xml_string, nsmap): return ret_value # ------------------------------------------------------------------------- -# ------------------------------------------------------------------------- -def get_attribute_value(element, attribute_name, default=""): +def get_attribute_value(element: ET.Element, attribute_name: str, default: str = "") -> str: """ Get the value of a specific attribute from an XML element. diff --git a/ruf_common/database.py b/ruf_common/database.py index 4984e72..b41b818 100644 --- a/ruf_common/database.py +++ b/ruf_common/database.py @@ -10,11 +10,11 @@ # ============================================================================= # TODO: Evaluate using SQLAlchemy for database handling # TODO: Handle additional database types beyond sqlite3 -# TODL: Ensure all sqlite3 specific code is in database_sqlite3.py +# TODO: Ensure all sqlite3 specific code is in database_sqlite3.py # ============================================================================= import sqlite3 import uuid as uuid_module -from typing import Optional +from typing import Any, Optional, Union from loguru import logger from . import helper from . import database_sqlite3 @@ -51,7 +51,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ class Database: - def __init__(self, type, target): + def __init__(self, type: str, target: str) -> None: """ Creates a database object and opens the database. type: The type of database. Supported types: @@ -72,7 +72,7 @@ def __init__(self, type, target): self.open() # ------------------------------------------------------------------------- - def __str__(self): + def __str__(self) -> str: ret_val = "" ret_val += f"Database Type: {self.type}" ret_val += f"Database Target: {self.target}" @@ -80,12 +80,12 @@ def __str__(self): return ret_val # ------------------------------------------------------------------------- - def __del__(self): + def __del__(self) -> None: if self.conn: self.conn.close() # ------------------------------------------------------------------------- - def open(self): + def open(self) -> None: """Executes the correct open function/tasks based on the database type.""" if self.type == "sqlite3": @@ -96,7 +96,7 @@ def open(self): logger.error(f"Unsupported database type: {self.type}") # ------------------------------------------------------------------------- - def check_for_tables(self, tables): + def check_for_tables(self, tables: dict) -> bool: """ Check for the presence of the expected tables in the database. """ @@ -113,7 +113,7 @@ def check_for_tables(self, tables): return status # ------------------------------------------------------------------------- - def table_exists(self, name): + def table_exists(self, name: str) -> bool: """ Determines if a table exists in the database - name: A string containing the name of the table @@ -152,7 +152,7 @@ def table_exists(self, name): return status # ------------------------------------------------------------------------- - def record_count(self, table, where_clause): + def record_count(self, table: str, where_clause: str) -> int: """ Returns the number of records that include a value - table: A string containing the name of the table @@ -195,7 +195,7 @@ def record_count(self, table, where_clause): # ------------------------------------------------------------------------- # From: https://en.ittrip.xyz/python/sqlite-error-handling - def db_execute(self, SQL_statements): + def db_execute(self, SQL_statements: Union[str, list[str]]) -> bool: """Executes a list of SQL statements in a transaction.""" status = False @@ -235,7 +235,7 @@ def db_execute(self, SQL_statements): return status # ------------------------------------------------------------------------- - def query(self, SQL_statement): + def query(self, SQL_statement: str) -> list[dict]: """ Executes a query and returns the results. SQL_statement: The SQL statement to @@ -259,7 +259,7 @@ def query(self, SQL_statement): return results # ------------------------------------------------------------------------- - def create_table(self, table_definition): + def create_table(self, table_definition: dict) -> bool: """ Creates a table in the database. @@ -304,7 +304,7 @@ def create_table(self, table_definition): return status # ------------------------------------------------------------------------- - def insert(self, table_name, table_fields, table_blob_fields={}): + def insert(self, table_name: str, table_fields: dict, table_blob_fields: dict = {}) -> bool: """ Inserts a record into a table. table_name: String @@ -344,7 +344,7 @@ def insert(self, table_name, table_fields, table_blob_fields={}): return status # ------------------------------------------------------------------------- - def drop_table(self, table_name): + def drop_table(self, table_name: str) -> bool: """ Drops a table from the database. table_name: String @@ -357,7 +357,7 @@ def drop_table(self, table_name): return status # ------------------------------------------------------------------------- - def cache_file(self, content, uuid = None, attributes={}): + def cache_file(self, content: Any, uuid: Optional[str] = None, attributes: dict = {}) -> Union[str, bool]: """ Stores file content in the filecache table. content: The file contents to be cached @@ -392,7 +392,7 @@ def cache_file(self, content, uuid = None, attributes={}): return status # ------------------------------------------------------------------------- - def retrieve_file(self, uuid): + def retrieve_file(self, uuid: str) -> Any: """ Retrieves a file from the filecache table. uuid: The UUID of the file to be retrieved. @@ -418,7 +418,7 @@ def retrieve_file(self, uuid): return ret_value # ------------------------------------------------------------------------- - def retrieve_file_name(self, uuid): + def retrieve_file_name(self, uuid: str) -> Optional[str]: """ Retrieves the name of a file from the filecache table. uuid: The UUID of the file to be retrieved. @@ -438,7 +438,7 @@ def retrieve_file_name(self, uuid): return filename # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -def oscal_datatype(datatype): +def oscal_datatype(datatype: str) -> Optional[str]: """ Aligns the datatype to the OSCAL datatype. datatype: The datatype to be aligned @@ -446,7 +446,7 @@ def oscal_datatype(datatype): """ # ----------------------------------------------------------------------------- -def db_datatype(datatype, database_type): +def db_datatype(datatype: str, database_type: str) -> str: """ Aligns the datatype to the database type. datatype: The datatype to be aligned diff --git a/ruf_common/database_sqlite3.py b/ruf_common/database_sqlite3.py index 12c348a..a64912b 100644 --- a/ruf_common/database_sqlite3.py +++ b/ruf_common/database_sqlite3.py @@ -12,7 +12,7 @@ FILE_CACHE_TABLE = 'filecache' -def save_to_db(conn, table_name: str, content: Any, identifier: Optional[str] = None, +def save_to_db(conn: sqlite3.Connection, table_name: str, content: Any, identifier: Optional[str] = None, additional_fields: Optional[Dict] = None) -> str: """ Save content and additional fields to a SQLite database with type preservation. @@ -80,7 +80,7 @@ def save_to_db(conn, table_name: str, content: Any, identifier: Optional[str] = raise e -def get_from_db(conn, table_name: str, identifier: str) -> Any: +def get_from_db(conn: sqlite3.Connection, table_name: str, identifier: str) -> Any: """ Retrieve content from the database and restore its original Python data type. @@ -163,7 +163,7 @@ def get_from_db(conn, table_name: str, identifier: str) -> Any: conn.close() """ -def update_record_from_dict(conn, table_name: str, identifier: str, update_dict: Dict) -> bool: +def update_record_from_dict(conn: sqlite3.Connection, table_name: str, identifier: str, update_dict: Dict) -> bool: """ Update a record in the database with new values from a dictionary. Only non-BLOB fields will be updated. @@ -214,7 +214,7 @@ def update_record_from_dict(conn, table_name: str, identifier: str, update_dict: raise e -def get_record_metadata(conn, table_name: str, identifier: str) -> Dict: +def get_record_metadata(conn: sqlite3.Connection, table_name: str, identifier: str) -> Dict: """ Retrieve all non-BLOB fields from a record as a dictionary. @@ -256,7 +256,7 @@ def get_record_metadata(conn, table_name: str, identifier: str) -> Dict: raise e -def store_blob_to_db(conn, identifier: str, blob, attributes: dict) -> bool: +def store_blob_to_db(conn: sqlite3.Connection, identifier: str, blob: Any, attributes: dict) -> bool: """ Store a binary large object (BLOB) in the database. If the UUID exists, update the record. Otherwise, insert a new one. @@ -339,7 +339,7 @@ def store_blob_to_db(conn, identifier: str, blob, attributes: dict) -> bool: conn.rollback() raise e -def retrieve_blob_from_db(conn, identifier: str) -> Any: +def retrieve_blob_from_db(conn: sqlite3.Connection, identifier: str) -> Any: """ Retrieve a binary large object (BLOB) from the database. @@ -406,7 +406,7 @@ def retrieve_blob_from_db(conn, identifier: str) -> Any: raise e # ----------------------------------------------------------------------------- -def open_sqlite3(target): +def open_sqlite3(target: str) -> Optional[sqlite3.Connection]: """ Opens a SQLite3 database file. SQLite3 will automatically create the database if it does not exist. @@ -449,42 +449,41 @@ def open_sqlite3(target): # --- MAIN: Only runs if the module is executed stand-alone. --- # ============================================================================= if __name__ == '__main__': - print("SQLite3 Functions. Not intended to be run as a stand-alone file.") - + example_usage = """ + import sqlite3 -# Example usage: -""" -import sqlite3 + conn = sqlite3.connect('data.db') -conn = sqlite3.connect('data.db') - -try: - # First create a record with some metadata - content = ["example", "data"] - additional_fields = { - 'category': 'test', - 'created_by': 'user123', - 'status': 1, - 'notes': 'Sample record' - } - - # Save the record - record_id = save_to_db(conn, "my_table", content, - additional_fields=additional_fields) - - # Retrieve the metadata - metadata = get_record_metadata(conn, "my_table", record_id) - print("Record metadata:", metadata) - # Output might look like: - # { - # 'uuid': '123e4567-e89b-12d3-a456-426614174000', - # 'datatype': 'list', - # 'category': 'test', - # 'created_by': 'user123', - # 'status': 1, - # 'notes': 'Sample record' - # } - -finally: - conn.close() -""" + try: + # First create a record with some metadata + content = ["example", "data"] + additional_fields = { + 'category': 'test', + 'created_by': 'user123', + 'status': 1, + 'notes': 'Sample record' + } + + # Save the record + record_id = save_to_db(conn, "my_table", content, + additional_fields=additional_fields) + + # Retrieve the metadata + metadata = get_record_metadata(conn, "my_table", record_id) + print("Record metadata:", metadata) + # Output might look like: + # { + # 'uuid': '123e4567-e89b-12d3-a456-426614174000', + # 'datatype': 'list', + # 'category': 'test', + # 'created_by': 'user123', + # 'status': 1, + # 'notes': 'Sample record' + # } + + finally: + conn.close() + """ + print("SQLite3 Functions. Not intended to be run as a stand-alone file.") + print("Example usage:") + print(example_usage) diff --git a/ruf_common/helper.py b/ruf_common/helper.py index 7579342..16fbff5 100644 --- a/ruf_common/helper.py +++ b/ruf_common/helper.py @@ -19,14 +19,14 @@ import json import getpass as gt from loguru import logger -from typing import Dict, Any +from typing import Dict, Any, Union # ----------------------------------------------------------------------------- # ============================================================================= # DATE/TIME FUNCTIONS # ============================================================================= -def convert_datetime_format(date_input=datetime.now(), include_time=True, assume_localtime=True, format = "%Y-%m-%dT%H:%M:%SZ") -> str: +def convert_datetime_format(date_input: Union[str, datetime] = datetime.now(), include_time: bool = True, assume_localtime: bool = True, format: str = "%Y-%m-%dT%H:%M:%SZ") -> str: """ Converts various datetime inputs to a formatted date string. Handles both datetime objects and ISO 8601 datetime strings including: @@ -163,7 +163,7 @@ def convert_datetime_format(date_input=datetime.now(), include_time=True, assume # ----------------------------------------------------------------------------- -def datetime_string(date_time = datetime.now(), format = "%Y-%m-%d--%H-%M-%S")-> str: +def datetime_string(date_time: datetime = datetime.now(), format: str = "%Y-%m-%d--%H-%M-%S") -> str: """ Converts a date and time to a formatted string. Optional Parameters: @@ -189,7 +189,7 @@ def datetime_string(date_time = datetime.now(), format = "%Y-%m-%d--%H-%M-%S")-> # ============================================================================= # LOGIC UTILITIES # ============================================================================= -def iif(condition, if_true, if_false): +def iif(condition: Any, if_true: Any, if_false: Any) -> Any: """ Accepts and evaluates a condition Returns the first parameter if the condition is true @@ -205,7 +205,7 @@ def iif(condition, if_true, if_false): # ============================================================================= # STRING UTILITIES # ============================================================================= -def normalize_content(content): +def normalize_content(content: Union[str, bytes]) -> str: """ Normalize Content Converts any bytes content to string. @@ -226,7 +226,7 @@ def normalize_content(content): return content # ----------------------------------------------------------------------------- -def get_first_non_whitespace_char(data): +def get_first_non_whitespace_char(data: str) -> str: """ Returns the first character this is not a space or tab. Returns an empty string if there is no content or if an error occurs. @@ -243,7 +243,7 @@ def get_first_non_whitespace_char(data): return ret_val # ----------------------------------------------------------------------------- -def safeJSON(object, keys): +def safeJSON(object: dict, keys: list) -> str: """ Depreciated in favor of the JSON librariy's `.get()` method. Always returns a string from a JSON key or nested keys. @@ -268,7 +268,7 @@ def safeJSON(object, keys): return ret_value # ----------------------------------------------------------------------------- -def JSON_safe_atomic(object, key): +def JSON_safe_atomic(object: dict, key: str) -> str: """ Always returns a string from a JSON key. If the value at the key is string, int, float, complex or boolean, returns the value as a string. @@ -296,11 +296,11 @@ def JSON_safe_atomic(object, key): return ret_value # ----------------------------------------------------------------------------- -def indent(level, length=3) -> str: +def indent(level: int, length: int = 3) -> str: return (" " * length * level) # ------------------------------------------------------------------------- -def has_repeated_ending(full_string, suffix, frequency=2): +def has_repeated_ending(full_string: str, suffix: str, frequency: int = 2) -> bool: """ Check if a string ends with a specific suffix repeated multiple times. @@ -331,7 +331,7 @@ def has_repeated_ending(full_string, suffix, frequency=2): # OS INTERACTION HELPER UTILITIES # ============================================================================= # ----------------------------------------------------------------------------- -def handle_environment_variables(env_name, verbose = False, error_only = True): +def handle_environment_variables(env_name: str, verbose: bool = False, error_only: bool = True) -> str: """ If the environment variable identified in the argument exits, return the value as a string. If the environment variable identified in the argument does not exit, return an empty string. @@ -353,7 +353,7 @@ def handle_environment_variables(env_name, verbose = False, error_only = True): return ret_value # ----------------------------------------------------------------------------- -def get_user_information(): +def get_user_information() -> str: """ Returns the current user's username. Uses the getpass module to retrieve the username. @@ -508,7 +508,7 @@ def is_valid_html_content(html_content: str) -> bool: return len(stack) == 0 # ------------------------------------------------------------------------- -def html_to_json_safe(html_content): +def html_to_json_safe(html_content: str) -> str: """ Convert HTML content to a JSON-safe string that can still be interpreted by browsers. @@ -531,7 +531,7 @@ def html_to_json_safe(html_content): return json_safe[1:-1] # ------------------------------------------------------------------------- -def html_from_json_safe(json_safe_content): +def html_from_json_safe(json_safe_content: str) -> str: """ Convert a JSON-safe HTML string back to regular HTML. @@ -552,7 +552,7 @@ def html_from_json_safe(json_safe_content): # UI HELPER UTILITIES # ============================================================================= # ----------------------------------------------------------------------------- -def tell_user(message, log_as = ""): +def tell_user(message: str, log_as: str = "") -> None: """ Outputs a message to the console. """ @@ -569,7 +569,7 @@ def tell_user(message, log_as = ""): pass # no logging # ----------------------------------------------------------------------------- -def processing(out_char = "."): +def processing(out_char: str = ".") -> None: """ Outputs a character to console. Intended to be called iterativley from a loop ton indicate progress. @@ -583,7 +583,7 @@ def processing(out_char = "."): # MISCELLANEOUS FUNCTIONS # ============================================================================= # ----------------------------------------------------------------------------- -def compare_semver(version1, version2): +def compare_semver(version1: str, version2: str) -> int: """ Compare two semantic versions and return: -1 if version1 < version2 diff --git a/ruf_common/html_to_markdown.py b/ruf_common/html_to_markdown.py index 08e4537..137a225 100644 --- a/ruf_common/html_to_markdown.py +++ b/ruf_common/html_to_markdown.py @@ -9,7 +9,7 @@ import html -def html_to_markdown(html_content): +def html_to_markdown(html_content: str) -> str: """ Convert HTML formatting to markdown formatting. @@ -137,7 +137,7 @@ def convert_table(match): return content -def html_to_markdown_file(input_file, output_file): +def html_to_markdown_file(input_file: str, output_file: str) -> None: """ Convert HTML file to markdown file. diff --git a/ruf_common/lfs.py b/ruf_common/lfs.py index 14489f4..2c6f853 100644 --- a/ruf_common/lfs.py +++ b/ruf_common/lfs.py @@ -12,7 +12,7 @@ # ============================================================================= # --- PyInstaller Interactions --- # ============================================================================= -def resource_path(relative_path): +def resource_path(relative_path: str) -> str: """ Get absolute path to resource, works for dev and for PyInstaller """ # PyInstaller creates a temp folder and stores path in _MEIPASS base_path = getattr(sys, '_MEIPASS', os.path.abspath(".")) @@ -22,7 +22,7 @@ def resource_path(relative_path): # --- LFS File Level Interactions --- # ============================================================================= -def zip_file(file_to_zip, zip_filename, overwrite=False, recurse=False): +def zip_file(file_to_zip: str, zip_filename: str, overwrite: bool = False, recurse: bool = False) -> bool: """ Creates a zip archive containing files and/or directories. @@ -103,7 +103,7 @@ def zip_file(file_to_zip, zip_filename, overwrite=False, recurse=False): return status # ----------------------------------------------------------------------------- -def putfile(file_name, content): +def putfile(file_name: str, content: str) -> bool: """ Saves content to a file. Returns True if successful. @@ -122,7 +122,7 @@ def putfile(file_name, content): return status # ----------------------------------------------------------------------------- -def get_json(file_name) -> dict: +def get_json(file_name: str) -> dict: """ Opens a JSON file and returns the contents as a dict object. If an error occurs, an empty dict is returned. @@ -138,7 +138,7 @@ def get_json(file_name) -> dict: return json_data # ----------------------------------------------------------------------------- -def save_json(data, file_name): +def save_json(data: dict, file_name: str) -> bool: """ Saves a dict object as a JSON file. Returns True if successful. @@ -155,7 +155,7 @@ def save_json(data, file_name): return status # ----------------------------------------------------------------------------- -def chkfile(path) -> bool: +def chkfile(path: str) -> bool: """ Checks for the existence of a file. Returns: @@ -178,7 +178,7 @@ def chkfile(path) -> bool: return status # ----------------------------------------------------------------------------- -def getfile(file_name, normalize = True, mode="rb") -> str: +def getfile(file_name: str, normalize: bool = True, mode: str = "rb") -> str: """ Opens a file and returns the contents. Handles errors gracefully. If no optional parameters are passed, this will open the file as binary @@ -220,7 +220,7 @@ def getfile(file_name, normalize = True, mode="rb") -> str: return ret_value # ----------------------------------------------------------------------------- -def getjsonfile(file_name) -> dict: +def getjsonfile(file_name: str) -> dict: """ Gets a JSON file from the local file system @@ -239,7 +239,7 @@ def getjsonfile(file_name) -> dict: return json_results # ----------------------------------------------------------------------------- -def backup_file(filename): +def backup_file(filename: str) -> bool: """ Creates a backup of the specified file by duplicating it in its current location and appending the date and time to the root file name. @@ -280,7 +280,7 @@ def get_app_location() -> str: return application_path # ----------------------------------------------------------------------------- -def chkdir(path, make_if_not_present = False) -> bool: +def chkdir(path: str, make_if_not_present: bool = False) -> bool: """ Checks for the existence of a folder. @@ -313,7 +313,7 @@ def chkdir(path, make_if_not_present = False) -> bool: # ----------------------------------------------------------------------------- -def mkdir(path) -> bool: +def mkdir(path: str) -> bool: """ Creates any folders needed to ensure the specified path exists. Returns: diff --git a/ruf_common/logging.py b/ruf_common/logging.py index a1b71c1..7ead9b4 100644 --- a/ruf_common/logging.py +++ b/ruf_common/logging.py @@ -1,14 +1,14 @@ from loguru import logger -from typing import List, Dict, Optional +from typing import Any, List, Dict, Optional import sys class DictSink: """Custom sink that captures log records as dictionaries.""" - def __init__(self): + def __init__(self) -> None: self.records: List[Dict] = [] - def write(self, message): + def write(self, message: Any) -> None: record = message.record log_entry = { 'timestamp': record['time'].isoformat(), @@ -28,7 +28,7 @@ def write(self, message): def get_records(self) -> List[Dict]: return self.records.copy() - def clear(self): + def clear(self) -> None: self.records.clear() @@ -87,12 +87,12 @@ def get_logs(self) -> List[Dict]: return self._dict_sink.get_records() return [] - def clear_logs(self): + def clear_logs(self) -> None: """Clear captured log records.""" if hasattr(self, '_dict_sink') and self._dict_sink: self._dict_sink.clear() - def cleanup_logging(self): + def cleanup_logging(self) -> None: """Remove all handlers added by this instance.""" if hasattr(self, '_handler_ids'): for handler_id in self._handler_ids: diff --git a/ruf_common/network.py b/ruf_common/network.py index 5d07241..17fd183 100644 --- a/ruf_common/network.py +++ b/ruf_common/network.py @@ -3,9 +3,10 @@ from . import helper import socket import aiohttp +from typing import Any, Optional -def check_internet_connection(): +def check_internet_connection() -> bool: try: # Try to connect to a reliable host socket.create_connection(("8.8.8.8", 53), timeout=3) @@ -13,21 +14,21 @@ def check_internet_connection(): except OSError: return False -async def async_api_get(url, headers=None): +async def async_api_get(url: str, headers: Optional[dict] = None) -> Any: """Asynchronous version of api_get""" - try: + try: # type: ignore async with aiohttp.ClientSession() as session: async with session.get(url, headers=headers) as response: if response.status == 200: return await response.json() else: logger.error(f"API request failed with status {response.status}") - return None + return None # type: ignore except Exception as e: logger.error(f"Error during API request: {str(e)}") return None -async def async_download_file(url, filename): +async def async_download_file(url: str, filename: str) -> Optional[bytes]: """Asynchronous version of download_file""" try: async with aiohttp.ClientSession() as session: @@ -43,12 +44,12 @@ async def async_download_file(url, filename): def api_get(endpoint, http_headers={"Content-type": "application/json"}, timeout_seconds=10): """ - Calls a REST API and returns the response. + Calls a REST API and returns the response. endpoint: The full URL to the REST endpoint - http_headers: optional headers to include in the request + http_headers: optional headers to include in the request If not provided, this funciton requests a JSON response. - """ + rest_ret: Optional[requests.Response] = None rest_ret = None try: rest_ret = requests.get(endpoint, headers=http_headers, timeout=timeout_seconds) @@ -64,10 +65,10 @@ def api_get(endpoint, http_headers={"Content-type": "application/json"}, timeout except Exception as err: logger.error(f"Unrecognized error {type(err).__name__}: {str(err)}\n--for GET {endpoint}") - if not rest_ret.status_code == 200: - logger.error(f"HTTP Error: {str(rest_ret.status_code)} {rest_ret.text}\n--for GET {endpoint}") - - logger.debug(f"GET {endpoint} returned {rest_ret.status_code}") + if rest_ret is not None: + if rest_ret.status_code != 200: + logger.error(f"HTTP Error: {str(rest_ret.status_code)} {rest_ret.text}\n--for GET {endpoint}") + logger.debug(f"GET {endpoint} returned {rest_ret.status_code}") return rest_ret @@ -75,7 +76,7 @@ def download_file(url, filename): """ Downloads a file from a URL and saves it to the specified filename. """ - ret_value = "" + ret_value: str = "" try: response = requests.get(url) ret_value = helper.normalize_content(response.content) diff --git a/tests/run_unit_tests.sh b/tests/run_unit_tests.sh new file mode 100755 index 0000000..a6ffc28 --- /dev/null +++ b/tests/run_unit_tests.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Manually trigger unit tests + +VENV_DIR="$(dirname "$0")/.venv" + +if [[ ! -d "$VENV_DIR" ]]; then + echo "Virtual environment not found. Creating $VENV_DIR ..." + python3 -m venv "$VENV_DIR" + source "$VENV_DIR/bin/activate" + echo "Installing dependencies ..." + pip install --quiet "../[dev]" +else + source "$VENV_DIR/bin/activate" +fi + +clear +python -m pytest unit/ -v diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index 16fa79d..0000000 --- a/tests/test_data.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for the data module.""" - -import pytest -from ruf_common import data - - -class TestData: - """Test cases for data module functions.""" - - def test_module_imports(self): - """Verify the data module can be imported.""" - assert data is not None - - # Add more specific tests for XML, JSON, YAML functions - # Example: - # def test_parse_json(self): - # result = data.parse_json('{"key": "value"}') - # assert result == {"key": "value"} diff --git a/tests/test_helper.py b/tests/test_helper.py deleted file mode 100644 index 490969e..0000000 --- a/tests/test_helper.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests for the helper module.""" - -import pytest -from ruf_common import helper - - -class TestHelper: - """Test cases for helper module functions.""" - - def test_module_imports(self): - """Verify the helper module can be imported.""" - assert helper is not None - - # Add more specific tests as needed - # Example: - # def test_some_function(self): - # result = helper.some_function(input_value) - # assert result == expected_value diff --git a/tests/unit/test_country_code_converter.py b/tests/unit/test_country_code_converter.py new file mode 100644 index 0000000..96eb8c8 --- /dev/null +++ b/tests/unit/test_country_code_converter.py @@ -0,0 +1,93 @@ +"""Tests for the country_code_converter module.""" + + +from ruf_common import country_code_converter as ccc + + +class TestCountryNameToCodeSimple: + def test_known_country(self): + result = ccc.country_name_to_code_simple("United States") + assert result == "US" + + def test_unknown_country_returns_empty(self): + result = ccc.country_name_to_code_simple("Atlantis") + assert result == "" + + def test_case_sensitive(self): + result = ccc.country_name_to_code_simple("united states") + assert result == "" + + def test_germany(self): + result = ccc.country_name_to_code_simple("Germany") + assert result == "DE" + + def test_japan(self): + result = ccc.country_name_to_code_simple("Japan") + assert result == "JP" + + +class TestCountryNameToCodeFuzzy: + def test_exact_match(self): + result = ccc.country_name_to_code_fuzzy("United States") + assert result == "US" + + def test_case_insensitive(self): + result = ccc.country_name_to_code_fuzzy("united states") + assert result == "US" + + def test_partial_match(self): + result = ccc.country_name_to_code_fuzzy("United Kingdom") + assert result != "" + + def test_no_match_returns_empty(self): + result = ccc.country_name_to_code_fuzzy("Neverland123") + assert result == "" + + +class TestSafeCountryNameToCodeMap: + def test_known_country(self): + result = ccc.safe_country_name_to_code_map("United States") + assert result == "US" + + def test_unknown_returns_empty(self): + result = ccc.safe_country_name_to_code_map("Fantasyland") + assert result == "" + + +class TestBatchConvertCountries: + def test_converts_list(self): + result = ccc.batch_convert_countries(["United States", "Germany"]) + assert isinstance(result, dict) + assert result.get("United States") == "US" + assert result.get("Germany") == "DE" + + def test_empty_list(self): + result = ccc.batch_convert_countries([]) + assert result == {} + + def test_unknown_country_in_batch(self): + result = ccc.batch_convert_countries(["Neverland"]) + assert isinstance(result, dict) + assert result.get("Neverland") == "" + + +class TestProcessLocationWithCountryCodes: + def test_adds_country_code(self): + locations = [{"country": "United States", "city": "New York"}] + result = ccc.process_location_with_country_codes(locations) + assert result[0]["country_code"] == "US" + + def test_location_without_country_skipped(self): + locations = [{"city": "Unknown"}] + result = ccc.process_location_with_country_codes(locations) + assert "country_code" not in result[0] or result[0].get("country_code") == "" + + def test_multiple_locations(self): + locations = [ + {"country": "Germany"}, + {"country": "Japan"}, + ] + result = ccc.process_location_with_country_codes(locations) + codes = [r.get("country_code") for r in result] + assert "DE" in codes + assert "JP" in codes diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py new file mode 100644 index 0000000..e4383c7 --- /dev/null +++ b/tests/unit/test_data.py @@ -0,0 +1,255 @@ +"""Tests for the data module.""" + +import xml.etree.ElementTree as ET + +import pytest + +from ruf_common import data + +SIMPLE_XML = "text" +NS_XML = 'text' +SIMPLE_JSON = '{"key": "value", "num": 42}' +SIMPLE_YAML = "key: value\nnum: 42" + + +class TestDetectDataFormat: + def test_xml_declaration(self): + assert data.detect_data_format("") == "xml" + + def test_xml_tag(self): + assert data.detect_data_format("") == "xml" + + def test_xml_with_leading_whitespace(self): + assert data.detect_data_format(" \n") == "xml" + + def test_json_object(self): + assert data.detect_data_format('{"a": 1}') == "json" + + def test_json_array(self): + assert data.detect_data_format('[1, 2, 3]') == "json" + + def test_yaml(self): + assert data.detect_data_format("key: value") == "yaml" + + def test_unknown(self): + assert data.detect_data_format("just some plain text") == "unknown" + + +class TestSafeLoadXml: + def test_valid_xml(self): + result = data.safe_load_xml(SIMPLE_XML) + assert result is not None + assert isinstance(result, ET.Element) + assert result.tag == "root" + + def test_invalid_xml(self): + result = data.safe_load_xml("") + assert result is None + + def test_empty_string(self): + result = data.safe_load_xml("") + assert result is None + + +class TestSafeLoadJson: + def test_valid_json_object(self): + result = data.safe_load_json(SIMPLE_JSON) + assert result == {"key": "value", "num": 42} + + def test_valid_json_array(self): + result = data.safe_load_json("[1, 2, 3]") + assert result == [1, 2, 3] + + def test_invalid_json(self): + result = data.safe_load_json("{bad json}") + assert result is None + + +class TestSafeLoadYaml: + def test_valid_yaml(self): + result = data.safe_load_yaml(SIMPLE_YAML) + assert result == {"key": "value", "num": 42} + + def test_invalid_yaml(self): + result = data.safe_load_yaml("key: [unclosed") + assert result is None + + +class TestSafeLoad: + def test_auto_detect_xml(self): + result = data.safe_load(SIMPLE_XML) + assert isinstance(result, ET.Element) + + def test_auto_detect_json(self): + result = data.safe_load(SIMPLE_JSON) + assert result == {"key": "value", "num": 42} + + def test_auto_detect_yaml(self): + result = data.safe_load(SIMPLE_YAML) + assert result == {"key": "value", "num": 42} + + def test_explicit_format_override(self): + result = data.safe_load(SIMPLE_JSON, data_format="json") + assert result is not None + + def test_unknown_format_returns_none(self): + result = data.safe_load("some text", data_format="csv") + assert result is None + + +class TestXpath: + @pytest.fixture + def tree(self): + return ET.fromstring(SIMPLE_XML) + + def test_finds_single_element(self, tree): + result = data.xpath(tree, {}, "//child") + assert result is not None + assert hasattr(result, "tag") + assert result.tag == "child" + + def test_returns_none_for_no_match(self, tree): + result = data.xpath(tree, {}, "//nonexistent") + assert result is None + + def test_returns_list_for_multiple_matches(self): + xml = "ab" + tree = ET.fromstring(xml) + result = data.xpath(tree, {}, "//item") + assert isinstance(result, list) + assert len(result) == 2 + + def test_with_context(self, tree): + child = tree.find("child") + result = data.xpath(tree, {}, ".", child) + assert result is not None + + def test_invalid_xpath_returns_none(self, tree): + result = data.xpath(tree, {}, "///invalid xpath!!!") + assert result is None + + +class TestXpathAtomic: + @pytest.fixture + def tree(self): + return ET.fromstring(SIMPLE_XML) + + def test_returns_string_value(self, tree): + result = data.xpath_atomic(tree, {}, "//child/text()") + assert result == "text" + + def test_returns_empty_string_for_no_match(self, tree): + result = data.xpath_atomic(tree, {}, "//nonexistent/text()") + assert result == "" + + def test_returns_string_type(self, tree): + result = data.xpath_atomic(tree, {}, "//child/text()") + assert isinstance(result, str) + + +class TestRemoveNamespace: + def test_removes_namespace_from_tag(self): + element = ET.fromstring('') + # ET parses ns:root as {http://example.com}root + data.remove_namespace(element) + assert "}" not in element.tag + + def test_removes_namespace_from_children(self): + xml = '' + element = ET.fromstring(xml) + data.remove_namespace(element) + for child in element: + assert "}" not in child.tag + + def test_no_namespace_unchanged(self): + element = ET.fromstring("") + data.remove_namespace(element) + assert element.tag == "root" + + +class TestExtractElementContent: + def test_simple_text(self): + element = ET.fromstring("Hello world") + result = data.extract_element_content(element) + assert result == "Hello world" + + def test_none_returns_empty(self): + result = data.extract_element_content(None) + assert result == "" + + def test_mixed_content(self): + element = ET.fromstring("

Para

Two

") + result = data.extract_element_content(element) + assert "Para" in result + assert "Two" in result + + def test_preserves_child_tags(self): + element = ET.fromstring("bold") + result = data.extract_element_content(element) + assert "bold" in result + + +class TestXmlToString: + def test_single_element(self): + element = ET.fromstring(SIMPLE_XML) + result = data.xml_to_string(element) + assert isinstance(result, str) + assert "root" in result + + def test_list_returns_first(self): + xml = "ab" + tree = ET.fromstring(xml) + items = list(tree) + result = data.xml_to_string(items) + assert "item" in result + + def test_empty_list_returns_empty(self): + result = data.xml_to_string([]) + assert result == "" + + +class TestRemoveNamespaceFromHtml: + def test_removes_xmlns_attribute(self): + html = '

text

' + result = data.remove_namespace_from_html(html) + assert "xmlns" not in result + + def test_removes_prefixed_xmlns(self): + html = '' + result = data.remove_namespace_from_html(html) + assert "xmlns" not in result + + def test_removes_namespace_prefix_from_tags(self): + html = "text" + result = data.remove_namespace_from_html(html) + assert "ns:" not in result + + +class TestDeserializeXml: + def test_valid_xml(self): + result = data.deserialize_xml(SIMPLE_XML, "") + assert result is not None + assert isinstance(result, ET.Element) + + def test_invalid_xml_returns_none(self): + result = data.deserialize_xml("", "") + assert result is None + + +class TestGetAttributeValue: + def test_existing_attribute(self): + element = ET.fromstring("") + assert data.get_attribute_value(element, "key") == "hello" + + def test_missing_attribute_returns_default(self): + element = ET.fromstring("") + assert data.get_attribute_value(element, "missing") == "" + + def test_custom_default(self): + element = ET.fromstring("") + assert data.get_attribute_value(element, "missing", "N/A") == "N/A" + + def test_namespaced_attribute_name_stripped(self): + # Namespace is stripped from the lookup key, so {ns}key looks up plain "key" + element = ET.fromstring("") + assert data.get_attribute_value(element, "{http://ns}key") == "val" diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py new file mode 100644 index 0000000..0eb4904 --- /dev/null +++ b/tests/unit/test_database.py @@ -0,0 +1,204 @@ +"""Tests for the Database class in the database module.""" + +import pytest + +from ruf_common.database import OSCAL_COMMON_TABLES, Database, db_datatype + +# create_table expects SQL types directly and "attributes" (not "constraints") +TABLE_DEF = { + "table_name": "items", + "table_fields": [ + {"name": "uuid", "type": "TEXT", "attributes": "PRIMARY KEY"}, + {"name": "label", "type": "TEXT"}, + {"name": "value", "type": "INTEGER"}, + ], + "table_indexes": [], +} + + +@pytest.fixture +def db(tmp_path): + d = Database("sqlite3", str(tmp_path / "test.db")) + yield d + del d + + +@pytest.fixture +def db_with_table(db): + db.create_table(TABLE_DEF) + return db + + +class TestDatabaseInit: + def test_conn_is_not_none_on_success(self, db): + assert db.conn is not None + + def test_type_stored(self, db): + assert db.type == "sqlite3" + + def test_target_stored(self, db, tmp_path): + assert str(tmp_path / "test.db") in db.target + + def test_unsupported_type_has_no_conn(self, tmp_path): + d = Database("postgres", str(tmp_path / "pg.db")) + assert d.conn is None + + +class TestDatabaseStr: + def test_returns_string(self, db): + assert isinstance(str(db), str) + + +class TestTableExists: + def test_nonexistent_table(self, db): + assert db.table_exists("no_such_table") is False + + def test_created_table_exists(self, db_with_table): + assert db_with_table.table_exists("items") is True + + +class TestCreateTable: + def test_creates_table(self, db): + result = db.create_table(TABLE_DEF) + assert result is True + assert db.table_exists("items") + + def test_idempotent(self, db): + db.create_table(TABLE_DEF) + result = db.create_table(TABLE_DEF) + assert result is True + + +class TestCheckForTables: + def test_creates_missing_tables(self, db): + tables = { + "items": { + "table_name": "items", + "table_fields": [ + {"name": "uuid", "type": "TEXT", "attributes": "PRIMARY KEY"}, + ], + "table_indexes": [], + } + } + result = db.check_for_tables(tables) + assert result is True + assert db.table_exists("items") + + +class TestInsert: + def test_inserts_record(self, db_with_table): + result = db_with_table.insert("items", {"uuid": "abc", "label": "test", "value": 1}) + assert result is True + + def test_inserted_record_queryable(self, db_with_table): + db_with_table.insert("items", {"uuid": "xyz", "label": "hello", "value": 42}) + rows = db_with_table.query("SELECT * FROM items WHERE uuid='xyz'") + assert len(rows) == 1 + assert rows[0]["label"] == "hello" + + +class TestQuery: + def test_returns_list(self, db_with_table): + result = db_with_table.query("SELECT * FROM items") + assert isinstance(result, list) + + def test_empty_table_returns_empty_list(self, db_with_table): + result = db_with_table.query("SELECT * FROM items") + assert result == [] + + def test_returns_dicts(self, db_with_table): + db_with_table.insert("items", {"uuid": "q1", "label": "lbl", "value": 7}) + result = db_with_table.query("SELECT * FROM items") + assert isinstance(result[0], dict) + assert "uuid" in result[0] + + +class TestRecordCount: + def test_empty_table(self, db_with_table): + assert db_with_table.record_count("items", "1=1") == 0 + + def test_after_insert(self, db_with_table): + db_with_table.insert("items", {"uuid": "c1", "label": "x", "value": 1}) + db_with_table.insert("items", {"uuid": "c2", "label": "y", "value": 2}) + assert db_with_table.record_count("items", "1=1") == 2 + + def test_with_where_clause(self, db_with_table): + db_with_table.insert("items", {"uuid": "w1", "label": "a", "value": 1}) + db_with_table.insert("items", {"uuid": "w2", "label": "b", "value": 2}) + count = db_with_table.record_count("items", "value = 1") + assert count == 1 + + def test_bad_table_returns_minus_one(self, db): + assert db.record_count("no_table", "1=1") == -1 + + +class TestDbExecute: + def test_executes_ddl(self, db): + result = db.db_execute("CREATE TABLE tmp (id TEXT PRIMARY KEY)") + assert result is True + + def test_executes_list_of_statements(self, db): + result = db.db_execute([ + "CREATE TABLE t1 (id TEXT PRIMARY KEY)", + "CREATE TABLE t2 (id TEXT PRIMARY KEY)", + ]) + assert result is True + + def test_bad_sql_returns_false(self, db): + result = db.db_execute("NOT VALID SQL!!!") + assert result is False + + +class TestDropTable: + def test_drops_existing_table(self, db_with_table): + result = db_with_table.drop_table("items") + assert result is True + assert not db_with_table.table_exists("items") + + def test_drop_nonexistent_table_does_not_raise(self, db): + result = db.drop_table("nonexistent") + assert isinstance(result, bool) + + +class TestCacheFile: + @pytest.fixture + def db_with_cache(self, db): + db.check_for_tables(OSCAL_COMMON_TABLES) + return db + + def test_cache_returns_truthy_on_success(self, db_with_cache): + result = db_with_cache.cache_file(b"binary content", attributes={"filename": "test.bin"}) + assert result # True on success + + def test_cache_with_custom_uuid_stores_retrievable(self, db_with_cache): + db_with_cache.cache_file(b"data", uuid="my-uuid", attributes={"filename": "f.bin"}) + result = db_with_cache.retrieve_file("my-uuid") + assert result == b"data" + + def test_retrieve_cached_file(self, db_with_cache): + my_uuid = "test-retrieve-uuid" + db_with_cache.cache_file(b"test data", uuid=my_uuid, attributes={"filename": "t.bin"}) + result = db_with_cache.retrieve_file(my_uuid) + assert result == b"test data" + + def test_retrieve_nonexistent_raises(self, db_with_cache): + with pytest.raises(Exception): + db_with_cache.retrieve_file("no-such-uuid") + + +class TestDbDatatype: + def test_string_to_text(self): + assert db_datatype("string", "sqlite3") == "TEXT" + + def test_integer_to_integer(self): + assert db_datatype("integer", "sqlite3") == "INTEGER" + + def test_datetime_to_numeric(self): + assert db_datatype("date-time", "sqlite3") == "NUMERIC" + + def test_unknown_defaults_to_text(self): + result = db_datatype("unknown_type", "sqlite3") + assert result == "TEXT" + + def test_boolean_to_integer(self): + assert db_datatype("boolean", "sqlite3") == "INTEGER" diff --git a/tests/unit/test_database_sqlite3.py b/tests/unit/test_database_sqlite3.py new file mode 100644 index 0000000..1737e8a --- /dev/null +++ b/tests/unit/test_database_sqlite3.py @@ -0,0 +1,254 @@ +"""Tests for the database_sqlite3 module.""" + +import sqlite3 + +import pytest + +from ruf_common import database_sqlite3 + +FILECACHE_DDL = """ + CREATE TABLE IF NOT EXISTS filecache ( + uuid TEXT PRIMARY KEY, + filename TEXT, + original_location TEXT, + mime_type TEXT, + file_type TEXT, + acquired TEXT, + datatype TEXT, + compressed INTEGER, + content BLOB + ) +""" + +GENERIC_DDL = """ + CREATE TABLE IF NOT EXISTS items ( + uuid TEXT PRIMARY KEY, + content BLOB NOT NULL, + datatype TEXT NOT NULL, + label TEXT + ) +""" + + +@pytest.fixture +def conn(): + c = sqlite3.connect(":memory:") + yield c + c.close() + + +@pytest.fixture +def conn_with_items(conn): + conn.execute(GENERIC_DDL) + conn.commit() + return conn + + +@pytest.fixture +def conn_with_filecache(conn): + conn.execute(FILECACHE_DDL) + conn.commit() + return conn + + +class TestSaveToDb: + def test_saves_string(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "hello") + assert uid != "" + + def test_saves_list(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", [1, 2, 3]) + assert uid != "" + + def test_saves_dict(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", {"a": 1}) + assert uid != "" + + def test_returns_provided_identifier(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data", identifier="my-uuid") + assert uid == "my-uuid" + + def test_auto_generates_uuid(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data") + assert len(uid) == 36 # UUID4 format + + def test_saves_additional_fields(self, conn_with_items): + uid = database_sqlite3.save_to_db( + conn_with_items, "items", "data", + additional_fields={"label": "test-label"} + ) + row = conn_with_items.execute("SELECT label FROM items WHERE uuid=?", (uid,)).fetchone() + assert row[0] == "test-label" + + def test_missing_table_returns_empty(self, conn): + result = database_sqlite3.save_to_db(conn, "nonexistent_table", "data") + assert result == "" + + def test_update_existing_record(self, conn_with_items): + database_sqlite3.save_to_db(conn_with_items, "items", "original", identifier="fixed-id") + uid2 = database_sqlite3.save_to_db(conn_with_items, "items", "updated", identifier="fixed-id") + assert uid2 == "fixed-id" + count = conn_with_items.execute("SELECT COUNT(*) FROM items WHERE uuid=?", ("fixed-id",)).fetchone()[0] + assert count == 1 + + +class TestGetFromDb: + def test_retrieves_string(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "hello world") + result = database_sqlite3.get_from_db(conn_with_items, "items", uid) + assert result == "hello world" + + def test_retrieves_list(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", [10, 20, 30]) + result = database_sqlite3.get_from_db(conn_with_items, "items", uid) + assert result == [10, 20, 30] + + def test_retrieves_dict(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", {"x": 99}) + result = database_sqlite3.get_from_db(conn_with_items, "items", uid) + assert result == {"x": 99} + + def test_raises_on_missing_record(self, conn_with_items): + with pytest.raises(ValueError): + database_sqlite3.get_from_db(conn_with_items, "items", "no-such-uuid") + + +class TestUpdateRecordFromDict: + def test_updates_text_field(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data", + additional_fields={"label": "old"}) + database_sqlite3.update_record_from_dict(conn_with_items, "items", uid, {"label": "new"}) + row = conn_with_items.execute("SELECT label FROM items WHERE uuid=?", (uid,)).fetchone() + assert row[0] == "new" + + def test_ignores_unknown_fields_when_valid_fields_present(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data", + additional_fields={"label": "old"}) + result = database_sqlite3.update_record_from_dict( + conn_with_items, "items", uid, {"label": "new", "nonexistent_col": "val"} + ) + assert result is True + row = conn_with_items.execute("SELECT label FROM items WHERE uuid=?", (uid,)).fetchone() + assert row[0] == "new" + + def test_returns_true_on_success(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data", + additional_fields={"label": "x"}) + result = database_sqlite3.update_record_from_dict(conn_with_items, "items", uid, {"label": "y"}) + assert result is True + + +class TestGetRecordMetadata: + def test_returns_dict(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data", + additional_fields={"label": "meta"}) + result = database_sqlite3.get_record_metadata(conn_with_items, "items", uid) + assert isinstance(result, dict) + assert result["label"] == "meta" + + def test_raises_on_missing_record(self, conn_with_items): + with pytest.raises(ValueError): + database_sqlite3.get_record_metadata(conn_with_items, "items", "no-such-uuid") + + def test_does_not_include_blob_columns(self, conn_with_items): + uid = database_sqlite3.save_to_db(conn_with_items, "items", "data") + result = database_sqlite3.get_record_metadata(conn_with_items, "items", uid) + assert "content" not in result + + +class TestStoreBlobToDb: + def test_stores_bytes(self, conn_with_filecache): + result = database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-1", b"binary data", {} + ) + assert result is True + + def test_stores_string(self, conn_with_filecache): + result = database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-2", "text content", {} + ) + assert result is True + + def test_stores_list(self, conn_with_filecache): + result = database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-3", [1, 2, 3], {} + ) + assert result is True + + def test_stores_dict(self, conn_with_filecache): + result = database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-4", {"k": "v"}, {} + ) + assert result is True + + def test_stores_none(self, conn_with_filecache): + result = database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-5", None, {} + ) + assert result is True + + def test_unsupported_type_raises(self, conn_with_filecache): + with pytest.raises(ValueError): + database_sqlite3.store_blob_to_db( + conn_with_filecache, "blob-uuid-6", object(), {} + ) + + def test_missing_table_returns_false(self, conn): + result = database_sqlite3.store_blob_to_db(conn, "blob-uuid-7", b"data", {}) + assert result is False + + +class TestRetrieveBlobFromDb: + def test_retrieves_bytes(self, conn_with_filecache): + database_sqlite3.store_blob_to_db(conn_with_filecache, "r1", b"hello", {}) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r1") + assert result["content"] == b"hello" + assert result["datatype"] == "bytes" + + def test_retrieves_string(self, conn_with_filecache): + database_sqlite3.store_blob_to_db(conn_with_filecache, "r2", "text", {}) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r2") + assert result["content"] == "text" + + def test_retrieves_list(self, conn_with_filecache): + database_sqlite3.store_blob_to_db(conn_with_filecache, "r3", [9, 8, 7], {}) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r3") + assert result["content"] == [9, 8, 7] + + def test_retrieves_dict(self, conn_with_filecache): + database_sqlite3.store_blob_to_db(conn_with_filecache, "r4", {"a": "b"}, {}) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r4") + assert result["content"] == {"a": "b"} + + def test_result_has_metadata_fields(self, conn_with_filecache): + attrs = {"filename": "test.bin", "mime_type": "application/octet-stream"} + database_sqlite3.store_blob_to_db(conn_with_filecache, "r5", b"data", attrs) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r5") + assert result["filename"] == "test.bin" + assert result["mime_type"] == "application/octet-stream" + + def test_raises_on_missing_uuid(self, conn_with_filecache): + with pytest.raises(ValueError): + database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "nonexistent") + + def test_compressed_roundtrip(self, conn_with_filecache): + database_sqlite3.store_blob_to_db( + conn_with_filecache, "r6", b"compress me", {"compress": True} + ) + result = database_sqlite3.retrieve_blob_from_db(conn_with_filecache, "r6") + assert result["content"] == b"compress me" + + +class TestOpenSqlite3: + def test_opens_in_memory(self, tmp_path): + path = str(tmp_path / "test.db") + conn = database_sqlite3.open_sqlite3(path) + assert conn is not None + conn.close() + + def test_creates_directory_if_needed(self, tmp_path): + nested = tmp_path / "sub" / "db" / "test.db" + conn = database_sqlite3.open_sqlite3(str(nested)) + assert conn is not None + conn.close() + assert nested.exists() diff --git a/tests/unit/test_helper.py b/tests/unit/test_helper.py new file mode 100644 index 0000000..7e0c423 --- /dev/null +++ b/tests/unit/test_helper.py @@ -0,0 +1,268 @@ +"""Tests for the helper module.""" + +import json +import os +from datetime import datetime + +import pytest +import pytz + +from ruf_common import helper + + +class TestIif: + def test_true_condition(self): + assert helper.iif(True, "yes", "no") == "yes" + + def test_false_condition(self): + assert helper.iif(False, "yes", "no") == "no" + + def test_truthy_value(self): + assert helper.iif(1, "a", "b") == "a" + + def test_falsy_value(self): + assert helper.iif(0, "a", "b") == "b" + + def test_returns_any_type(self): + assert helper.iif(True, [1, 2], None) == [1, 2] + + +class TestNormalizeContent: + def test_string_passthrough(self): + assert helper.normalize_content("hello") == "hello" + + def test_bytes_decoded(self): + assert helper.normalize_content(b"hello") == "hello" + + def test_utf8_bytes(self): + assert helper.normalize_content("café".encode("utf-8")) == "café" + + +class TestGetFirstNonWhitespaceChar: + def test_simple(self): + assert helper.get_first_non_whitespace_char(" hello") == "h" + + def test_no_whitespace(self): + assert helper.get_first_non_whitespace_char("abc") == "a" + + def test_all_whitespace(self): + assert helper.get_first_non_whitespace_char(" ") == "" + + def test_empty_string(self): + assert helper.get_first_non_whitespace_char("") == "" + + def test_tab_then_char(self): + assert helper.get_first_non_whitespace_char("\t\tX") == "X" + + +class TestJsonSafeAtomic: + def test_string_value(self): + assert helper.JSON_safe_atomic({"k": "v"}, "k") == "v" + + def test_int_value(self): + assert helper.JSON_safe_atomic({"n": 42}, "n") == "42" + + def test_float_value(self): + assert helper.JSON_safe_atomic({"f": 3.14}, "f") == "3.14" + + def test_bool_value(self): + assert helper.JSON_safe_atomic({"b": True}, "b") == "True" + + def test_dict_value_serialized(self): + result = helper.JSON_safe_atomic({"d": {"a": 1}}, "d") + assert json.loads(result) == {"a": 1} + + def test_missing_key_returns_empty(self): + assert helper.JSON_safe_atomic({"a": 1}, "missing") == "" + + +class TestIndent: + def test_zero_level(self): + assert helper.indent(0) == "" + + def test_one_level_default(self): + assert helper.indent(1) == " " + + def test_two_levels_default(self): + assert helper.indent(2) == " " + + def test_custom_length(self): + assert helper.indent(1, 4) == " " + + +class TestHasRepeatedEnding: + def test_double_suffix(self): + assert helper.has_repeated_ending("abcabc", "abc") is True + + def test_single_suffix_not_double(self): + assert helper.has_repeated_ending("xyzabc", "abc") is False + + def test_empty_suffix_false(self): + assert helper.has_repeated_ending("abc", "") is False + + def test_empty_string_false(self): + assert helper.has_repeated_ending("", "abc") is False + + def test_triple_with_frequency_3(self): + assert helper.has_repeated_ending("xabcabcabc", "abc", 3) is True + + def test_too_short_string(self): + assert helper.has_repeated_ending("ab", "abc") is False + + +class TestHandleEnvironmentVariables: + def test_existing_variable(self): + os.environ["TEST_VAR_RUF"] = "hello" + try: + assert helper.handle_environment_variables("TEST_VAR_RUF") == "hello" + finally: + del os.environ["TEST_VAR_RUF"] + + def test_missing_variable(self): + result = helper.handle_environment_variables("DEFINITELY_NOT_SET_RUF_XYZ") + assert result == "" + + +class TestGetUserInformation: + def test_returns_string(self): + result = helper.get_user_information() + assert isinstance(result, str) + assert len(result) > 0 + + +class TestPrepareHtmlForJson: + def test_escapes_quotes(self): + result = helper.prepare_html_for_json('
y
') + assert '\\"' in result + + def test_escapes_newlines(self): + result = helper.prepare_html_for_json("line1\nline2") + assert "\\n" in result + + def test_escapes_tabs(self): + result = helper.prepare_html_for_json("a\tb") + assert "\\t" in result + + def test_type_error_on_non_string(self): + with pytest.raises(TypeError): + helper.prepare_html_for_json(123) + + def test_passthrough_plain_text(self): + result = helper.prepare_html_for_json("hello world") + assert "hello world" in result + + +class TestCreateHtmlUpdateMessage: + def test_basic_message(self): + msg = helper.create_html_update_message("my-div", "

hi

") + parsed = json.loads(msg) + assert parsed["type"] == "html" + assert parsed["targetId"] == "my-div" + assert "hi" in parsed["content"] + + def test_additional_data_merged(self): + msg = helper.create_html_update_message("div", "

", {"extra": "val"}) + parsed = json.loads(msg) + assert parsed["extra"] == "val" + + def test_empty_target_raises(self): + with pytest.raises(ValueError): + helper.create_html_update_message("", "

") + + def test_invalid_additional_data_raises(self): + with pytest.raises(TypeError): + helper.create_html_update_message("div", "

", "not a dict") + + +class TestIsValidHtmlContent: + def test_valid_simple(self): + assert helper.is_valid_html_content("

hello
") is True + + def test_mismatched_tags(self): + assert helper.is_valid_html_content("
hello") is False + + def test_self_closing(self): + assert helper.is_valid_html_content("
") is True + + def test_empty_string(self): + assert helper.is_valid_html_content("") is False + + def test_whitespace_only(self): + assert helper.is_valid_html_content(" ") is False + + def test_nested_valid(self): + assert helper.is_valid_html_content("

text

") is True + + +class TestHtmlToJsonSafe: + def test_roundtrip(self): + original = '
Hello & world\n
' + safe = helper.html_to_json_safe(original) + recovered = helper.html_from_json_safe(safe) + assert recovered == original + + def test_empty_input(self): + assert helper.html_to_json_safe("") == "" + + def test_returns_string(self): + result = helper.html_to_json_safe("

hi

") + assert isinstance(result, str) + + +class TestHtmlFromJsonSafe: + def test_empty_input(self): + assert helper.html_from_json_safe("") == "" + + def test_unescapes_quotes(self): + safe = helper.html_to_json_safe('link') + result = helper.html_from_json_safe(safe) + assert '"url"' in result + + +class TestDatetimeString: + def test_returns_formatted_string(self): + dt = datetime(2025, 1, 15, 10, 30, 45) + result = helper.datetime_string(dt) + assert result == "2025-01-15--10-30-45" + + def test_custom_format(self): + dt = datetime(2025, 6, 1) + result = helper.datetime_string(dt, format="%Y/%m/%d") + assert result == "2025/06/01" + + +class TestConvertDatetimeFormat: + def test_datetime_object_returns_string(self): + dt = datetime(2025, 3, 15, 12, 0, 0, tzinfo=pytz.UTC) + result = helper.convert_datetime_format(dt) + assert isinstance(result, str) + assert len(result) > 0 + + def test_empty_string_returns_empty(self): + assert helper.convert_datetime_format("") == "" + + def test_invalid_string_returns_empty(self): + assert helper.convert_datetime_format("not-a-date") == "" + + def test_date_only(self): + dt = datetime(2025, 3, 15, 12, 0, 0, tzinfo=pytz.UTC) + result = helper.convert_datetime_format(dt, include_time=False) + assert "2025" in result + assert ":" not in result + + +class TestCompareSemver: + def test_equal(self): + assert helper.compare_semver("1.2.3", "1.2.3") == 0 + + def test_less_than(self): + assert helper.compare_semver("1.0.0", "2.0.0") == -1 + + def test_greater_than(self): + assert helper.compare_semver("2.0.0", "1.9.9") == 1 + + def test_patch_difference(self): + assert helper.compare_semver("1.0.1", "1.0.0") == 1 + + def test_minor_difference(self): + assert helper.compare_semver("1.1.0", "1.2.0") == -1 diff --git a/tests/unit/test_html_to_markdown.py b/tests/unit/test_html_to_markdown.py new file mode 100644 index 0000000..45e9cde --- /dev/null +++ b/tests/unit/test_html_to_markdown.py @@ -0,0 +1,153 @@ +"""Tests for the html_to_markdown module.""" + + +from ruf_common.html_to_markdown import html_to_markdown + + +class TestHtmlToMarkdownBasic: + def test_empty_string_returns_empty(self): + assert html_to_markdown("") == "" + + def test_plain_text_passthrough(self): + result = html_to_markdown("Hello world") + assert "Hello world" in result + + def test_whitespace_only_returns_empty_marker(self): + result = html_to_markdown("

") + assert result == "_Empty_" or result.strip() == "" + + +class TestHeaders: + def test_h1(self): + result = html_to_markdown("

Title

") + assert result.startswith("# Title") + + def test_h2(self): + result = html_to_markdown("

Sub

") + assert result.startswith("## Sub") + + def test_h6(self): + result = html_to_markdown("
Tiny
") + assert result.startswith("###### Tiny") + + +class TestInlineFormatting: + def test_bold_strong(self): + result = html_to_markdown("bold") + assert "**bold**" in result + + def test_bold_b_tag(self): + result = html_to_markdown("bold") + assert "**bold**" in result + + def test_italic_em(self): + result = html_to_markdown("italic") + assert "*italic*" in result + + def test_italic_i_tag(self): + result = html_to_markdown("italic") + assert "*italic*" in result + + def test_inline_code(self): + result = html_to_markdown("x = 1") + assert "`x = 1`" in result + + def test_link(self): + result = html_to_markdown('click') + assert "[click](https://example.com)" in result + + +class TestLineBreakAndHr: + def test_br_becomes_newline(self): + result = html_to_markdown("line1
line2") + assert "line1" in result + assert "line2" in result + + def test_hr_becomes_dashes(self): + result = html_to_markdown("
") + assert "---" in result + + +class TestLists: + def test_unordered_list(self): + html = "
  • one
  • two
" + result = html_to_markdown(html) + assert "- one" in result + assert "- two" in result + + def test_ordered_list(self): + html = "
  1. first
  2. second
" + result = html_to_markdown(html) + assert "1. first" in result + assert "2. second" in result + + +class TestParagraph: + def test_paragraph_content_preserved(self): + result = html_to_markdown("

Hello paragraph

") + assert "Hello paragraph" in result + + +class TestCodeBlock: + def test_pre_becomes_fenced(self): + result = html_to_markdown("
code here
") + assert "```" in result + assert "code here" in result + + +class TestBlockquote: + def test_blockquote_prefix(self): + result = html_to_markdown("
quoted text
") + assert "> " in result + assert "quoted text" in result + + +class TestImage: + def test_image_with_alt(self): + result = html_to_markdown('desc') + assert "![desc](img.png)" in result + + def test_image_without_alt(self): + result = html_to_markdown('') + assert "img.png" in result + + +class TestTable: + def test_table_with_headers(self): + html = ( + "" + "" + "" + "
AB
12
" + ) + result = html_to_markdown(html) + assert "| A |" in result or "| A" in result + assert "---" in result + + def test_table_rows(self): + html = ( + "" + "" + "" + "
H
val
" + ) + result = html_to_markdown(html) + assert "val" in result + + +class TestHtmlEntities: + def test_ampersand_entity(self): + result = html_to_markdown("

a & b

") + assert "a & b" in result + + def test_nbsp_entity(self): + result = html_to_markdown("

hello world

") + assert "hello" in result and "world" in result + + +class TestStripRemainingTags: + def test_unknown_tags_removed(self): + result = html_to_markdown("
text
") + assert "
" not in result + assert "" not in result + assert "text" in result diff --git a/tests/test_imports.py b/tests/unit/test_imports.py similarity index 98% rename from tests/test_imports.py rename to tests/unit/test_imports.py index 162f149..8bde192 100644 --- a/tests/test_imports.py +++ b/tests/unit/test_imports.py @@ -6,7 +6,7 @@ import pytest -PACKAGE_DIR = Path(__file__).resolve().parent.parent / "ruf_common" +PACKAGE_DIR = Path(__file__).resolve().parent.parent.parent / "ruf_common" # Internal sub-modules that are not part of the public API # (imported by other modules, not re-exported from __init__.py) diff --git a/tests/unit/test_lfs.py b/tests/unit/test_lfs.py new file mode 100644 index 0000000..4ad8c41 --- /dev/null +++ b/tests/unit/test_lfs.py @@ -0,0 +1,164 @@ +"""Tests for the lfs (local file system) module.""" + +import json +import os +import zipfile + +from ruf_common import lfs + + +class TestChkfile: + def test_existing_file(self, tmp_path): + f = tmp_path / "test.txt" + f.write_text("hello") + assert lfs.chkfile(str(f)) is True + + def test_missing_file(self, tmp_path): + assert lfs.chkfile(str(tmp_path / "no_such_file.txt")) is False + + +class TestChkdir: + def test_existing_dir(self, tmp_path): + assert lfs.chkdir(str(tmp_path)) is True + + def test_missing_dir(self, tmp_path): + assert lfs.chkdir(str(tmp_path / "nonexistent")) is False + + def test_make_if_not_present(self, tmp_path): + new_dir = str(tmp_path / "newdir") + assert lfs.chkdir(new_dir, make_if_not_present=True) is True + assert os.path.isdir(new_dir) + + +class TestMkdir: + def test_creates_directory(self, tmp_path): + new_dir = str(tmp_path / "created") + result = lfs.mkdir(new_dir) + assert result is True + assert os.path.isdir(new_dir) + + def test_already_exists(self, tmp_path): + result = lfs.mkdir(str(tmp_path)) + assert result is True + + +class TestPutfile: + def test_writes_content(self, tmp_path): + f = str(tmp_path / "out.txt") + result = lfs.putfile(f, "hello world") + assert result is True + assert open(f).read() == "hello world" + + def test_returns_false_on_invalid_path(self): + result = lfs.putfile("/no_such_dir/file.txt", "data") + assert result is False + + +class TestGetfile: + def test_reads_text_file(self, tmp_path): + f = tmp_path / "test.txt" + f.write_text("content here", encoding="utf-8") + result = lfs.getfile(str(f)) + assert "content here" in result + + def test_missing_file_returns_empty(self, tmp_path): + result = lfs.getfile(str(tmp_path / "missing.txt")) + assert result == "" + + def test_normalize_false_may_return_bytes(self, tmp_path): + f = tmp_path / "test.txt" + f.write_bytes(b"raw bytes") + result = lfs.getfile(str(f), normalize=False) + assert result is not None + + +class TestGetJson: + def test_reads_valid_json(self, tmp_path): + f = tmp_path / "data.json" + f.write_text('{"key": "val"}', encoding="utf-8") + result = lfs.get_json(str(f)) + assert result == {"key": "val"} + + def test_missing_file_returns_empty_dict(self, tmp_path): + result = lfs.get_json(str(tmp_path / "missing.json")) + assert result == {} + + def test_invalid_json_returns_empty_dict(self, tmp_path): + f = tmp_path / "bad.json" + f.write_text("{bad json}", encoding="utf-8") + result = lfs.get_json(str(f)) + assert result == {} + + +class TestSaveJson: + def test_writes_json_file(self, tmp_path): + f = str(tmp_path / "out.json") + result = lfs.save_json({"a": 1}, f) + assert result is True + loaded = json.loads(open(f).read()) + assert loaded == {"a": 1} + + def test_json_is_indented(self, tmp_path): + f = str(tmp_path / "out.json") + lfs.save_json({"x": 1}, f) + content = open(f).read() + assert "\n" in content + + +class TestGetjsonfile: + def test_reads_json_file(self, tmp_path): + f = tmp_path / "data.json" + f.write_text('{"k": "v"}', encoding="utf-8") + result = lfs.getjsonfile(str(f)) + assert result == {"k": "v"} + + def test_missing_returns_empty(self, tmp_path): + result = lfs.getjsonfile(str(tmp_path / "nope.json")) + assert result == {} + + +class TestBackupFile: + def test_creates_backup(self, tmp_path): + f = tmp_path / "config.json" + f.write_text('{"a": 1}', encoding="utf-8") + result = lfs.backup_file(str(f)) + assert result is True + backups = [p for p in tmp_path.iterdir() if "config_" in p.name] + assert len(backups) == 1 + + def test_backup_has_same_content(self, tmp_path): + f = tmp_path / "config.json" + f.write_text('{"a": 1}', encoding="utf-8") + lfs.backup_file(str(f)) + backups = [p for p in tmp_path.iterdir() if "config_" in p.name] + assert backups[0].read_text() == '{"a": 1}' + + def test_missing_file_returns_false(self, tmp_path): + result = lfs.backup_file(str(tmp_path / "nonexistent.txt")) + assert result is False + + +class TestZipFile: + def test_zips_single_file(self, tmp_path): + src = tmp_path / "source.txt" + src.write_text("hello") + zip_path = str(tmp_path / "out.zip") + result = lfs.zip_file(str(src), zip_path) + assert result is True + assert zipfile.is_zipfile(zip_path) + + def test_overwrite_false_blocked_if_exists(self, tmp_path): + src = tmp_path / "source.txt" + src.write_text("hello") + zip_path = str(tmp_path / "out.zip") + lfs.zip_file(str(src), zip_path) + result = lfs.zip_file(str(src), zip_path, overwrite=False) + assert result is None or result is False + + def test_overwrite_true_replaces(self, tmp_path): + src = tmp_path / "source.txt" + src.write_text("hello") + zip_path = str(tmp_path / "out.zip") + lfs.zip_file(str(src), zip_path) + result = lfs.zip_file(str(src), zip_path, overwrite=True) + assert result is True diff --git a/tests/unit/test_logging_module.py b/tests/unit/test_logging_module.py new file mode 100644 index 0000000..a1d587f --- /dev/null +++ b/tests/unit/test_logging_module.py @@ -0,0 +1,120 @@ +"""Tests for the logging module (DictSink and LoggableMixin).""" + +import pytest +from loguru import logger + +from ruf_common.logging import DictSink, LoggableMixin + + +class TestDictSink: + @pytest.fixture + def sink(self): + s = DictSink() + handler_id = logger.add(s, format="{message}", level="DEBUG") + yield s + logger.remove(handler_id) + + def test_captures_log_record(self, sink): + logger.info("test message") + records = sink.get_records() + assert any(r["message"] == "test message" for r in records) + + def test_record_has_required_fields(self, sink): + logger.warning("check fields") + records = sink.get_records() + assert len(records) > 0 + record = records[-1] + assert "timestamp" in record + assert "level" in record + assert "message" in record + assert "module" in record + assert "function" in record + assert "line" in record + + def test_level_captured_correctly(self, sink): + logger.error("an error") + records = sink.get_records() + error_records = [r for r in records if r["message"] == "an error"] + assert len(error_records) == 1 + assert error_records[0]["level"] == "ERROR" + + def test_get_records_returns_copy(self, sink): + logger.info("msg") + r1 = sink.get_records() + r2 = sink.get_records() + assert r1 == r2 + r1.clear() + assert len(sink.get_records()) > 0 # original unaffected + + def test_clear_removes_records(self, sink): + logger.info("something") + sink.clear() + assert sink.get_records() == [] + + def test_multiple_records(self, sink): + logger.info("one") + logger.info("two") + logger.info("three") + records = sink.get_records() + messages = [r["message"] for r in records] + assert "one" in messages + assert "two" in messages + assert "three" in messages + + +class TestLoggableMixin: + class _MyClass(LoggableMixin): + pass + + @pytest.fixture + def obj(self): + instance = self._MyClass() + yield instance + try: + instance.cleanup_logging() + except ValueError: + pass # handlers already removed in the test + + def test_setup_dict_mode(self, obj): + obj.setup_logging(log_mode="dict") + assert obj._dict_sink is not None + + def test_get_logs_returns_list(self, obj): + obj.setup_logging(log_mode="dict") + logs = obj.get_logs() + assert isinstance(logs, list) + + def test_clear_logs(self, obj): + obj.setup_logging(log_mode="dict") + logger.debug("captured") + obj.clear_logs() + assert obj.get_logs() == [] + + def test_get_logs_without_dict_mode_returns_empty(self, obj): + assert obj.get_logs() == [] + + def test_cleanup_removes_handlers(self, obj): + obj.setup_logging(log_mode="dict") + handler_ids_before = list(obj._handler_ids) + obj.cleanup_logging() + assert obj._handler_ids == handler_ids_before # ids still stored, but removed from loguru + + def test_file_mode_requires_log_file(self, obj): + with pytest.raises(ValueError): + obj.setup_logging(log_mode="file", log_file=None) + + def test_file_mode_creates_file(self, obj, tmp_path): + log_file = str(tmp_path / "test.log") + obj.setup_logging(log_mode="file", log_file=log_file) + logger.info("file test") + obj.cleanup_logging() + assert os.path.exists(log_file) + + def test_setup_reinitializes(self, obj): + obj.setup_logging(log_mode="dict") + first_id = list(obj._handler_ids) + obj.setup_logging(log_mode="dict") + assert obj._handler_ids != first_id or len(obj._handler_ids) > 0 + + +import os # noqa: E402 — needed for test_file_mode_creates_file diff --git a/tests/unit/test_network.py b/tests/unit/test_network.py new file mode 100644 index 0000000..2d3d23d --- /dev/null +++ b/tests/unit/test_network.py @@ -0,0 +1,85 @@ +"""Tests for the network module (mocked external calls).""" + +from unittest.mock import MagicMock, patch + +from ruf_common import network + + +class TestCheckInternetConnection: + def test_returns_bool(self): + result = network.check_internet_connection() + assert isinstance(result, bool) + + def test_true_when_connection_succeeds(self): + with patch("socket.create_connection", return_value=MagicMock()): + result = network.check_internet_connection() + assert result is True + + def test_false_when_socket_raises(self): + with patch("socket.create_connection", side_effect=OSError("unreachable")): + result = network.check_internet_connection() + assert result is False + + +class TestApiGet: + def test_returns_response_on_success(self): + mock_response = MagicMock() + mock_response.status_code = 200 + with patch("requests.get", return_value=mock_response): + result = network.api_get("https://example.com/api") + assert result is not None + assert result.status_code == 200 + + def test_returns_none_on_connection_error(self): + import requests + with patch("requests.get", side_effect=requests.exceptions.ConnectionError("no connection")): + result = network.api_get("https://example.com/api") + assert result is None + + def test_returns_none_on_timeout(self): + import requests + with patch("requests.get", side_effect=requests.exceptions.Timeout("timeout")): + result = network.api_get("https://example.com/api") + assert result is None + + def test_passes_headers(self): + mock_response = MagicMock() + mock_response.status_code = 200 + with patch("requests.get", return_value=mock_response) as mock_get: + custom_headers = {"Authorization": "Bearer token"} + network.api_get("https://example.com/api", custom_headers) + call_kwargs = mock_get.call_args + assert call_kwargs is not None + + +class TestDownloadFile: + def test_returns_content_on_success(self): + mock_response = MagicMock() + mock_response.content = b"file content" + with patch("requests.get", return_value=mock_response): + result = network.download_file("https://example.com/file.bin", "output.bin") + assert result is not None + + def test_returns_empty_string_on_error(self): + import requests + with patch("requests.get", side_effect=requests.exceptions.RequestException("error")): + result = network.download_file("https://example.com/file.bin", "output.bin") + assert result == "" + + +class TestAsyncApiGet: + def test_is_coroutine(self): + import inspect + assert inspect.iscoroutinefunction(network.async_api_get) + + def test_returns_none_on_error(self): + import asyncio + + async def run(): + with patch("aiohttp.ClientSession") as mock_session: + mock_session.return_value.__aenter__ = MagicMock(side_effect=Exception("network error")) + mock_session.return_value.__aexit__ = MagicMock(return_value=False) + return await network.async_api_get("https://example.com/api") + + result = asyncio.run(run()) + assert result is None diff --git a/tests/unit/test_stats.py b/tests/unit/test_stats.py new file mode 100644 index 0000000..52f1e77 --- /dev/null +++ b/tests/unit/test_stats.py @@ -0,0 +1,77 @@ +"""Tests for the stats module.""" + +import pytest + +import ruf_common.stats as stats_module + + +@pytest.fixture(autouse=True) +def reset_stats(): + """Reset the global stats dict before each test.""" + stats_module.stats.clear() + yield + stats_module.stats.clear() + + +class TestIncrementStat: + def test_creates_new_stat(self): + stats_module.increment_stat("hits") + assert stats_module.stats["hits"] == 1 + + def test_increments_existing_stat(self): + stats_module.increment_stat("hits") + stats_module.increment_stat("hits") + assert stats_module.stats["hits"] == 2 + + def test_custom_increment(self): + stats_module.increment_stat("bytes", 100) + assert stats_module.stats["bytes"] == 100 + + def test_multiple_stats_independent(self): + stats_module.increment_stat("a") + stats_module.increment_stat("b") + stats_module.increment_stat("b") + assert stats_module.stats["a"] == 1 + assert stats_module.stats["b"] == 2 + + +class TestGetStat: + def test_returns_zero_for_missing(self): + assert stats_module.get_stat("nonexistent") == 0 + + def test_returns_correct_value(self): + stats_module.increment_stat("score", 5) + assert stats_module.get_stat("score") == 5 + + def test_after_multiple_increments(self): + stats_module.increment_stat("x") + stats_module.increment_stat("x") + stats_module.increment_stat("x") + assert stats_module.get_stat("x") == 3 + + +class TestStatsSummary: + def test_empty_stats_has_heading(self): + result = stats_module.stats_summary() + assert "Summary" in result + + def test_custom_heading(self): + result = stats_module.stats_summary("My Report") + assert "My Report" in result + + def test_contains_stat_names_and_values(self): + stats_module.increment_stat("errors", 3) + result = stats_module.stats_summary() + assert "errors" in result + assert "3" in result + + def test_empty_heading_no_label(self): + result = stats_module.stats_summary(heading="") + assert "Summary" not in result + + def test_multiple_stats_all_present(self): + stats_module.increment_stat("a", 1) + stats_module.increment_stat("b", 2) + result = stats_module.stats_summary() + assert "a" in result + assert "b" in result diff --git a/tests/unit/test_timezone_lookup.py b/tests/unit/test_timezone_lookup.py new file mode 100644 index 0000000..5277228 --- /dev/null +++ b/tests/unit/test_timezone_lookup.py @@ -0,0 +1,95 @@ +"""Tests for the timezone_lookup module (external calls mocked).""" + +from unittest.mock import MagicMock, patch + +from ruf_common import timezone_lookup + +# The module imports Nominatim and TimezoneFinder at the top level, so we must +# patch them within the ruf_common.timezone_lookup namespace. +_NOMINATIM = "ruf_common.timezone_lookup.Nominatim" +_TIMEZONEFINDER = "ruf_common.timezone_lookup.TimezoneFinder" +_TIME_SLEEP = "ruf_common.timezone_lookup.time.sleep" + + +def _make_location(lat=40.7128, lng=-74.0060): + loc = MagicMock() + loc.latitude = lat + loc.longitude = lng + return loc + + +class TestLookupTimezone: + def test_returns_timezone_string_on_success(self): + mock_location = _make_location() + + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf: + mock_nom.return_value.geocode.return_value = mock_location + mock_tf.return_value.timezone_at.return_value = "America/New_York" + + result = timezone_lookup.lookup_timezone("New York", "US") + assert result == "America/New_York" + + def test_returns_none_when_city_not_found(self): + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER): + mock_nom.return_value.geocode.return_value = None + + result = timezone_lookup.lookup_timezone("Nonexistentville", "XX") + assert result is None + + def test_returns_none_on_geocoding_exception(self): + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER): + mock_nom.return_value.geocode.side_effect = Exception("geocode error") + + result = timezone_lookup.lookup_timezone("Berlin", "DE") + assert result is None + + def test_returns_none_when_no_timezone_found(self): + mock_location = _make_location(lat=0.0, lng=0.0) + + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf: + mock_nom.return_value.geocode.return_value = mock_location + mock_tf.return_value.timezone_at.return_value = None + + result = timezone_lookup.lookup_timezone("Nowhere", "ZZ") + assert result is None + + +class TestLookupTimezoneBatch: + def test_returns_dict(self): + mock_location = _make_location(lat=51.5074, lng=-0.1278) + + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf, \ + patch(_TIME_SLEEP): + mock_nom.return_value.geocode.return_value = mock_location + mock_tf.return_value.timezone_at.return_value = "Europe/London" + + result = timezone_lookup.lookup_timezone_batch([("London", "UK")]) + assert isinstance(result, dict) + assert result.get("London, UK") == "Europe/London" + + def test_empty_list_returns_empty_dict(self): + result = timezone_lookup.lookup_timezone_batch([]) + assert result == {} + + def test_failed_lookup_stored_as_none(self): + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER), \ + patch(_TIME_SLEEP): + mock_nom.return_value.geocode.return_value = None + + result = timezone_lookup.lookup_timezone_batch([("Nowhere", "ZZ")]) + assert result.get("Nowhere, ZZ") is None + + def test_multiple_locations(self): + def geocode_side_effect(query, timeout=10): + return _make_location() + + with patch(_NOMINATIM) as mock_nom, patch(_TIMEZONEFINDER) as mock_tf, \ + patch(_TIME_SLEEP): + mock_nom.return_value.geocode.side_effect = geocode_side_effect + mock_tf.return_value.timezone_at.return_value = "UTC" + + result = timezone_lookup.lookup_timezone_batch([ + ("City1", "C1"), + ("City2", "C2"), + ]) + assert len(result) == 2 diff --git a/tests/unit/test_xml_formatter.py b/tests/unit/test_xml_formatter.py new file mode 100644 index 0000000..5074bba --- /dev/null +++ b/tests/unit/test_xml_formatter.py @@ -0,0 +1,145 @@ +"""Tests for the xml_formatter module.""" + + +import pytest + +from ruf_common import xml_formatter + +SIMPLE_XML = "text" +MULTI_CHILD_XML = "123" + + +class TestFormatXmlString: + def test_returns_string(self): + result = xml_formatter.format_xml_string(SIMPLE_XML) + assert isinstance(result, str) + + def test_output_is_valid_xml(self): + import xml.etree.ElementTree as ET + result = xml_formatter.format_xml_string(SIMPLE_XML) + # Should parse without error + ET.fromstring(result.strip()) + + def test_adds_indentation(self): + result = xml_formatter.format_xml_string(SIMPLE_XML) + assert "\n" in result + + def test_preserves_content(self): + result = xml_formatter.format_xml_string(SIMPLE_XML) + assert "text" in result + assert "child" in result + assert "root" in result + + def test_invalid_xml_raises(self): + with pytest.raises((ValueError, Exception)): + xml_formatter.format_xml_string("") + + def test_custom_line_wrap(self): + result = xml_formatter.format_xml_string(SIMPLE_XML, line_wrap_column=120) + assert result is not None + + def test_ends_with_newline(self): + result = xml_formatter.format_xml_string(SIMPLE_XML) + assert result.endswith("\n") + + +class TestFormatXmlFileToString: + @pytest.fixture + def xml_file(self, tmp_path): + f = tmp_path / "test.xml" + f.write_text(SIMPLE_XML, encoding="utf-8") + return str(f) + + def test_returns_formatted_string(self, xml_file): + result = xml_formatter.format_xml_file_to_string(xml_file) + assert isinstance(result, str) + assert "root" in result + + def test_does_not_modify_file(self, xml_file): + original = open(xml_file).read() + xml_formatter.format_xml_file_to_string(xml_file) + assert open(xml_file).read() == original + + +class TestFormatXmlFileProgrammatic: + @pytest.fixture + def xml_file(self, tmp_path): + f = tmp_path / "test.xml" + f.write_text(SIMPLE_XML, encoding="utf-8") + return str(f) + + def test_in_place_false_returns_string(self, xml_file): + result = xml_formatter.format_xml_file_programmatic(xml_file, in_place=False) + assert isinstance(result, str) + assert "root" in result + + def test_in_place_false_does_not_modify(self, xml_file): + original = open(xml_file).read() + xml_formatter.format_xml_file_programmatic(xml_file, in_place=False) + assert open(xml_file).read() == original + + def test_in_place_true_modifies_file(self, xml_file): + result = xml_formatter.format_xml_file_programmatic(xml_file, in_place=True) + assert result is True + content = open(xml_file).read() + assert "\n" in content + + +class TestFormatXmlFolder: + @pytest.fixture + def xml_dir(self, tmp_path): + (tmp_path / "a.xml").write_text(SIMPLE_XML, encoding="utf-8") + (tmp_path / "b.xml").write_text(MULTI_CHILD_XML, encoding="utf-8") + (tmp_path / "other.txt").write_text("not xml", encoding="utf-8") + return str(tmp_path) + + def test_in_place_true_returns_dict_of_bools(self, xml_dir): + result = xml_formatter.format_xml_folder(xml_dir, in_place=True) + assert isinstance(result, dict) + assert len(result) == 2 + for v in result.values(): + assert v is True + + def test_in_place_false_returns_dict_of_strings(self, xml_dir): + result = xml_formatter.format_xml_folder(xml_dir, in_place=False) + assert isinstance(result, dict) + for v in result.values(): + assert isinstance(v, str) + + def test_only_xml_files_processed(self, xml_dir): + result = xml_formatter.format_xml_folder(xml_dir, in_place=False) + for path in result.keys(): + assert path.endswith(".xml") + + +class TestFindXmlFiles: + @pytest.fixture + def dir_with_files(self, tmp_path): + (tmp_path / "a.xml").write_text(SIMPLE_XML) + (tmp_path / "b.xml").write_text(SIMPLE_XML) + (tmp_path / "c.txt").write_text("text") + sub = tmp_path / "sub" + sub.mkdir() + (sub / "d.xml").write_text(SIMPLE_XML) + return str(tmp_path) + + def test_finds_xml_in_dir(self, dir_with_files): + result = xml_formatter.find_xml_files(dir_with_files) + assert len(result) == 2 + + def test_recursive_finds_all(self, dir_with_files): + result = xml_formatter.find_xml_files(dir_with_files, recursive=True) + assert len(result) == 3 + + +class TestWrapXmlElement: + def test_short_line_unchanged(self): + line = '' + result = xml_formatter.wrap_xml_element(line) + assert result == [line] + + def test_long_line_wrapped(self): + long_line = '' + result = xml_formatter.wrap_xml_element(long_line) + assert isinstance(result, list) + assert len(result) >= 1