diff --git a/video2x/interpolator.py b/video2x/interpolator.py index fe58451..4ed8bb6 100755 --- a/video2x/interpolator.py +++ b/video2x/interpolator.py @@ -21,16 +21,16 @@ Author: K4YT3X """ import time +from importlib import import_module from loguru import logger from PIL import ImageChops, ImageStat -from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife from .processor import Processor class Interpolator: - ALGORITHM_CLASSES = {"rife": Rife} + ALGORITHM_CLASSES = {"rife": "rife_ncnn_vulkan_python.rife_ncnn_vulkan.Rife"} processor_objects = {} @@ -43,9 +43,16 @@ class Interpolator: if difference_ratio < difference_threshold: processor_object = self.processor_objects.get(algorithm) + if processor_object is None: - processor_object = self.ALGORITHM_CLASSES[algorithm](0) + module_name, class_name = self.ALGORITHM_CLASSES[algorithm].rsplit( + ".", 1 + ) + processor_module = import_module(module_name) + processor_class = getattr(processor_module, class_name) + processor_object = processor_class(0) self.processor_objects[algorithm] = processor_object + interpolated_image = processor_object.process(image0, image1) else: diff --git a/video2x/upscaler.py b/video2x/upscaler.py index b9d9cb9..697a553 100755 --- a/video2x/upscaler.py +++ b/video2x/upscaler.py @@ -22,13 +22,9 @@ Author: K4YT3X import math import time +from importlib import import_module -from anime4k_python import Anime4K from PIL import Image -from realcugan_ncnn_vulkan_python import Realcugan -from realsr_ncnn_vulkan_python import Realsr -from srmd_ncnn_vulkan_python import Srmd -from waifu2x_ncnn_vulkan_python import Waifu2x from .processor import Processor @@ -45,11 +41,11 @@ class Upscaler: } ALGORITHM_CLASSES = { - "anime4k": Anime4K, - "realcugan": Realcugan, - "realsr": Realsr, - "srmd": Srmd, - "waifu2x": Waifu2x, + "anime4k": "anime4k_python.Anime4K", + "realcugan": "realcugan_ncnn_vulkan_python.Realcugan", + "realsr": "realsr_ncnn_vulkan_python.Realsr", + "srmd": "srmd_ncnn_vulkan_python.Srmd", + "waifu2x": "waifu2x_ncnn_vulkan_python.Waifu2x", } processor_objects = {} @@ -148,9 +144,12 @@ class Upscaler: # create a new object if none are available processor_object = self.processor_objects.get((algorithm, task)) if processor_object is None: - processor_object = self.ALGORITHM_CLASSES[algorithm]( - noise=noise, scale=task + module_name, class_name = self.ALGORITHM_CLASSES[algorithm].rsplit( + ".", 1 ) + processor_module = import_module(module_name) + processor_class = getattr(processor_module, class_name) + processor_object = processor_class(noise=noise, scale=task) self.processor_objects[(algorithm, task)] = processor_object # process the image with the selected algorithm diff --git a/video2x/video2x.py b/video2x/video2x.py index 2b62496..240ff79 100755 --- a/video2x/video2x.py +++ b/video2x/video2x.py @@ -37,9 +37,10 @@ import signal import sys import time from enum import Enum +from importlib import import_module from multiprocessing import Manager, Pool, Queue, Value from pathlib import Path -from typing import Any, Callable, Optional +from typing import Callable, Optional import ffmpeg from cv2 import cv2 @@ -156,9 +157,12 @@ class Video2X: # process by directly invoking the # if the selected algorithm does not support frameserving if mode == ProcessingMode.UPSCALE: - standalone_processor: Any = Upscaler.ALGORITHM_CLASSES[ + standalone_processor_path: str = Upscaler.ALGORITHM_CLASSES[ processing_settings[2] ] + module_name, class_name = standalone_processor_path.rsplit(".", 1) + processor_module = import_module(module_name) + standalone_processor = getattr(processor_module, class_name) if getattr(standalone_processor, "process", None) is None: logger.warning("No progress bar available for this processor") standalone_processor().process_video( @@ -172,9 +176,12 @@ class Video2X: return # elif mode == ProcessingMode.INTERPOLATE: else: - standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[ + standalone_processor_path: str = Interpolator.ALGORITHM_CLASSES[ processing_settings[1] ] + module_name, class_name = standalone_processor_path.rsplit(".", 1) + processor_module = import_module(module_name) + standalone_processor = getattr(processor_module, class_name) if getattr(standalone_processor, "process", None) is None: logger.warning("No progress bar available for this processor") standalone_processor().process_video(