From 3c80ac5a0451dcfc93fc2783e47a0c35f1101396 Mon Sep 17 00:00:00 2001 From: Aki Date: Sun, 4 Sep 2022 13:59:00 +0200 Subject: Refactored edge cases handling --- szilagyi/nomogram.py | 26 ++++++++++++++++++-------- szilagyi/plots.py | 5 +---- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/szilagyi/nomogram.py b/szilagyi/nomogram.py index 8c567d7..4dcd141 100644 --- a/szilagyi/nomogram.py +++ b/szilagyi/nomogram.py @@ -4,6 +4,11 @@ from collections import deque from . import _dataset +class NomogramEdgeCase(Exception): + def __init__(self, value): + self.value = value + + def look_downwards(data, x, start): for i in range(start, 0, -1): if data[i - 1].x < x: @@ -46,24 +51,29 @@ def find_boundary_curves(swis, x, y): middle = segments[1][1] run = middle[j].x - middle[i].x if run == 0: - raise RuntimeError # tidy up dataset + raise NomogramEdgeCase(segments[1][0]) # Tidy up the dataset nonetheless slope = (middle[j].y - middle[i].y) / run intercept = middle[j].y - slope * middle[j].x value = slope * x + intercept if value == y: - raise RuntimeError # Exactly on point; SWI == index + raise NomogramEdgeCase(segments[1][0]) if value < y: segments.popleft() else: segments.pop() if len(segments) == 1: - raise RuntimeError # SWI == -10 + raise NomogramEdgeCase(-10) return segments def calculate_swi(x, y): - low, high = find_boundary_curves(_dataset.INDICES, x, y) - vec = _dataset.Vector(x, y) - dist_to_low = min(abs(vec - p) for p in (low[1][low[2]], low[1][low[2]])) - dist_to_high = min(abs(vec - p) for p in (high[1][high[2]], high[1][high[2]])) - return dist_to_low / (dist_to_low + dist_to_high) * (high[0] - low[0]) + low[0] + try: + low, high = find_boundary_curves(_dataset.INDICES, x, y) + vec = _dataset.Vector(x, y) + dist_to_low = min(abs(vec - p) for p in (low[1][low[2]], low[1][low[2]])) + dist_to_high = min(abs(vec - p) for p in (high[1][high[2]], high[1][high[2]])) + return dist_to_low / (dist_to_low + dist_to_high) * (high[0] - low[0]) + low[0] + except NomogramEdgeCase as edge: + return edge.value + except IndexError: + return -10 # Reduce the amount of cases in which it may occur diff --git a/szilagyi/plots.py b/szilagyi/plots.py index 44782d6..9453919 100644 --- a/szilagyi/plots.py +++ b/szilagyi/plots.py @@ -23,10 +23,7 @@ if __name__ == "__main__": row = [] for x_scaled in range(0, 1000): x = x_scaled / 25 - try: - swi = calculate_swi(x, y) - except (IndexError, RuntimeError): - swi = -10 + swi = calculate_swi(x, y) row.append(swi) C.append(row) plot.pcolormesh(X, Y, C, cmap='viridis', vmin=-10, vmax=10, rasterized=True) -- cgit v1.1