diff options
Diffstat (limited to 'api.py')
-rw-r--r-- | api.py | 32 |
1 files changed, 27 insertions, 5 deletions
@@ -3,6 +3,8 @@ from fastapi.responses import StreamingResponse from io import BytesIO import matplotlib.pyplot as plt import matplotlib.font_manager as font_manager +from matplotlib.font_manager import FontProperties +import matplotlib as mpl import numpy as np import pandas as pd @@ -32,6 +34,7 @@ class PlotDataRequest(BaseModel): y_axis_scale: str = Field("linear", title="Y Axis Scale", description="Scale of the Y axis") trendline_equation: str = Field(None, title="Trendline Equation", description="Manually specify the equation for the trendline") + special_mode: bool = Field(False, title="Special Mode", description="Special mode") constant_line_vals: list[float] = Field([], title="Constant Line Values", description="List of values for the constant line") constant_line_name: list[str] = Field([], title="Constant Line Names", description="List of names for the constant line") @@ -42,7 +45,7 @@ def plot_data(x_data, y_data, std_dev_data, color_picker, labels, df, title = "Plot", x_label = "X Axis", y_label = "Y Axis", plot_background_color="#ffffff", constant_line=[], enable_trendline=True, enable_grid=False, - trendline_color="#000000", x_axis_scale="linear", y_axis_scale="linear", trendline_equation=None): + trendline_color="#000000", x_axis_scale="linear", y_axis_scale="linear", trendline_equation=None, special_mode=False): fig, ax = plt.subplots(dpi=300) plots = [] @@ -52,10 +55,16 @@ def plot_data(x_data, y_data, std_dev_data, color_picker, labels, df, y = df[y_data[idx]].astype(float) color = color_picker[idx] data_series_title = labels[idx] - if (std_dev_data[idx] != None): - plot = ax.errorbar(x, y, yerr=df[std_dev_data[idx]].astype(float), fmt='o', ecolor='black', capsize=5, color=color, label=data_series_title) + if special_mode: + if (std_dev_data[idx] != None): + plot = ax.errorbar(x, y, yerr=df[std_dev_data[idx]].astype(float), fmt='o', ecolor='black', capsize=5, color="navy", label=data_series_title) + else: + plot = ax.plot(x, y, 'o', color="navy", label=data_series_title) else: - plot = ax.plot(x, y, 'o', color=color, label=data_series_title) + if (std_dev_data[idx] != None): + plot = ax.errorbar(x, y, yerr=df[std_dev_data[idx]].astype(float), fmt='o', ecolor='black', capsize=5, color=color, label=data_series_title) + else: + plot = ax.plot(x, y, color=color, label=data_series_title) if (type(plot) == list): plots.extend(plot) @@ -122,6 +131,17 @@ def plot_data(x_data, y_data, std_dev_data, color_picker, labels, df, ax.set_xscale(x_axis_scale) ax.set_yscale(y_axis_scale) + if special_mode: + ax.grid(linestyle="dashed", dashes=(5,3)) + ax.spines[['right', 'top']].set_visible(False) + ax.tick_params(axis='x', length=0) + ax.tick_params(axis='y', length=0) + arial_font = FontProperties(fname='./arial.ttf') + mpl.font_manager.fontManager.addfont('./arial.ttf') + with mpl.rc_context({"font.family": arial_font.get_name(), "font.size": 10}): + ax.legend(handles, labels, loc='best') + return fig + return fig @app.post("/post") @@ -158,7 +178,9 @@ async def create_plot(request: PlotDataRequest, data: Request): title=request.title, x_label=request.x_label, y_label=request.y_label, enable_trendline=request.enable_trendline, enable_grid=request.enable_grid, constant_line=constant_line, - x_axis_scale=request.x_axis_scale, y_axis_scale=request.y_axis_scale, trendline_equation=request.trendline_equation) + x_axis_scale=request.x_axis_scale, y_axis_scale=request.y_axis_scale, trendline_equation=request.trendline_equation, + special_mode=request.special_mode + ) buf = BytesIO() fig.savefig(buf, format="png") |