|
|
|
|
|
"""
ONNX multi-head image classification inference example.

Works with either a single full model (model.onnx) or a split model
(encoder.onnx + head.onnx).
"""
|
|
|
|
|
import onnxruntime as ort |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torchvision.transforms as transforms |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Preprocessing constants for the image pipeline below.
_INPUT_SIZE = 224
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image -> normalised CHW float tensor.
# NOTE(review): Resize((224, 224)) already yields a 224x224 image, so the
# CenterCrop step is a no-op; both are kept so the pipeline stays identical
# to whatever was used at training time — confirm before changing either.
transform = transforms.Compose(
    [
        transforms.Resize((_INPUT_SIZE, _INPUT_SIZE)),
        transforms.CenterCrop(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
|
def load_model_info(model_info_path):
    """Read the model metadata JSON file and return the parsed object.

    Args:
        model_info_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The deserialised JSON content (a dict for the expected metadata file).
    """
    with open(model_info_path, encoding='utf-8') as handle:
        info = json.load(handle)
    return info
|
|
|
|
|
def preprocess_image(image_path):
    """Load an image file and turn it into a single-item model input batch.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        numpy array produced by the module-level ``transform`` pipeline with a
        leading batch axis of size 1, i.e. ``(1, 3, H, W)`` for an RGB image
        (224x224 with the default pipeline).
    """
    # Open under a context manager so the underlying file handle is released
    # immediately; ``convert`` loads the pixels into a new, independent image
    # before the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    tensor = transform(image)
    # Add the batch dimension expected by the ONNX session.
    return tensor.unsqueeze(0).numpy()
|
|
|
|
|
def softmax(x, axis=1):
    """Numerically stable softmax.

    Args:
        x: Array-like of logits.
        axis: Axis along which the values are normalised. Defaults to 1, the
            class axis of a ``(batch, num_classes)`` logits array, which is
            how this module calls it.

    Returns:
        ndarray with the same shape as ``x`` whose slices along ``axis`` sum
        to 1.
    """
    x = np.asarray(x)
    # Subtract the per-slice maximum before exponentiating so large logits
    # do not overflow; the result is mathematically unchanged.
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
|
|
|
|
|
def _decode_head_outputs(outputs, model_info):
    """Turn raw per-head logits into readable predictions for the first sample.

    Args:
        outputs: Sequence of logit arrays as returned by ``session.run``.
            ``outputs[i]`` is paired with the i-th head name listed under
            ``model_info['output_specification']['heads']`` (JSON key order).
            NOTE(review): assumes the ONNX output order matches that key
            order — confirm against the export script.
        model_info: Parsed model metadata (see ``load_model_info``).

    Returns:
        dict mapping each head name to a dict with keys ``'predicted_class'``,
        ``'confidence'`` and ``'top3'`` (a list of ``{'class', 'probability'}``
        dicts, highest probability first).
    """
    results = {}
    head_names = list(model_info['output_specification']['heads'].keys())

    for i, output_name in enumerate(head_names):
        # Only the first item of the batch is reported.
        probabilities = softmax(outputs[i])[0]

        # Class mappings are keyed by the stringified class index; unknown
        # indices fall back to a generic "Class_<idx>" label.
        class_names = model_info['class_mappings'].get(output_name, {})

        pred_idx = np.argmax(probabilities)
        pred_class = class_names.get(str(pred_idx), f"Class_{pred_idx}")
        pred_prob = probabilities[pred_idx]

        # Three largest probabilities, descending (fewer if the head has
        # fewer than three classes).
        top3_indices = np.argsort(probabilities)[-3:][::-1]
        top3_results = [
            {
                'class': class_names.get(str(idx), f"Class_{idx}"),
                'probability': float(probabilities[idx]),
            }
            for idx in top3_indices
        ]

        results[output_name] = {
            'predicted_class': pred_class,
            'confidence': float(pred_prob),
            'top3': top3_results,
        }

    return results


def predict_image_full_model(model_path, model_info_path, image_path, *, providers=None):
    """Classify an image with a single end-to-end ONNX model.

    Args:
        model_path: Path to the full ONNX model; its input is fed under the
            name ``'image'``.
        model_info_path: Path to the model metadata JSON file.
        image_path: Path to the image to classify.
        providers: Optional execution-provider list forwarded to
            ``ort.InferenceSession``. ``None`` keeps onnxruntime's default
            (GPU builds of onnxruntime require an explicit list).

    Returns:
        dict of per-head predictions (see ``_decode_head_outputs``).
    """
    model_info = load_model_info(model_info_path)
    session = ort.InferenceSession(model_path, providers=providers)
    image_array = preprocess_image(image_path)

    # ``None`` asks onnxruntime for every model output, in graph order.
    outputs = session.run(None, {'image': image_array})
    return _decode_head_outputs(outputs, model_info)


def predict_image_split_model(encoder_path, head_path, model_info_path, image_path, *, providers=None):
    """Classify an image with a separate encoder model and multi-head model.

    The encoder's first output is fed to the head model under the input name
    ``'features'``.

    Args:
        encoder_path: Path to the encoder ONNX model (input name ``'image'``).
        head_path: Path to the head ONNX model (input name ``'features'``).
        model_info_path: Path to the model metadata JSON file.
        image_path: Path to the image to classify.
        providers: Optional execution-provider list forwarded to both
            ``ort.InferenceSession`` instances. ``None`` keeps onnxruntime's
            default.

    Returns:
        dict of per-head predictions (see ``_decode_head_outputs``).
    """
    model_info = load_model_info(model_info_path)
    encoder_session = ort.InferenceSession(encoder_path, providers=providers)
    head_session = ort.InferenceSession(head_path, providers=providers)
    image_array = preprocess_image(image_path)

    features = encoder_session.run(None, {'image': image_array})[0]
    outputs = head_session.run(None, {'features': features})
    return _decode_head_outputs(outputs, model_info)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import: only the script entry point needs ``sys``.
    import sys

    model_info_path = "model_info.json"
    image_path = "test_image.jpg"

    # Prefer the split encoder/head pair when both files exist in the
    # working directory; otherwise fall back to the single full model.
    if Path("encoder.onnx").exists() and Path("head.onnx").exists():
        print("분리 모델 사용")
        results = predict_image_split_model("encoder.onnx", "head.onnx", model_info_path, image_path)
    elif Path("model.onnx").exists():
        print("전체 모델 사용")
        results = predict_image_full_model("model.onnx", model_info_path, image_path)
    else:
        print("ONNX 모델을 찾을 수 없습니다.", file=sys.stderr)
        # ``sys.exit`` is always available; the bare ``exit`` builtin is
        # injected by the ``site`` module and is missing under ``python -S``.
        sys.exit(1)

    print(f"\n이미지 분류 결과: {image_path}")
    print("=" * 50)

    for output_name, result in results.items():
        print(f"\n{output_name.upper()}:")
        print(f" 예측 클래스: {result['predicted_class']}")
        print(f" 신뢰도: {result['confidence']:.4f}")
        print(" Top 3:")
        for rank, top_result in enumerate(result['top3'], 1):
            print(f" {rank}. {top_result['class']}: {top_result['probability']:.4f}")
|
|
|