RMalikM/AlexNet_Pytorch


AlexNet PyTorch Implementation

Overview

This repository contains the following:

  1. An implementation of AlexNet using PyTorch for training, testing, and inference on the CIFAR-10 dataset.
  2. A FastAPI-based REST API for image classification using a trained AlexNet model on the CIFAR-10 dataset.
  3. A Streamlit web application for classifying images using a FastAPI backend with a trained AlexNet model on the CIFAR-10 dataset.

Repository Structure

ALEXNET_PYTORCH/
    │
    ├── data/                      # Downloaded CIFAR-10 dataset
    ├── src/
    │   ├── data_preprocessing/     # Scripts to clean/prepare data
    │   │   └── dataset.py
    │   │
    │   ├── model/                  # Training and evaluation code
    │   │   ├── train.py
    │   │   └── model.py            # Model architecture
    │   │
    │   └── inference/
    │       └── predict.py          # Code to run inference on new data
    │
    ├── models/                     # Saved models (e.g., .pth)
    │   └── best_model.pth
    │
    ├── api/                        # API using FastAPI
    │   └── main.py                 # API entry point
    │
    ├── streamlit_app/             # Streamlit frontend
    │   └── app.py                  # Streamlit entry point
    │
    ├── tests/                     # Unit and integration tests
    │   └── test.py
    │
    ├── requirements.txt           # Project dependencies
    ├── LICENSE                    # MIT License
    ├── README.md
    └── .gitignore

Requirements

Install the project dependencies before running the scripts:

pip install -r requirements.txt

Training the Model

To train AlexNet on CIFAR-10, run the following command:

cd src/model
python train.py

Training Parameters

The training script uses the following default parameters:

  • Number of Classes: 10 (CIFAR-10 dataset)

  • Epochs: 20

  • Batch Size: 64

  • Learning Rate: 0.001

The best model based on validation accuracy will be saved as best_model.pth.

Training Workflow

  1. Load the CIFAR-10 dataset and split it into training and validation sets.

  2. Initialize the AlexNet model and configure it to run on a GPU (if available).

  3. Define the loss function (CrossEntropyLoss) and the optimizer (SGD).

  4. Train the model for a specified number of epochs, calculating loss and updating weights.

  5. Evaluate the model on the validation set after each epoch.

  6. Save the model if the validation accuracy improves.
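The workflow above can be sketched as follows. This is a minimal illustration, not the repository's train.py: the function names `evaluate` and `train_and_keep_best`, the tiny stand-in model, and the random tensors used in place of CIFAR-10 are all assumptions chosen to keep the example fast and self-contained.

```python
import os
import tempfile

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return the classification accuracy of `model` over `loader`."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


def train_and_keep_best(model, train_loader, val_loader, save_path,
                        epochs=20, lr=0.001):
    """Train with SGD and save the weights whenever validation accuracy improves."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    best_acc = -1.0  # guarantees the first epoch is saved
    for _ in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        val_acc = evaluate(model, val_loader, device)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
    return best_acc


# Hypothetical stand-ins: random 32x32 RGB tensors and a single linear layer.
torch.manual_seed(0)
x = torch.randn(128, 3, 32, 32)
y = torch.randint(0, 10, (128,))
train_loader = DataLoader(TensorDataset(x[:96], y[:96]), batch_size=64, shuffle=True)
val_loader = DataLoader(TensorDataset(x[96:], y[96:]), batch_size=64)
tiny_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

save_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
best = train_and_keep_best(tiny_model, train_loader, val_loader, save_path, epochs=2)
print(f"best validation accuracy: {best:.3f}")
```

Swapping in the real model and CIFAR-10 loaders would leave the loop itself unchanged; only the stand-in data and model at the bottom are placeholders.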

Testing the Model

Once trained, you can evaluate the model using:

cd tests
python test.py

Running Inference

To perform inference on new images, use the predict.py script:

cd src/inference
python predict.py --image_path test_images/sample.jpg

Model Implementation

The model.py script contains the AlexNet architecture defined using PyTorch. The model is adapted for CIFAR-10 by modifying the fully connected layers to match the dataset's 10 output classes.
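For orientation, here is a minimal sketch of an AlexNet-style network with a 10-class output layer. It follows the classic AlexNet layer layout for 227x227 inputs; the class name `AlexNetSketch` and the exact layer sizes are assumptions and may differ from what model.py actually defines.

```python
import torch
from torch import nn


class AlexNetSketch(nn.Module):
    """Illustrative AlexNet-style network; not the repository's model.py."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),      # 227 -> 55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),           # 55 -> 27
            nn.Conv2d(96, 256, kernel_size=5, padding=2),    # 27 -> 27
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),           # 27 -> 13
            nn.Conv2d(256, 384, kernel_size=3, padding=1),   # 13 -> 13
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),           # 13 -> 6
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            # The final layer is sized to the dataset: 10 outputs for CIFAR-10.
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = AlexNetSketch(num_classes=10).eval()
with torch.no_grad():
    logits = model(torch.zeros(1, 3, 227, 227))
print(logits.shape)  # torch.Size([1, 10])
```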

Dataset Handling

The src/data_preprocessing/dataset.py script includes utilities for loading and preprocessing the CIFAR-10 dataset, including splitting it into training and validation sets.
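The split itself can be illustrated with a short, framework-free sketch. The function name `split_indices`, the 10% validation fraction, and the seed are assumptions for illustration; dataset.py may use a different ratio or a PyTorch utility instead.

```python
import random


def split_indices(num_samples, val_fraction=0.1, seed=42):
    """Shuffle sample indices and split them into (train, val) lists."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    indices = list(range(num_samples))
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(indices)
    val_size = int(num_samples * val_fraction)
    return indices[val_size:], indices[:val_size]


# The CIFAR-10 training set contains 50,000 images.
train_idx, val_idx = split_indices(50_000)
print(len(train_idx), len(val_idx))  # 45000 5000
```

Fixing the seed means the same images land in the validation set on every run, which keeps validation accuracy comparable across training runs.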

Results

During training, validation accuracy is monitored, and the best model is saved. The final accuracy will be printed at the end of training.

AlexNet Image Classification API

A FastAPI-based REST API for image classification using a trained AlexNet model on the CIFAR-10 dataset.

Features

  • Single Image Classification: Upload and classify individual images
  • Batch Processing: Classify multiple images in a single request (max 10 files)
  • Top-3 Predictions: Get confidence scores for the top 3 most likely classes
  • Health Monitoring: Check API and model status
  • CIFAR-10 Classes: Supports classification of 10 object categories

Supported Classes

airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck

Quick Start

Prerequisites

pip install fastapi uvicorn torch torchvision pillow

Setup

  1. Update Model Path: Edit the model_path in the code to point to your trained AlexNet model:

    model_path = Path('path/to/your/best_model.pth')

  2. Run the API:

    cd api
    python main.py

The API will be available at http://localhost:8000

API Endpoints

GET /

API information and available endpoints

GET /health

Check API status and model loading state

GET /classes

List all supported CIFAR-10 classes

POST /predict

Upload a single image for classification

  • Input: Image file (PNG, JPG, JPEG)
  • Output: Prediction with confidence score and top-3 results

POST /predict-batch

Upload multiple images for batch classification

  • Input: Up to 10 image files
  • Output: Array of predictions for each image

Example Response

{
  "filename": "cat.jpg",
  "file_size": 45632,
  "image_dimensions": "224x224",
  "prediction": {
    "predicted_class": "cat",
    "confidence": 0.89,
    "top3_predictions": [
      {"class": "cat", "confidence": 0.89},
      {"class": "dog", "confidence": 0.08},
      {"class": "deer", "confidence": 0.02}
    ]
  }
}
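As a rough illustration of how the `prediction` object in such a response could be derived from raw model scores, the sketch below applies a softmax and keeps the three highest-probability classes. The helper names `softmax` and `build_prediction` and the sample scores are hypothetical; the API's own code may compute this differently (for example with PyTorch tensor operations).

```python
import math

# Standard CIFAR-10 class order, matching the Supported Classes list.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_prediction(logits):
    """Return the predicted class, its confidence, and the top-3 classes."""
    probs = softmax(logits)
    ranked = sorted(zip(CLASSES, probs), key=lambda pair: pair[1], reverse=True)
    top3 = [{"class": name, "confidence": round(p, 4)} for name, p in ranked[:3]]
    return {
        "predicted_class": top3[0]["class"],
        "confidence": top3[0]["confidence"],
        "top3_predictions": top3,
    }


# Hypothetical raw scores for one image (one value per class).
sample_logits = [0.1, 0.2, 0.3, 4.0, 1.0, 2.5, 0.0, 0.4, -1.0, 0.2]
result = build_prediction(sample_logits)
print(result["predicted_class"])  # cat
```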

Requirements

  • Python 3.9+
  • PyTorch
  • FastAPI
  • Pillow (PIL)
  • A trained AlexNet model file (.pth)

Notes

  • Images are automatically resized to 227x227 pixels (AlexNet input size)
  • The API handles various image formats and converts them to RGB
  • GPU acceleration is used when available; otherwise the API falls back to CPU

AlexNet Image Classifier

A Streamlit web application for classifying images using a FastAPI backend with a trained AlexNet model on the CIFAR-10 dataset.

Features

  • Single Image Classification: Upload and classify individual images
  • Batch Processing: Classify multiple images at once (up to 10)
  • Real-time API Health Monitoring: Check backend server status
  • Interactive Visualizations: Confidence charts and probability distributions
  • CIFAR-10 Classes: Classifies 10 categories (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck)

Prerequisites

  • Python 3.7+
  • FastAPI backend server running
  • Required Python packages (see Installation)

Installation

pip install streamlit requests pillow numpy plotly

Usage

  1. Start your FastAPI server (typically on http://localhost:8000)

  2. Run the Streamlit app:

    streamlit run app.py

  3. Configure API URL in the sidebar (default: http://localhost:8000)

  4. Upload images and get predictions!

App Structure

Main Tabs

  • 🖼️ Single Image: Upload one image for classification
  • 📂 Batch Upload: Process multiple images simultaneously
  • 🔧 API Info: View API status and endpoints

Sidebar Features

  • API configuration and health check
  • CIFAR-10 class reference
  • Application information

Required FastAPI Endpoints

Your backend should provide these endpoints:

  • GET /health - API health status
  • POST /predict - Single image prediction
  • POST /predict-batch - Batch image prediction
  • GET /classes - Available classes

File Support

  • Supported formats: PNG, JPG, JPEG
  • Batch limit: 10 images maximum
  • Size recommendations: Standard web image sizes for best performance
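The format and batch-size limits above could be checked before any upload is sent. The following is a minimal sketch under that assumption; `validate_batch` and the constant names are hypothetical and are not taken from app.py.

```python
from pathlib import Path

ALLOWED_EXTENSIONS = {".png", ".jpg", ".jpeg"}
MAX_BATCH_SIZE = 10


def validate_batch(filenames):
    """Return a list of error messages; an empty list means the batch is acceptable."""
    errors = []
    if not filenames:
        errors.append("no files provided")
    if len(filenames) > MAX_BATCH_SIZE:
        errors.append(f"too many files: {len(filenames)} > {MAX_BATCH_SIZE}")
    for name in filenames:
        # Compare extensions case-insensitively so "PHOTO.JPG" is accepted.
        if Path(name).suffix.lower() not in ALLOWED_EXTENSIONS:
            errors.append(f"unsupported format: {name}")
    return errors


print(validate_batch(["cat.jpg", "ship.PNG"]))  # []
print(validate_batch(["notes.txt"]))            # ['unsupported format: notes.txt']
```

Returning every problem at once, rather than stopping at the first, lets a frontend show the user a complete list of files to fix.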

Visualization Features

  • Horizontal bar charts for top predictions
  • Pie charts for probability distribution
  • Real-time confidence metrics
  • Color-coded class indicators

Error Handling

  • Connection timeout handling
  • API error reporting
  • File format validation
  • Graceful degradation when API is unavailable
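One way to get this kind of graceful degradation is to wrap the health check so that timeouts and connection failures yield a sentinel value instead of an exception. The sketch below uses only the standard library for portability, whereas the app lists `requests` among its dependencies; the function name `check_health` and the assumption that `/health` returns a JSON body are illustrative.

```python
import json
import urllib.error
import urllib.request


def check_health(base_url, timeout=5.0):
    """Return the parsed /health JSON, or None if the API cannot be reached."""
    url = base_url.rstrip("/") + "/health"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except (urllib.error.URLError, OSError, ValueError):
        # URLError/OSError: refused connection, DNS failure, timeout, HTTP error.
        # ValueError: the body was not valid JSON.
        return None


status = check_health("http://localhost:8000", timeout=2.0)
if status is None:
    print("API unavailable; predictions are disabled until it is reachable.")
else:
    print("API healthy:", status)
```

Callers can then branch on `None` to disable upload controls or show a warning rather than crashing the page.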

Configuration

Default settings can be modified in the sidebar:

  • FastAPI server URL
  • API timeout settings
  • Display preferences
