OmniSVG commited on
Commit
062fa0c
·
verified ·
1 Parent(s): aac856b

Update decoder.py

Browse files
Files changed (1) hide show
  1. decoder.py +60 -10
decoder.py CHANGED
@@ -26,29 +26,79 @@ def load_config(config_path=None):
26
  return config
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  class SketchDecoder(nn.Module):
30
  """
31
- Autoregressive generative model
32
  """
33
- def __init__(self, config_path=None, model_path=None, **kwargs):
 
 
 
 
 
 
 
 
 
34
  super().__init__()
35
 
36
  config_data = load_config(config_path)
37
 
 
 
 
 
 
 
 
 
38
  model_config = config_data.get('model', {})
39
- huggingface_config = config_data.get('huggingface', {})
40
 
41
  self.bos_token_id = model_config['bos_token_id']
42
  self.eos_token_id = model_config['eos_token_id']
43
  self.pad_token_id = model_config['pad_token_id']
44
 
45
- self.vocab_size = model_config.get(
46
- 'vocab_size',
47
- max(self.bos_token_id, self.eos_token_id, self.pad_token_id) + 1
48
- )
 
 
 
49
 
 
50
  if model_path is None:
51
- model_path = huggingface_config['qwen_model']
 
 
 
 
 
 
52
 
53
  config = AutoConfig.from_pretrained(
54
  model_path,
@@ -61,7 +111,7 @@ class SketchDecoder(nn.Module):
61
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
62
  model_path,
63
  config=config,
64
- torch_dtype=torch.bfloat16,
65
  attn_implementation="sdpa",
66
  device_map="auto",
67
  ignore_mismatched_sizes=True
@@ -70,4 +120,4 @@ class SketchDecoder(nn.Module):
70
  self.transformer.resize_token_embeddings(self.vocab_size)
71
 
72
  def forward(self, *args, **kwargs):
73
- raise NotImplementedError("Forward pass not included in open-source version")
 
26
  return config
27
 
28
 
29
+ def get_model_specific_value(config, model_size, *keys):
30
+ """Get model-specific config value with fallback to shared config."""
31
+ # Try model-specific config first
32
+ model_cfg = config.get('models', {}).get(model_size, {})
33
+ value = model_cfg
34
+ for key in keys:
35
+ if isinstance(value, dict) and key in value:
36
+ value = value[key]
37
+ else:
38
+ value = None
39
+ break
40
+
41
+ # Fallback to shared config if not found
42
+ if value is None:
43
+ value = config
44
+ for key in keys:
45
+ if isinstance(value, dict) and key in value:
46
+ value = value[key]
47
+ else:
48
+ return None
49
+
50
+ return value
51
+
52
+
53
  class SketchDecoder(nn.Module):
54
  """
55
+ Autoregressive generative model - supports both 8B and 4B models
56
  """
57
+ def __init__(self, config_path=None, model_path=None, model_size=None, **kwargs):
58
+ """
59
+ Initialize SketchDecoder.
60
+
61
+ Args:
62
+ config_path: Path to config.yaml
63
+ model_path: HuggingFace model path (overrides config if provided)
64
+ model_size: Model size ("8B" or "4B"). If None, uses default from config.
65
+ **kwargs: Additional arguments (e.g., torch_dtype, pix_len, text_len)
66
+ """
67
  super().__init__()
68
 
69
  config_data = load_config(config_path)
70
 
71
+ # Determine model size
72
+ self.model_size = model_size or config_data.get('default_model_size', '8B')
73
+ if self.model_size not in config_data.get('models', {}):
74
+ raise ValueError(f"Invalid model_size: {self.model_size}. Must be one of: {list(config_data.get('models', {}).keys())}")
75
+
76
+ print(f"[SketchDecoder] Initializing with model_size: {self.model_size}")
77
+
78
+ # Get model-specific and shared configs
79
  model_config = config_data.get('model', {})
 
80
 
81
  self.bos_token_id = model_config['bos_token_id']
82
  self.eos_token_id = model_config['eos_token_id']
83
  self.pad_token_id = model_config['pad_token_id']
84
 
85
+ # Get vocab_size from model-specific config
86
+ self.vocab_size = get_model_specific_value(config_data, self.model_size, 'model', 'vocab_size')
87
+ if self.vocab_size is None:
88
+ self.vocab_size = model_config.get(
89
+ 'vocab_size',
90
+ max(self.bos_token_id, self.eos_token_id, self.pad_token_id) + 1
91
+ )
92
 
93
+ # Determine model path
94
  if model_path is None:
95
+ model_path = get_model_specific_value(config_data, self.model_size, 'huggingface', 'qwen_model')
96
+
97
+ print(f"[SketchDecoder] Using Qwen model: {model_path}")
98
+ print(f"[SketchDecoder] Vocab size: {self.vocab_size}")
99
+
100
+ # Get torch_dtype from kwargs or use default
101
+ torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
102
 
103
  config = AutoConfig.from_pretrained(
104
  model_path,
 
111
  self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
112
  model_path,
113
  config=config,
114
+ torch_dtype=torch_dtype,
115
  attn_implementation="sdpa",
116
  device_map="auto",
117
  ignore_mismatched_sizes=True
 
120
  self.transformer.resize_token_embeddings(self.vocab_size)
121
 
122
  def forward(self, *args, **kwargs):
123
+ raise NotImplementedError("Forward pass not included in open-source version")