#!/usr/bin/env python3
"""
Multi-head image classification inference example using ONNX models.
Supports either the full model (model.onnx) or the split model (encoder.onnx + head.onnx).
"""
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import json
from pathlib import Path

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
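
# The pipeline above yields a float32 tensor of shape (3, 224, 224); after the
# unsqueeze(0) in preprocess_image() below, the models receive an NCHW batch of
# shape (1, 3, 224, 224), which the ONNX exports are assumed to expect.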


def load_model_info(model_info_path):
    """Load model metadata from a JSON file."""
    with open(model_info_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def preprocess_image(image_path):
    """Load an image and preprocess it into a batched NCHW array."""
    image = Image.open(image_path).convert('RGB')
    tensor = transform(image)
    return tensor.unsqueeze(0).numpy()  # add batch dimension


def softmax(x):
    """Numerically stable softmax along the class axis."""
    exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)
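
# Illustrative example: softmax(np.array([[1.0, 2.0, 3.0]])) is approximately
# [[0.090, 0.245, 0.665]]; the probabilities along axis 1 sum to 1.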


def predict_image_full_model(model_path, model_info_path, image_path):
    """Classify an image with the full (single-file) ONNX model."""
    # Load model metadata
    model_info = load_model_info(model_info_path)

    # Create the ONNX Runtime session
    session = ort.InferenceSession(model_path)

    # Preprocess the image
    image_array = preprocess_image(image_path)

    # Run inference
    inputs = {'image': image_array}
    outputs = session.run(None, inputs)

    # Interpret the results
    results = {}
    head_names = list(model_info['output_specification']['heads'].keys())
    for i, output_name in enumerate(head_names):
        logits = outputs[i]
        probabilities = softmax(logits)[0]

        # Map class indices to class names
        class_names = model_info['class_mappings'].get(output_name, {})

        # Highest-probability class
        pred_idx = np.argmax(probabilities)
        pred_class = class_names.get(str(pred_idx), f"Class_{pred_idx}")
        pred_prob = probabilities[pred_idx]

        # Top-3 classes
        top3_indices = np.argsort(probabilities)[-3:][::-1]
        top3_results = []
        for idx in top3_indices:
            class_name = class_names.get(str(idx), f"Class_{idx}")
            prob = probabilities[idx]
            top3_results.append({'class': class_name, 'probability': float(prob)})

        results[output_name] = {
            'predicted_class': pred_class,
            'confidence': float(pred_prob),
            'top3': top3_results
        }

    return results
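
# Note on output ordering: the loop above pairs outputs[i] with the i-th head
# listed in model_info['output_specification']['heads']. This assumes the ONNX
# export emits its outputs in that same order (the split-model variant below
# makes the same assumption); session.get_outputs() can be used to check the
# actual output names if the ordering is ever in doubt.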


def predict_image_split_model(encoder_path, head_path, model_info_path, image_path):
    """Classify an image with the split ONNX models (encoder + head)."""
    # Load model metadata
    model_info = load_model_info(model_info_path)

    # Create the ONNX Runtime sessions
    encoder_session = ort.InferenceSession(encoder_path)
    head_session = ort.InferenceSession(head_path)

    # Preprocess the image
    image_array = preprocess_image(image_path)

    # Extract the feature vector with the encoder
    encoder_inputs = {'image': image_array}
    features = encoder_session.run(None, encoder_inputs)[0]

    # Classify with the head
    head_inputs = {'features': features}
    outputs = head_session.run(None, head_inputs)

    # Interpret the results
    results = {}
    head_names = list(model_info['output_specification']['heads'].keys())
    for i, output_name in enumerate(head_names):
        logits = outputs[i]
        probabilities = softmax(logits)[0]

        # Map class indices to class names
        class_names = model_info['class_mappings'].get(output_name, {})

        # Highest-probability class
        pred_idx = np.argmax(probabilities)
        pred_class = class_names.get(str(pred_idx), f"Class_{pred_idx}")
        pred_prob = probabilities[pred_idx]

        # Top-3 classes
        top3_indices = np.argsort(probabilities)[-3:][::-1]
        top3_results = []
        for idx in top3_indices:
            class_name = class_names.get(str(idx), f"Class_{idx}")
            prob = probabilities[idx]
            top3_results.append({'class': class_name, 'probability': float(prob)})

        results[output_name] = {
            'predicted_class': pred_class,
            'confidence': float(pred_prob),
            'top3': top3_results
        }

    return results
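

# Optional debugging aid, not used by the script itself: a minimal sketch that
# prints an ONNX model's input/output names and shapes, handy for confirming the
# names assumed above ('image' for the full model and encoder, 'features' for the head).
def print_model_io(model_path):
    """Print the input and output signatures of an ONNX model."""
    session = ort.InferenceSession(model_path)
    print(f"Model: {model_path}")
    for inp in session.get_inputs():
        print(f"  input : {inp.name} shape={inp.shape} type={inp.type}")
    for out in session.get_outputs():
        print(f"  output: {out.name} shape={out.shape} type={out.type}")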


# Usage example
if __name__ == "__main__":
    model_info_path = "model_info.json"
    image_path = "test_image.jpg"

    # Prefer the split model if both parts are present
    if Path("encoder.onnx").exists() and Path("head.onnx").exists():
        print("Using the split model")
        results = predict_image_split_model("encoder.onnx", "head.onnx", model_info_path, image_path)
    elif Path("model.onnx").exists():
        print("Using the full model")
        results = predict_image_full_model("model.onnx", model_info_path, image_path)
    else:
        print("No ONNX model found.")
        raise SystemExit(1)

    print(f"\nImage classification results: {image_path}")
    print("=" * 50)

    for output_name, result in results.items():
        print(f"\n{output_name.upper()}:")
        print(f"  Predicted class: {result['predicted_class']}")
        print(f"  Confidence: {result['confidence']:.4f}")
        print("  Top 3:")
        for i, top_result in enumerate(result['top3'], 1):
            print(f"    {i}. {top_result['class']}: {top_result['probability']:.4f}")
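
# Example invocation (assuming model_info.json, test_image.jpg, and at least one
# of the exported ONNX files sit next to this script):
#   python inference_example.py
#
# To prefer GPU execution (a sketch; requires the onnxruntime-gpu package), the
# sessions above could instead be created with explicit providers, e.g.:
#   ort.InferenceSession("model.onnx",
#                        providers=["CUDAExecutionProvider", "CPUExecutionProvider"])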