A Java-based neural network implementation for classifying Pokemon card images. This project implements a custom neural network architecture with support for multi-threaded training and data augmentation.
This project was my first attempt at using neural networks, so please be kind.
- Custom neural network implementation with configurable layers
- Multi-threaded batch training for improved performance
- Data augmentation capabilities for training data
- Support for serialization and deserialization of trained models
- Progress tracking and accuracy reporting during training
- Image preprocessing and resizing utilities
- Main.java: Main entry point and training loop implementation
- NeuralNetwork.java: Core neural network implementation
- Layer.java: Neural network layer implementation
- DataLoader.java: Data loading and batch management
- DataScraper.java: Data collection utilities
- PokemonScraper.java: Pokemon-specific data collection
- DataAugment.py: Python script for data augmentation
- ImageResizer.java: Image preprocessing utilities
- Java Development Kit (JDK), since the sources are compiled with javac
- Python (for data augmentation)
- Pokemon card image dataset
- Prepare your dataset:
  - Place Pokemon card images in the Data/Pokemon/ResizedCardImages/CEL directory
  - Place corresponding labels in Data/Pokemon/CardNames/CEL.txt
- Configure training parameters in Main.java:

      public static float learnRate = 0.5f;
      public static float dataSetCoverage = 1f;
      public static int epochNumbers = 1000;
      public static int batchSize = 50;
      public static int threadNumber = 4;
- Run the training:

      javac *.java
      java Main
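
Before starting a long training run, it can help to confirm that the dataset is where the steps above expect it. The following is a minimal sketch of such a check; `DatasetCheckSketch` and `looksReady` are hypothetical names and are not part of the project, and the check only looks at whether the paths exist, not at the file contents.

```java
import java.io.File;

// Hypothetical helper, not part of the project: verifies that the image
// directory is non-empty and that the label file exists.
class DatasetCheckSketch {
    static boolean looksReady(File imageDir, File labelFile) {
        String[] entries = imageDir.list(); // null when imageDir is not a directory
        return entries != null && entries.length > 0 && labelFile.isFile();
    }

    public static void main(String[] args) {
        // Paths taken from the dataset preparation step.
        File images = new File("Data/Pokemon/ResizedCardImages/CEL");
        File labels = new File("Data/Pokemon/CardNames/CEL.txt");
        System.out.println(looksReady(images, labels)
                ? "Dataset found"
                : "Dataset missing; check the paths above");
    }
}
```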
The default network architecture consists of:
- Input layer: 1104 nodes
- Hidden layer 1: 786 nodes
- Hidden layer 2: 400 nodes
- Output layer: 50 nodes
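
To get a feel for the size of this default architecture, the sketch below counts its trainable parameters. It assumes every layer is fully connected to the next and that each non-input node has one bias; the README does not state this, so treat it as an assumption. `ArchitectureSketch` and `parameterCount` are hypothetical names, not part of the project.

```java
// Hypothetical helper, not part of the project: counts weights and biases
// of a fully connected network with the given layer sizes.
class ArchitectureSketch {
    static long parameterCount(int[] layerSizes) {
        long total = 0;
        for (int i = 1; i < layerSizes.length; i++) {
            long weights = (long) layerSizes[i - 1] * layerSizes[i];
            long biases = layerSizes[i]; // one bias per node in layer i
            total += weights + biases;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] defaultLayers = {1104, 786, 400, 50};
        System.out.println(parameterCount(defaultLayers)); // prints 1203380
    }
}
```

Under these assumptions the default network has roughly 1.2 million parameters, most of them in the connections between the input layer and the first hidden layer.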
- The network loads or initializes a new model
- Training data is loaded and preprocessed
- Training proceeds in epochs with multi-threaded batch processing
- After each epoch:
  - Model is saved
  - Test accuracy is reported
  - Progress is displayed
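
The multi-threaded batch step above can be sketched roughly as follows. This is not the project's code: it swaps the neural network for a one-weight model (y = w * x) so the example stays short, and `ThreadedBatchSketch` is a hypothetical name. The parameter names mirror the settings in Main.java. Each batch is split across worker threads, every worker sums the gradient for its slice, and the main thread applies one averaged update.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: mini-batch gradient descent on y = w * x,
// with each batch divided among a fixed pool of worker threads.
class ThreadedBatchSketch {
    static float train(float[] xs, float[] ys, float learnRate,
                       int epochNumbers, int batchSize, int threadNumber) {
        ExecutorService pool = Executors.newFixedThreadPool(threadNumber);
        float w = 0f;
        try {
            for (int epoch = 0; epoch < epochNumbers; epoch++) {
                for (int start = 0; start < xs.length; start += batchSize) {
                    int end = Math.min(start + batchSize, xs.length);
                    final float current = w; // workers read a fixed snapshot
                    int chunk = (end - start + threadNumber - 1) / threadNumber;
                    List<Future<Float>> parts = new ArrayList<>();
                    for (int from = start; from < end; from += chunk) {
                        final int lo = from;
                        final int hi = Math.min(from + chunk, end);
                        Callable<Float> task = () -> {
                            float grad = 0f;
                            for (int i = lo; i < hi; i++) {
                                // derivative of 0.5 * (w * x - y)^2 with respect to w
                                grad += (current * xs[i] - ys[i]) * xs[i];
                            }
                            return grad;
                        };
                        parts.add(pool.submit(task));
                    }
                    float gradSum = 0f;
                    for (Future<Float> part : parts) {
                        gradSum += part.get(); // wait for every slice
                    }
                    w -= learnRate * gradSum / (end - start);
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("worker failed", e);
        } finally {
            pool.shutdown();
        }
        return w;
    }
}
```

Because the workers only read a snapshot of the weight and the update happens on the main thread, no locking is needed in this sketch. The real implementation may combine gradients differently.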
The project includes a Python script (DataAugment.py) for augmenting the training dataset with various transformations to improve model robustness.
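
The specific transformations in DataAugment.py are not listed here. As an illustration of the general idea only, the following Java sketch applies one common augmentation, a brightness change, to a grayscale image. It assumes pixel values normalized to the range 0 to 1, and `AugmentSketch` and `adjustBrightness` are hypothetical names that do not come from the project or its Python script.

```java
// Hypothetical sketch, unrelated to DataAugment.py's actual contents:
// scales pixel brightness and clamps the result to [0, 1].
class AugmentSketch {
    static float[][] adjustBrightness(float[][] pixels, float factor) {
        float[][] out = new float[pixels.length][];
        for (int row = 0; row < pixels.length; row++) {
            out[row] = new float[pixels[row].length];
            for (int col = 0; col < pixels[row].length; col++) {
                float scaled = pixels[row][col] * factor;
                out[row][col] = Math.min(1f, Math.max(0f, scaled));
            }
        }
        return out; // the input array is left unchanged
    }
}
```

Calling this with a few different factors (for example 0.8f and 1.2f) yields several variants of each training image.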
Trained models are saved with the .ser extension and can be loaded for continued training or inference.
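
The .ser extension suggests standard Java object serialization, although the README does not confirm the mechanism. Under that assumption, saving and loading could look like the sketch below. `ModelSketch` is a hypothetical stand-in; the real NeuralNetwork class may store its state differently.

```java
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Hypothetical stand-in for a trained model, holding only a weight matrix.
class ModelSketch implements Serializable {
    private static final long serialVersionUID = 1L;
    final float[][] weights;

    ModelSketch(float[][] weights) {
        this.weights = weights;
    }

    // Writes this object to the given file using Java serialization.
    void save(File file) {
        try (ObjectOutputStream out =
                 new ObjectOutputStream(new FileOutputStream(file))) {
            out.writeObject(this);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Reads a previously saved model back from the given file.
    static ModelSketch load(File file) {
        try (ObjectInputStream in =
                 new ObjectInputStream(new FileInputStream(file))) {
            return (ModelSketch) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

With this approach, a file written by one version of a class may fail to load after the class's fields change, so it is worth keeping the model class stable between training sessions.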