devices.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from typing import TYPE_CHECKING
  2. if TYPE_CHECKING:
  3. from render import Render
  4. try:
  5. from .error import RenderError
  6. from .device import Device
  7. except ImportError:
  8. from error import RenderError
  9. from device import Device
  10. class Devices:
  11. def __init__(self, render_instance: "Render"):
  12. self._render_instance = render_instance
  13. self._devices: set[Device] = set()
  14. # Tracks all container device paths to make sure they are not duplicated
  15. self._container_device_paths: set[str] = set()
  16. # Scan values for devices we should automatically add
  17. # for example /dev/dri for gpus
  18. self._auto_add_devices_from_values()
  19. def _auto_add_devices_from_values(self):
  20. resources = self._render_instance.values.get("resources", {})
  21. if resources.get("gpus", {}).get("use_all_gpus", False):
  22. self.add_device("/dev/dri", "/dev/dri", allow_disallowed=True)
  23. if resources["gpus"].get("kfd_device_exists", False):
  24. self.add_device("/dev/kfd", "/dev/kfd", allow_disallowed=True) # AMD ROCm
  25. def add_device(self, host_device: str, container_device: str, cgroup_perm: str = "", allow_disallowed=False):
  26. # Host device can be mapped to multiple container devices,
  27. # so we only make sure container devices are not duplicated
  28. if container_device in self._container_device_paths:
  29. raise RenderError(f"Device with container path [{container_device}] already added")
  30. self._devices.add(Device(host_device, container_device, cgroup_perm, allow_disallowed))
  31. self._container_device_paths.add(container_device)
  32. def add_usb_bus(self):
  33. self.add_device("/dev/bus/usb", "/dev/bus/usb", allow_disallowed=True)
  34. def _add_snd_device(self):
  35. self.add_device("/dev/snd", "/dev/snd", allow_disallowed=True)
  36. def _add_tun_device(self):
  37. self.add_device("/dev/net/tun", "/dev/net/tun", allow_disallowed=True)
  38. def has_devices(self):
  39. return len(self._devices) > 0
  40. # Mainly will be used from dependencies
  41. # There is no reason to pass devices to
  42. # redis or postgres for example
  43. def remove_devices(self):
  44. self._devices.clear()
  45. self._container_device_paths.clear()
  46. # Check if there are any gpu devices
  47. # Used to determine if we should add groups
  48. # like 'video' to the container
  49. def has_gpus(self):
  50. for d in self._devices:
  51. if d.host_device == "/dev/dri":
  52. return True
  53. return False
  54. def render(self) -> list[str]:
  55. return sorted([d.render() for d in self._devices])