Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 208 additions & 18 deletions lib/src/rdswrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class RdsReader {

std::string get_rtype() const {
if (!ptr) throw std::runtime_error("Null pointer in 'get_rtype'.");
// py::print("arg::", static_cast<int>(ptr->type()));
switch (ptr->type()) {
case rds2cpp::SEXPType::S4: return "S4";
case rds2cpp::SEXPType::INT: return "integer";
Expand Down Expand Up @@ -239,29 +238,220 @@ class RdaObject {
}
};

// ---- writers ----

std::unique_ptr<rds2cpp::RObject> py_to_robject(const py::object& obj, std::vector<rds2cpp::Symbol>& symbols);

// Append a "names" attribute, built from a Python list, to `attributes`.
//
// Every entry of `names` becomes one element of a string vector: `None`
// entries are stored as default-constructed elements, all other entries are
// cast to `std::string` and tagged as UTF-8.
//
// attributes: attribute container of the object being built; gains one entry.
// names:      Python list holding the names (or `None` placeholders).
// symbols:    symbol table in which the "names" symbol is registered.
void add_names_attribute(
    std::vector<rds2cpp::Attribute>& attributes,
    const py::list& names,
    std::vector<rds2cpp::Symbol>& symbols)
{
    auto name_vec = std::make_unique<rds2cpp::StringVector>();
    for (const auto& entry : names) {
        if (!entry.is_none()) {
            name_vec->data.emplace_back(entry.cast<std::string>(), rds2cpp::StringEncoding::UTF8);
        } else {
            name_vec->data.emplace_back();
        }
    }

    auto names_symbol = rds2cpp::register_symbol("names", rds2cpp::StringEncoding::UTF8, symbols);
    attributes.emplace_back(std::move(names_symbol), std::move(name_vec));
}

// Recursively convert a Python object into an rds2cpp object tree.
//
// Supported inputs and the R object they produce:
//   - None                        -> Null
//   - numpy bool array            -> LogicalVector (1 for true, 0 for false)
//   - numpy int8/16/32/64 array   -> IntegerVector
//   - numpy float32/64 array      -> DoubleVector
//   - dict                        -> GenericVector whose "names" attribute holds the keys
//   - list of only str/None       -> StringVector (None -> default-constructed element)
//   - any other list              -> GenericVector, elements converted recursively
//   - bool / int / float / str    -> length-1 vector of the matching type
//
// Numpy arrays are read through a 1-D unchecked view, so only 1-D arrays are
// accepted (pybind11's `unchecked<1>()` rejects any other rank).
//
// obj:     the Python object to convert.
// symbols: symbol table of the file being assembled; attribute names created
//          during conversion are registered in it.
//
// Throws std::runtime_error for unsupported numpy dtypes, unsupported Python
// types, and 64-bit integer array values that do not fit in 32 bits.
std::unique_ptr<rds2cpp::RObject> py_to_robject(const py::object& obj, std::vector<rds2cpp::Symbol>& symbols) {
    // None -> Null
    if (obj.is_none()) {
        return std::make_unique<rds2cpp::Null>();
    }

    // numpy array
    if (py::isinstance<py::array>(obj)) {
        auto arr = obj.cast<py::array>();
        auto dtype = arr.dtype();

        // bool arrays
        if (dtype.is(py::dtype::of<bool>())) {
            auto buf = arr.cast<py::array_t<bool, py::array::c_style | py::array::forcecast>>();
            auto r = buf.unchecked<1>();
            auto vec = std::make_unique<rds2cpp::LogicalVector>();

            const py::ssize_t n = r.shape(0);
            vec->data.reserve(static_cast<size_t>(n));
            for (py::ssize_t i = 0; i < n; ++i) {
                vec->data.push_back(r(i) ? 1 : 0);
            }

            return vec;
        }

        // 64-bit integer arrays: the output vector stores 32-bit values, so a
        // blind forcecast would silently wrap anything outside that range.
        // Read the values at full width and reject the ones that do not fit.
        if (py::isinstance<py::array_t<int64_t>>(arr)) {
            auto buf = arr.cast<py::array_t<int64_t, py::array::c_style | py::array::forcecast>>();
            auto r = buf.unchecked<1>();
            auto vec = std::make_unique<rds2cpp::IntegerVector>();

            const py::ssize_t n = r.shape(0);
            vec->data.reserve(static_cast<size_t>(n));
            for (py::ssize_t i = 0; i < n; ++i) {
                const int64_t value = r(i);
                if (value < INT32_MIN || value > INT32_MAX) {
                    throw std::runtime_error(
                        "Integer value " + std::to_string(value) + " at index " + std::to_string(i) +
                        " does not fit in a 32-bit R integer; convert the array to float64 before writing");
                }
                // NOTE(review): INT32_MIN is R's NA_integer_ sentinel, so such a
                // value will presumably read back as NA in R - confirm this is intended.
                vec->data.push_back(static_cast<int32_t>(value));
            }

            return vec;
        }

        // narrower integer arrays always fit, so they are simply widened to int32
        if (py::isinstance<py::array_t<int32_t>>(arr) ||
            py::isinstance<py::array_t<int16_t>>(arr) ||
            py::isinstance<py::array_t<int8_t>>(arr)) {
            auto buf = arr.cast<py::array_t<int32_t, py::array::c_style | py::array::forcecast>>();
            auto r = buf.unchecked<1>();
            auto vec = std::make_unique<rds2cpp::IntegerVector>();

            const py::ssize_t n = r.shape(0);
            vec->data.reserve(static_cast<size_t>(n));
            for (py::ssize_t i = 0; i < n; ++i) {
                vec->data.push_back(r(i));
            }

            return vec;
        }

        // float arrays (float32 is widened to double)
        if (py::isinstance<py::array_t<double>>(arr) ||
            py::isinstance<py::array_t<float>>(arr)) {
            auto buf = arr.cast<py::array_t<double, py::array::c_style | py::array::forcecast>>();
            auto r = buf.unchecked<1>();
            auto vec = std::make_unique<rds2cpp::DoubleVector>();

            const py::ssize_t n = r.shape(0);
            vec->data.reserve(static_cast<size_t>(n));
            for (py::ssize_t i = 0; i < n; ++i) {
                vec->data.push_back(r(i));
            }

            return vec;
        }

        throw std::runtime_error("Unsupported numpy dtype for RDS writing");
    }

    // dict -> GenericVector with names attribute
    if (py::isinstance<py::dict>(obj)) {
        auto d = obj.cast<py::dict>();
        auto gvec = std::make_unique<rds2cpp::GenericVector>();
        gvec->data.reserve(py::len(d));

        // Keys and values are collected in the same pass so that names[i]
        // always labels data[i].
        py::list keys;
        for (const auto& item : d) {
            keys.append(item.first);
            gvec->data.push_back(py_to_robject(py::reinterpret_borrow<py::object>(item.second), symbols));
        }
        add_names_attribute(gvec->attributes, keys, symbols);

        return gvec;
    }

    // list
    if (py::isinstance<py::list>(obj)) {
        auto lst = obj.cast<py::list>();
        const size_t count = py::len(lst);
        if (count == 0) {
            return std::make_unique<rds2cpp::GenericVector>();
        }

        // Check if all elements are strings (or None) -> StringVector
        bool all_strings = true;
        for (size_t i = 0; i < count; ++i) {
            auto item = lst[i];
            if (!item.is_none() && !py::isinstance<py::str>(item)) {
                all_strings = false;
                break;
            }
        }

        if (all_strings) {
            auto svec = std::make_unique<rds2cpp::StringVector>();
            svec->data.reserve(count);
            for (size_t i = 0; i < count; ++i) {
                auto item = lst[i];
                if (item.is_none()) {
                    svec->data.emplace_back();
                } else {
                    svec->data.emplace_back(item.cast<std::string>(), rds2cpp::StringEncoding::UTF8);
                }
            }

            return svec;
        }

        // Otherwise -> GenericVector
        auto gvec = std::make_unique<rds2cpp::GenericVector>();
        gvec->data.reserve(count);
        for (size_t i = 0; i < count; ++i) {
            gvec->data.push_back(py_to_robject(lst[i].cast<py::object>(), symbols));
        }

        return gvec;
    }

    // bool check before int, since bool is a subclass of int
    if (py::isinstance<py::bool_>(obj)) {
        auto vec = std::make_unique<rds2cpp::LogicalVector>();
        vec->data.push_back(obj.cast<bool>() ? 1 : 0);
        return vec;
    }

    if (py::isinstance<py::int_>(obj)) {
        auto vec = std::make_unique<rds2cpp::IntegerVector>();
        // The cast to int32_t is expected to fail (throw) rather than wrap for
        // Python ints outside the 32-bit range.
        vec->data.push_back(obj.cast<int32_t>());
        return vec;
    }

    if (py::isinstance<py::float_>(obj)) {
        auto vec = std::make_unique<rds2cpp::DoubleVector>();
        vec->data.push_back(obj.cast<double>());
        return vec;
    }

    if (py::isinstance<py::str>(obj)) {
        auto svec = std::make_unique<rds2cpp::StringVector>();
        svec->data.emplace_back(obj.cast<std::string>(), rds2cpp::StringEncoding::UTF8);
        return svec;
    }

    throw std::runtime_error("Unsupported Python type for RDS writing: " + std::string(py::str(obj.get_type())));
}

// Convert `obj` into an R object tree and write it to `path` as an RDS file.
//
// obj:  Python object to serialise; converted with `py_to_robject`.
// path: destination file path handed to `rds2cpp::write_rds`.
//
// The writer runs with default-constructed `WriteRdsOptions`.
void write_rds_file(const py::object& obj, const std::string& path) {
    rds2cpp::WriteRdsOptions write_opts;

    // The symbol table lives inside the file object, so conversion registers
    // symbols directly into the structure that is written out.
    rds2cpp::RdsFile rds;
    rds.object = py_to_robject(obj, rds.symbols);

    rds2cpp::write_rds(rds, path, write_opts);
}

// Convert a dict of named Python objects and write them to `path` as an RData file.
//
// objects: dict mapping variable names to Python objects; each key is cast to
//          `std::string`, each value is converted with `py_to_robject`.
// path:    destination file path handed to `rds2cpp::write_rda`.
//
// The writer runs with default-constructed `WriteRdaOptions`.
void write_rda_file(const py::dict& objects, const std::string& path) {
    rds2cpp::RdaFile rda;
    auto& symbol_table = rda.symbols;

    for (const auto& entry : objects) {
        const auto variable = entry.first.cast<std::string>();
        // Register the variable's own symbol before converting its value, so
        // symbols are added to the table in the same order as before.
        auto symbol = rds2cpp::register_symbol(variable, rds2cpp::StringEncoding::UTF8, symbol_table);
        auto converted = py_to_robject(py::reinterpret_borrow<py::object>(entry.second), symbol_table);
        rda.objects.emplace_back(std::move(symbol), std::move(converted));
    }

    rds2cpp::WriteRdaOptions write_opts;
    rds2cpp::write_rda(rda, path, write_opts);
}

// Python module definition for `lib_rds_parser`.
//
// Exposes the reader classes (RdsObject, RdaObject, RdsReader) and the two
// writer entry points (write_rds, write_rda). Each class is bound through a
// single `.def(...)` chain ending in one `;` - a chain cannot be continued
// after its terminating semicolon.
PYBIND11_MODULE(lib_rds_parser, m) {
    // std::runtime_error thrown by the wrappers surfaces in Python as RdsParserError.
    py::register_exception<std::runtime_error>(m, "RdsParserError");

    py::class_<RdsObject>(m, "RdsObject")
        .def(py::init<const std::string&>())
        .def("get_robject", &RdsObject::get_robject, py::return_value_policy::reference_internal);

    py::class_<RdaObject>(m, "RdaObject")
        .def(py::init<const std::string&>())
        .def("get_object_names", &RdaObject::get_object_names)
        .def("get_object_count", &RdaObject::get_object_count)
        // keep_alive<0, 1> keeps the RdaObject alive for as long as the returned object exists.
        .def("get_object_by_index", &RdaObject::get_object_by_index, py::return_value_policy::take_ownership, py::keep_alive<0, 1>())
        .def("get_object_by_name", &RdaObject::get_object_by_name, py::return_value_policy::take_ownership, py::keep_alive<0, 1>());

    py::class_<RdsReader>(m, "RdsReader")
        .def("get_rtype", &RdsReader::get_rtype)
        .def("get_rsize", &RdsReader::get_rsize)
        .def("get_numeric_data", &RdsReader::get_numeric_data)
        .def("get_string_arr", &RdsReader::get_string_arr)
        .def("get_attribute_names", &RdsReader::get_attribute_names)
        .def("load_attribute_by_name", &RdsReader::load_attribute_by_name)
        .def("load_vec_element", &RdsReader::load_vec_element)
        .def("get_package_name", &RdsReader::get_package_name)
        .def("get_class_name", &RdsReader::get_class_name)
        .def("get_dimensions", &RdsReader::get_dimensions);

    m.def("write_rds", &write_rds_file, "Write a Python object to an RDS file",
        py::arg("obj"), py::arg("path"));

    m.def("write_rda", &write_rda_file, "Write named Python objects to an RData file",
        py::arg("objects"), py::arg("path"));
}
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ package_dir =
=src

# Require a min/specific Python version (comma-separated conditions)
python_requires = >=3.9
python_requires = >=3.10

# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
Expand All @@ -50,7 +50,7 @@ python_requires = >=3.9
install_requires =
importlib-metadata; python_version<"3.8"
numpy
biocutils>=0.1.5
biocutils>=0.4.1

[options.packages.find]
where = src
Expand Down
4 changes: 2 additions & 2 deletions src/rds2py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
del version, PackageNotFoundError


from .generics import read_rds, read_rda
from .rdsutils import parse_rds, parse_rda
from .generics import read_rds, read_rda, save_rds
from .rdsutils import parse_rds, parse_rda, write_rds, write_rda
50 changes: 33 additions & 17 deletions src/rds2py/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
print(type(data))
"""

from functools import singledispatch
from importlib import import_module
from typing import List, Optional
from typing import Any, List, Optional
from warnings import warn

from .rdsutils import get_class, parse_rda, parse_rds
Expand Down Expand Up @@ -72,22 +73,6 @@
}


# @singledispatch
# def save_rds(x, path: str):
# """Save a Python object as RDS file.

# Args:
# x:
# Object to save.

# path:
# Path to save the object.
# """
# raise NotImplementedError(
# f"No `save_rds` method implemented for '{type(x).__name__}' objects."
# )


def read_rds(path: str, **kwargs):
"""Read an RDS file and convert it to an appropriate Python object.

Expand Down Expand Up @@ -177,3 +162,34 @@ def _dispatcher(robject: dict, **kwargs):
)

return robject


@singledispatch
def save_rds(x: Any, path: Optional[str] = None):
    """Fallback ``save_rds`` handler, used when no implementation is registered for ``type(x)``.

    Type-specific implementations are attached through ``save_rds.register``;
    this generic version only reports that the type is unsupported.

    Args:
        x:
            Object to save.

        path:
            Path to save the object. If ``None``, returns the converted representation.

    Raises:
        NotImplementedError:
            Always, naming the type of ``x``.
    """
    type_name = type(x).__name__
    message = f"No `save_rds` method implemented for '{type_name}' objects."
    raise NotImplementedError(message)


# Import all modules with save_rds registrations to ensure they are loaded
from . import ( # noqa: E402
save_atomic, # noqa: F401
save_compressed_list, # noqa: F401
save_delayed_matrix, # noqa: F401
save_dict, # noqa: F401
save_factor, # noqa: F401
save_frame, # noqa: F401
save_granges, # noqa: F401
save_mae, # noqa: F401
save_matrix, # noqa: F401
save_rle, # noqa: F401
save_sce, # noqa: F401
save_se, # noqa: F401
)
36 changes: 35 additions & 1 deletion src/rds2py/rdsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
information from parsed objects.
"""

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

from .lib_rds_parser import write_rda as _write_rda_native
from .PyRdaReader import PyRdaParser
from .PyRdsReader import PyRdsParser

Expand Down Expand Up @@ -57,6 +58,39 @@ def parse_rda(path: str, objects: Optional[List[str]] = None) -> Dict[str, dict]
return result


def write_rds(obj: Any, path: str) -> None:
    """Write a Python object to an RDS file.

    Thin convenience wrapper that forwards to the ``save_rds`` generic, which
    dispatches on the type of ``obj``.

    Args:
        obj:
            The Python object to write.

        path:
            Output file path.
    """
    # Imported inside the function rather than at module level; `generics`
    # itself imports from this module.
    from . import generics

    generics.save_rds(obj, path)


def write_rda(objects: Dict[str, Any], path: str) -> None:
    """Write multiple named Python objects to an RData file.

    Each value is first converted by calling ``save_rds`` without a path, and
    the converted values are then handed to the native writer together with
    their names. Keys are coerced to ``str``.

    NOTE(review): the previous docstring described the output as
    gzip-compressed; that depends on the native writer's defaults - confirm.

    Args:
        objects:
            Dictionary mapping variable names to Python objects.

        path:
            Output file path.
    """
    # Imported inside the function rather than at module level; `generics`
    # itself imports from this module.
    from .generics import save_rds

    converted = {}
    for name, value in objects.items():
        converted[str(name)] = save_rds(value)

    _write_rda_native(converted, path)


def get_class(robj: dict) -> str:
"""Infer the R class name from a parsed RDS object.

Expand Down
Loading
Loading