maomao88 commited on
Commit
a084789
·
1 Parent(s): 70ddcac

display model structure

Browse files
backend/__pycache__/hf_model_utils.cpython-313.pyc CHANGED
Binary files a/backend/__pycache__/hf_model_utils.cpython-313.pyc and b/backend/__pycache__/hf_model_utils.cpython-313.pyc differ
 
backend/hf_model_utils.py CHANGED
@@ -28,6 +28,8 @@ def module_hash(module):
28
  rep_bytes = str(rep).encode('utf-8')
29
  return hashlib.md5(rep_bytes).hexdigest()
30
 
 
 
31
 
32
  def hf_style_structural_dict(module):
33
  """
@@ -52,7 +54,7 @@ def hf_style_structural_dict(module):
52
  count = 1
53
  j = i + 1
54
  # Count consecutive children that are structurally identical
55
- while j < len(children) and module_hash(children[j][1]) == current_hash:
56
  count += 1
57
  j += 1
58
 
@@ -93,6 +95,10 @@ def get_model_structure(model_name: str, model_type: str | None):
93
  "hidden_size": getattr(config, "hidden_size", None),
94
  "num_hidden_layers": getattr(config, "num_hidden_layers", None),
95
  "num_attention_heads": getattr(config, "num_attention_heads", None),
 
 
 
 
96
  "layers": hf_style_structural_dict(model)
97
  }
98
 
 
28
  rep_bytes = str(rep).encode('utf-8')
29
  return hashlib.md5(rep_bytes).hexdigest()
30
 
31
+ def is_number_string(value):
32
+ return isinstance(value, str) and value.isdigit()
33
 
34
  def hf_style_structural_dict(module):
35
  """
 
54
  count = 1
55
  j = i + 1
56
  # Count consecutive children that are structurally identical
57
+ while j < len(children) and is_number_string(name) and module_hash(children[j][1]) == current_hash:
58
  count += 1
59
  j += 1
60
 
 
95
  "hidden_size": getattr(config, "hidden_size", None),
96
  "num_hidden_layers": getattr(config, "num_hidden_layers", None),
97
  "num_attention_heads": getattr(config, "num_attention_heads", None),
98
+ "image_size": getattr(config, "image_size", None),
99
+ "intermediate_size": getattr(config, "intermediate_size", None),
100
+ "patch_size": getattr(config, "patch_size", None),
101
+ "vocab_size": getattr(config, "vocab_size", None),
102
  "layers": hf_style_structural_dict(model)
103
  }
104
 
frontend/src/App.jsx CHANGED
@@ -1,6 +1,7 @@
1
  import { useState } from "react";
2
  import ModelInputBar from "./components/ModelInputBar";
3
  import ModelLayersCard from "./components/ModelLayersCard";
 
4
 
5
  export default function App() {
6
  const [structure, setStructure] = useState(null);
@@ -40,7 +41,7 @@ export default function App() {
40
  <div className="w-full max-w-3xl bg-white rounded-2xl shadow-lg p-6">
41
  {/* Header */}
42
  <h1 className="text-3xl font-bold text-center text-slate-800 mb-4">
43
- Hugging Face Model Structure Viewer
44
  </h1>
45
  <p className="text-center text-slate-500 mb-6">
46
  Enter a model name (e.g. <code>deepseek-ai/deepseek-moe-16b-base</code>) to view its
@@ -56,19 +57,19 @@ export default function App() {
56
  <div className="text-red-600 text-center font-medium mb-4">{error}</div>
57
  )}
58
 
59
- {/* Model Structure */}
60
- {structure && (
61
- <div className="max-h-[500px] overflow-y-auto bg-slate-50 rounded-xl p-4 border border-slate-200 shadow-inner">
62
- <pre className="text-sm text-slate-800 whitespace-pre-wrap">
63
- {structure.model_type}
64
- </pre>
65
- </div>
66
- )}
67
  </div>
68
 
69
- {/* Model Layers Card */}
70
  {structure && (
71
- <ModelLayersCard layers={structure?.layers?.children || {}} />
72
  )}
73
 
74
  {/* Footer */}
 
1
  import { useState } from "react";
2
  import ModelInputBar from "./components/ModelInputBar";
3
  import ModelLayersCard from "./components/ModelLayersCard";
4
+ import ModelInfo from "./components/ModelInfo";
5
 
6
  export default function App() {
7
  const [structure, setStructure] = useState(null);
 
41
  <div className="w-full max-w-3xl bg-white rounded-2xl shadow-lg p-6">
42
  {/* Header */}
43
  <h1 className="text-3xl font-bold text-center text-slate-800 mb-4">
44
+ Transformer Model Structure Viewer
45
  </h1>
46
  <p className="text-center text-slate-500 mb-6">
47
  Enter a model name (e.g. <code>deepseek-ai/deepseek-moe-16b-base</code>) to view its
 
57
  <div className="text-red-600 text-center font-medium mb-4">{error}</div>
58
  )}
