"""
This module contains examples that load data and generate images.
"""
import timeit
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from shutil import rmtree
from tempfile import mkdtemp
from typing import List, Optional, Union
from uuid import uuid4
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from pydantic import BaseModel, Extra
# Seed NumPy's global random state at import time so that the random data
# generated by the examples below is reproducible from run to run.
np.random.seed(20280808)
class Example(BaseModel, extra=Extra.allow):
    """
    The base example class containing most of the common inputs and methods used in
    other examples inheriting from this. Note, this should not be used directly.

    Creating an instance also creates a fresh temporary directory (prefixed with
    ``enjoyn_``) that holds any images written to disk; call `cleanup_images`
    to remove it.

    Args:
        length: The number of items in the data.
        scratch_directory: The base directory to create the temporary directory
            for intermediary files.
        to_bytes_io: If True, save output to `BytesIO`; if False, save to disk.
    """

    length: int = 1000
    scratch_directory: Optional[Path] = None
    to_bytes_io: bool = False
    # Populated in __init__ with the per-instance temporary directory.
    # NOTE(review): presumably `extra=Extra.allow` is what permits assigning this
    # underscore-prefixed attribute on the model -- confirm against pydantic docs.
    _temporary_directory: Path = None

    def __init__(self, **data):
        super().__init__(**data)
        # dir=None makes mkdtemp fall back to the system default temp location.
        self._temporary_directory = Path(
            mkdtemp(prefix="enjoyn_", dir=self.scratch_directory)
        )

    @contextmanager
    def time_run(self):
        """
        A context manager for tracking and printing the runtime.

        The elapsed time of the ``with`` body is printed to stdout when the
        body exits, including when it exits by raising an exception.

        Yields:
            None; the runtime is only printed, not returned.
        """
        start = timeit.default_timer()
        try:
            yield
        finally:
            # Report in `finally` so a failing run still shows how long it took.
            runtime = timeit.default_timer() - start
            print(f"Runtime: {runtime} seconds")

    def cleanup_images(self):
        """
        Deletes the temporary directory and everything inside it.

        Does nothing if the directory no longer exists.
        """
        if self._temporary_directory.exists():
            rmtree(self._temporary_directory)

    def size_of(self, file: Union[Path, str]):
        """
        Prints the size of a file in MBs.

        Args:
            file: The path to the file to measure.
        """
        path = Path(file)
        # st_size is in bytes; divide by 1024 twice to convert to megabytes.
        file_size = path.stat().st_size / 1024 / 1024
        print(f"File size of {path.name}: {file_size:.2f} MBs")
class RandomWalkExample(Example):
    """
    An example related to a numpy array of random coordinates.

    Args:
        length: The number of items in the data.
        scratch_directory: The base directory to create the temporary directory
            for intermediary files.
        to_bytes_io: If True, save output to `BytesIO`; if False, save to disk.
    """

    def load_data(self) -> np.ndarray:
        """
        Loads a `(self.length, 2)` shaped array.

        The rows are the successive positions of a 2D random walk: a random
        starting point plus the cumulative sum of uniformly distributed steps
        in the range [-0.2, 0.2) along each axis.
        """
        start = np.random.random(2)
        steps = np.random.uniform(-0.2, 0.2, size=(self.length, 2))
        # `start` (shape (2,)) broadcasts across every row of the cumulative sum.
        return start + np.cumsum(steps, axis=0)

    def plot_image(self, data_subset: np.ndarray) -> Union[BytesIO, Path]:
        """
        Plots an image from the data subset.

        Args:
            data_subset: The subset data array; should be shaped (n, 2).

        Returns:
            The output image as `BytesIO` or `Path`. A returned `BytesIO` is
            not rewound after being written to.
        """
        # Hold explicit figure/axes handles rather than relying on pyplot's
        # implicit "current figure" state.
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            points = np.asarray(data_subset)
            # Column slicing also works for an empty (0, 2) subset, where
            # unpacking `zip(*data_subset)` would raise.
            ax.plot(points[:, 0], points[:, 1])
            if self.to_bytes_io:
                output = BytesIO()
            else:
                # A random file name avoids collisions between frames.
                output = self._temporary_directory / f"{uuid4().hex}.png"
            fig.savefig(output, transparent=False, facecolor="white")
        finally:
            # Always release this figure, even if plotting or saving fails;
            # otherwise figures accumulate across the many frames rendered.
            plt.close(fig)
        return output

    def output_images(self) -> List[Union[BytesIO, Path]]:
        """
        Outputs a list of images as `BytesIO` or `Path`.

        One image is rendered per row of the data; the i-th image shows the
        first i points of the walk.
        """
        data = self.load_data()
        return [self.plot_image(data[:stop]) for stop in range(1, len(data) + 1)]
class AirTemperatureExample(Example):
    """
    An example related to an xarray Dataset of air temperatures.

    Args:
        length: The number of items in the data.
        scratch_directory: The base directory to create the temporary directory
            for intermediary files.
        to_bytes_io: If True, save output to `BytesIO`; if False, save to disk.
    """

    length: int = 2920

    def load_data(self) -> xr.Dataset:
        """
        Loads an xarray Dataset.

        Opens the xarray tutorial dataset named ``air_temperature``, chunks it
        along ``time`` in blocks of 10 and keeps the first `self.length` time
        steps.
        """
        ds = xr.tutorial.open_dataset("air_temperature").chunk({"time": 10})
        return ds.isel(time=slice(None, self.length))

    def plot_image(self, ds_sel: xr.Dataset) -> Union[BytesIO, Path]:
        """
        Plots an image from the data subset.

        Args:
            ds_sel: The dataset selected at a single time step; it must provide
                `lon`, `lat`, `air` and a scalar `time`, with `air` being 2D
                so that it can be drawn as filled contours over lon/lat.

        Returns:
            The output image as `BytesIO` or `Path`. A returned `BytesIO` is
            not rewound after being written to.
        """
        # Hold explicit figure/axes handles rather than relying on pyplot's
        # implicit "current figure" state.
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            img = ax.contourf(
                ds_sel["lon"],
                ds_sel["lat"],
                ds_sel["air"],
                cmap="RdBu_r",
                # Fixed contour levels keep the color scale identical across frames.
                levels=range(220, 320, 10),
            )
            fig.colorbar(img, ax=ax)
            # `.item()` requires `time` to be a single value, i.e. one time step.
            title = ds_sel["time"].dt.strftime("%H:%MZ %Y-%m-%d").item()
            ax.set_title(title)
            if self.to_bytes_io:
                output = BytesIO()
            else:
                # A random file name avoids collisions between frames.
                output = self._temporary_directory / f"{uuid4().hex}.png"
            fig.savefig(output, transparent=False, facecolor="white")
        finally:
            # Always release this figure, even if plotting or saving fails;
            # otherwise figures accumulate across the many frames rendered.
            plt.close(fig)
        return output

    def output_images(self) -> List[Union[BytesIO, Path]]:
        """
        Outputs a list of images as `BytesIO` or `Path`.

        One image is rendered for each value of the dataset's `time` coordinate.
        """
        ds = self.load_data()
        return [self.plot_image(ds.sel(time=time)) for time in ds["time"].values]