This repository contains the following:
- An implementation of AlexNet using PyTorch for training, testing, and inference on the CIFAR-10 dataset.
- A FastAPI-based REST API for image classification using a trained AlexNet model on the CIFAR-10 dataset.
- A Streamlit web application for classifying images using a FastAPI backend with a trained AlexNet model on the CIFAR-10 dataset.
ALEXNET_PYTORCH/
│
├── data/ # Downloaded CIFAR-10 dataset
├── src/
│   ├── data_preprocessing/ # Scripts to clean/prepare data
│   │   └── dataset.py
│   │
│   ├── model/ # Training and evaluation code
│   │   ├── train.py
│   │   └── model.py # Model architecture
│   │
│   └── inference/
│       └── predict.py # Code to run inference on new data
│
├── models/ # Saved models (e.g., .pth)
│   └── best_model.pth
│
├── api/ # API using FastAPI
│   └── main.py # API entry point
│
├── streamlit_app/ # Streamlit frontend
│   └── app.py # Streamlit entry point
│
├── tests/ # Unit and integration tests
│   └── test.py
│
├── requirements.txt # Project dependencies
├── LICENSE # MIT License
├── README.md
└── .gitignore
Install the project dependencies before running the scripts:
pip install -r requirements.txt
To train AlexNet on CIFAR-10, run the following command:
cd src/model
python train.py
The training script uses the following default parameters:
- Number of Classes: 10 (CIFAR-10 dataset)
- Epochs: 20
- Batch Size: 64
- Learning Rate: 0.001
The best model based on validation accuracy will be saved as best_model.pth.
The training script performs the following steps:
- Load the CIFAR-10 dataset and split it into training and validation sets.
- Initialize the AlexNet model and move it to the GPU if one is available.
- Define the loss function (CrossEntropyLoss) and the optimizer (SGD).
- Train the model for the configured number of epochs, computing the loss and updating the weights.
- Evaluate the model on the validation set after each epoch.
- Save the model whenever the validation accuracy improves.
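The steps above can be sketched as a single PyTorch function. This is a minimal illustration, not the repository's train.py: the function name `train`, its signature, and the SGD momentum of 0.9 are assumptions. Only the loss (CrossEntropyLoss), the optimizer type (SGD), the defaults (20 epochs, learning rate 0.001), and the save-on-improvement rule come from this README.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader,
          epochs: int = 20, lr: float = 0.001,
          save_path: str = "best_model.pth") -> float:
    """Train with CrossEntropyLoss + SGD; keep the checkpoint with the best val accuracy."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Momentum 0.9 is an assumption; the README only states that SGD is used.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    best_acc = float("-inf")  # so the first epoch always writes a checkpoint

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)

        # Validation pass: eval mode disables dropout, no_grad skips autograd bookkeeping
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                preds = model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total
        print(f"epoch {epoch + 1}/{epochs} "
              f"loss={running_loss / len(train_loader.dataset):.4f} val_acc={val_acc:.4f}")

        # Save only when validation accuracy improves
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
    return best_acc
```

Because only the `state_dict` is saved, reloading the checkpoint later requires constructing the same architecture first and then calling `load_state_dict`.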
Once trained, you can evaluate the model using:
cd tests
python test.py
To perform inference on new images, use the predict.py script:
cd src/inference
python predict.py --image_path test_images/sample.jpg
The model.py script contains the AlexNet architecture defined using PyTorch. The model is adapted for CIFAR-10 by modifying the fully connected layers to match the dataset's 10 output classes.
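For illustration, an AlexNet-style network for 227x227 RGB inputs with a 10-way output might look like the sketch below. The layer sizes follow the original AlexNet design (without local response normalization) and are an assumption; the class name `AlexNet` and its `num_classes` argument are hypothetical, and the repository's model.py may differ.

```python
import torch
from torch import nn


class AlexNet(nn.Module):
    """AlexNet-style CNN; expects input of shape (N, 3, 227, 227)."""

    def __init__(self, num_classes: int = 10, dropout: float = 0.5) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),     # -> 96 x 55 x 55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # -> 96 x 27 x 27
            nn.Conv2d(96, 256, kernel_size=5, padding=2),   # -> 256 x 27 x 27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # -> 256 x 13 x 13
            nn.Conv2d(256, 384, kernel_size=3, padding=1),  # -> 384 x 13 x 13
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),  # -> 384 x 13 x 13
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # -> 256 x 13 x 13
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),          # -> 256 x 6 x 6
        )
        # Fully connected head; the final layer outputs one logit per CIFAR-10 class
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, 1)  # (N, 256, 6, 6) -> (N, 9216)
        return self.classifier(x)
```

torchvision also ships `torchvision.models.alexnet`, which accepts `num_classes=10`; it uses somewhat different channel counts and an adaptive pooling layer, so it is less strict about the input size.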
The src/data_preprocessing/dataset.py script includes utilities for loading and preprocessing the CIFAR-10 dataset, including splitting it into training and validation sets.
During training, validation accuracy is monitored, and the best model is saved. The final accuracy will be printed at the end of training.
A FastAPI-based REST API for image classification using a trained AlexNet model on the CIFAR-10 dataset.
- Single Image Classification: Upload and classify individual images
- Batch Processing: Classify multiple images in a single request (max 10 files)
- Top-3 Predictions: Get confidence scores for the top 3 most likely classes
- Health Monitoring: Check API and model status
- CIFAR-10 Classes: Supports classification of 10 object categories: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
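One common way to produce a confidence score and a top-3 list is to apply softmax to the model's raw outputs and take `torch.topk`. The sketch assumes the model returns unnormalized logits in the standard CIFAR-10 class order listed above; the function name `format_prediction` is hypothetical, while the dictionary keys match those in the API's JSON response example.

```python
import torch

# Standard CIFAR-10 label order (index -> class name)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def format_prediction(logits: torch.Tensor, k: int = 3) -> dict:
    """Turn one image's logits (shape [10]) into a response-style dictionary."""
    probs = torch.softmax(logits, dim=-1)        # logits -> probabilities summing to 1
    top_probs, top_idx = torch.topk(probs, k)    # k largest, sorted in descending order
    top = [{"class": CIFAR10_CLASSES[i], "confidence": round(p, 4)}
           for p, i in zip(top_probs.tolist(), top_idx.tolist())]
    return {
        "predicted_class": top[0]["class"],
        "confidence": top[0]["confidence"],
        "top3_predictions": top,
    }
```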
pip install fastapi uvicorn torch torchvision pillow python-multipart
- Update Model Path: Edit model_path in the code to point to your trained AlexNet model:
  model_path = Path('path/to/your/best_model.pth')
- Run the API:
  cd api
  python main.py
The API will be available at http://localhost:8000
- API information and available endpoints
- GET /health: Check API status and model loading state
- GET /classes: List all supported CIFAR-10 classes
- POST /predict: Upload a single image for classification
  - Input: Image file (PNG, JPG, JPEG)
  - Output: Prediction with confidence score and top-3 results
- POST /predict-batch: Upload multiple images for batch classification
  - Input: Up to 10 image files
  - Output: Array of predictions for each image
Example response for a single image:
{
"filename": "cat.jpg",
"file_size": 45632,
"image_dimensions": "224x224",
"prediction": {
"predicted_class": "cat",
"confidence": 0.89,
"top3_predictions": [
{"class": "cat", "confidence": 0.89},
{"class": "dog", "confidence": 0.08},
{"class": "deer", "confidence": 0.02}
]
}
}
Requirements:
- Python 3.9+
- PyTorch
- FastAPI
- Pillow (PIL)
- A trained AlexNet model file (.pth)
- Images are automatically resized to 227x227 pixels (AlexNet input size)
- The API handles various image formats and converts them to RGB
- GPU acceleration is used when available; otherwise the API falls back to the CPU
A Streamlit web application for classifying images using a FastAPI backend with a trained AlexNet model on the CIFAR-10 dataset.
- Single Image Classification: Upload and classify individual images
- Batch Processing: Classify multiple images at once (up to 10)
- Real-time API Health Monitoring: Check backend server status
- Interactive Visualizations: Confidence charts and probability distributions
- CIFAR-10 Classes: Classifies 10 categories (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck)
- Python 3.7+
- FastAPI backend server running
- Required Python packages (see the install command below)
pip install streamlit requests pillow numpy plotly
- Start your FastAPI server (typically on http://localhost:8000)
- Run the Streamlit app:
  cd streamlit_app
  streamlit run app.py
- Configure the API URL in the sidebar (default: http://localhost:8000)
- Upload images and get predictions!
- 🖼️ Single Image: Upload one image for classification
- Batch Upload: Process multiple images simultaneously
- 🔧 API Info: View API status and endpoints
- API configuration and health check
- CIFAR-10 class reference
- Application information
Your backend should provide these endpoints:
- GET /health: API health status
- POST /predict: Single image prediction
- POST /predict-batch: Batch image prediction
- GET /classes: Available classes
- Supported formats: PNG, JPG, JPEG
- Batch limit: 10 images maximum
- Size recommendations: Standard web image sizes for best performance
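These upload limits can be checked in the frontend before any request is sent. A minimal, standard-library-only sketch: the function name `validate_batch` and the error messages are hypothetical; only the allowed formats and the 10-file limit come from this README. Type hints use `typing.List` to stay compatible with Python 3.7.

```python
from pathlib import Path
from typing import List

ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg"}  # PNG, JPG, JPEG
MAX_BATCH_SIZE = 10                             # batch limit


def validate_batch(filenames: List[str]) -> List[str]:
    """Return a list of error messages; an empty list means the batch is acceptable."""
    errors = []
    if not filenames:
        errors.append("No files provided.")
    if len(filenames) > MAX_BATCH_SIZE:
        errors.append(f"Too many files: {len(filenames)} (maximum is {MAX_BATCH_SIZE}).")
    for name in filenames:
        # Compare extensions case-insensitively so "photo.JPG" is accepted
        if Path(name).suffix.lower() not in ALLOWED_EXTENSIONS:
            errors.append(f"Unsupported format: {name}")
    return errors
```

An extension check does not prove the bytes are a valid image; opening the file with Pillow is a stronger check, and the backend should validate again regardless.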
- Horizontal bar charts for top predictions
- Pie charts for probability distribution
- Real-time confidence metrics
- Color-coded class indicators
- Connection timeout handling
- API error reporting
- File format validation
- Graceful degradation when API is unavailable
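The timeout handling, error reporting, and graceful degradation listed above can be sketched with `requests`. Assumptions: the function names are hypothetical, `/health` is treated as healthy when it returns HTTP 200, and the upload form field is assumed to be named `file`; the real field name depends on the FastAPI endpoint's signature.

```python
import mimetypes
from typing import Optional, Tuple

import requests

DEFAULT_API_URL = "http://localhost:8000"


def check_health(api_url: str = DEFAULT_API_URL, timeout: float = 5.0) -> bool:
    """Return True if GET /health answers with HTTP 200, False on any failure."""
    try:
        response = requests.get(f"{api_url}/health", timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        # Connection errors and timeouts count as "unhealthy" instead of crashing the app
        return False


def predict_image(image_bytes: bytes, filename: str, api_url: str = DEFAULT_API_URL,
                  timeout: float = 30.0) -> Tuple[Optional[dict], Optional[str]]:
    """POST one image to /predict; return (result, None) or (None, error message)."""
    mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    files = {"file": (filename, image_bytes, mime_type)}  # field name "file" is an assumption
    try:
        response = requests.post(f"{api_url}/predict", files=files, timeout=timeout)
        response.raise_for_status()  # turn 4xx/5xx responses into HTTPError
        return response.json(), None
    except requests.Timeout:
        return None, f"Request timed out after {timeout} seconds."
    except requests.ConnectionError:
        return None, f"Could not connect to the API at {api_url}."
    except requests.HTTPError as exc:
        return None, f"API error: {exc}"
    except ValueError:
        return None, "The API response was not valid JSON."
```

Returning an error message instead of raising lets the UI show it (for example with `st.error`) while the rest of the app stays usable when the backend is down.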
Default settings can be modified in the sidebar:
- FastAPI server URL
- API timeout settings
- Display preferences