#!/usr/bin/env python3
"""
Script to analyze and compare training results from multiple model runs.
"""
import json
from pathlib import Path


def load_metadata(run_dir):
    """Load metadata from a training run directory"""
    metadata_path = Path(run_dir) / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    return None
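

# Each runs/<run_id>/metadata.json is expected to look roughly like the
# sketch below. This is inferred from the keys read in analyze_all_runs;
# the values shown are illustrative placeholders, not real results:
#
# {
#     "classifier": "LinearSVC",
#     "config_name": "VNTC_tfidf",
#     "max_features": 50000,
#     "ngram_range": [1, 2],
#     "train_accuracy": 0.99,
#     "test_accuracy": 0.92,
#     "train_time": 12.3,
#     "prediction_time": 0.8,
#     "train_samples": 40000,
#     "test_samples": 10000
# }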
def analyze_all_runs():
    """Analyze all training runs and create comparison"""
    runs_dir = Path("runs")
    results = []

    # Collect one summary record per run directory that has metadata
    for run_dir in runs_dir.glob("*/"):
        if run_dir.is_dir():
            metadata = load_metadata(run_dir)
            if metadata:
                results.append({
                    'run_id': run_dir.name,
                    'model': metadata.get('classifier', 'Unknown'),
                    # Dataset is inferred from the config name; any run
                    # without 'VNTC' in it is treated as UTS2017_Bank
                    'dataset': 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank',
                    'max_features': metadata.get('max_features', 0),
                    'ngram_range': metadata.get('ngram_range', [1, 1]),
                    'train_accuracy': metadata.get('train_accuracy', 0),
                    'test_accuracy': metadata.get('test_accuracy', 0),
                    'train_time': metadata.get('train_time', 0),
                    'prediction_time': metadata.get('prediction_time', 0),
                    'train_samples': metadata.get('train_samples', 0),
                    'test_samples': metadata.get('test_samples', 0)
                })
    return results


def print_dataset_table(title, results):
    """Print one comparison table, sorted by test accuracy (descending)"""
    print(f"\n{title}:")
    print("-" * 120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} {'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-" * 120)
    for result in sorted(results, key=lambda x: x['test_accuracy'], reverse=True):
        model = result['model'][:18]
        features = f"{result['max_features'] // 1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"
        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} {test_acc:<12} {train_time:<12} {pred_time:<12}")


def print_comparison_table(results):
    """Print formatted comparison tables for both datasets"""
    print("\n" + "=" * 120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("=" * 120)

    # One table per dataset; both tables share the same layout
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    if vntc_results:
        print_dataset_table("VNTC Dataset (Vietnamese News Classification)", vntc_results)

    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    if bank_results:
        print_dataset_table("UTS2017_Bank Dataset (Vietnamese Banking Text Classification)", bank_results)

    print("=" * 120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} with {best_vntc['test_accuracy']:.4f} test accuracy")
    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} with {best_bank['test_accuracy']:.4f} test accuracy")


def main():
    """Main analysis function"""
    print("Analyzing Vietnamese Text Classification Training Results...")
    results = analyze_all_runs()

    if not results:
        print("No training results found in runs/ directory.")
        return

    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    # Create summary statistics
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']

    print("\nSummary:")
    print(f"- VNTC runs: {len(vntc_results)}")
    print(f"- UTS2017_Bank runs: {len(bank_results)}")

    if vntc_results:
        avg_vntc_acc = sum(r['test_accuracy'] for r in vntc_results) / len(vntc_results)
        print(f"- Average VNTC test accuracy: {avg_vntc_acc:.4f}")
    if bank_results:
        avg_bank_acc = sum(r['test_accuracy'] for r in bank_results) / len(bank_results)
        print(f"- Average UTS2017_Bank test accuracy: {avg_bank_acc:.4f}")


if __name__ == "__main__":
    main()