environment.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from typing import Any, TYPE_CHECKING
  2. if TYPE_CHECKING:
  3. from render import Render
  4. try:
  5. from .error import RenderError
  6. from .formatter import escape_dollar
  7. from .resources import Resources
  8. except ImportError:
  9. from error import RenderError
  10. from formatter import escape_dollar
  11. from resources import Resources
  12. class Environment:
  13. def __init__(self, render_instance: "Render", resources: Resources):
  14. self._render_instance = render_instance
  15. self._resources = resources
  16. # Stores variables that user defined
  17. self._user_vars: dict[str, Any] = {}
  18. # Stores variables that are automatically added (based on values)
  19. self._auto_variables: dict[str, Any] = {}
  20. # Stores variables that are added by the application developer
  21. self._app_dev_variables: dict[str, Any] = {}
  22. self._skip_generic_variables: bool = render_instance.values.get("skip_generic_variables", False)
  23. self._auto_add_variables_from_values()
  24. def _auto_add_variables_from_values(self):
  25. if not self._skip_generic_variables:
  26. self._add_generic_variables()
  27. self._add_nvidia_variables()
  28. def _add_generic_variables(self):
  29. self._auto_variables["TZ"] = self._render_instance.values.get("TZ", "Etc/UTC")
  30. self._auto_variables["UMASK"] = self._render_instance.values.get("UMASK", "002")
  31. self._auto_variables["UMASK_SET"] = self._render_instance.values.get("UMASK", "002")
  32. run_as = self._render_instance.values.get("run_as", {})
  33. user = run_as.get("user")
  34. group = run_as.get("group")
  35. if user:
  36. self._auto_variables["PUID"] = user
  37. self._auto_variables["UID"] = user
  38. self._auto_variables["USER_ID"] = user
  39. if group:
  40. self._auto_variables["PGID"] = group
  41. self._auto_variables["GID"] = group
  42. self._auto_variables["GROUP_ID"] = group
  43. def _add_nvidia_variables(self):
  44. if self._resources._nvidia_ids:
  45. self._auto_variables["NVIDIA_DRIVER_CAPABILITIES"] = "all"
  46. self._auto_variables["NVIDIA_VISIBLE_DEVICES"] = ",".join(sorted(self._resources._nvidia_ids))
  47. else:
  48. self._auto_variables["NVIDIA_VISIBLE_DEVICES"] = "void"
  49. def _format_value(self, v: Any) -> str:
  50. value = str(v)
  51. # str(bool) returns "True" or "False",
  52. # but we want "true" or "false"
  53. if isinstance(v, bool):
  54. value = value.lower()
  55. return value
  56. def add_env(self, name: str, value: Any):
  57. if not name:
  58. raise RenderError(f"Environment variable name cannot be empty. [{name}]")
  59. if name in self._app_dev_variables.keys():
  60. raise RenderError(
  61. f"Found duplicate environment variable [{name}] in application developer environment variables."
  62. )
  63. self._app_dev_variables[name] = value
  64. def add_user_envs(self, user_env: list[dict]):
  65. for item in user_env:
  66. if not item.get("name"):
  67. raise RenderError(f"Environment variable name cannot be empty. [{item}]")
  68. if item["name"] in self._user_vars.keys():
  69. raise RenderError(
  70. f"Found duplicate environment variable [{item['name']}] in user environment variables."
  71. )
  72. self._user_vars[item["name"]] = item.get("value")
  73. def has_variables(self):
  74. return len(self._auto_variables) > 0 or len(self._user_vars) > 0 or len(self._app_dev_variables) > 0
  75. def render(self):
  76. result: dict[str, str] = {}
  77. # Add envs from auto variables
  78. result.update({k: self._format_value(v) for k, v in self._auto_variables.items()})
  79. # Track defined keys for faster lookup
  80. defined_keys = set(result.keys())
  81. # Add envs from application developer (prohibit overwriting auto variables)
  82. for k, v in self._app_dev_variables.items():
  83. if k in defined_keys:
  84. raise RenderError(f"Environment variable [{k}] is already defined automatically from the library.")
  85. result[k] = self._format_value(v)
  86. defined_keys.add(k)
  87. # Add envs from user (prohibit overwriting app developer envs and auto variables)
  88. for k, v in self._user_vars.items():
  89. if k in defined_keys:
  90. raise RenderError(f"Environment variable [{k}] is already defined from the application developer.")
  91. result[k] = self._format_value(v)
  92. return {k: escape_dollar(v) for k, v in result.items()}