import os
from functools import partial
from typing import Any, Optional

import torch

import videosys

from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
class VideoSysEngine:
    """
    Multi-GPU inference engine; partly inspired by vLLM.
    """

    def __init__(self, config):
        self.config = config
        self.parallel_worker_tasks = None
        self._init_worker(config.pipeline_cls)
    def _init_worker(self, pipeline_cls):
        world_size = self.config.num_gpus
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))

        # Disable torch async compiling, which does not work with daemonic processes
        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

        # Set OMP_NUM_THREADS to 1 if it is not set explicitly, to avoid CPU
        # contention among the shards
        if "OMP_NUM_THREADS" not in os.environ:
            os.environ["OMP_NUM_THREADS"] = "1"

        # NOTE: The two following lines need adaptation for multi-node setups
        assert world_size <= torch.cuda.device_count()
        # change the address for multi-node
        distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())

        if world_size == 1:
            self.workers = []
            self.worker_monitor = None
        else:
            result_handler = ResultHandler()
            self.workers = [
                ProcessWorkerWrapper(
                    result_handler,
                    partial(
                        self._create_pipeline,
                        pipeline_cls=pipeline_cls,
                        rank=rank,
                        local_rank=rank,
                        distributed_init_method=distributed_init_method,
                    ),
                )
                for rank in range(1, world_size)
            ]

            self.worker_monitor = WorkerMonitor(self.workers, result_handler)
            result_handler.start()
            self.worker_monitor.start()

        # Rank 0 (the driver) runs its pipeline in the current process.
        self.driver_worker = self._create_pipeline(
            pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
        )
    # TODO: add more options here for pipeline, or wrap all options into config
    def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
        videosys.initialize(rank=rank, world_size=self.config.num_gpus, init_method=distributed_init_method, seed=42)

        pipeline = pipeline_cls(self.config)
        return pipeline
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""

        # Start the workers first.
        worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]

        if async_run_tensor_parallel_workers_only:
            # Just return futures
            return worker_outputs

        driver_worker_method = getattr(self.driver_worker, method)
        driver_worker_output = driver_worker_method(*args, **kwargs)

        # Get the results of the workers.
        return [driver_worker_output] + [output.get() for output in worker_outputs]
    def _driver_execute_model(self, *args, **kwargs):
        return self.driver_worker.generate(*args, **kwargs)

    def generate(self, *args, **kwargs):
        # Run generation on the driver and all worker processes; return the driver's result.
        return self._run_workers("generate", *args, **kwargs)[0]
    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return

        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit the model loop cleanly
        # (this will raise otherwise)
        self._wait_for_tasks_completion(parallel_worker_tasks)

    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
        """Wait for futures returned from _run_workers() with
        async_run_tensor_parallel_workers_only=True to complete."""
        for result in parallel_worker_tasks:
            result.get()
    def save_video(self, video, output_path):
        return self.driver_worker.save_video(video, output_path)

    def shutdown(self):
        # Stop the worker monitor (and its worker processes), then tear down the
        # torch.distributed process group.
        if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
            worker_monitor.close()
        torch.distributed.destroy_process_group()

    def __del__(self):
        self.shutdown()
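

# Example usage (a minimal sketch, not part of the engine itself). The engine is
# driven by a config object that must expose at least `num_gpus` and `pipeline_cls`;
# the config class named below (OpenSoraConfig) and the `.video` attribute on the
# result are assumptions based on typical VideoSys pipelines, so adjust them to the
# pipeline you actually use.
#
#     from videosys import OpenSoraConfig, VideoSysEngine
#
#     config = OpenSoraConfig(num_gpus=1)   # with one GPU, no extra worker processes are spawned
#     engine = VideoSysEngine(config)
#     result = engine.generate("a drone shot over a snowy mountain ridge")
#     engine.save_video(result.video[0], "./outputs/sample.mp4")
#     engine.shutdown()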