59
 
60
+ {/* Model Info */}
61
+ {structure && (
62
+ <ModelInfo structure={
63
+ Object.fromEntries(
64
+ Object.entries(structure).filter(([key]) => key !== "layers")
65
+ )
66
+ } />
67
+ )}
68
  </div>
69
 
70
+ {/* Model Layers */}
71
  {structure && (
72
+ <ModelLayersCard name={structure?.layers?.class_name} layers={structure?.layers?.children || {}} />
73
  )}
74
 
75
  {/* Footer */}
frontend/src/components/ModelInfo.jsx ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React from "react";
2
+
3
+ export default function ModelStructure({ structure }) {
4
+ if (!structure) return null;
5
+
6
+ const items = {
7
+ model_type: "Model Type",
8
+ hidden_size: "Hidden Size",
9
+ num_attention_heads: "Number of Attention Heads",
10
+ num_hidden_layers: "Number of Hidden Layers",
11
+ image_size: "Image Size",
12
+ intermediate_size: "Intermediate Size",
13
+ patch_size: "Patch Size",
14
+ vocab_size: "Vocab Size",
15
+ }
16
+
17
+ // Find which keys in structure match our interest list
18
+ const matched = Object.keys(items).filter((k) => k in structure && !!structure[k]);
19
+
20
+ return (
21
+ <div className="max-h-[500px] overflow-y-auto bg-slate-50 rounded-xl p-4 border border-slate-200 shadow-inner">
22
+ {matched.length === 0 ? (
23
+ <pre className="text-sm text-slate-800 whitespace-pre-wrap">{structure.model_type}</pre>
24
+ ) : (
25
+ <div className="flex flex-wrap gap-2">
26
+ {matched.map((k) => (
27
+ <div
28
+ key={k}
29
+ className="bg-sky-100 border border-sky-300 text-sky-900 text-sm px-3 py-1 rounded-full flex items-center gap-3 shadow-sm"
30
+ >
31
+ <div className="font-medium">{items[k]}:</div>
32
+ <div className="text-sky-700">{String(structure[k])}</div>
33
+ </div>
34
+ ))}
35
+ </div>
36
+ )}
37
+ </div>
38
+ );
39
+ }
frontend/src/components/ModelInputBar.jsx CHANGED
@@ -31,51 +31,55 @@ export default function ModelInputBar({ loading, fetchModelStructure }) {
31
  }
32
  }
33
 
34
- const [focused, setFocused] = useState(false);
35
-
36
  return (
37
- <div className="flex flex-col gap-2 mb-4 w-full">
38
- <div className="flex gap-2 w-full">
39
- <input
40
- type="text"
41
- value={modelName}
42
- onChange={(e) => setModelName(e.target.value)}
43
- placeholder={placeholder}
44
- className="flex-1 px-4 py-2 rounded-lg border border-slate-200 focus:outline-none focus:ring-2 focus:ring-sky-300"
45
- onKeyDown={(e) => {
46
- if (e.key === "Enter") handleFetch();
47
- }}
48
- onFocus={() => setFocused(true)}
49
- onBlur={() => setFocused(false)}
50
- aria-label="Model name"
51
- />
52
-
53
- <button
54
- onClick={handleFetch}
55
- disabled={loading}
56
- className={`px-4 py-2 rounded-lg font-medium text-white ${
57
- loading ? "bg-slate-400 cursor-wait" : "bg-sky-600 hover:bg-sky-700"
58
- }`}
59
- >
60
- {loading ? "Loading..." : "Fetch"}
61
- </button>
62
- </div>
63
-
64
- <div className="flex flex-wrap items-center gap-4 text-sm text-slate-700">
65
- {options.map((option) => (
66
- <label key={option.value} className="flex items-center gap-2">
67
  <input
68
- type="radio"
69
- name="fetchOption"
70
- value={option.value}
71
- checked={selectedOption === option.value}
72
- onChange={handelRadioChange}
73
- className="w-4 h-4"
 
 
 
74
  />
75
- <span>{option.label}</span>
76
- </label>
77
- ))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  </div>
79
- </div>
80
- );
81
  }
 
