File size: 5,487 Bytes
0b1c1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python3
"""
Script to analyze and compare training results from multiple model runs.
"""

import json
import os
from pathlib import Path

def load_metadata(run_dir):
    """Load and parse ``metadata.json`` from a training run directory.

    Args:
        run_dir: Path to a single run directory (``str`` or ``Path``).

    Returns:
        The parsed metadata dict, or ``None`` when ``metadata.json`` is
        missing or contains invalid JSON (a corrupt run should be
        skipped, not abort the whole analysis).
    """
    metadata_path = Path(run_dir) / "metadata.json"
    try:
        # EAFP: open directly rather than exists()+open(), which is
        # racy if the file disappears between the check and the read.
        with metadata_path.open('r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except json.JSONDecodeError:
        # Treat unreadable metadata the same as absent metadata.
        return None

def analyze_all_runs(runs_dir="runs"):
    """Collect metadata from every training run under *runs_dir*.

    Args:
        runs_dir: Directory containing one subdirectory per run.
            Defaults to ``"runs"``, matching the training scripts'
            output layout, so existing callers are unaffected.

    Returns:
        A list of dicts, one per run with readable metadata, carrying
        normalized keys (run_id, model, dataset, accuracies, timings,
        sample counts). Missing metadata fields fall back to 0 /
        'Unknown' defaults.
    """
    results = []

    for run_dir in Path(runs_dir).glob("*/"):
        if not run_dir.is_dir():
            continue
        metadata = load_metadata(run_dir)
        if not metadata:
            continue
        # Dataset is inferred from the config name; anything that is
        # not VNTC is assumed to be the banking dataset.
        dataset = 'VNTC' if 'VNTC' in metadata.get('config_name', '') else 'UTS2017_Bank'
        results.append({
            'run_id': run_dir.name,
            'model': metadata.get('classifier', 'Unknown'),
            'dataset': dataset,
            'max_features': metadata.get('max_features', 0),
            'ngram_range': metadata.get('ngram_range', [1, 1]),
            'train_accuracy': metadata.get('train_accuracy', 0),
            'test_accuracy': metadata.get('test_accuracy', 0),
            'train_time': metadata.get('train_time', 0),
            'prediction_time': metadata.get('prediction_time', 0),
            'train_samples': metadata.get('train_samples', 0),
            'test_samples': metadata.get('test_samples', 0),
        })

    return results

def _print_dataset_section(dataset_results, title):
    """Print one dataset's runs as a fixed-width table.

    Sorts *dataset_results* in place by descending test accuracy —
    this in-place sort matches the original behavior, so the caller's
    list order reflects the printed order afterwards.

    Args:
        dataset_results: Non-empty list of result dicts (as produced
            by analyze_all_runs) for a single dataset.
        title: Section heading, printed without a trailing colon.
    """
    print(f"\n{title}:")
    print("-"*120)
    print(f"{'Model':<20} {'Features':<10} {'N-gram':<10} {'Train Acc':<12} {'Test Acc':<12} {'Train Time':<12} {'Pred Time':<12}")
    print("-"*120)

    # Best-performing configurations first.
    dataset_results.sort(key=lambda x: x['test_accuracy'], reverse=True)

    for result in dataset_results:
        model = result['model'][:18]
        # Feature counts shown in thousands; 0 means "not recorded".
        features = f"{result['max_features']//1000}k" if result['max_features'] > 0 else "N/A"
        ngram = f"{result['ngram_range'][0]}-{result['ngram_range'][1]}"
        train_acc = f"{result['train_accuracy']:.4f}"
        test_acc = f"{result['test_accuracy']:.4f}"
        train_time = f"{result['train_time']:.1f}s"
        pred_time = f"{result['prediction_time']:.1f}s"

        print(f"{model:<20} {features:<10} {ngram:<10} {train_acc:<12} {test_acc:<12} {train_time:<12} {pred_time:<12}")

def print_comparison_table(results):
    """Print a formatted comparison of all runs, grouped by dataset.

    Renders one table per dataset (VNTC, then UTS2017_Bank), each
    sorted by test accuracy, followed by a best-model summary line
    per dataset. Datasets with no runs are omitted.

    Args:
        results: List of result dicts from analyze_all_runs().
    """
    print("\n" + "="*120)
    print("VIETNAMESE TEXT CLASSIFICATION - MODEL COMPARISON RESULTS")
    print("="*120)

    # The two sections were previously duplicated inline; the shared
    # rendering now lives in _print_dataset_section.
    vntc_results = [r for r in results if r['dataset'] == 'VNTC']
    if vntc_results:
        _print_dataset_section(vntc_results, "VNTC Dataset (Vietnamese News Classification)")

    bank_results = [r for r in results if r['dataset'] == 'UTS2017_Bank']
    if bank_results:
        _print_dataset_section(bank_results, "UTS2017_Bank Dataset (Vietnamese Banking Text Classification)")

    print("="*120)

    if vntc_results:
        best_vntc = max(vntc_results, key=lambda x: x['test_accuracy'])
        print(f"\nBest VNTC model: {best_vntc['model']} with {best_vntc['test_accuracy']:.4f} test accuracy")

    if bank_results:
        best_bank = max(bank_results, key=lambda x: x['test_accuracy'])
        print(f"Best UTS2017_Bank model: {best_bank['model']} with {best_bank['test_accuracy']:.4f} test accuracy")

def main():
    """Entry point: load every run's metadata and report comparisons.

    Prints the full comparison table plus per-dataset run counts and
    average test accuracies; bails out early when no runs exist.
    """
    print("Analyzing Vietnamese Text Classification Training Results...")

    results = analyze_all_runs()
    if not results:
        print("No training results found in runs/ directory.")
        return

    print(f"Found {len(results)} training runs.")
    print_comparison_table(results)

    # Bucket results once per dataset for the summary statistics.
    by_dataset = {
        'VNTC': [r for r in results if r['dataset'] == 'VNTC'],
        'UTS2017_Bank': [r for r in results if r['dataset'] == 'UTS2017_Bank'],
    }

    print("\nSummary:")
    print(f"- VNTC runs: {len(by_dataset['VNTC'])}")
    print(f"- UTS2017_Bank runs: {len(by_dataset['UTS2017_Bank'])}")

    vntc = by_dataset['VNTC']
    if vntc:
        avg_vntc_acc = sum(r['test_accuracy'] for r in vntc) / len(vntc)
        print(f"- Average VNTC test accuracy: {avg_vntc_acc:.4f}")

    bank = by_dataset['UTS2017_Bank']
    if bank:
        avg_bank_acc = sum(r['test_accuracy'] for r in bank) / len(bank)
        print(f"- Average UTS2017_Bank test accuracy: {avg_bank_acc:.4f}")

if __name__ == "__main__":
    main()