1+ import matplotlib .pyplot as plt
2+ import numpy as np
3+ import pandas as pd
4+ import torch , argparse
5+ from pathlib import Path
6+ from evaluation_function .models .basic_nn import TinyNet , f , train_model , MODEL_PATH
7+
8+ def plot_letter_histogram (show_plots : bool = False , media_dir : Path = None ):
9+ """Plot a histogram from norvig_letter_single.csv."""
10+ csv_path = Path (__file__ ).parent .parent / "evaluation_function" / "models" / "storage" / "norvig_letter_single.csv"
11+ df = pd .read_csv (csv_path )
12+
13+ df = df .sort_values (by = "Percent" , ascending = False )
14+
15+ plt .bar (df ["Letter" ], df ["Percent" ], color = "skyblue" , edgecolor = "black" )
16+ plt .xlabel ("Letter" )
17+ plt .ylabel ("Frequency" )
18+ plt .tight_layout ()
19+
20+ out_path = media_dir / "letter_histogram.png"
21+ plt .savefig (out_path , dpi = 150 , bbox_inches = "tight" )
22+ if show_plots :
23+ print (f"Plot saved to { out_path } , displaying plot now." )
24+ plt .show ()
25+ else :
26+ print (f"Plot saved to { out_path } ." )
27+
28+ def plot_neural_network_results (show_plots : bool = False , media_dir : Path = None ):
29+ """Plot the results of a neural network model against the data.
30+
31+ Args:
32+ x (torch.Tensor): Input data.
33+ y (torch.Tensor): Target data.
34+ model (torch.nn.Module): Trained neural network model.
35+ """
36+ # Load trained model (or train if needed)
37+ model = TinyNet ().to (device )
38+ model .load_state_dict (torch .load (MODEL_PATH , map_location = device ))
39+ model .eval ()
40+
41+ # Recreate training data for plotting
42+ x = torch .linspace (- 2 * torch .pi , 2 * torch .pi , 200 ).unsqueeze (1 ).to (device )
43+ y = (f (x ) + 0.1 * torch .randn_like (x )).to (device )
44+
45+ with torch .no_grad ():
46+ # Make domain twice as wide as training range
47+ x_plot = torch .linspace (2 * x .min ().item (), 2 * x .max ().item (), 800 , device = x .device ).unsqueeze (1 )
48+ y_plot = model (x_plot )
49+
50+ plt .scatter (x .cpu (), y .cpu (), s = 10 , label = "Data" )
51+ plt .plot (x_plot .cpu (), y_plot .cpu (), color = "red" , label = "Model" )
52+ plt .legend ()
53+ out_path = media_dir / "basic_nn_plot.png"
54+ plt .savefig (out_path , dpi = 150 , bbox_inches = "tight" ) # good web resolution
55+ if show_plots :
56+ print (f"Plot saved to { out_path } , displaying plot now." )
57+ plt .show ()
58+ else :
59+ print (f"Plot saved to { out_path } ." )
60+
61+
62+ if __name__ == "__main__" :
63+ device = torch .device ("mps" if torch .backends .mps .is_available () else "cpu" )
64+ parser = argparse .ArgumentParser ()
65+ parser .add_argument (
66+ "--show-plots" ,
67+ action = "store_true" ,
68+ help = "Display plots interactively instead of just saving them."
69+ )
70+ args = parser .parse_args ()
71+ media_dir = Path (__file__ ).parent / "media"
72+ media_dir .mkdir (exist_ok = True )
73+ #plot_neural_network_results(show_plots=args.show_plots, media_dir=media_dir)
74+ plot_letter_histogram (show_plots = args .show_plots , media_dir = media_dir )
0 commit comments