# CI pipeline: builds the C++ extension with `make` and runs the pytest suite
# on every supported Python version, plus a separate build-only check.
name: CI

on:
  push:
    branches: [ main, develop, feature/* ]
  pull_request:
    branches: [ main, develop ]

jobs:
  test:
    name: Test with Python ${{ matrix.python-version }}
    runs-on: ubuntu-latest

    strategy:
      # Let the remaining Python versions finish even when one of them fails.
      fail-fast: false
      matrix:
        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install system dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y build-essential g++

      - name: Install Python dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy pybind11 pytest pillow

      - name: Build C++ extension
        run: |
          make clean
          make all

      - name: Create test image
        run: |
          # Git does not track empty directories, so make sure the target
          # directory exists before Pillow tries to write into it.
          mkdir -p tests/test_data
          python -c "from PIL import Image; import numpy as np; img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8); Image.fromarray(img).save('tests/test_data/sample.jpg', 'JPEG', quality=90)"

      - name: Run tests
        run: |
          python -m pytest tests/ -v --tb=short

      - name: Test import
        run: |
          python -c "import sys; sys.path.insert(0, 'src/python'); import fast_jpeg_decoder as fjd; print('Import successful'); print('Version:', fjd.__version__)"

  build-check:
    name: Build Check
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install numpy pybind11

      - name: Build extension
        run: |
          make clean
          make all

      - name: Check build artifacts
        run: |
          # Fails the job if the shared object was not produced.
          ls -lh src/python/fast_jpeg_decoder/_fast_jpeg_decoder*.so
a/.gitignore b/.gitignore new file mode 100644 index 0000000..9762dcb --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ + +# C++ build artifacts +*.o +*.so +*.a +*.dylib + +# Python build artifacts +__pycache__/ +*.pyc +*.pyo +*.pyd +*.egg-info/ +dist/ +build/ +*.egg + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# Test images +*.jpg +*.jpeg +*.png +!tests/test_data/*.jpg +!tests/test_data/*.jpeg +!tests/test_data/*.png + +# OS +.DS_Store +Thumbs.db + +# Documentation +info.md \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1b7c7ab --- /dev/null +++ b/Makefile @@ -0,0 +1,82 @@ +# Makefile for Fast JPEG Decoder + +# Compiler settings +CXX = g++ +CXXFLAGS = -std=c++11 -O3 -Wall -fPIC +INCLUDES = -Isrc/cpp + +# Python settings +PYTHON = python3 +PYTHON_CONFIG = python3-config +PYTHON_INCLUDES = $(shell $(PYTHON_CONFIG) --includes) +PYTHON_LDFLAGS = $(shell $(PYTHON_CONFIG) --ldflags) + +# pybind11 settings +PYBIND11_INCLUDES = $(shell $(PYTHON) -m pybind11 --includes) + +# Source files +CPP_SOURCES = src/cpp/bitstream.cpp \ + src/cpp/idct.cpp \ + src/cpp/huffman.cpp \ + src/cpp/decoder.cpp + +BINDING_SOURCES = src/bindings/bindings.cpp + +# Object files +CPP_OBJECTS = $(CPP_SOURCES:.cpp=.o) +BINDING_OBJECTS = $(BINDING_SOURCES:.cpp=.o) + +# Output +EXTENSION_SUFFIX = $(shell $(PYTHON_CONFIG) --extension-suffix) +TARGET = src/python/fast_jpeg_decoder/_fast_jpeg_decoder$(EXTENSION_SUFFIX) + +# Targets +.PHONY: all clean install test develop + +all: $(TARGET) + +# Build Python extension +$(TARGET): $(CPP_OBJECTS) $(BINDING_OBJECTS) + $(CXX) -shared -fPIC $(CXXFLAGS) $^ -o $@ $(PYTHON_LDFLAGS) + +# Build C++ objects +src/cpp/%.o: src/cpp/%.cpp + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# Build binding objects +src/bindings/%.o: src/bindings/%.cpp + $(CXX) $(CXXFLAGS) $(INCLUDES) $(PYTHON_INCLUDES) $(PYBIND11_INCLUDES) -c $< -o $@ + +# Install package +install: all + $(PYTHON) setup.py install + +# Develop mode (editable 
install) +develop: all + $(PYTHON) -m pip install -e . + +# Run tests +test: develop + $(PYTHON) -m pytest tests/ -v + +# Clean build artifacts +clean: + rm -f $(CPP_OBJECTS) $(BINDING_OBJECTS) + rm -f $(TARGET) + rm -rf build/ dist/ *.egg-info + rm -rf src/python/fast_jpeg_decoder/*.so + rm -rf src/python/fast_jpeg_decoder/*.pyd + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete + +# Help +help: + @echo "Fast JPEG Decoder - Makefile" + @echo "" + @echo "Targets:" + @echo " all - Build the Python extension (default)" + @echo " install - Install the package" + @echo " develop - Install in development mode (editable)" + @echo " test - Run tests" + @echo " clean - Remove build artifacts" + @echo " help - Show this help message" diff --git a/README.md b/README.md index 5295477..58cf2b4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,111 @@ -# Fast-Jpeg-Decoder +# Fast JPEG Decoder + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + 視訊壓縮期末專案 + +高效能 JPEG 解碼器,核心計算使用 C++ 實現,透過 pybind11 提供 Python API。 + +## 特點 + +- **高效能**: 核心解碼邏輯使用 C++ 實現 +- **易用性**: 提供簡潔的 Python API +- **正確性**: 完整實現 JPEG Baseline DCT 解碼流程 +- **可擴展**: 模組化設計,便於後續優化(OpenMP, SIMD) + +## 安裝 + +### 依賴 + +- Python 3.8+ +- NumPy +- pybind11 +- C++ 編譯器(支援 C++11) + +### 從原始碼安裝 + +```bash +# 安裝依賴 +pip install numpy pybind11 + +# 編譯並安裝 +make develop +``` + +## 使用方法 + +### 基本用法 + +```python +import fast_jpeg_decoder as fjd + +# 從檔案載入 JPEG 圖片 +image = fjd.load('photo.jpg') +print(image.shape) # (height, width, 3) + +# 從 bytes 載入 +with open('photo.jpg', 'rb') as f: + data = f.read() +image = fjd.load_bytes(data) +``` + +### 使用 Decoder 類別 + +```python +import fast_jpeg_decoder as fjd + +decoder = fjd.JPEGDecoder() +decoder.decode_file('photo.jpg') + +print(f"Width: {decoder.width}") +print(f"Height: {decoder.height}") +print(f"Channels: {decoder.channels}") + +image = decoder.get_image_data() +``` + 
def main():
    """Decode the JPEG named on the command line and print a short report.

    The file is decoded twice -- once through the ``fjd.load`` convenience
    function and once through ``fjd.JPEGDecoder`` -- and the two pixel
    arrays are compared.

    Returns:
        int: 0 when both decodes succeed and agree; 1 when no filename was
        given, when decoding raised, when ``decode_file`` reported failure,
        or when the two results differ. The value is passed to
        ``sys.exit`` by the script entry point, so every path must return
        an explicit status.
    """
    if len(sys.argv) < 2:
        print("用法: python example.py <image.jpg>")
        print("\n範例:")
        print(" python example.py photo.jpg")
        # A missing argument is a usage error, not a success.
        return 1

    filename = sys.argv[1]

    print(f"正在解碼: {filename}")

    try:
        # Method 1: the one-call convenience function.
        image = fjd.load(filename)
        print("✓ 解碼成功!")
        print(f" 圖片尺寸: {image.shape[1]} x {image.shape[0]}")
        print(f" 通道數: {image.shape[2]}")
        print(f" 資料類型: {image.dtype}")
        print(f" 數值範圍: [{image.min()}, {image.max()}]")

        # Method 2: the explicit decoder object.
        print("\n使用 Decoder 類別:")
        decoder = fjd.JPEGDecoder()
        if not decoder.decode_file(filename):
            # decode_file signals failure through its return value rather
            # than an exception, so it has to be checked explicitly.
            print("✗ 解碼失敗: JPEGDecoder.decode_file 回傳 False")
            return 1

        print("✓ 解碼成功!")
        print(f" 寬度: {decoder.width}")
        print(f" 高度: {decoder.height}")
        print(f" 通道: {decoder.channels}")

        image2 = decoder.get_image_data()

        # Both code paths must produce exactly the same pixels.
        if not np.array_equal(image, image2):
            print("\n✗ 兩種方法結果不同")
            return 1
        print("\n✓ 兩種方法結果一致")

    except Exception as e:  # top-level script boundary: report and fail
        print(f"✗ 解碼失敗: {e}")
        return 1

    return 0
version='0.1.0', + author='Your Name', + description='A high-performance JPEG decoder implemented in C++', + long_description='', + ext_modules=ext_modules, + packages=['fast_jpeg_decoder'], + package_dir={'fast_jpeg_decoder': 'src/python/fast_jpeg_decoder'}, + install_requires=['numpy', 'pybind11>=2.6.0'], + python_requires='>=3.6', + zip_safe=False, +) diff --git a/src/bindings/bindings.cpp b/src/bindings/bindings.cpp new file mode 100644 index 0000000..3226c61 --- /dev/null +++ b/src/bindings/bindings.cpp @@ -0,0 +1,87 @@ +#include +#include +#include +#include "../cpp/decoder.h" + +namespace py = pybind11; + +PYBIND11_MODULE(_fast_jpeg_decoder, m) { + m.doc() = "Fast JPEG Decoder - C++ implementation with Python bindings"; + + py::class_(m, "JPEGDecoder") + .def(py::init<>()) + .def("decode_file", &jpeg::JPEGDecoder::decodeFile, + py::arg("filename"), + "Decode a JPEG file from disk") + .def("decode_memory", [](jpeg::JPEGDecoder& self, py::bytes data) { + std::string str = data; + return self.decodeMemory( + reinterpret_cast(str.data()), + str.size() + ); + }, + py::arg("data"), + "Decode a JPEG from memory (bytes)") + .def("get_image_data", [](const jpeg::JPEGDecoder& self) { + int width = self.getWidth(); + int height = self.getHeight(); + const uint8_t* data = self.getImageData(); + + // 建立 numpy array + auto result = py::array_t({height, width, 3}); + auto buf = result.request(); + uint8_t* ptr = static_cast(buf.ptr); + + // 複製資料 + std::memcpy(ptr, data, width * height * 3); + + return result; + }, + "Get decoded image data as numpy array (H, W, 3)") + .def_property_readonly("width", &jpeg::JPEGDecoder::getWidth, + "Image width") + .def_property_readonly("height", &jpeg::JPEGDecoder::getHeight, + "Image height") + .def_property_readonly("channels", &jpeg::JPEGDecoder::getChannels, + "Number of color channels"); + + // 解碼檔案並返回 numpy array + m.def("decode", [](const std::string& filename) { + jpeg::JPEGDecoder decoder; + if (!decoder.decodeFile(filename)) { + 
throw std::runtime_error("Failed to decode JPEG file: " + filename); + } + + int width = decoder.getWidth(); + int height = decoder.getHeight(); + const uint8_t* data = decoder.getImageData(); + + auto result = py::array_t({height, width, 3}); + auto buf = result.request(); + uint8_t* ptr = static_cast(buf.ptr); + std::memcpy(ptr, data, width * height * 3); + + return result; + }, py::arg("filename"), "Decode a JPEG file and return as numpy array"); + + m.def("decode_bytes", [](py::bytes data) { + std::string str = data; + jpeg::JPEGDecoder decoder; + if (!decoder.decodeMemory( + reinterpret_cast(str.data()), + str.size())) { + throw std::runtime_error("Failed to decode JPEG from bytes"); + } + + int width = decoder.getWidth(); + int height = decoder.getHeight(); + const uint8_t* img_data = decoder.getImageData(); + + auto result = py::array_t({height, width, 3}); + auto buf = result.request(); + uint8_t* ptr = static_cast(buf.ptr); + std::memcpy(ptr, img_data, width * height * 3); + + return result; + }, py::arg("data"), "Decode a JPEG from bytes and return as numpy array"); +} diff --git a/src/cpp/bitstream.cpp b/src/cpp/bitstream.cpp new file mode 100644 index 0000000..d7de4ee --- /dev/null +++ b/src/cpp/bitstream.cpp @@ -0,0 +1,121 @@ +#include "bitstream.h" +#include +#include + +namespace jpeg { + +BitStream::BitStream(const uint8_t* data, size_t size) + : data_(data), size_(size), byte_pos_(0), bit_buffer_(0), bits_in_buffer_(0) { + if (!data || size == 0) { + throw std::invalid_argument("BitStream: invalid data or size"); + } +} + +void BitStream::fillBuffer() { + while (bits_in_buffer_ <= 24 && byte_pos_ < size_) { + uint8_t byte = data_[byte_pos_++]; + + // 處理 byte stuffing: 如果是 0xFF,檢查下一個 byte + if (byte == 0xFF) { + if (byte_pos_ < size_) { + uint8_t next_byte = data_[byte_pos_]; + if (next_byte == 0x00) { + // 是 byte stuffing (FF 00),跳過 00 + byte_pos_++; + } else if (next_byte >= 0xD0 && next_byte <= 0xD7) { + // 是 RST marker (FF Dx),跳過整個 marker + 
byte_pos_++; + continue; // 不要將 0xFF 放入緩衝區,繼續讀取下一個 byte + } else { + // 其他 marker,可能是錯誤或資料結束 + // 根據 JPEG 規範,這裡應該是資料結束,但我們保守地放入 0xFF + } + } + } + + bit_buffer_ = (bit_buffer_ << 8) | byte; + bits_in_buffer_ += 8; + } +} + +uint16_t BitStream::readBits(int n_bits) { + if (n_bits <= 0 || n_bits > 16) { + throw std::invalid_argument("BitStream: n_bits must be between 1 and 16"); + } + + if (bits_in_buffer_ < n_bits) { + fillBuffer(); + } + + if (bits_in_buffer_ < n_bits) { + throw std::runtime_error("BitStream: unexpected end of data"); + } + + uint16_t result = bit_buffer_ >> (bits_in_buffer_ - n_bits); + + bits_in_buffer_ -= n_bits; + bit_buffer_ &= (1U << bits_in_buffer_) - 1; + + return result; +} + +uint16_t BitStream::peekBits(int n_bits) { + if (n_bits <= 0 || n_bits > 16) { + throw std::invalid_argument("BitStream: n_bits must be between 1 and 16"); + } + + if (bits_in_buffer_ < n_bits) { + fillBuffer(); + } + + if (bits_in_buffer_ < n_bits) { + throw std::runtime_error("BitStream: unexpected end of data"); + } + + return bit_buffer_ >> (bits_in_buffer_ - n_bits); +} + +void BitStream::skipBits(int n_bits) { + if (bits_in_buffer_ >= n_bits) { + bits_in_buffer_ -= n_bits; + bit_buffer_ &= (1U << bits_in_buffer_) - 1; + } else { + n_bits -= bits_in_buffer_; + bits_in_buffer_ = 0; + bit_buffer_ = 0; + + // 這裡可以優化,但為了簡單起見,先這樣 + size_t bytes_to_skip = n_bits / 8; + byte_pos_ += bytes_to_skip; + n_bits %= 8; + + if (n_bits > 0) { + readBits(n_bits); + } + } +} + +bool BitStream::hasMoreData() const { + return bits_in_buffer_ > 0 || byte_pos_ < size_; +} + +size_t BitStream::getBitPosition() const { + return byte_pos_ * 8 - bits_in_buffer_; +} + +void BitStream::reset(size_t byte_pos, int bit_pos) { + if (byte_pos > size_) { // allow reset to end + throw std::out_of_range("BitStream: byte_pos out of range"); + } + if (bit_pos < 0 || bit_pos >= 8) { + throw std::out_of_range("BitStream: bit_pos must be 0-7"); + } + byte_pos_ = byte_pos; + bits_in_buffer_ = 0; + 
/**
 * BitStream - bit-level reader for JPEG entropy-coded data.
 *
 * The data that follows a JPEG SOS (Start of Scan) marker is coded bit by
 * bit and uses byte stuffing: a 0xFF data byte is followed by 0x00, and that
 * 0x00 has to be skipped while decoding.
 */
class BitStream {
public:
    /**
     * Wrap an existing buffer; the data is referenced, not copied.
     * @param data pointer to the first byte
     * @param size number of bytes available
     * @throws std::invalid_argument if data is null or size is 0
     */
    BitStream(const uint8_t* data, size_t size);

    /**
     * Read and consume the given number of bits.
     * @param n_bits number of bits to read (1-16)
     * @return the bits read, as an unsigned value
     * @throws std::invalid_argument if n_bits is outside 1-16
     * @throws std::runtime_error    if the data ends first
     */
    uint16_t readBits(int n_bits);

    /**
     * Look at the upcoming bits without consuming them.
     * @param n_bits number of bits to inspect (1-16)
     * @return the bits, as an unsigned value
     * @throws std::invalid_argument if n_bits is outside 1-16
     * @throws std::runtime_error    if the data ends first
     */
    uint16_t peekBits(int n_bits);

    /**
     * Skip the given number of bits.
     * @param n_bits number of bits to skip
     */
    void skipBits(int n_bits);

    /**
     * Check whether any data is left to read.
     * @return true if bits remain in the buffer or bytes remain in the input
     */
    bool hasMoreData() const;

    /**
     * Get the current position in bits.
     * @return byte position times 8, minus the bits still held in the buffer
     */
    size_t getBitPosition() const;

    /**
     * Reposition the stream.
     * @param byte_pos byte offset (may equal the data size)
     * @param bit_pos  bit offset inside that byte (0-7)
     * @throws std::out_of_range if either argument is out of range
     */
    void reset(size_t byte_pos = 0, int bit_pos = 0);

private:
    const uint8_t* data_;   // pointer to the raw input (not owned)
    size_t size_;           // input size in bytes
    size_t byte_pos_;       // index of the next byte to load

    uint32_t bit_buffer_;   // bits loaded but not yet consumed
    int bits_in_buffer_;    // number of valid bits in bit_buffer_

    /**
     * Refill the bit buffer from the input, handling byte stuffing.
     */
    void fillBuffer();
};
JPEGDecoder::decodeFile(const std::string& filename) { + + std::ifstream file(filename, std::ios::binary | std::ios::ate); + if (!file.is_open()) { + return false; + } + + size_t size = file.tellg(); + file.seekg(0, std::ios::beg); + + jpeg_data_.resize(size); + file.read(reinterpret_cast(jpeg_data_.data()), size); + file.close(); + + data_pos_ = 0; + return parse(); +} + +bool JPEGDecoder::decodeMemory(const uint8_t* data, size_t size) { + jpeg_data_.assign(data, data + size); + data_pos_ = 0; + return parse(); +} + +uint8_t JPEGDecoder::readByte() { + if (data_pos_ >= jpeg_data_.size()) { + throw std::runtime_error("Unexpected end of JPEG data"); + } + return jpeg_data_[data_pos_++]; +} + +uint16_t JPEGDecoder::readWord() { + uint16_t high = readByte(); + uint16_t low = readByte(); + return (high << 8) | low; +} + +uint8_t JPEGDecoder::readMarker() { + uint8_t byte = readByte(); + if (byte != 0xFF) { + throw std::runtime_error("Expected marker"); + } + + // 跳過填充的 0xFF + byte = readByte(); + while (byte == 0xFF) { + byte = readByte(); + } + + return byte; +} + +void JPEGDecoder::skipSegment() { + uint16_t length = readWord(); + data_pos_ += length - 2; +} + +bool JPEGDecoder::parse() { + try { + // 檢查 SOI marker + uint8_t marker = readMarker(); + if (marker != MARKER_SOI) { + return false; + } + + // 解析各個 segments + while (true) { + marker = readMarker(); + + switch (marker) { + case MARKER_SOF0: + if (!processSOF0()) return false; + break; + + case MARKER_DQT: + if (!processDQT()) return false; + break; + + case MARKER_DHT: + if (!processDHT()) return false; + break; + + case MARKER_SOS: + if (!processSOS()) return false; + break; + + case MARKER_DRI: + if (!processDRI()) return false; + break; + + case MARKER_EOI: + return true; + + case MARKER_APP0: + default: + // 跳過不認識的 segment + skipSegment(); + break; + } + } + } catch (const std::exception& e) { + return false; + } + + return false; +} + +bool JPEGDecoder::processDQT() { + uint16_t length = readWord(); + 
size_t end_pos = data_pos_ + length - 2; + + while (data_pos_ < end_pos) { + uint8_t qt_info = readByte(); + int precision = (qt_info >> 4); // 0 = 8-bit, 1 = 16-bit + int qt_id = qt_info & 0x0F; + + if (qt_id >= 4) { + return false; + } + + // 讀取 64 個量化值 + for (int i = 0; i < 64; ++i) { + if (precision == 0) { + quantization_tables_[qt_id][i] = readByte(); + } else { + quantization_tables_[qt_id][i] = readWord(); + } + } + + qt_set_[qt_id] = true; + } + + return true; +} + +bool JPEGDecoder::processSOF0() { + readWord(); // length + + uint8_t precision = readByte(); // precision (通常是 8) + if (precision != 8) { + throw std::runtime_error("Only 8-bit precision is supported"); + } + + height_ = readWord(); + width_ = readWord(); + num_components_ = readByte(); + + if (num_components_ != 1 && num_components_ != 3) { + return false; // 只支援 grayscale 或 YCbCr + } + + for (int i = 0; i < num_components_; ++i) { + components_[i].id = readByte(); + uint8_t sampling = readByte(); + components_[i].h_sample = sampling >> 4; + components_[i].v_sample = sampling & 0x0F; + components_[i].qt_id = readByte(); + } + + return true; +} + +bool JPEGDecoder::processDHT() { + uint16_t length = readWord(); + size_t end_pos = data_pos_ + length - 2; + + while (data_pos_ < end_pos) { + uint8_t ht_info = readByte(); + int table_class = (ht_info >> 4); // 0 = DC, 1 = AC + int table_id = ht_info & 0x0F; + + // 讀取 bits 陣列 + uint8_t bits[16]; + int total_codes = 0; + for (int i = 0; i < 16; ++i) { + bits[i] = readByte(); + total_codes += bits[i]; + } + + // 讀取符號值 + std::vector values(total_codes); + for (int i = 0; i < total_codes; ++i) { + values[i] = readByte(); + } + + // 建立霍夫曼表 + huffman_decoder_.setTable(table_id, table_class == 1, bits, values.data()); + } + + return true; +} + +bool JPEGDecoder::processDRI() { + readWord(); // length + restart_interval_ = readWord(); + return true; +} + +bool JPEGDecoder::processSOS() { + readWord(); // length + + int num_components = readByte(); + + for 
(int i = 0; i < num_components; ++i) { + int component_id = readByte(); + uint8_t table_ids = readByte(); + + // 找到對應的 component + for (int j = 0; j < num_components_; ++j) { + if (components_[j].id == component_id) { + components_[j].dc_table_id = table_ids >> 4; + components_[j].ac_table_id = table_ids & 0x0F; + break; + } + } + } + + // 跳過 3 個 bytes (Ss, Se, Ah/Al) + readByte(); // Start of spectral selection + readByte(); // End of spectral selection + readByte(); // Successive approximation + + // 剩下的資料是壓縮的圖像資料 + size_t scan_data_start = data_pos_; + + // 找到掃描資料的結尾(下一個 marker 或 EOI) + size_t scan_data_end = data_pos_; + while (scan_data_end < jpeg_data_.size()) { + if (jpeg_data_[scan_data_end] == 0xFF) { + uint8_t next = jpeg_data_[scan_data_end + 1]; + if (next != 0x00 && !(next >= 0xD0 && next <= 0xD7)) { + // 找到下一個 marker + break; + } + } + scan_data_end++; + } + + size_t scan_data_size = scan_data_end - scan_data_start; + + // 使用 BitStream 解碼 + BitStream bs(jpeg_data_.data() + scan_data_start, scan_data_size); + + // 計算 MCU 的數量 + int max_h_sample = 1, max_v_sample = 1; + for (int i = 0; i < num_components_; ++i) { + max_h_sample = std::max(max_h_sample, components_[i].h_sample); + max_v_sample = std::max(max_v_sample, components_[i].v_sample); + } + + int mcu_width = max_h_sample * 8; + int mcu_height = max_v_sample * 8; + int mcu_cols = (width_ + mcu_width - 1) / mcu_width; + int mcu_rows = (height_ + mcu_height - 1) / mcu_height; + + // 準備輸出緩衝區 + image_data_.resize(width_ * height_ * 3); + + // 儲存 Y, Cb, Cr 分量 + std::vector> y_data(mcu_rows * mcu_cols); + std::vector> cb_data(mcu_rows * mcu_cols); + std::vector> cr_data(mcu_rows * mcu_cols); + + // 解碼所有 MCU + int16_t prev_dc[3] = {0, 0, 0}; + + for (int mcu_row = 0; mcu_row < mcu_rows; ++mcu_row) { + for (int mcu_col = 0; mcu_col < mcu_cols; ++mcu_col) { + int mcu_index = mcu_row * mcu_cols + mcu_col; + + // 為每個分量解碼區塊 + for (int comp = 0; comp < num_components_; ++comp) { + int h_blocks = 
components_[comp].h_sample; + int v_blocks = components_[comp].v_sample; + + for (int v = 0; v < v_blocks; ++v) { + for (int h = 0; h < h_blocks; ++h) { + int16_t block[64]; + decodeBlock(bs, comp, &prev_dc[comp], block); + + // IDCT + uint8_t pixels[64]; + IDCT::transform8x8(block, pixels); + + // 儲存到對應的分量 + if (comp == 0) { // Y + y_data[mcu_index].insert(y_data[mcu_index].end(), pixels, pixels + 64); + } else if (comp == 1) { // Cb + cb_data[mcu_index].insert(cb_data[mcu_index].end(), pixels, pixels + 64); + } else { // Cr + cr_data[mcu_index].insert(cr_data[mcu_index].end(), pixels, pixels + 64); + } + } + } + } + } + } + + // Upsampling 和 YCbCr 轉 RGB + upsample(y_data, cb_data, cr_data); + + data_pos_ = scan_data_end; + return true; +} + +void JPEGDecoder::decodeBlock(BitStream& bs, int component_id, int16_t* prev_dc, int16_t block[64]) { + int16_t dc = huffman_decoder_.decodeDC(bs, components_[component_id].dc_table_id, *prev_dc); + *prev_dc = dc; + + int16_t ac[63]; + huffman_decoder_.decodeAC(bs, components_[component_id].ac_table_id, ac); + + // 組合成 64 個係數 (ZigZag 順序) + int16_t coeffs[64]; + coeffs[0] = dc; + std::memcpy(coeffs + 1, ac, 63 * sizeof(int16_t)); + + // 反 ZigZag + int16_t matrix[64]; + ZigZag::toMatrix(coeffs, matrix); + + // 反量化 + int qt_id = components_[component_id].qt_id; + for (int i = 0; i < 64; ++i) { + block[i] = matrix[i] * quantization_tables_[qt_id][i]; + } +} + +void JPEGDecoder::ycbcrToRgb(int y, int cb, int cr, uint8_t& r, uint8_t& g, uint8_t& b) { + int r_val = y + 1.402 * (cr - 128); + int g_val = y - 0.344136 * (cb - 128) - 0.714136 * (cr - 128); + int b_val = y + 1.772 * (cb - 128); + + r = (r_val < 0) ? 0 : (r_val > 255) ? 255 : r_val; + g = (g_val < 0) ? 0 : (g_val > 255) ? 255 : g_val; + b = (b_val < 0) ? 0 : (b_val > 255) ? 
255 : b_val; +} + +void JPEGDecoder::upsample(const std::vector>& y_blocks, + const std::vector>& cb_blocks, + const std::vector>& cr_blocks) { + // 簡化版:假設 4:4:4 或 4:2:0 + // 這裡實作簡單的 nearest neighbor upsampling + + if (num_components_ == 1) { + // Grayscale + for (int row = 0; row < height_; ++row) { + for (int col = 0; col < width_; ++col) { + int mcu_col = col / 8; + int mcu_row = row / 8; + int block_x = col % 8; + int block_y = row % 8; + int mcu_index = mcu_row * ((width_ + 7) / 8) + mcu_col; + + if (mcu_index < static_cast(y_blocks.size()) && + !y_blocks[mcu_index].empty()) { + uint8_t y = y_blocks[mcu_index][block_y * 8 + block_x]; + int pixel_index = (row * width_ + col) * 3; + image_data_[pixel_index] = y; + image_data_[pixel_index + 1] = y; + image_data_[pixel_index + 2] = y; + } + } + } + } else { + // YCbCr 轉 RGB + for (int row = 0; row < height_; ++row) { + for (int col = 0; col < width_; ++col) { + int mcu_col = col / 8; + int mcu_row = row / 8; + int block_x = col % 8; + int block_y = row % 8; + int mcu_index = mcu_row * ((width_ + 7) / 8) + mcu_col; + + if (mcu_index < static_cast(y_blocks.size()) && + !y_blocks[mcu_index].empty()) { + uint8_t y = y_blocks[mcu_index][block_y * 8 + block_x]; + uint8_t cb = 128, cr = 128; + + if (!cb_blocks[mcu_index].empty()) { + cb = cb_blocks[mcu_index][block_y * 8 + block_x]; + } + if (!cr_blocks[mcu_index].empty()) { + cr = cr_blocks[mcu_index][block_y * 8 + block_x]; + } + + uint8_t r, g, b; + ycbcrToRgb(y, cb, cr, r, g, b); + + int pixel_index = (row * width_ + col) * 3; + image_data_[pixel_index] = r; + image_data_[pixel_index + 1] = g; + image_data_[pixel_index + 2] = b; + } + } + } + } +} + +} // namespace jpeg \ No newline at end of file diff --git a/src/cpp/decoder.h b/src/cpp/decoder.h new file mode 100644 index 0000000..b3e1206 --- /dev/null +++ b/src/cpp/decoder.h @@ -0,0 +1,177 @@ +#ifndef DECODER_H +#define DECODER_H + +#include +#include +#include +#include +#include "huffman.h" +#include 
"bitstream.h" + +namespace jpeg { + +/** + * JPEGDecoder - JPEG 解碼器主類別 + * + * 負責整個 JPEG 解碼流程: + * 1. 解析文件結構 (markers) + * 2. 讀取量化表 (DQT) + * 3. 讀取圖像資訊 (SOF) + * 4. 讀取霍夫曼表 (DHT) + * 5. 解碼圖像資料 (SOS) + */ +class JPEGDecoder { +public: + JPEGDecoder(); + ~JPEGDecoder(); + + /** + * 從檔案解碼 JPEG 圖像 + * + * @param filename JPEG 檔案路徑 + * @return true 如果成功 + */ + bool decodeFile(const std::string& filename); + + /** + * 從記憶體解碼 JPEG 圖像 + * + * @param data JPEG 資料 + * @param size 資料大小 + * @return true 如果成功 + */ + bool decodeMemory(const uint8_t* data, size_t size); + + /** + * 取得解碼後的圖像資料(RGB 格式) + * + * @return RGB 資料指標(每個像素 3 bytes: R, G, B) + */ + const uint8_t* getImageData() const { return image_data_.data(); } + + int getWidth() const { return width_; } + + int getHeight() const { return height_; } + + int getChannels() const { return num_components_; } + +private: + // JPEG Markers + static constexpr uint8_t MARKER_SOI = 0xD8; // Start of Image + static constexpr uint8_t MARKER_EOI = 0xD9; // End of Image + static constexpr uint8_t MARKER_SOF0 = 0xC0; // Start of Frame (Baseline DCT) + static constexpr uint8_t MARKER_DHT = 0xC4; // Define Huffman Table + static constexpr uint8_t MARKER_DQT = 0xDB; // Define Quantization Table + static constexpr uint8_t MARKER_SOS = 0xDA; // Start of Scan + static constexpr uint8_t MARKER_DRI = 0xDD; // Define Restart Interval + static constexpr uint8_t MARKER_APP0 = 0xE0; // Application Segment 0 + + // 圖像資訊 + int width_; + int height_; + int num_components_; // 1 = grayscale, 3 = YCbCr + + // 顏色分量資訊 + struct Component { + int id; + int h_sample; // 水平取樣因子 + int v_sample; // 垂直取樣因子 + int qt_id; // 量化表 ID + int dc_table_id; // DC 霍夫曼表 ID + int ac_table_id; // AC 霍夫曼表 ID + }; + Component components_[3]; // 最多 3 個分量 (Y, Cb, Cr) + + // 量化表 + uint16_t quantization_tables_[4][64]; + bool qt_set_[4]; + + // 霍夫曼解碼器 + HuffmanDecoder huffman_decoder_; + + // Restart interval + int restart_interval_; + + // 解碼後的圖像資料 (RGB) + std::vector 
image_data_; + + // JPEG 原始資料 + std::vector jpeg_data_; + size_t data_pos_; + + /** + * 讀取一個 byte + */ + uint8_t readByte(); + + /** + * 讀取兩個 bytes (big-endian) + */ + uint16_t readWord(); + + /** + * 尋找下一個 marker + */ + uint8_t readMarker(); + + /** + * 解析 JPEG 檔案 + */ + bool parse(); + + /** + * 處理 DQT (Define Quantization Table) + */ + bool processDQT(); + + /** + * 處理 SOF0 (Start of Frame - Baseline DCT) + */ + bool processSOF0(); + + /** + * 處理 DHT (Define Huffman Table) + */ + bool processDHT(); + + /** + * 處理 SOS (Start of Scan) + */ + bool processSOS(); + + /** + * 處理 DRI (Define Restart Interval) + */ + bool processDRI(); + + /** + * 跳過一個 segment + */ + void skipSegment(); + + /** + * 解碼一個 MCU (Minimum Coded Unit) + */ + void decodeMCU(BitStream& bs, int mcu_row, int mcu_col); + + /** + * 解碼一個 8x8 區塊 + */ + void decodeBlock(BitStream& bs, int component_id, int16_t* prev_dc, int16_t block[64]); + + /** + * YCbCr 轉 RGB + */ + static void ycbcrToRgb(int y, int cb, int cr, uint8_t& r, uint8_t& g, uint8_t& b); + + /** + * Upsampling (色度子採樣重建) + */ + void upsample(const std::vector>& y_blocks, + const std::vector>& cb_blocks, + const std::vector>& cr_blocks); +}; + +} // namespace jpeg + +#endif // DECODER_H diff --git a/src/cpp/huffman.cpp b/src/cpp/huffman.cpp new file mode 100644 index 0000000..6bf05e0 --- /dev/null +++ b/src/cpp/huffman.cpp @@ -0,0 +1,179 @@ +#include "huffman.h" +#include +#include + +namespace jpeg { + +HuffmanTable::HuffmanTable() : built_(false) { + std::memset(min_code_, 0, sizeof(min_code_)); + std::memset(max_code_, 0, sizeof(max_code_)); + std::memset(val_offset_, 0, sizeof(val_offset_)); +} + +void HuffmanTable::build(const uint8_t bits[16], const uint8_t* values) { + // 計算總符號數 + int total_symbols = 0; + for (int i = 0; i < 16; ++i) { + total_symbols += bits[i]; + } + + // 複製符號值 + symbols_.clear(); + symbols_.reserve(total_symbols); + for (int i = 0; i < total_symbols; ++i) { + symbols_.push_back(values[i]); + } + + // 建立霍夫曼表(使用 
JPEG 標準算法) + int k = 0; + uint32_t code = 0; + + for (int i = 0; i < 16; ++i) { + if (bits[i] != 0) { + val_offset_[i] = k; + min_code_[i] = code; + + for (int j = 0; j < bits[i]; ++j) { + code_to_symbol_[code] = symbols_[k]; + code++; + k++; + } + + max_code_[i] = code - 1; + } else { + min_code_[i] = 0xFFFFFFFF; + max_code_[i] = 0; + val_offset_[i] = 0; + } + + code <<= 1; + } + + built_ = true; +} + +uint8_t HuffmanTable::decode(BitStream& bs) const { + if (!built_) { + throw std::runtime_error("HuffmanTable: table not built"); + } + + uint32_t code = 0; + + // 逐位元讀取,嘗試匹配霍夫曼碼 + for (int i = 0; i < 16; ++i) { + code = (code << 1) | bs.readBits(1); + + if (code <= max_code_[i] && code >= min_code_[i]) { + // 找到匹配的碼 + int index = val_offset_[i] + (code - min_code_[i]); + if (index < static_cast(symbols_.size())) { + return symbols_[index]; + } + } + } + + throw std::runtime_error("HuffmanTable: invalid huffman code"); +} + +// HuffmanDecoder 實作 + +void HuffmanDecoder::setTable(int table_id, bool is_ac, const uint8_t bits[16], const uint8_t* values) { + if (table_id < 0 || table_id >= 4) { + throw std::out_of_range("HuffmanDecoder: table_id must be 0-3"); + } + + if (is_ac) { + ac_tables_[table_id].build(bits, values); + } else { + dc_tables_[table_id].build(bits, values); + } +} + +const HuffmanTable* HuffmanDecoder::getTable(int table_id, bool is_ac) const { + if (table_id < 0 || table_id >= 4) { + return nullptr; + } + + if (is_ac) { + return ac_tables_[table_id].isBuilt() ? &ac_tables_[table_id] : nullptr; + } else { + return dc_tables_[table_id].isBuilt() ? 
&dc_tables_[table_id] : nullptr; + } +} + +int16_t HuffmanDecoder::receiveExtend(BitStream& bs, int size) { + if (size == 0) { + return 0; + } + + // 讀取 size 個 bits + int value = bs.readBits(size); + + // 檢查是否為負數(最高位為 0) + int vt = 1 << (size - 1); + if (value < vt) { + // 負數:需要擴展 + vt = (-1 << size) + 1; + value = value + vt; + } + + return static_cast(value); +} + +int16_t HuffmanDecoder::decodeDC(BitStream& bs, int table_id, int16_t prev_dc) const { + const HuffmanTable* table = getTable(table_id, false); + if (!table) { + throw std::runtime_error("HuffmanDecoder: DC table not found"); + } + + // 解碼 DC 差分值的大小 + uint8_t size = table->decode(bs); + + // 讀取實際的差分值 + int16_t diff = receiveExtend(bs, size); + + // DC 使用差分編碼 + return prev_dc + diff; +} + +void HuffmanDecoder::decodeAC(BitStream& bs, int table_id, int16_t coeffs[63]) const { + const HuffmanTable* table = getTable(table_id, true); + if (!table) { + throw std::runtime_error("HuffmanDecoder: AC table not found"); + } + + // 初始化為 0 + std::memset(coeffs, 0, 63 * sizeof(int16_t)); + + int k = 0; + while (k < 63) { + // 解碼 (Run, Size) 對 + uint8_t rs = table->decode(bs); + + int run = rs >> 4; // 高 4 bits 是 run (前面有幾個 0) + int size = rs & 0x0F; // 低 4 bits 是 size (係數的 bit 長度) + + if (size == 0) { + if (run == 15) { + // ZRL (Zero Run Length): 16 個連續的 0 + k += 16; + } else { + // EOB (End of Block): 剩下的都是 0 + break; + } + } else { + // 跳過 run 個 0 + k += run; + + if (k >= 63) { + break; + } + + // 讀取 AC 係數 + coeffs[k] = receiveExtend(bs, size); + k++; + } + } +} + +} // namespace jpeg diff --git a/src/cpp/huffman.h b/src/cpp/huffman.h new file mode 100644 index 0000000..aef8209 --- /dev/null +++ b/src/cpp/huffman.h @@ -0,0 +1,126 @@ +#ifndef HUFFMAN_H +#define HUFFMAN_H + +#include +#include +#include +#include "bitstream.h" + +namespace jpeg { + +/** + * HuffmanTable - 霍夫曼表 + * + * 儲存 JPEG DHT (Define Huffman Table) segment 的資訊 + * 並提供解碼功能 + */ +class HuffmanTable { +public: + HuffmanTable(); + + /** + * 從 DHT 
segment 資料建立霍夫曼表 + * + * @param bits 16 個元素的陣列,bits[i] 表示長度為 i+1 的碼字數量 + * @param values 所有符號值的陣列 + */ + void build(const uint8_t bits[16], const uint8_t* values); + + /** + * 使用此霍夫曼表解碼一個符號 + * + * @param bs BitStream 物件 + * @return 解碼出的符號值 + */ + uint8_t decode(BitStream& bs) const; + + /** + * 檢查表是否已建立 + */ + bool isBuilt() const { return built_; } + +private: + bool built_; + + // 霍夫曼碼表:code -> symbol + // 為了效率,可以用不同的資料結構,這裡先用簡單的 map + std::map code_to_symbol_; + + // 每個碼長的最小碼值 + uint32_t min_code_[16]; + + // 每個碼長的最大碼值 + uint32_t max_code_[16]; + + // 每個碼長的符號起始索引 + int val_offset_[16]; + + // 符號值陣列 + std::vector symbols_; +}; + +/** + * HuffmanDecoder - 霍夫曼解碼器 + * + * 管理多個霍夫曼表(DC 和 AC,可能有多個) + */ +class HuffmanDecoder { +public: + /** + * 設定霍夫曼表 + * + * @param table_id 表格 ID (0-3) + * @param is_ac true 表示 AC 表,false 表示 DC 表 + * @param bits 16 個元素的陣列 + * @param values 符號值陣列 + */ + void setTable(int table_id, bool is_ac, const uint8_t bits[16], const uint8_t* values); + + /** + * 取得霍夫曼表 + * + * @param table_id 表格 ID + * @param is_ac true 表示 AC 表,false 表示 DC 表 + * @return 霍夫曼表的指標 + */ + const HuffmanTable* getTable(int table_id, bool is_ac) const; + + /** + * 解碼一個 MCU (Minimum Coded Unit) 的 DC 係數 + * + * @param bs BitStream 物件 + * @param table_id 霍夫曼表 ID + * @param prev_dc 前一個 DC 值(DC 使用差分編碼) + * @return 解碼出的 DC 係數 + */ + int16_t decodeDC(BitStream& bs, int table_id, int16_t prev_dc) const; + + /** + * 解碼一個 MCU 的 AC 係數 + * + * @param bs BitStream 物件 + * @param table_id 霍夫曼表 ID + * @param coeffs 輸出的 63 個 AC 係數(不包含 DC) + */ + void decodeAC(BitStream& bs, int table_id, int16_t coeffs[63]) const; + +private: + // DC 表(最多 4 個) + HuffmanTable dc_tables_[4]; + + // AC 表(最多 4 個) + HuffmanTable ac_tables_[4]; + + /** + * 從 BitStream 讀取一個有符號整數 + * + * @param bs BitStream 物件 + * @param size bit 數量 + * @return 解碼的整數值 + */ + static int16_t receiveExtend(BitStream& bs, int size); +}; + +} // namespace jpeg + +#endif // HUFFMAN_H diff --git a/src/cpp/idct.cpp b/src/cpp/idct.cpp 
new file mode 100644 index 0000000..ff6b90a --- /dev/null +++ b/src/cpp/idct.cpp @@ -0,0 +1,97 @@ +#include "idct.h" +#include +#include + +namespace jpeg { + +// ZigZag 掃描順序表 +const int ZigZag::ZIGZAG_TABLE[64] = { + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63 +}; + +void ZigZag::toMatrix(const int16_t zigzag[64], int16_t matrix[64]) { + for (int i = 0; i < 64; ++i) { + matrix[ZIGZAG_TABLE[i]] = zigzag[i]; + } +} + +void ZigZag::fromMatrix(const int16_t matrix[64], int16_t zigzag[64]) { + for (int i = 0; i < 64; ++i) { + zigzag[i] = matrix[ZIGZAG_TABLE[i]]; + } +} + +const int* ZigZag::getZigZagTable() { + return ZIGZAG_TABLE; +} + +// IDCT 實作 - Naive 版本使用標準公式 + +void IDCT::idct1D(const double input[8], double output[8]) { + for (int x = 0; x < 8; ++x) { + double sum = 0.0; + for (int u = 0; u < 8; ++u) { + double cu = (u == 0) ? 
1.0 / std::sqrt(2.0) : 1.0; + sum += cu * input[u] * std::cos((2 * x + 1) * u * PI / 16.0); + } + output[x] = sum / 2.0; + } +} + +void IDCT::transform8x8Float(const int16_t input[64], uint8_t output[64]) { + double temp[64]; + double temp2[64]; + + // 先對每一行做 1D IDCT + for (int row = 0; row < 8; ++row) { + double row_input[8]; + double row_output[8]; + + for (int col = 0; col < 8; ++col) { + row_input[col] = static_cast(input[row * 8 + col]); + } + + idct1D(row_input, row_output); + + for (int col = 0; col < 8; ++col) { + temp[row * 8 + col] = row_output[col]; + } + } + + // 再對每一列做 1D IDCT + for (int col = 0; col < 8; ++col) { + double col_input[8]; + double col_output[8]; + + for (int row = 0; row < 8; ++row) { + col_input[row] = temp[row * 8 + col]; + } + + idct1D(col_input, col_output); + + for (int row = 0; row < 8; ++row) { + temp2[row * 8 + col] = col_output[row]; + } + } + + // Level shift (加 128) 並 clamp 到 0-255 + for (int i = 0; i < 64; ++i) { + int value = static_cast(std::round(temp2[i])) + 128; + output[i] = clamp(value); + } +} + +void IDCT::transform8x8(const int16_t input[64], uint8_t output[64]) { + // 目前使用 float 版本 + // 之後可以實作更快的整數版本或 SIMD 版本 + transform8x8Float(input, output); +} + +} // namespace jpeg diff --git a/src/cpp/idct.h b/src/cpp/idct.h new file mode 100644 index 0000000..cc68717 --- /dev/null +++ b/src/cpp/idct.h @@ -0,0 +1,87 @@ +#ifndef IDCT_H +#define IDCT_H + +#include + +namespace jpeg { + +/** + * IDCT - 反離散餘弦變換 (Inverse Discrete Cosine Transform) + * + * 將 8x8 的 DCT 係數矩陣轉換回 8x8 的像素值 + */ +class IDCT { +public: + /** + * 對 8x8 區塊執行 2D IDCT + * + * @param input 8x8 的 DCT 係數矩陣 (已反量化) + * @param output 8x8 的輸出像素矩陣 (0-255) + */ + static void transform8x8(const int16_t input[64], uint8_t output[64]); + + /** + * 對 8x8 區塊執行 2D IDCT (float 版本,更精確但較慢) + * + * @param input 8x8 的 DCT 係數矩陣 + * @param output 8x8 的輸出矩陣 (會自動進行 level shift 和 clamp) + */ + static void transform8x8Float(const int16_t input[64], uint8_t output[64]); + +private: + 
// IDCT 的縮放因子 + static constexpr double PI = 3.14159265358979323846; + + /** + * 1D IDCT + * @param input 8 個輸入係數 + * @param output 8 個輸出值 + */ + static void idct1D(const double input[8], double output[8]); + + /** + * 將值 clamp 到 0-255 範圍 + */ + static inline uint8_t clamp(int value) { + if (value < 0) return 0; + if (value > 255) return 255; + return static_cast(value); + } +}; + +/** + * ZigZag - 處理 ZigZag 掃描順序 + * + * JPEG 使用 ZigZag 順序來排列 DCT 係數, + * 需要將 1D 的 64 個係數重新排列成 8x8 矩陣 + */ +class ZigZag { +public: + /** + * 將 ZigZag 順序的 1D 陣列轉換為 8x8 矩陣 + * @param zigzag 64 個 ZigZag 順序的係數 + * @param matrix 8x8 矩陣輸出 + */ + static void toMatrix(const int16_t zigzag[64], int16_t matrix[64]); + + /** + * 將 8x8 矩陣轉換為 ZigZag 順序 + * @param matrix 8x8 矩陣 + * @param zigzag 64 個 ZigZag 順序的係數輸出 + */ + static void fromMatrix(const int16_t matrix[64], int16_t zigzag[64]); + + /** + * 取得 ZigZag 掃描表 + * @return 64 個位置的對應表 + */ + static const int* getZigZagTable(); + +private: + // ZigZag 掃描順序表 + static const int ZIGZAG_TABLE[64]; +}; + +} // namespace jpeg + +#endif // IDCT_H diff --git a/src/python/fast_jpeg_decoder/__init__.py b/src/python/fast_jpeg_decoder/__init__.py new file mode 100644 index 0000000..b42bcd5 --- /dev/null +++ b/src/python/fast_jpeg_decoder/__init__.py @@ -0,0 +1,82 @@ +""" +Fast JPEG Decoder +================= + +A high-performance JPEG decoder implemented in C++ with Python bindings. + +Usage: + >>> import fast_jpeg_decoder as fjd + >>> image = fjd.decode('path/to/image.jpg') + >>> print(image.shape) # (height, width, 3) + + >>> # Or use the decoder class + >>> decoder = fjd.JPEGDecoder() + >>> decoder.decode_file('path/to/image.jpg') + >>> image = decoder.get_image_data() +""" + +try: + from ._fast_jpeg_decoder import ( + JPEGDecoder, + decode, + decode_bytes + ) +except ImportError as e: + raise ImportError( + "Failed to import C++ extension. " + "Please make sure the extension is built. " + "Run 'make' to build the extension." 
+ ) from e + +__version__ = '0.1.0' +__all__ = ['JPEGDecoder', 'decode', 'decode_bytes'] + + +def load(filename): + """ + Load a JPEG image from file. + + Parameters + ---------- + filename : str + Path to the JPEG file + + Returns + ------- + numpy.ndarray + Image data as numpy array with shape (height, width, 3) + and dtype uint8. Channels are in RGB order. + + Examples + -------- + >>> import fast_jpeg_decoder as fjd + >>> image = fjd.load('photo.jpg') + >>> print(image.shape) + (480, 640, 3) + """ + return decode(filename) + + +def load_bytes(data): + """ + Load a JPEG image from bytes. + + Parameters + ---------- + data : bytes + JPEG image data as bytes + + Returns + ------- + numpy.ndarray + Image data as numpy array with shape (height, width, 3) + and dtype uint8. Channels are in RGB order. + + Examples + -------- + >>> import fast_jpeg_decoder as fjd + >>> with open('photo.jpg', 'rb') as f: + ... data = f.read() + >>> image = fjd.load_bytes(data) + """ + return decode_bytes(data) diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..c2fe0ad --- /dev/null +++ b/tests/README.md @@ -0,0 +1,36 @@ +# 測試資料 + +## 準備測試圖片 + +在運行測試之前,請在 `tests/test_data/` 目錄下放置一個測試用的 JPEG 圖片,命名為 `sample.jpg`。 + +### 使用 Python PIL 生成簡單測試圖片 + +```python +from PIL import Image +import numpy as np + +# 創建一個簡單的測試圖片 +width, height = 640, 480 +img = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) +Image.fromarray(img).save('tests/test_data/sample.jpg', 'JPEG', quality=90) +``` + +### 或使用現有的 JPEG 圖片 + +```bash +cp /path/to/your/image.jpg tests/test_data/sample.jpg +``` + +## 運行測試 + +```bash +# 確保先編譯擴展 +make develop + +# 運行所有測試 +make test + +# 或使用 pytest +pytest tests/ -v +``` diff --git a/tests/test_data/lena.jpg b/tests/test_data/lena.jpg new file mode 100644 index 0000000..7e83c44 Binary files /dev/null and b/tests/test_data/lena.jpg differ diff --git a/tests/test_data/sample.jpg b/tests/test_data/sample.jpg new file mode 100644 index 
0000000..5cba5d1 Binary files /dev/null and b/tests/test_data/sample.jpg differ diff --git a/tests/test_decoder.py b/tests/test_decoder.py new file mode 100644 index 0000000..1578884 --- /dev/null +++ b/tests/test_decoder.py @@ -0,0 +1,166 @@ +""" +Unit tests for Fast JPEG Decoder +""" + +import pytest +import numpy as np +import os +import sys + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src', 'python')) + +try: + import fast_jpeg_decoder as fjd +except ImportError: + pytest.skip("Extension not built yet", allow_module_level=True) + + +class TestJPEGDecoder: + """Test cases for JPEG decoder""" + + def test_decoder_creation(self): + """Test that decoder can be created""" + decoder = fjd.JPEGDecoder() + assert decoder is not None + + def test_load_function_exists(self): + """Test that load function exists""" + assert hasattr(fjd, 'load') + assert hasattr(fjd, 'decode') + + def test_decoder_methods(self): + """Test that decoder has required methods""" + decoder = fjd.JPEGDecoder() + assert hasattr(decoder, 'decode_file') + assert hasattr(decoder, 'decode_memory') + assert hasattr(decoder, 'get_image_data') + assert hasattr(decoder, 'width') + assert hasattr(decoder, 'height') + assert hasattr(decoder, 'channels') + + @pytest.mark.skipif( + not os.path.exists('tests/test_data'), + reason="Test data directory not found" + ) + def test_decode_sample_image(self): + """Test decoding a sample JPEG image""" + # This test requires a sample JPEG image in tests/test_data/ + test_image_path = 'tests/test_data/sample.jpg' + + if not os.path.exists(test_image_path): + pytest.skip("Sample image not found") + + # Test using load function + image = fjd.load(test_image_path) + + # Check that result is a numpy array + assert isinstance(image, np.ndarray) + + # Check shape (should be H x W x 3) + assert len(image.shape) == 3 + assert image.shape[2] == 3 + + # Check dtype + assert image.dtype == np.uint8 + + # Check that values are in valid 
range + assert np.all(image >= 0) + assert np.all(image <= 255) + + @pytest.mark.skipif( + not os.path.exists('tests/test_data'), + reason="Test data directory not found" + ) + def test_decoder_class(self): + """Test using decoder class directly""" + test_image_path = 'tests/test_data/sample.jpg' + + if not os.path.exists(test_image_path): + pytest.skip("Sample image not found") + + decoder = fjd.JPEGDecoder() + success = decoder.decode_file(test_image_path) + + assert success is True + assert decoder.width > 0 + assert decoder.height > 0 + assert decoder.channels in [1, 3] + + image = decoder.get_image_data() + assert isinstance(image, np.ndarray) + assert image.shape == (decoder.height, decoder.width, 3) + + def test_invalid_file(self): + """Test that invalid file raises appropriate error""" + with pytest.raises((RuntimeError, Exception)): + fjd.load('nonexistent_file.jpg') + + def test_decode_bytes(self): + """Test decoding from bytes""" + test_image_path = 'tests/test_data/sample.jpg' + + if not os.path.exists(test_image_path): + pytest.skip("Sample image not found") + + with open(test_image_path, 'rb') as f: + data = f.read() + + image = fjd.load_bytes(data) + assert isinstance(image, np.ndarray) + assert len(image.shape) == 3 + + +class TestImageProperties: + """Test image properties and correctness""" + + @pytest.mark.skipif( + not os.path.exists('tests/test_data'), + reason="Test data directory not found" + ) + def test_image_dimensions(self): + """Test that decoded image has correct dimensions""" + test_image_path = 'tests/test_data/sample.jpg' + + if not os.path.exists(test_image_path): + pytest.skip("Sample image not found") + + decoder = fjd.JPEGDecoder() + decoder.decode_file(test_image_path) + + # Width and height should be positive + assert decoder.width > 0 + assert decoder.height > 0 + + # Get image data + image = decoder.get_image_data() + + # Check that numpy array matches reported dimensions + assert image.shape[0] == decoder.height + assert 
image.shape[1] == decoder.width + + @pytest.mark.skipif( + not os.path.exists('tests/test_data'), + reason="Test data directory not found" + ) + def test_rgb_channels(self): + """Test that image has correct RGB channel order""" + test_image_path = 'tests/test_data/sample.jpg' + + if not os.path.exists(test_image_path): + pytest.skip("Sample image not found") + + image = fjd.load(test_image_path) + + # Should have 3 channels for RGB + assert image.shape[2] == 3 + + # Each channel should have valid values + for channel in range(3): + channel_data = image[:, :, channel] + assert channel_data.min() >= 0 + assert channel_data.max() <= 255 + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])