Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 48 additions & 34 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
land_mask: xr.DataArray = None,
time_dim: str = "time",
spatial_dims: Tuple[str, str] = ("lat", "lon"),
patch_size: Tuple[int, int] = (16, 16), # (lat, lon)
patch_size: Tuple[int, int, int] = (1, 16, 16), # (M, lat, lon)
stride: Tuple[int, int] = None,
sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
Expand All @@ -33,12 +33,12 @@ def __init__(
"""Initialize the dataset with daily and monthly data, and optional land mask.

Args:
input_da: xarray DataArray with daily data (M, time, H, W) or hourly data (M, time, H, W)
input_da: xarray DataArray with daily data (time, H, W) or hourly data (time, H, W)
monthly_da: xarray DataArray with monthly data (M, H, W)
land_mask: Optional xarray DataArray with land mask (H, W) or (1, H, W)
time_dim: Name of the time dimension in the input data
spatial_dims: Tuple of (lat_dim, lon_dim) names in the input data
patch_size: Tuple of (patch_height, patch_width) in pixels
patch_size: Tuple of (patch_time, patch_height, patch_width) in time unit and pixels
stride: Tuple of (stride_height, stride_width) in pixels. If None, defaults to patch_size (non-overlapping patches).
is_hourly: Whether the daily data is hourly (T=31*24) or daily (T=31).

Expand All @@ -47,7 +47,7 @@ def __init__(
self.patch_size = patch_size
self.input_da = input_da
self.monthly_da = monthly_da
self.stride = stride if stride is not None else patch_size
self.stride = stride if stride is not None else (patch_size[1], patch_size[2])

self.sh_embed_dim = sh_embed_dim
self.sh_order_L = sh_order_L
Expand All @@ -60,8 +60,8 @@ def __init__(
raise ValueError(f"Spatial dimension '{dim}' not found in input data")

if (
patch_size[0] > input_da.sizes[spatial_dims[0]]
or patch_size[1] > input_da.sizes[spatial_dims[1]]
patch_size[1] > input_da.sizes[spatial_dims[0]]
or patch_size[2] > input_da.sizes[spatial_dims[1]]
):
raise ValueError(
f"Patch size {patch_size} is larger than data dimensions {input_da.sizes}"
Expand Down Expand Up @@ -111,12 +111,12 @@ def __init__(
self.daily_std = None

# Pre-build zero land tensor for the no-mask case
ph, pw = self.patch_size
_, ph, pw = self.patch_size
self._zero_land = torch.zeros(ph, pw, dtype=torch.bool)

# Precompute lazy index mapping for patches
H, W = self.daily_t.shape[2], self.daily_t.shape[3]
self.patch_indices = self._compute_patch_indices(H, W)
M, H, W = self.daily_t.shape[0], self.daily_t.shape[2], self.daily_t.shape[3]
self.patch_indices = self._compute_patch_indices(M, H, W)

# Precompute geoposition and scale embeddings for patches
self.sh_geo_pos = None
Expand All @@ -143,11 +143,19 @@ def _get_geo_pos(self, sh_pos_table: str):
# compatibility of L and sh_dim between requested
# and loaded. Raise error if not consistent

def _compute_patch_indices(self, H: int, W: int) -> list:
def _compute_patch_indices(self, M: int, H: int, W: int) -> list:
"""Generate patch start indices with coverage warning (overlap support)."""
ph, pw = self.patch_size
pm, ph, pw = self.patch_size
sh, sw = self.stride

# validate temporal patch size
if pm > M:
raise ValueError(
f"Temporal patch size {pm} is larger than available months {M}."
)
if pm < 1:
raise ValueError(f"Temporal patch size {pm} must be at least 1.")

# Validate stride
if sh > ph or sw > pw:
warnings.warn(
Expand All @@ -158,42 +166,48 @@ def _compute_patch_indices(self, H: int, W: int) -> list:

# Compute patch start indices using stride
# Ensure we don't go out of bounds
m_starts = list(range(0, M - pm + 1, pm)) # Temporal patches are non-overlapping
i_starts = list(range(0, H - ph + 1, sh))
j_starts = list(range(0, W - pw + 1, sw))

# Fail fast if no patch fits along at least one axis
if not i_starts or not j_starts:
if not i_starts or not j_starts or not m_starts:
raise ValueError(
f"No valid patches can be extracted. Image size ({H}, {W}) "
f"No valid patches can be extracted. Image size ({M}, {H}, {W}) "
f"is smaller than patch size {self.patch_size}."
)

# Check edge coverage
last_m = m_starts[-1] + pm
last_i = i_starts[-1] + ph
last_j = j_starts[-1] + pw
if last_i < H or last_j < W:
if last_m < M or last_i < H or last_j < W:
warnings.warn(
f"Patches do not fully cover the image. "
f"Uncovered pixels: {H - last_i} in height, {W - last_j} in width. "
f"Uncovered pixels: {M - last_m} in time, {H - last_i} in height, {W - last_j} in width. "
f"Consider adjusting stride or adding edge patches.",
UserWarning,
)

overlap_h = ph - sh if sh < ph else 0
overlap_w = pw - sw if sw < pw else 0

len_m = len(m_starts)
len_i = len(i_starts)
len_j = len(j_starts)
print(
f"Patch grid: {len(i_starts)} x {len(j_starts)} = {len(i_starts) * len(j_starts)} patches"
f"Patch grid (m x i x j): {len_m} x {len_i} x {len_j} = {len_m * len_i * len_j} patches"
)
print(f"Overlap: {overlap_h} pixels (height), {overlap_w} pixels (width)")

return [(i, j) for i in i_starts for j in j_starts]
return [(m, i, j) for m in m_starts for i in i_starts for j in j_starts]

def _compute_geoscalepatch_embeddings(self):
patch_geo_embeddings = []
patch_scale_features = []

for i, j in self.patch_indices:
ph, pw = self.patch_size
for _, i, j in self.patch_indices:
_, ph, pw = self.patch_size
geo_pos_tensor = self.sh_geo_pos[
i : i + ph,
j : j + pw,
Expand Down Expand Up @@ -227,18 +241,18 @@ def __getitem__(self, idx):
if idx < 0 or idx >= len(self.patch_indices):
raise IndexError("Index out of range")

i, j = self.patch_indices[idx]
ph, pw = self.patch_size
m, i, j = self.patch_indices[idx]
pm, ph, pw = self.patch_size

# Extract spatial patch via slicing — faster than xarray indexing
# (M, T, H, W) -> (M,T,pH, pW)
daily_t_patch = self.daily_t[:, :, i : i + ph, j : j + pw].unsqueeze(0)
daily_t_patch = self.daily_t[m : m + pm, :, i : i + ph, j : j + pw].unsqueeze(0)

# (M, H, W) -> (M, pH, pW)
monthly_t_patch = self.monthly_t[:, i : i + ph, j : j + pw]
monthly_t_patch = self.monthly_t[m : m + pm, i : i + ph, j : j + pw]

# (M, T, H, W) -> (M, T, pH, pW)
daily_nan_mask_t_patch = self.daily_nan_mask_t[:, :, i : i + ph, j : j + pw].unsqueeze(0)
daily_nan_mask_t_patch = self.daily_nan_mask[m : m + pm, :, i : i + ph, j : j + pw].unsqueeze(0)

if self.land_mask_t is not None:
land_t_patch = self.land_mask_t[i : i + ph, j : j + pw] # (H, W)
Expand All @@ -263,18 +277,18 @@ def __getitem__(self, idx):

# Convert to tensors
return {
"daily_patch": daily_t_patch, # (C=1, M, T=31, pH, pW)
"monthly_patch": monthly_t_patch, # (M, pH, pW)
"daily_mask_patch": daily_mask_t_patch, # (C=1, M, T=31, pH, pW)
"daily_patch": daily_t_patch, # (C=1, pm, T=31, pH, pW)
"monthly_patch": monthly_t_patch, # (pm, pH, pW)
"daily_mask_patch": daily_mask_t_patch, # (C=1, pm, T=31, pH, pW)
"land_mask_patch": land_t_patch, # (pH,pW) True=Land
"daily_timef_patch": self.daily_timef_t, # (M, T=31, 2)
"padded_days_mask": self.padded_days_t, # (M, T=31) True=padded
"daily_timef_patch": self.daily_timef_t[m : m + pm], # (pm, T=31, 2)
"padded_days_mask": self.padded_days_t[m : m + pm], # (pm, T=31) True=padded
"scale_feature_patch": scale_feature_t, # (10,)
"geo_pos_embedding_patch": geo_pos_embedding_t, # (sh_embed_dim,)
"sh_embed_dim": self.sh_embed_dim_t,
"harmonic_order": self.harmonic_order_t,
"scale_f_dim": self.scale_f_dim,
"coords": torch.tensor([i, j]),
"coords": torch.tensor([m, i, j]),
"lat_patch": lat_patch, # (pH,)
"lon_patch": lon_patch, # (pW,)
}
Expand All @@ -292,15 +306,15 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
data = self.monthly_t.numpy() # (M, H, W)
else:
# Stack selected spatial patches
ph, pw = self.patch_size
pm, ph, pw = self.patch_size
patches = []
for idx in indices:
i, j = self.patch_indices[idx]
patch = self.monthly_t[:, i : i + ph, j : j + pw].numpy()
m, i, j = self.patch_indices[idx]
patch = self.monthly_t[m : m + pm, i : i + ph, j : j + pw].numpy()
patches.append(patch)
data = np.concatenate(patches, axis=-1)

mean, std = calc_stats(data) # (M,)
mean, std = calc_stats(data) # (pm,)

self.daily_mean = mean
self.daily_std = std
Expand Down
9 changes: 4 additions & 5 deletions climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def _save_netcdf(predictions: np.ndarray, dataset: Dataset, save_dir: str):
times = base_dataset.monthly_da.coords["time"].values

full_predictions = np.full(
(M, len(lats), len(lons)), np.nan, dtype=predictions.dtype
(len(times), len(lats), len(lons)), np.nan, dtype=predictions.dtype
)
for i, patch_idx in enumerate(indices):
lat_start, lon_start = base_dataset.patch_indices[patch_idx]
full_predictions[:, lat_start : lat_start + H, lon_start : lon_start + W] = (
month_start, lat_start, lon_start = base_dataset.patch_indices[patch_idx]
full_predictions[month_start : month_start + M, lat_start : lat_start + H, lon_start : lon_start + W] = (
predictions[i]
)

Expand Down Expand Up @@ -106,8 +106,7 @@ def predict_monthly_var(
# Initialize an empty list to store predictions
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset

M = base_dataset.monthly_t.shape[0]
H, W = base_dataset.patch_size
M, H, W = base_dataset.patch_size
all_predictions = torch.empty(len(dataset), M, H, W, device=device)

# Set up logging
Expand Down
14 changes: 0 additions & 14 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ def __init__(
patch_w=4,
hidden=128,
overlap=1,
num_months=12,
dropout=0.0,
):
"""
Expand All @@ -368,7 +367,6 @@ def __init__(
The default is 128, which can be tuned.
overlap: Overlap size for deconvolution. It creates smooth blending
between adjacent upsampled patches. Default is 1, no overlap at edges.
num_months: Number of months. Default is 12.
dropout: Dropout rate for regularization in the refinement block. Default is 0.0.
"""
super().__init__()
Expand Down Expand Up @@ -420,10 +418,6 @@ def __init__(
# Final conv head to map to single-channel output
self.head = nn.Conv2d(out_channels, 1, kernel_size=1)

# Learnable scale and bias (mean and std) to improve predictions
self.scale = nn.Parameter(torch.ones(num_months))
self.bias = nn.Parameter(torch.zeros(num_months))

def forward(self, latent, M, out_H, out_W, land_mask=None):
"""Reconstruct 2D maps from latent patch tokens.
Args:
Expand Down Expand Up @@ -458,11 +452,6 @@ def forward(self, latent, M, out_H, out_W, land_mask=None):

# Apply final conv head to get single channel output
out = self.head(out) # (B*M, 1, H, W)

# Apply scale and bias per month to improve predictions; reshape to (B*M, 1, 1, 1) for broadcasting
scale = self.scale[:M].unsqueeze(0).expand(B, M).reshape(B * M, 1, 1, 1)
bias = self.bias[:M].unsqueeze(0).expand(B, M).reshape(B * M, 1, 1, 1)
out = out * scale + bias
out = out.view(B, M, out_H, out_W) # (B, M, H, W)

# Mask out land areas if land_mask is provided
Expand Down Expand Up @@ -629,7 +618,6 @@ def __init__(
embed_dim=128,
patch_size=(1, 4, 4),
max_months=12,
num_months=12,
hidden=256,
overlap=1,
spatial_depth=2,
Expand All @@ -645,7 +633,6 @@ def __init__(
embed_dim: Dimension of the patch embedding
patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
max_months: Maximum number of months for temporal positional encoding
num_months: Number of months to predict (output channels in decoder)
hidden: Hidden dimension used in the decoder
overlap: Overlap for deconvolution in the decoder
max_H: Maximum spatial height for 2D positional encoding
Expand Down Expand Up @@ -690,7 +677,6 @@ def __init__(
patch_w=patch_size[2],
hidden=hidden,
overlap=overlap,
num_months=num_months,
dropout=dropout,
)
self.patch_size = patch_size
Expand Down
19 changes: 0 additions & 19 deletions climanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,6 @@ def _run_one_batch(model: torch.nn.Module, batch: dict):
return compute_masked_loss(pred, batch["monthly_patch"], batch["land_mask_patch"])


def _compute_stats(dataset: Dataset):
# check if dataset has indices attribute for stats calculation
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset
indices = dataset.indices if hasattr(dataset, "indices") else None
mean, std = base_dataset.compute_stats(indices)
return mean, std


def _initialize_decoder(model: torch.nn.Module, dataset: Dataset):
mean, std = _compute_stats(dataset)
decoder = model.module.decoder if hasattr(model, "module") else model.decoder
with torch.no_grad():
decoder.bias.copy_(torch.from_numpy(mean))
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

return model


def train_monthly_model(
model: torch.nn.Module,
dataset: Dataset,
Expand Down Expand Up @@ -77,7 +59,6 @@ def train_monthly_model(
"""
# Initialize the model
model = model.to(device)
model = _initialize_decoder(model, dataset)

# Create data loader
use_cuda = device == "cuda"
Expand Down