31
  }
32
  }
33
 
 
 
34
  return (
35
+ <div className="flex flex-col gap-2 mb-4 w-full">
36
+ {/* Responsive flex row/col for title and input/button */}
37
+ <div className="flex flex-col sm:flex-row sm:items-center gap-2 w-full py-2">
38
+ <h3 className="text-sm font-semibold text-slate-800 mb-2 whitespace-nowrap sm:mb-0">
39
+ Model Name:
40
+ </h3>
41
+ <div className="flex flex-1 gap-2">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  <input
43
+ type="text"
44
+ value={modelName}
45
+ onChange={(e) => setModelName(e.target.value)}
46
+ placeholder={placeholder}
47
+ className="flex-1 px-4 py-2 rounded-lg border border-slate-200 focus:outline-none focus:ring-2 focus:ring-sky-300"
48
+ onKeyDown={(e) => {
49
+ if (e.key === "Enter") handleFetch();
50
+ }}
51
+ aria-label="Model name"
52
  />
53
+ <button
54
+ onClick={handleFetch}
55
+ disabled={loading}
56
+ className={`model-name px-4 py-2 rounded-lg font-medium text-white ${
57
+ loading ? "bg-slate-400 cursor-wait" : "bg-sky-600 hover:bg-sky-700"
58
+ }`}
59
+ >
60
+ {loading ? "Loading..." : "View"}
61
+ </button>
62
+ </div>
63
+ </div>
64
+
65
+ <div className="flex flex-col sm:flex-row sm:top gap-2 w-full py-2">
66
+ <h3 className="text-sm font-semibold text-slate-800 mb-2 whitespace-nowrap sm:mb-0">Model Type:</h3>
67
+ <div className="flex flex-wrap items-center gap-4 text-sm text-slate-700">
68
+ {options.map((option) => (
69
+ <label key={option.value} className="model-type flex items-center gap-2">
70
+ <input
71
+ type="radio"
72
+ name="fetchOption"
73
+ value={option.value}
74
+ checked={selectedOption === option.value}
75
+ onChange={handelRadioChange}
76
+ className="w-4 h-4"
77
+ />
78
+ <span>{option.label}</span>
79
+ </label>
80
+ ))}
81
+ </div>
82
+ </div>
83
  </div>
84
+ );
 
85
  }
frontend/src/components/ModelLayersCard.jsx CHANGED
@@ -1,28 +1,53 @@
1
  import React from "react";
2
 
