container.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. from typing import Any, TYPE_CHECKING
  2. if TYPE_CHECKING:
  3. from render import Render
  4. from storage import IxStorage
  5. try:
  6. from .configs import ContainerConfigs
  7. from .depends import Depends
  8. from .deploy import Deploy
  9. from .device_cgroup_rules import DeviceCGroupRules
  10. from .devices import Devices
  11. from .dns import Dns
  12. from .environment import Environment
  13. from .error import RenderError
  14. from .expose import Expose
  15. from .extra_hosts import ExtraHosts
  16. from .formatter import escape_dollar, get_image_with_hashed_data
  17. from .healthcheck import Healthcheck
  18. from .labels import Labels
  19. from .ports import Ports
  20. from .restart import RestartPolicy
  21. from .tmpfs import Tmpfs
  22. from .validations import (
  23. valid_cap_or_raise,
  24. valid_cgroup_or_raise,
  25. valid_ipc_mode_or_raise,
  26. valid_network_mode_or_raise,
  27. valid_pid_mode_or_raise,
  28. valid_port_bind_mode_or_raise,
  29. valid_port_mode_or_raise,
  30. valid_pull_policy_or_raise,
  31. )
  32. from .security_opts import SecurityOpts
  33. from .storage import Storage
  34. from .sysctls import Sysctls
  35. except ImportError:
  36. from configs import ContainerConfigs
  37. from depends import Depends
  38. from deploy import Deploy
  39. from device_cgroup_rules import DeviceCGroupRules
  40. from devices import Devices
  41. from dns import Dns
  42. from environment import Environment
  43. from error import RenderError
  44. from expose import Expose
  45. from extra_hosts import ExtraHosts
  46. from formatter import escape_dollar, get_image_with_hashed_data
  47. from healthcheck import Healthcheck
  48. from labels import Labels
  49. from ports import Ports
  50. from restart import RestartPolicy
  51. from tmpfs import Tmpfs
  52. from validations import (
  53. valid_cap_or_raise,
  54. valid_cgroup_or_raise,
  55. valid_ipc_mode_or_raise,
  56. valid_network_mode_or_raise,
  57. valid_pid_mode_or_raise,
  58. valid_port_bind_mode_or_raise,
  59. valid_port_mode_or_raise,
  60. valid_pull_policy_or_raise,
  61. )
  62. from security_opts import SecurityOpts
  63. from storage import Storage
  64. from sysctls import Sysctls
  65. class Container:
  66. def __init__(self, render_instance: "Render", name: str, image: str):
  67. self._render_instance = render_instance
  68. self._name: str = name
  69. self._image: str = self._resolve_image(image)
  70. self._build_image: str = ""
  71. self._pull_policy: str = ""
  72. self._user: str = ""
  73. self._tty: bool = False
  74. self._stdin_open: bool = False
  75. self._init: bool | None = None
  76. self._read_only: bool | None = None
  77. self._extra_hosts: ExtraHosts = ExtraHosts(self._render_instance)
  78. self._hostname: str = ""
  79. self._cap_drop: set[str] = set(["ALL"]) # Drop all capabilities by default and add caps granularly
  80. self._cap_add: set[str] = set()
  81. self._security_opt: SecurityOpts = SecurityOpts(self._render_instance)
  82. self._privileged: bool = False
  83. self._group_add: set[int | str] = set()
  84. self._network_mode: str = ""
  85. self._entrypoint: list[str] = []
  86. self._command: list[str] = []
  87. self._grace_period: int | None = None
  88. self._shm_size: int | None = None
  89. self._storage: Storage = Storage(self._render_instance, self)
  90. self._tmpfs: Tmpfs = Tmpfs(self._render_instance, self)
  91. self._ipc_mode: str | None = None
  92. self._pid_mode: str | None = None
  93. self._cgroup: str | None = None
  94. self._device_cgroup_rules: DeviceCGroupRules = DeviceCGroupRules(self._render_instance)
  95. self.sysctls: Sysctls = Sysctls(self._render_instance, self)
  96. self.configs: ContainerConfigs = ContainerConfigs(self._render_instance, self._render_instance.configs)
  97. self.deploy: Deploy = Deploy(self._render_instance)
  98. self.networks: set[str] = set()
  99. self.devices: Devices = Devices(self._render_instance)
  100. self.environment: Environment = Environment(self._render_instance, self.deploy.resources)
  101. self.dns: Dns = Dns(self._render_instance)
  102. self.depends: Depends = Depends(self._render_instance)
  103. self.healthcheck: Healthcheck = Healthcheck(self._render_instance)
  104. self.labels: Labels = Labels(self._render_instance)
  105. self.restart: RestartPolicy = RestartPolicy(self._render_instance)
  106. self.ports: Ports = Ports(self._render_instance)
  107. self.expose: Expose = Expose(self._render_instance)
  108. self._auto_set_network_mode()
  109. self._auto_add_labels()
  110. self._auto_add_groups()
  111. def _auto_add_groups(self):
  112. self.add_group(568)
  113. def _auto_set_network_mode(self):
  114. if self._render_instance.values.get("network", {}).get("host_network", False):
  115. self.set_network_mode("host")
  116. def _auto_add_labels(self):
  117. labels = self._render_instance.values.get("labels", [])
  118. if not labels:
  119. return
  120. for label in labels:
  121. containers = label.get("containers", [])
  122. if not containers:
  123. raise RenderError(f'Label [{label.get("key", "")}] must have at least one container')
  124. if self._name in containers:
  125. self.labels.add_label(label["key"], label["value"])
  126. def _resolve_image(self, image: str):
  127. images = self._render_instance.values["images"]
  128. if image not in images:
  129. raise RenderError(
  130. f"Image [{image}] not found in values. " f"Available images: [{', '.join(images.keys())}]"
  131. )
  132. repo = images[image].get("repository", "")
  133. tag = images[image].get("tag", "")
  134. if not repo:
  135. raise RenderError(f"Repository not found for image [{image}]")
  136. if not tag:
  137. raise RenderError(f"Tag not found for image [{image}]")
  138. return f"{repo}:{tag}"
  139. def build_image(self, content: list[str | None]):
  140. dockerfile = f"FROM {self._image}\n"
  141. for line in content:
  142. line = line.strip() if line else ""
  143. if not line:
  144. continue
  145. if line.startswith("FROM"):
  146. # TODO: This will also block multi-stage builds
  147. # We can revisit this later if we need it
  148. raise RenderError(
  149. "FROM cannot be used in build image. Define the base image when creating the container."
  150. )
  151. dockerfile += line + "\n"
  152. self._build_image = dockerfile
  153. self._image = get_image_with_hashed_data(self._image, dockerfile)
  154. def set_pull_policy(self, pull_policy: str):
  155. self._pull_policy = valid_pull_policy_or_raise(pull_policy)
  156. def set_user(self, user: int, group: int):
  157. for i in (user, group):
  158. if not isinstance(i, int) or i < 0:
  159. raise RenderError(f"User/Group [{i}] is not valid")
  160. self._user = f"{user}:{group}"
  161. def add_extra_host(self, host: str, ip: str):
  162. self._extra_hosts.add_host(host, ip)
  163. def add_group(self, group: int | str):
  164. if isinstance(group, str):
  165. group = str(group).strip()
  166. if group.isdigit():
  167. raise RenderError(f"Group is a number [{group}] but passed as a string")
  168. if group in self._group_add:
  169. raise RenderError(f"Group [{group}] already added")
  170. self._group_add.add(group)
  171. def get_additional_groups(self) -> list[int | str]:
  172. result = []
  173. if self.deploy.resources.has_gpus() or self.devices.has_gpus():
  174. result.append(44) # video
  175. result.append(107) # render
  176. return result
  177. def get_current_groups(self) -> list[str]:
  178. result = [str(g) for g in self._group_add]
  179. result.extend([str(g) for g in self.get_additional_groups()])
  180. return result
  181. def set_tty(self, enabled: bool = False):
  182. self._tty = enabled
  183. def set_stdin(self, enabled: bool = False):
  184. self._stdin_open = enabled
  185. def set_ipc_mode(self, ipc_mode: str):
  186. self._ipc_mode = valid_ipc_mode_or_raise(ipc_mode, self._render_instance.container_names())
  187. def set_pid_mode(self, mode: str = ""):
  188. self._pid_mode = valid_pid_mode_or_raise(mode, self._render_instance.container_names())
  189. def add_device_cgroup_rule(self, dev_grp_rule: str):
  190. self._device_cgroup_rules.add_rule(dev_grp_rule)
  191. def set_cgroup(self, cgroup: str):
  192. self._cgroup = valid_cgroup_or_raise(cgroup)
  193. def set_init(self, enabled: bool = False):
  194. self._init = enabled
  195. def set_read_only(self, enabled: bool = False):
  196. self._read_only = enabled
  197. def set_hostname(self, hostname: str):
  198. self._hostname = hostname
  199. def set_grace_period(self, grace_period: int):
  200. if grace_period < 0:
  201. raise RenderError(f"Grace period [{grace_period}] cannot be negative")
  202. self._grace_period = grace_period
  203. def set_privileged(self, enabled: bool = False):
  204. self._privileged = enabled
  205. def clear_caps(self):
  206. self._cap_add.clear()
  207. self._cap_drop.clear()
  208. def add_caps(self, caps: list[str]):
  209. for c in caps:
  210. if c in self._cap_add:
  211. raise RenderError(f"Capability [{c}] already added")
  212. self._cap_add.add(valid_cap_or_raise(c))
  213. def add_security_opt(self, key: str, value: str | bool | None = None, arg: str | None = None):
  214. self._security_opt.add_opt(key, value, arg)
  215. def remove_security_opt(self, key: str):
  216. self._security_opt.remove_opt(key)
  217. def set_network_mode(self, mode: str):
  218. self._network_mode = valid_network_mode_or_raise(mode, self._render_instance.container_names())
  219. def add_port(self, port_config: dict | None = None, dev_config: dict | None = None):
  220. port_config = port_config or {}
  221. dev_config = dev_config or {}
  222. # Merge port_config and dev_config (dev_config has precedence)
  223. config = port_config | dev_config
  224. bind_mode = valid_port_bind_mode_or_raise(config.get("bind_mode", ""))
  225. # Skip port if its neither published nor exposed
  226. if not bind_mode:
  227. return
  228. # Collect port config
  229. mode = valid_port_mode_or_raise(config.get("mode", "ingress"))
  230. host_port = config.get("port_number", 0)
  231. container_port = config.get("container_port", 0) or host_port
  232. protocol = config.get("protocol", "tcp")
  233. host_ips = config.get("host_ips") or ["0.0.0.0", "::"]
  234. if not isinstance(host_ips, list):
  235. raise RenderError(f"Expected [host_ips] to be a list, got [{host_ips}]")
  236. if bind_mode == "published":
  237. for host_ip in host_ips:
  238. self.ports._add_port(
  239. host_port, container_port, {"protocol": protocol, "host_ip": host_ip, "mode": mode}
  240. )
  241. elif bind_mode == "exposed":
  242. self.expose.add_port(container_port, protocol)
  243. def set_entrypoint(self, entrypoint: list[str]):
  244. self._entrypoint = [escape_dollar(str(e)) for e in entrypoint]
  245. def set_command(self, command: list[str]):
  246. self._command = [escape_dollar(str(e)) for e in command]
  247. def add_storage(self, mount_path: str, config: "IxStorage"):
  248. if config.get("type", "") == "tmpfs":
  249. self._tmpfs.add(mount_path, config)
  250. else:
  251. self._storage.add(mount_path, config)
  252. def add_docker_socket(self, read_only: bool = True, mount_path: str = "/var/run/docker.sock"):
  253. self.add_group(999)
  254. self._storage._add_docker_socket(read_only, mount_path)
  255. def add_udev(self, read_only: bool = True, mount_path: str = "/run/udev"):
  256. self._storage._add_udev(read_only, mount_path)
  257. def add_tun_device(self):
  258. self.devices._add_tun_device()
  259. def add_snd_device(self):
  260. self.add_group(29)
  261. self.devices._add_snd_device()
  262. def set_shm_size_mb(self, size: int):
  263. self._shm_size = size
  264. # Easily remove devices from the container
  265. # Useful in dependencies like postgres and redis
  266. # where there is no need to pass devices to them
  267. def remove_devices(self):
  268. self.deploy.resources.remove_devices()
  269. self.devices.remove_devices()
  270. @property
  271. def storage(self):
  272. return self._storage
  273. def render(self) -> dict[str, Any]:
  274. if self._network_mode and self.networks:
  275. raise RenderError("Cannot set both [network_mode] and [networks]")
  276. result = {
  277. "image": self._image,
  278. "platform": "linux/amd64",
  279. "tty": self._tty,
  280. "stdin_open": self._stdin_open,
  281. "restart": self.restart.render(),
  282. }
  283. if self._pull_policy:
  284. result["pull_policy"] = self._pull_policy
  285. if self.healthcheck.has_healthcheck():
  286. result["healthcheck"] = self.healthcheck.render()
  287. if self._hostname:
  288. result["hostname"] = self._hostname
  289. if self._build_image:
  290. result["build"] = {"tags": [self._image], "dockerfile_inline": self._build_image}
  291. if self.configs.has_configs():
  292. result["configs"] = self.configs.render()
  293. if self._ipc_mode is not None:
  294. result["ipc"] = self._ipc_mode
  295. if self._pid_mode is not None:
  296. result["pid"] = self._pid_mode
  297. if self._device_cgroup_rules.has_rules():
  298. result["device_cgroup_rules"] = self._device_cgroup_rules.render()
  299. if self._cgroup is not None:
  300. result["cgroup"] = self._cgroup
  301. if self._extra_hosts.has_hosts():
  302. result["extra_hosts"] = self._extra_hosts.render()
  303. if self._init is not None:
  304. result["init"] = self._init
  305. if self._read_only is not None:
  306. result["read_only"] = self._read_only
  307. if self._grace_period is not None:
  308. result["stop_grace_period"] = f"{self._grace_period}s"
  309. if self._user:
  310. result["user"] = self._user
  311. for g in self.get_additional_groups():
  312. self.add_group(g)
  313. if self._group_add:
  314. result["group_add"] = sorted(self._group_add, key=lambda g: (isinstance(g, str), g))
  315. if self._shm_size is not None:
  316. result["shm_size"] = f"{self._shm_size}M"
  317. if self._privileged is not None:
  318. result["privileged"] = self._privileged
  319. if self._cap_drop:
  320. result["cap_drop"] = sorted(self._cap_drop)
  321. if self._cap_add:
  322. result["cap_add"] = sorted(self._cap_add)
  323. if self._security_opt.has_opts():
  324. result["security_opt"] = self._security_opt.render()
  325. if self._network_mode:
  326. result["network_mode"] = self._network_mode
  327. if self.sysctls.has_sysctls():
  328. result["sysctls"] = self.sysctls.render()
  329. if self._network_mode != "host":
  330. if self.ports.has_ports():
  331. result["ports"] = self.ports.render()
  332. if self.expose.has_ports():
  333. result["expose"] = self.expose.render()
  334. if self._entrypoint:
  335. result["entrypoint"] = self._entrypoint
  336. if self._command:
  337. result["command"] = self._command
  338. if self.devices.has_devices():
  339. result["devices"] = self.devices.render()
  340. if self.deploy.has_deploy():
  341. result["deploy"] = self.deploy.render()
  342. if self.environment.has_variables():
  343. result["environment"] = self.environment.render()
  344. if self.labels.has_labels():
  345. result["labels"] = self.labels.render()
  346. if self.dns.has_dns_nameservers():
  347. result["dns"] = self.dns.render_dns_nameservers()
  348. if self.dns.has_dns_searches():
  349. result["dns_search"] = self.dns.render_dns_searches()
  350. if self.dns.has_dns_opts():
  351. result["dns_opt"] = self.dns.render_dns_opts()
  352. if self.depends.has_dependencies():
  353. result["depends_on"] = self.depends.render()
  354. if self._storage.has_mounts():
  355. result["volumes"] = self._storage.render()
  356. if self._tmpfs.has_tmpfs():
  357. result["tmpfs"] = self._tmpfs.render()
  358. return result