Skip to content

Commit ecb50b0

Browse files
Peter JohnsonPeter Johnson
authored andcommitted
Add educational material
1 parent 23423fb commit ecb50b0

File tree

9 files changed

+1044
-44
lines changed

9 files changed

+1044
-44
lines changed

.DS_Store

-6 KB
Binary file not shown.

educational_material/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Educational Material (for `langModels`)
2+
3+
This folder uses the Evaluation Function to build educational materials to accompany its use. In partilar graphs and plots, and text for lessons using the Function.
4+
5+
This folder is not part of the _Evaluation Function_ (it is not deployed) or part of the docs for the evaluation function.

educational_material/__init__.py

Whitespace-only changes.

educational_material/main.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import pandas as pd
4+
import torch, argparse
5+
from pathlib import Path
6+
from evaluation_function.models.basic_nn import TinyNet, f, train_model, MODEL_PATH
7+
8+
def plot_letter_histogram(show_plots: bool=False, media_dir: Path=None):
9+
"""Plot a histogram from norvig_letter_single.csv."""
10+
csv_path = Path(__file__).parent.parent / "evaluation_function" / "models" / "storage" / "norvig_letter_single.csv"
11+
df = pd.read_csv(csv_path)
12+
13+
df = df.sort_values(by="Percent", ascending=False)
14+
15+
plt.bar(df["Letter"], df["Percent"], color="skyblue", edgecolor="black")
16+
plt.xlabel("Letter")
17+
plt.ylabel("Frequency")
18+
plt.tight_layout()
19+
20+
out_path = media_dir / "letter_histogram.png"
21+
plt.savefig(out_path, dpi=150, bbox_inches="tight")
22+
if show_plots:
23+
print(f"Plot saved to {out_path}, displaying plot now.")
24+
plt.show()
25+
else:
26+
print(f"Plot saved to {out_path}.")
27+
28+
def plot_neural_network_results(show_plots: bool=False, media_dir: Path=None):
29+
"""Plot the results of a neural network model against the data.
30+
31+
Args:
32+
x (torch.Tensor): Input data.
33+
y (torch.Tensor): Target data.
34+
model (torch.nn.Module): Trained neural network model.
35+
"""
36+
# Load trained model (or train if needed)
37+
model = TinyNet().to(device)
38+
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
39+
model.eval()
40+
41+
# Recreate training data for plotting
42+
x = torch.linspace(-2*torch.pi, 2*torch.pi, 200).unsqueeze(1).to(device)
43+
y = (f(x) + 0.1*torch.randn_like(x)).to(device)
44+
45+
with torch.no_grad():
46+
# Make domain twice as wide as training range
47+
x_plot = torch.linspace(2*x.min().item(), 2*x.max().item(), 800, device=x.device).unsqueeze(1)
48+
y_plot = model(x_plot)
49+
50+
plt.scatter(x.cpu(), y.cpu(), s=10, label="Data")
51+
plt.plot(x_plot.cpu(), y_plot.cpu(), color="red", label="Model")
52+
plt.legend()
53+
out_path = media_dir / "basic_nn_plot.png"
54+
plt.savefig(out_path, dpi=150, bbox_inches="tight") # good web resolution
55+
if show_plots:
56+
print(f"Plot saved to {out_path}, displaying plot now.")
57+
plt.show()
58+
else:
59+
print(f"Plot saved to {out_path}.")
60+
61+
62+
if __name__ == "__main__":
63+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
64+
parser = argparse.ArgumentParser()
65+
parser.add_argument(
66+
"--show-plots",
67+
action="store_true",
68+
help="Display plots interactively instead of just saving them."
69+
)
70+
args = parser.parse_args()
71+
media_dir = Path(__file__).parent / "media"
72+
media_dir.mkdir(exist_ok=True)
73+
#plot_neural_network_results(show_plots=args.show_plots, media_dir=media_dir)
74+
plot_letter_histogram(show_plots=args.show_plots, media_dir=media_dir)
44.2 KB
Loading
19 KB
Loading

evaluation_function/dev.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"basic_nn": {
33
"response": "1.0",
4-
"answer": "1.0",
4+
"answer": "10.0",
55
"model": "basic_nn",
66
"refresh": false
77
},

poetry.lock

Lines changed: 960 additions & 41 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ priority = "explicit"
1919
python = ">=3.11,<3.14"
2020

2121
# Uncomment below and comment out cpu if using macOS, the marker feature does not work (at least on my machine) - MM
22-
#torch = { version = "^2.8.0", markers = "sys_platform != 'linux'" }
23-
torch = { version = "^2.8.0+cpu", markers = "sys_platform == 'linux'", source = "pytorch-cpu" }
22+
torch = { version = "^2.8.0", markers = "sys_platform != 'linux'" }
23+
#torch = { version = "^2.8.0+cpu", markers = "sys_platform == 'linux'", source = "pytorch-cpu" }
2424

2525
typing_extensions = "^4.12.2"
2626
lf_toolkit = { git = "https://github.com/lambda-feedback/toolkit-python.git", branch = "main", extras = [
@@ -39,6 +39,8 @@ sympy = ">=1.13.3"
3939
pytest = "^8.2.2"
4040
flake8 = "^7.1.0"
4141
nltk = "^3.9.2"
42+
matplotlib = "^3.8.0"
43+
pandas = "^2.2.1"
4244

4345
[build-system]
4446
requires = ["poetry-core"]

0 commit comments

Comments
 (0)