3
- function LayerNode({ node }) {
4
- if (!node) return null;
 
5
 
 
 
6
  const name = node.class_name || node.name || "<unknown>";
7
  const shape = node.params?.weight?.shape || [];
8
 
9
  // Normalize children: object map
10
  let children = [];
 
11
  if (node.children && typeof node.children === "object") {
12
  children = Object.values(node.children);
 
13
  }
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  return (
16
- <div className="pl-2">
17
- <div className="flex items-center gap-2">
 
 
 
18
  <div className="text-sm text-slate-800 font-medium">{name}</div>
 
19
  <div className="text-xs text-slate-500">{shape.join(" x ")}</div>
 
20
  </div>
21
 
22
  {children.length > 0 && (
23
- <div className="pl-4 mt-2 border-l border-slate-100">
24
  {children.map((child, idx) => (
25
- <LayerNode key={idx} node={child} />
26
  ))}
27
  </div>
28
  )}
@@ -30,15 +55,18 @@ function LayerNode({ node }) {
30
  );
31
  }
32
 
33
- export default function ModelLayersCard({ layers = {} }) {
34
  let rootNodes = [];
 
35
 
36
  if (layers && typeof layers === "object") {
37
  if (layers.children && typeof layers.children === "object") {
38
  rootNodes = Object.values(layers.children);
 
39
  } else {
40
  // Fallback: convert the top-level keyed object into an array
41
  rootNodes = Object.values(layers);
 
42
  }
43
  }
44
 
@@ -48,10 +76,10 @@ export default function ModelLayersCard({ layers = {} }) {
48
 
49
  return (
50
  <div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4">
51
- <h2 className="text-lg font-semibold text-slate-800 mb-3">Model Layers</h2>
52
  <div className="space-y-2">
53
  {rootNodes.map((node, idx) => (
54
- <LayerNode key={idx} node={node} />
55
  ))}
56
  </div>
57
  </div>
 
1
  import React from "react";
2
 
3
+ function isNumberString(value) {
4
+ return typeof value === "string" && /^[0-9]+$/.test(value);
5
+ }
6
 
7
+ function LayerNode({ node, nodeKey, last, depth = 0 }) {
8
+ if (!node) return null;
9
  const name = node.class_name || node.name || "<unknown>";
10
  const shape = node.params?.weight?.shape || [];
11
 
12
  // Normalize children: object map
13
  let children = [];
14
+ let keys = [];
15
  if (node.children && typeof node.children === "object") {
16
  children = Object.values(node.children);
17
+ keys = Object.keys(node.children);
18
  }
19
 
20
+ // choose background by depth for subtle alternation
21
+ const bgByDepth = ["bg-slate-50", "bg-slate-100", "bg-slate-200", "bg-slate-50"];
22
+ const bgClass = bgByDepth[depth % bgByDepth.length];
23
+ const isEmbed = String(name).toLowerCase().includes("embed");
24
+ const embedStyle = isEmbed ? { backgroundColor: "#fbdfe2" } : undefined;
25
+ const isAttn = String(name).toLowerCase().includes("attention") || String(name).toLowerCase().includes("attn");
26
+ const attnStyle = isAttn ? { backgroundColor: "#fddfba" } : undefined;
27
+ const isFFN = String(name).toLowerCase().includes("ffn") || String(name).toLowerCase().includes("mlp");
28
+ const ffnStyle = isFFN ? { backgroundColor: "#c2e6f8" } : undefined;
29
+ const isNorm = String(name).toLowerCase().includes("norm");
30
+ const normStyle = isNorm ? { backgroundColor: "#f3f7c3" } : undefined;
31
+ const isLastNorm = last && isNorm;
32
+ const lastNormStyle = isLastNorm ? { backgroundColor: "#DEDFF1" } : undefined;
33
+
34
+
35
  return (
36
+ <div
37
+ className={`pl-2 ${!isEmbed ? bgClass : ""} rounded-md border border-slate-200 p-2`}
38
+ style={embedStyle || attnStyle || ffnStyle || lastNormStyle || normStyle}
39
+ >
40
+ <div className="flex items-center gap-3">
41
  <div className="text-sm text-slate-800 font-medium">{name}</div>
42
+ {nodeKey && !isNumberString(nodeKey) && <div className="text-sm text-slate-800 font-medium">({nodeKey})</div>}
43
  <div className="text-xs text-slate-500">{shape.join(" x ")}</div>
44
+ <div>{node.num_repeats && <span className="text-blue-600 font-bold tracking-wide">x {node.num_repeats}</span>}</div>
45
  </div>
46
 
47
  {children.length > 0 && (
48
+ <div className="pl-4 mt-2">
49
  {children.map((child, idx) => (
50
+ <LayerNode key={idx} node={child} nodeKey={keys[idx]} depth={depth + 1} />
51
  ))}
52
  </div>
53
  )}
 
55
  );
56
  }
57
 
58
+ export default function ModelLayersCard({ layers = {}, name = "" }) {
59
  let rootNodes = [];
60
+ let rootKeys = [];
61
 
62
  if (layers && typeof layers === "object") {
63
  if (layers.children && typeof layers.children === "object") {
64
  rootNodes = Object.values(layers.children);
65
+ rootKeys = Object.keys(layers.children);
66
  } else {
67
  // Fallback: convert the top-level keyed object into an array
68
  rootNodes = Object.values(layers);
69
+ rootKeys = Object.keys(layers);
70
  }
71
  }
72
 
 
76
 
77
  return (
78
  <div className="w-full max-w-3xl mt-6 bg-white rounded-2xl shadow-lg p-4">
79
+ <h2 className="text-lg font-semibold text-slate-800 mb-3">{name}</h2>
80
  <div className="space-y-2">
81
  {rootNodes.map((node, idx) => (
82
+ <LayerNode key={idx} node={node} nodeKey={rootKeys[idx]} last={idx == rootNodes.length - 1} />
83
  ))}
84
  </div>
85
  </div>
frontend/src/index.css CHANGED
@@ -1 +1,5 @@
1
  @import "tailwindcss";
 
 
 
 
 
1
  @import "tailwindcss";
2
+
3
+ .model-name, .model-type {
4
+ cursor: pointer;
5
+ }