diff --git a/LoopStructural/__init__.py b/LoopStructural/__init__.py index 147a984f..ca6c3855 100644 --- a/LoopStructural/__init__.py +++ b/LoopStructural/__init__.py @@ -19,6 +19,7 @@ ch.setLevel(logging.WARNING) loggers = {} from .modelling.core.geological_model import GeologicalModel +from .modelling.core.stratigraphic_column import StratigraphicColumn from .interpolators._api import LoopInterpolator from .interpolators import InterpolatorBuilder from .datatypes import BoundingBox @@ -28,26 +29,43 @@ logger.info("Imported LoopStructural") -def setLogging(level="info"): +def setLogging(level="info", handler=None): """ - Set the logging parameters for log file + Set the logging parameters for log file or custom handler Parameters ---------- - filename : string - name of file or path to file - level : str, optional - 'info', 'warning', 'error', 'debug' mapped to logging levels, by default 'info' + level : str + 'info', 'warning', 'error', 'debug' + handler : logging.Handler, optional + A logging handler to use instead of the default StreamHandler """ import LoopStructural - logger = getLogger(__name__) - levels = get_levels() - level = levels.get(level, logging.WARNING) - LoopStructural.ch.setLevel(level) + level_value = levels.get(level, logging.WARNING) + + # Create default handler if none provided + if handler is None: + handler = logging.StreamHandler() + + formatter = logging.Formatter( + "%(levelname)s: %(asctime)s: %(filename)s:%(lineno)d -- %(message)s" + ) + handler.setFormatter(formatter) + handler.setLevel(level_value) + # Replace handlers in all known loggers for name in LoopStructural.loggers: logger = logging.getLogger(name) - logger.setLevel(level) - logger.info(f'Set logging to {level}') + logger.handlers = [] + logger.addHandler(handler) + logger.setLevel(level_value) + + # Also apply to main module logger + main_logger = logging.getLogger(__name__) + main_logger.handlers = [] + main_logger.addHandler(handler) + main_logger.setLevel(level_value) + + main_logger.info(f"Set logging to {level}") diff --git a/LoopStructural/modelling/core/geological_model.py b/LoopStructural/modelling/core/geological_model.py index d2131e8b..c7ec0365 100644 --- a/LoopStructural/modelling/core/geological_model.py +++ b/LoopStructural/modelling/core/geological_model.py @@ -37,7 +37,7 @@ from ...modelling.intrusions import IntrusionBuilder from ...modelling.intrusions import IntrusionFrameBuilder - +from .stratigraphic_column import StratigraphicColumn logger = getLogger(__name__) @@ -61,14 +61,11 @@ class GeologicalModel: the origin of the model box parameters : dict a dictionary tracking the parameters used to build the model - + """ - def __init__( - self, - *args - ): + def __init__(self, *args): """ Parameters ---------- @@ -78,7 +75,7 @@ def __init__( the origin of the model maximum : np.array(3,dtype=doubles) the maximum of the model - + Examples -------- Demo data @@ -126,7 +123,8 @@ def __init__( self.feature_name_index = {} self._data = pd.DataFrame() # None - self.stratigraphic_column = None + self.stratigraphic_column = StratigraphicColumn() + self.tol = 1e-10 * np.max(self.bounding_box.maximum - self.bounding_box.origin) self._dtm = None @@ -148,29 +146,6 @@ def to_dict(self): # json["features"] = [f.to_json() for f in self.features] return json - # @classmethod - # def from_json(cls,json): - # """ - # Create a geological model from a json string - - # Parameters - # ---------- - # json : str - # json string of the geological model - - # Returns - # ------- - # model : GeologicalModel - # a geological model - # """ - # model = cls(json["model"]["origin"],json["model"]["maximum"],data=None) - # model.stratigraphic_column = json["model"]["stratigraphic_column"] - # model.nsteps = json["model"]["nsteps"] - # model.data = pd.read_json(json["model"]["data"]) - # model.features = [] - # for feature in json["features"]: - # model.features.append(GeologicalFeature.from_json(feature,model)) - # return model def __str__(self): return f"GeologicalModel with {len(self.features)} features" @@ -181,6 +156,38 @@ def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: data = data.copy() data[['X', 'Y', 'Z']] = self.bounding_box.project(data[['X', 'Y', 'Z']].to_numpy()) + if "type" in data: + logger.warning("'type' is deprecated replace with 'feature_name' \n") + data.rename(columns={"type": "feature_name"}, inplace=True) + if "feature_name" not in data: + logger.error("Data does not contain 'feature_name' column") + raise BaseException("Cannot load data") + for h in all_heading(): + if h not in data: + data[h] = np.nan + if h == "w": + data[h] = 1.0 + if h == "coord": + data[h] = 0 + if h == "polarity": + data[h] = 1.0 + # LS wants polarity as -1 or 1, change 0 to -1 + data.loc[data["polarity"] == 0, "polarity"] = -1.0 + data.loc[np.isnan(data["w"]), "w"] = 1.0 + if "strike" in data and "dip" in data: + logger.info("Converting strike and dip to vectors") + mask = np.all(~np.isnan(data.loc[:, ["strike", "dip"]]), axis=1) + data.loc[mask, gradient_vec_names()] = ( + strikedip2vector(data.loc[mask, "strike"], data.loc[mask, "dip"]) + * data.loc[mask, "polarity"].to_numpy()[:, None] + ) + data.drop(["strike", "dip"], axis=1, inplace=True) + data[['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz']] = data[ + ['X', 'Y', 'Z', 'val', 'nx', 'ny', 'nz', 'gx', 'gy', 'gz', 'tx', 'ty', 'tz'] + ].astype(float) + return data + + if "type" in data: logger.warning("'type' is deprecated replace with 'feature_name' \n") data.rename(columns={"type": "feature_name"}, inplace=True) @@ -403,7 +410,6 @@ def fault_names(self): return [f.name for f in self.faults] - def to_file(self, file): """Save a model to a pickle file requires dill @@ -501,7 +507,6 @@ def data(self, data: pd.DataFrame): # self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy()) - def set_model_data(self, data): logger.warning("deprecated method. Model data can now be set using the data attribute") self.data = data.copy() @@ -527,28 +532,34 @@ def set_stratigraphic_column(self, stratigraphic_column, cmap="tab20"): } """ + self.stratigraphic_column.clear() # if the colour for a unit hasn't been specified we can just sample from # a colour map e.g. tab20 logger.info("Adding stratigraphic column to model") - random_colour = True - n_units = 0 + DeprecationWarning( + "set_stratigraphic_column is deprecated, use model.stratigraphic_column.add_units instead" + ) for g in stratigraphic_column.keys(): for u in stratigraphic_column[g].keys(): - if "colour" in stratigraphic_column[g][u]: - random_colour = False - break - n_units += 1 - if random_colour: - import matplotlib.cm as cm - - cmap = cm.get_cmap(cmap, n_units) - cmap_colours = cmap.colors - ci = 0 - for g in stratigraphic_column.keys(): - for u in stratigraphic_column[g].keys(): - stratigraphic_column[g][u]["colour"] = cmap_colours[ci, :] - ci += 1 - self.stratigraphic_column = stratigraphic_column + thickness = 0 + if "min" in stratigraphic_column[g][u] and "max" in stratigraphic_column[g][u]: + min_val = stratigraphic_column[g][u]["min"] + max_val = stratigraphic_column[g][u].get("max", None) + thickness = max_val - min_val if max_val is not None else None + logger.warning( + f""" + model.stratigraphic_column.add_unit({u}, + colour={stratigraphic_column[g][u].get("colour", None)}, + thickness={thickness})""" + ) + self.stratigraphic_column.add_unit( + u, + colour=stratigraphic_column[g][u].get("colour", None), + thickness=thickness, + ) + self.stratigraphic_column.add_unconformity( + name=''.join([g, 'unconformity']), + ) def create_and_add_foliation( self, @@ -595,7 +606,7 @@ def create_and_add_foliation( An interpolator will be chosen by calling :meth:`LoopStructural.GeologicalModel.get_interpolator` """ - + # if tol is not specified use the model default if tol is None: tol = self.tol @@ -631,7 +642,7 @@ def create_and_add_foliation( def create_and_add_fold_frame( self, - fold_frame_name:str, + fold_frame_name: str, *, fold_frame_data=None, interpolatortype="FDI", @@ -660,14 +671,14 @@ def create_and_add_fold_frame( :class:`LoopStructural.modelling.features.builders.StructuralFrameBuilder` and :meth:`LoopStructural.modelling.features.builders.StructuralFrameBuilder.setup` and the interpolator, such as `domain` or `tol` - + Returns ------- fold_frame : FoldFrame the created fold frame """ - + if tol is None: tol = self.tol @@ -743,7 +754,7 @@ def create_and_add_folded_foliation( :class:`LoopStructural.modelling.features.builders.FoldedFeatureBuilder` """ - + if tol is None: tol = self.tol @@ -772,11 +783,8 @@ def create_and_add_folded_foliation( if foliation_data.shape[0] == 0: logger.warning(f"No data for {foliation_name}, skipping") return - series_builder.add_data_from_data_frame( - self.prepare_data( - foliation_data - ) - ) + series_builder.add_data_from_data_frame(self.prepare_data(foliation_data)) + self._add_faults(series_builder) # series_builder.add_data_to_interpolator(True) # build feature @@ -843,7 +851,7 @@ def create_and_add_folded_fold_frame( see :class:`LoopStructural.modelling.features.fold.FoldEvent`, :class:`LoopStructural.modelling.features.builders.FoldedFeatureBuilder` """ - + if tol is None: tol = self.tol @@ -1170,7 +1178,7 @@ def add_onlap_unconformity(self, feature: GeologicalFeature, value: float) -> Ge return uc_feature def create_and_add_domain_fault( - self, fault_surface_data,*, nelements=10000, interpolatortype="FDI", **kwargs + self, fault_surface_data, *, nelements=10000, interpolatortype="FDI", **kwargs ): """ Parameters @@ -1224,7 +1232,7 @@ def create_and_add_fault( fault_name: str, displacement: float, *, - fault_data:Optional[pd.DataFrame] = None, + fault_data: Optional[pd.DataFrame] = None, interpolatortype="FDI", tol=None, fault_slip_vector=None, @@ -1309,7 +1317,7 @@ def create_and_add_fault( if "data_region" in kwargs: kwargs.pop("data_region") logger.error("kwarg data_region currently not supported, disabling") - displacement_scaled = displacement + displacement_scaled = displacement fault_frame_builder = FaultBuilder( interpolatortype, bounding_box=self.bounding_box, @@ -1330,11 +1338,11 @@ def create_and_add_fault( if fault_center is not None and ~np.isnan(fault_center).any(): fault_center = self.scale(fault_center, inplace=False) if minor_axis: - minor_axis = minor_axis + minor_axis = minor_axis if major_axis: - major_axis = major_axis + major_axis = major_axis if intermediate_axis: - intermediate_axis = intermediate_axis + intermediate_axis = intermediate_axis fault_frame_builder.create_data_from_geometry( fault_frame_data=self.prepare_data(fault_data), fault_center=fault_center, @@ -1390,7 +1398,8 @@ def rescale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: """ - return self.bounding_box.reproject(points,inplace=inplace) + return self.bounding_box.reproject(points, inplace=inplace) + # TODO move scale to bounding box/transformer def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: @@ -1408,7 +1417,8 @@ def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray: points : np.a::rray((N,3),dtype=double) """ - return self.bounding_box.project(np.array(points).astype(float),inplace=inplace) + return self.bounding_box.project(np.array(points).astype(float), inplace=inplace) + def regular_grid(self, *, nsteps=None, shuffle=True, rescale=False, order="C"): """ @@ -1557,7 +1567,7 @@ def evaluate_fault_displacements(self, points, scale=True): if f.type == FeatureType.FAULT: disp = f.displacementfeature.evaluate_value(points) vals[~np.isnan(disp)] += disp[~np.isnan(disp)] - return vals # convert from restoration magnutude to displacement + return vals # convert from restoration magnutude to displacement def get_feature_by_name(self, feature_name) -> GeologicalFeature: """Returns a feature from the mode given a name @@ -1726,30 +1736,15 @@ def get_stratigraphic_surfaces(self, units: List[str] = [], bottoms: bool = True units = [] if self.stratigraphic_column is None: return [] - for group in self.stratigraphic_column.keys(): - if group == "faults": + units = self.stratigraphic_column.get_isovalues() + for name, u in units.items(): + if u['group'] not in self: + logger.warning(f"Group {u['group']} not found in model") continue - for series in self.stratigraphic_column[group].values(): - series['feature_name'] = group - units.append(series) - unit_table = pd.DataFrame(units) - for u in unit_table['feature_name'].unique(): - - values = unit_table.loc[unit_table['feature_name'] == u, 'min' if bottoms else 'max'] - if 'name' not in unit_table.columns: - unit_table['name'] = unit_table['feature_name'] - - names = unit_table[unit_table['feature_name'] == u]['name'] - values = values.loc[~np.logical_or(values == np.inf, values == -np.inf)] + feature = self.get_feature_by_name(u['group']) + surfaces.extend( - self.get_feature_by_name(u).surfaces( - values.to_list(), - self.bounding_box, - name=names.loc[values.index].to_list(), - colours=unit_table.loc[unit_table['feature_name'] == u, 'colour'].tolist()[ - 1: - ], # we don't isosurface basement, no value - ) + feature.surfaces([u['value']], self.bounding_box, name=name, colours=[u['colour']]) ) return surfaces diff --git a/LoopStructural/modelling/core/stratigraphic_column.py b/LoopStructural/modelling/core/stratigraphic_column.py new file mode 100644 index 00000000..02c8caf5 --- /dev/null +++ b/LoopStructural/modelling/core/stratigraphic_column.py @@ -0,0 +1,473 @@ +import enum +from typing import Dict +import numpy as np +from LoopStructural.utils import rng, getLogger + +logger = getLogger(__name__) +logger.info("Imported LoopStructural Stratigraphic Column module") +class UnconformityType(enum.Enum): + """ + An enumeration for different types of unconformities in a stratigraphic column. + """ + + ERODE = 'erode' + ONLAP = 'onlap' + + +class StratigraphicColumnElementType(enum.Enum): + """ + An enumeration for different types of elements in a stratigraphic column. + """ + + UNIT = 'unit' + UNCONFORMITY = 'unconformity' + + +class StratigraphicColumnElement: + """ + A class to represent an element in a stratigraphic column, which can be a unit or a topological object + for example unconformity. + """ + + def __init__(self, uuid=None): + """ + Initializes the StratigraphicColumnElement with a uuid. + """ + if uuid is None: + import uuid as uuid_module + + uuid = str(uuid_module.uuid4()) + self.uuid = uuid + + +class StratigraphicUnit(StratigraphicColumnElement): + """ + A class to represent a stratigraphic unit. + """ + + def __init__(self, *, uuid=None, name=None, colour=None, thickness=None, data=None): + """ + Initializes the StratigraphicUnit with a name and an optional description. + """ + super().__init__(uuid) + self.name = name + if colour is None: + colour = rng.random(3) + self.colour = colour + self.thickness = thickness + self.data = data + self.element_type = StratigraphicColumnElementType.UNIT + + def to_dict(self): + """ + Converts the stratigraphic unit to a dictionary representation. + """ + colour = self.colour + if isinstance(colour, np.ndarray): + colour = colour.astype(float).tolist() + return {"name": self.name, "colour": colour, "thickness": self.thickness, 'uuid': self.uuid} + + @classmethod + def from_dict(cls, data): + """ + Creates a StratigraphicUnit from a dictionary representation. + """ + if not isinstance(data, dict): + raise TypeError("Data must be a dictionary") + name = data.get("name") + colour = data.get("colour") + thickness = data.get("thickness", None) + uuid = data.get("uuid", None) + return cls(uuid=uuid, name=name, colour=colour, thickness=thickness) + + def __str__(self): + """ + Returns a string representation of the stratigraphic unit. + """ + return ( + f"StratigraphicUnit(name={self.name}, colour={self.colour}, thickness={self.thickness})" + ) + + +class StratigraphicUnconformity(StratigraphicColumnElement): + """ + A class to represent a stratigraphic unconformity, which is a surface of discontinuity in the stratigraphic record. + """ + + def __init__( + self, *, uuid=None, name=None, unconformity_type: UnconformityType = UnconformityType.ERODE + ): + """ + Initializes the StratigraphicUnconformity with a name and an optional description. + """ + super().__init__(uuid) + + self.name = name + if unconformity_type not in [UnconformityType.ERODE, UnconformityType.ONLAP]: + raise ValueError("Invalid unconformity type") + self.unconformity_type = unconformity_type + self.element_type = StratigraphicColumnElementType.UNCONFORMITY + + def to_dict(self): + """ + Converts the stratigraphic unconformity to a dictionary representation. + """ + return { + "uuid": self.uuid, + "name": self.name, + "unconformity_type": self.unconformity_type.value, + } + + def __str__(self): + """ + Returns a string representation of the stratigraphic unconformity. + """ + return ( + f"StratigraphicUnconformity(name={self.name}, " + f"unconformity_type={self.unconformity_type.value})" + ) + + @classmethod + def from_dict(cls, data): + """ + Creates a StratigraphicUnconformity from a dictionary representation. + """ + if not isinstance(data, dict): + raise TypeError("Data must be a dictionary") + name = data.get("name") + unconformity_type = UnconformityType( + data.get("unconformity_type", UnconformityType.ERODE.value) + ) + uuid = data.get("uuid", None) + return cls(uuid=uuid, name=name, unconformity_type=unconformity_type) +class StratigraphicGroup: + """ + A class to represent a group of stratigraphic units. + This class is not fully implemented and serves as a placeholder for future development. + """ + + def __init__(self, name=None, units=None): + """ + Initializes the StratigraphicGroup with a name and an optional list of units. + """ + self.name = name + self.units = units if units is not None else [] + + +class StratigraphicColumn: + """ + A class to represent a stratigraphic column, which is a vertical section of the Earth's crust + showing the sequence of rock layers and their relationships. + """ + + def __init__(self): + """ + Initializes the StratigraphicColumn with a name and a list of layers. + """ + self.order = [StratigraphicUnit(name='Basement', colour='grey', thickness=np.inf),StratigraphicUnconformity(name='Base Unconformity', unconformity_type=UnconformityType.ERODE)] + self.group_mapping = {} + def clear(self,basement=True): + """ + Clears the stratigraphic column, removing all elements. + """ + if basement: + self.order = [StratigraphicUnit(name='Basement', colour='grey', thickness=np.inf),StratigraphicUnconformity(name='Base Unconformity', unconformity_type=UnconformityType.ERODE)] + else: + self.order = [] + self.group_mapping = {} + + def add_unit(self, name,*, colour=None, thickness=None, where='top'): + unit = StratigraphicUnit(name=name, colour=colour, thickness=thickness) + + if where == 'top': + self.order.append(unit) + elif where == 'bottom': + self.order.insert(0, unit) + else: + raise ValueError("Invalid 'where' argument. Use 'top' or 'bottom'.") + + return unit + + def remove_unit(self, uuid): + """ + Removes a unit or unconformity from the stratigraphic column by its uuid. + """ + for i, element in enumerate(self.order): + if element.uuid == uuid: + del self.order[i] + return True + return False + + def add_unconformity(self, name, *, unconformity_type=UnconformityType.ERODE, where='top' ): + unconformity = StratigraphicUnconformity( + uuid=None, name=name, unconformity_type=unconformity_type + ) + + if where == 'top': + self.order.append(unconformity) + elif where == 'bottom': + self.order.insert(0, unconformity) + else: + raise ValueError("Invalid 'where' argument. Use 'top' or 'bottom'.") + return unconformity + + def get_element_by_index(self, index): + """ + Retrieves an element by its index from the stratigraphic column. + """ + if index < 0 or index >= len(self.order): + raise IndexError("Index out of range") + return self.order[index] + + def get_unit_by_name(self, name): + """ + Retrieves a unit by its name from the stratigraphic column. + """ + for unit in self.order: + if isinstance(unit, StratigraphicUnit) and unit.name == name: + return unit + + return None + def get_unconformity_by_name(self, name): + """ + Retrieves an unconformity by its name from the stratigraphic column. + """ + for unconformity in self.order: + if isinstance(unconformity, StratigraphicUnconformity) and unconformity.name == name: + return unconformity + + return None + def get_element_by_uuid(self, uuid): + """ + Retrieves an element by its uuid from the stratigraphic column. + """ + for element in self.order: + if element.uuid == uuid: + return element + raise KeyError(f"No element found with uuid: {uuid}") + def add_element(self, element): + """ + Adds a StratigraphicColumnElement to the stratigraphic column. + """ + if isinstance(element, StratigraphicColumnElement): + self.order.append(element) + else: + raise TypeError("Element must be an instance of StratigraphicColumnElement") + + def get_elements(self): + """ + Returns a list of all elements in the stratigraphic column. + """ + return self.order + + def get_groups(self): + groups = [] + i=0 + group = StratigraphicGroup( + name=( + f'Group_{i}' + if f'Group_{i}' not in self.group_mapping + else self.group_mapping[f'Group_{i}'] + ) + ) + for e in reversed(self.order): + if isinstance(e, StratigraphicUnit): + group.units.append(e) + else: + if group.units: + groups.append(group) + i+=1 + group = StratigraphicGroup( + name=( + f'Group_{i}' + if f'Group_{i}' not in self.group_mapping + else self.group_mapping[f'Group_{i}'] + ) + ) + if group: + groups.append(group) + return groups + + def get_unitname_groups(self): + groups = self.get_groups() + groups_list = [] + group = [] + for g in groups: + group = [u.name for u in g.units if isinstance(u, StratigraphicUnit)] + groups_list.append(group) + return groups_list + + + def __getitem__(self, uuid): + """ + Retrieves an element by its uuid from the stratigraphic column. + """ + for element in self.order: + if element.uuid == uuid: + return element + raise KeyError(f"No element found with uuid: {uuid}") + + def update_order(self, new_order): + """ + Updates the order of elements in the stratigraphic column based on a new order list. + """ + if not isinstance(new_order, list): + raise TypeError("New order must be a list") + self.order = [ + self.__getitem__(uuid) for uuid in new_order if self.__getitem__(uuid) is not None + ] + + def update_element(self, unit_data: Dict): + """ + Updates an existing element in the stratigraphic column with new data. + :param unit_data: A dictionary containing the updated data for the element. + """ + if not isinstance(unit_data, dict): + raise TypeError("unit_data must be a dictionary") + element = self.__getitem__(unit_data['uuid']) + if isinstance(element, StratigraphicUnit): + element.name = unit_data.get('name', element.name) + element.colour = unit_data.get('colour', element.colour) + element.thickness = unit_data.get('thickness', element.thickness) + elif isinstance(element, StratigraphicUnconformity): + element.name = unit_data.get('name', element.name) + element.unconformity_type = UnconformityType( + unit_data.get('unconformity_type', element.unconformity_type.value) + ) + + def __str__(self): + """ + Returns a string representation of the stratigraphic column, listing all elements. + """ + return "\n".join([f"{i+1}. {element}" for i, element in enumerate(self.order)]) + + def to_dict(self): + """ + Converts the stratigraphic column to a dictionary representation. + """ + return { + "elements": [element.to_dict() for element in self.order], + } + def update_from_dict(self, data): + """ + Updates the stratigraphic column from a dictionary representation. + """ + if not isinstance(data, dict): + raise TypeError("Data must be a dictionary") + self.clear(basement=False) + elements_data = data.get("elements", []) + for element_data in elements_data: + if "unconformity_type" in element_data: + element = StratigraphicUnconformity.from_dict(element_data) + else: + element = StratigraphicUnit.from_dict(element_data) + self.add_element(element) + @classmethod + def from_dict(cls, data): + """ + Creates a StratigraphicColumn from a dictionary representation. + """ + if not isinstance(data, dict): + raise TypeError("Data must be a dictionary") + column = cls() + column.clear(basement=False) + elements_data = data.get("elements", []) + for element_data in elements_data: + if "unconformity_type" in element_data: + element = StratigraphicUnconformity.from_dict(element_data) + else: + element = StratigraphicUnit.from_dict(element_data) + column.add_element(element) + return column + + def get_isovalues(self) -> Dict[str, float]: + """ + Returns a dictionary of isovalues for the stratigraphic units in the column. + """ + surface_values = {} + for g in reversed(self.get_groups()): + v = 0 + for u in g.units: + surface_values[u.name] = {'value':v,'group':g.name,'colour':u.colour} + v += u.thickness + return surface_values + + def plot(self,*, ax=None, **kwargs): + import matplotlib.pyplot as plt + from matplotlib import cm + from matplotlib.patches import Polygon + from matplotlib.collections import PatchCollection + n_units = 0 # count how many discrete colours (number of stratigraphic units) + xmin = 0 + ymin = 0 + ymax = 1 + xmax = 1 + fig = None + if ax is None: + fig, ax = plt.subplots(figsize=(2, 10)) + patches = [] # stores the individual stratigraphic unit polygons + + total_height = 0 + prev_coords = [0, 0] + + # iterate through groups, skipping faults + for u in reversed(self.order): + if u.element_type == StratigraphicColumnElementType.UNCONFORMITY: + logger.info(f"Plotting unconformity {u.name} of type {u.unconformity_type.value}") + ax.axhline(y=total_height, linestyle='--', color='black') + ax.annotate( + getattr(u, 'name', 'Unconformity'), + xy=(xmin, total_height), + fontsize=8, + ha='left', + ) + + total_height -= 0.05 # Adjust height slightly for visual separation + continue + + if u.element_type == StratigraphicColumnElementType.UNIT: + logger.info(f"Plotting unit {u.name} of type {u.element_type}") + + n_units += 1 + + ymax = total_height + ymin = ymax - (getattr(u, 'thickness', np.nan) if not np.isinf(getattr(u, 'thickness', np.nan)) else np.nanmean([getattr(e, 'thickness', np.nan) for e in self.order if not np.isinf(getattr(e, 'thickness', np.nan))])) + + if not np.isfinite(ymin): + ymin = prev_coords[1] - (prev_coords[1] - prev_coords[0]) * (1 + rng.random()) + + total_height = ymin + + prev_coords = (ymin, ymax) + + polygon_points = np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]) + patches.append(Polygon(polygon_points)) + ax.annotate(getattr(u, 'name', 'Unknown'), xy=(xmin+(xmax-xmin)/2, (ymax-ymin)/2+ymin), fontsize=8, ha='left') + + if 'cmap' not in kwargs: + import matplotlib.colors as colors + + colours = [] + boundaries = [] + data = [] + for i, u in enumerate(self.order): + if u.element_type != StratigraphicColumnElementType.UNIT: + continue + data.append((i, u.colour)) + colours.append(u.colour) + boundaries.append(i) # print(u,v) + cmap = colors.ListedColormap(colours) + else: + cmap = cm.get_cmap(kwargs['cmap'], n_units - 1) + p = PatchCollection(patches, cmap=cmap) + + colors = np.arange(len(patches)) + p.set_array(np.array(colors)) + + ax.add_collection(p) + + ax.set_ylim(total_height - (total_height - prev_coords[0]) * 0.1, 0) + + ax.axis("off") + + return fig