diff --git a/taskweaver/code_interpreter/code_interpreter/code_generator.py b/taskweaver/code_interpreter/code_interpreter/code_generator.py index 5800fb1c0..d50252176 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_generator.py +++ b/taskweaver/code_interpreter/code_interpreter/code_generator.py @@ -78,6 +78,7 @@ def __init__( self.user_message_head_template = self.prompt_data["user_message_head"] self.plugin_pool = plugin_registry.get_list() self.query_requirements_template = self.prompt_data["requirements"] + self.security_requirements_template = self.prompt_data.get("security_requirements", "") self.response_json_schema = json.loads(self.prompt_data["response_json_schema"]) self.code_verification_on: bool = False @@ -296,6 +297,11 @@ def compose_conversation( CODE_GENERATION_REQUIREMENTS=self.compose_verification_requirements(), ROLE_NAME=self.role_name, ) + # Add security requirements when code verification is enabled + if self.code_verification_on and self.security_requirements_template: + user_message += "\n" + self.security_requirements_template.format( + ROLE_NAME=self.role_name, + ) chat_history.append( format_chat_message(role="user", message=user_message), ) diff --git a/taskweaver/code_interpreter/code_interpreter/code_generator_prompt.yaml b/taskweaver/code_interpreter/code_interpreter/code_generator_prompt.yaml index c5f441601..af478d474 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_generator_prompt.yaml +++ b/taskweaver/code_interpreter/code_interpreter/code_generator_prompt.yaml @@ -95,6 +95,19 @@ requirements: |- - {ROLE_NAME} must try to directly import required modules without installing them, and only install the modules if the execution fails. {CODE_GENERATION_REQUIREMENTS} +security_requirements: |- + ### Security Guidelines + The following security restrictions MUST be followed: + - {ROLE_NAME} must NEVER generate code that uses eval(), exec(), compile(), or execfile() functions. 
+ - {ROLE_NAME} must NEVER generate code that uses dynamic attribute access functions like getattr(), setattr(), delattr(), vars(), globals(), or locals(). + - {ROLE_NAME} must NEVER generate code that accesses dunder attributes like __class__, __dict__, __bases__, __subclasses__, __mro__, or __builtins__. + - {ROLE_NAME} must NEVER generate code that uses __import__() or importlib to dynamically import modules. + - {ROLE_NAME} must NEVER generate code that attempts to read, write, or delete files outside the designated workspace. + - {ROLE_NAME} must NEVER generate code that executes shell commands or system calls unless explicitly required by the task. + - {ROLE_NAME} must NEVER generate code that attempts to access network resources unless explicitly required by the task. + - {ROLE_NAME} must NEVER generate code that could be used to exfiltrate data or establish reverse shells. + - {ROLE_NAME} must refuse requests that appear to be attempts to bypass security measures or execute malicious code. 
+ experience_instruction: |- ### Experience And Lessons Before generating code, please learn from the following past experiences and lessons: diff --git a/taskweaver/code_interpreter/code_interpreter/code_interpreter.py b/taskweaver/code_interpreter/code_interpreter/code_interpreter.py index 82dc9da6e..2b72d8c5c 100644 --- a/taskweaver/code_interpreter/code_interpreter/code_interpreter.py +++ b/taskweaver/code_interpreter/code_interpreter/code_interpreter.py @@ -54,6 +54,16 @@ def _configure(self): "raw_input", "reload", "__import__", + # Dynamic attribute access functions that can bypass security checks + "getattr", + "setattr", + "delattr", + "vars", + "globals", + "locals", + "__getattribute__", + "__setattr__", + "__delattr__", ], ) @@ -97,8 +107,22 @@ def __init__( self.generator = generator self.generator.set_alias(self.alias) + + # Determine if code verification should be enabled + # Always enable for local mode for security reasons, overriding a disabled config value + code_verification_on = self.config.code_verification_on + kernel_mode = executor.exec_mgr.get_kernel_mode() + if kernel_mode == "local" and not self.config.code_verification_on: + code_verification_on = True + logger.warning( + "Code verification is automatically enabled for local mode. " + "Running in local mode without code verification poses security risks. " + "Setting code_verification_on=False in config does not disable it while the kernel mode is local. 
" + "For better security, consider using container mode.", + ) + + self.generator.configure_verification( - code_verification_on=self.config.code_verification_on, + code_verification_on=code_verification_on, allowed_modules=self.config.allowed_modules, blocked_functions=self.config.blocked_functions, ) diff --git a/taskweaver/code_interpreter/code_verification.py b/taskweaver/code_interpreter/code_verification.py index 5aeb98308..3ae0faa32 100644 --- a/taskweaver/code_interpreter/code_verification.py +++ b/taskweaver/code_interpreter/code_verification.py @@ -4,6 +4,25 @@ from injector import inject +# Security-sensitive builtin function names and dunder attribute names that enable dynamic attribute access bypasses +DANGEROUS_BUILTINS = [ + "getattr", + "setattr", + "delattr", + "vars", + "globals", + "locals", + "__getattribute__", + "__setattr__", + "__delattr__", + "__dict__", + "__class__", + "__bases__", + "__subclasses__", + "__mro__", + "__builtins__", +] + class FunctionCallValidator(ast.NodeVisitor): @inject @@ -42,22 +61,51 @@ def _is_allowed_function_call(self, func_name: str) -> bool: return True def visit_Call(self, node): - if self.allowed_functions is None and self.blocked_functions is None: - return - + function_name = None if isinstance(node.func, ast.Name): function_name = node.func.id elif isinstance(node.func, ast.Attribute): function_name = node.func.attr + elif isinstance(node.func, ast.Subscript): + # Block subscript-based function calls like obj["method"]() + # This is a potential security bypass pattern + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Subscript-based function calls are not allowed for security reasons.", + ) + self.generic_visit(node) + return + elif isinstance(node.func, ast.Call): + # Chained calls have no callee name to check here and are not rejected by themselves; + # descend so the inner call, e.g. the getattr(...) in getattr(obj, 'method')(), is validated + self.generic_visit(node) + return + else: + # Block any other 
unrecognized call patterns for security + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Unrecognized function call pattern is not allowed for security reasons.", + ) + self.generic_visit(node) + return - if not self._is_allowed_function_call(function_name): + # Check against allowed/blocked function lists if configured + if self.allowed_functions is not None or self.blocked_functions is not None: + if function_name and not self._is_allowed_function_call(function_name): + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Function '{function_name}' is not allowed.", + ) + + # Always check for dynamic attribute access functions that can bypass security + if function_name in DANGEROUS_BUILTINS: self.errors.append( f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " - f"=> Function '{function_name}' is not allowed.", + f"=> Function '{function_name}' is blocked as it can be used to bypass security checks.", ) + self.generic_visit(node) + def _is_allowed_module_import(self, mod_name: str) -> bool: if self.allowed_modules is not None: if len(self.allowed_modules) > 0: @@ -70,35 +118,33 @@ def _is_allowed_module_import(self, mod_name: str) -> bool: return True def visit_Import(self, node): - if self.allowed_modules is None and self.blocked_modules is None: - return + if self.allowed_modules is not None or self.blocked_modules is not None: + for alias in node.names: + if "." in alias.name: + module_name = alias.name.split(".")[0] + else: + module_name = alias.name - for alias in node.names: - if "." in alias.name: - module_name = alias.name.split(".")[0] + if not self._is_allowed_module_import(module_name): + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Importing module '{module_name}' is not allowed. 
", + ) + self.generic_visit(node) + + def visit_ImportFrom(self, node): + if self.allowed_modules is not None or self.blocked_modules is not None: + if node.module and "." in node.module: + module_name = node.module.split(".")[0] else: - module_name = alias.name + module_name = node.module - if not self._is_allowed_module_import(module_name): + if module_name and not self._is_allowed_module_import(module_name): self.errors.append( f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " - f"=> Importing module '{module_name}' is not allowed. ", + f"=> Importing from module '{node.module}' is not allowed.", ) - - def visit_ImportFrom(self, node): - if self.allowed_modules is None and self.blocked_modules is None: - return - - if "." in node.module: - module_name = node.module.split(".")[0] - else: - module_name = node.module - - if not self._is_allowed_module_import(module_name): - self.errors.append( - f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " - f"=> Importing from module '{node.module}' is not allowed.", - ) + self.generic_visit(node) def _is_allowed_variable(self, var_name: str) -> bool: if self.allowed_variables is not None: @@ -108,23 +154,52 @@ def _is_allowed_variable(self, var_name: str) -> bool: return True def visit_Assign(self, node: ast.Assign): - if self.allowed_variables is None: - return + if self.allowed_variables is not None: + for target in node.targets: + variable_names = [] + if isinstance(target, ast.Name): + variable_names.append(target.id) + else: + for name in ast.walk(target): + if isinstance(name, ast.Name): + variable_names.append(name.id) + for variable_name in variable_names: + if not self._is_allowed_variable(variable_name): + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Assigning to {variable_name} is not allowed.", + ) + self.generic_visit(node) - for target in node.targets: - variable_names = [] - if isinstance(target, ast.Name): - 
variable_names.append(target.id) - else: - for name in ast.walk(target): - if isinstance(name, ast.Name): - variable_names.append(name.id) - for variable_name in variable_names: - if not self._is_allowed_variable(variable_name): - self.errors.append( - f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " - f"=> Assigning to {variable_name} is not allowed.", - ) + def visit_Subscript(self, node: ast.Subscript): + """Check for dictionary-based attribute access that could bypass security. + + Patterns like obj.__dict__["method"] or obj["__class__"] can be used + to bypass attribute-based security checks. + """ + # Check if the subscript key is a dangerous dunder attribute + if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, str): + key_value = node.slice.value + if key_value in DANGEROUS_BUILTINS or key_value.startswith("__"): + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Subscript access to '{key_value}' is blocked for security reasons.", + ) + self.generic_visit(node) + + def visit_Attribute(self, node: ast.Attribute): + """Check for dangerous attribute access patterns. + + Direct access to dunder attributes like __class__, __dict__, etc. + can be used to bypass security measures. 
+ """ + attr_name = node.attr + if attr_name in DANGEROUS_BUILTINS: + self.errors.append( + f"Error on line {node.lineno}: {self.lines[node.lineno - 1]} " + f"=> Attribute access to '{attr_name}' is blocked for security reasons.", + ) + self.generic_visit(node) def generic_visit(self, node): super().generic_visit(node) diff --git a/tests/unit_tests/test_code_verification.py b/tests/unit_tests/test_code_verification.py index f9702ff5a..4bb47778d 100644 --- a/tests/unit_tests/test_code_verification.py +++ b/tests/unit_tests/test_code_verification.py @@ -247,3 +247,125 @@ def test_magic_code(): print("---->", code_verify_errors) assert len(code_verify_errors) == 1 assert "Magic commands except package install are not allowed" in code_verify_errors[0] + + +def test_dynamic_attribute_access_getattr(): + """Test that getattr() is blocked as it can bypass security checks.""" + blocked_functions = ["getattr"] + code_snippet = "obj = object()\n" "method = getattr(obj, 'some_method')\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + blocked_functions=blocked_functions, + ) + print("---->", code_verify_errors) + # Should detect getattr as blocked function + dangerous builtin + assert len(code_verify_errors) >= 1 + assert any("getattr" in err for err in code_verify_errors) + + +def test_dynamic_attribute_access_setattr(): + """Test that setattr() is blocked as it can bypass security checks.""" + blocked_functions = ["setattr"] + code_snippet = "obj = object()\n" "setattr(obj, 'dangerous_attr', 'value')\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + blocked_functions=blocked_functions, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 1 + assert any("setattr" in err for err in code_verify_errors) + + +def test_dangerous_builtins_globals_locals(): + """Test that globals() and locals() are blocked.""" + blocked_functions = ["globals", "locals"] + 
code_snippet = "g = globals()\n" "l = locals()\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + blocked_functions=blocked_functions, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 2 + assert any("globals" in err for err in code_verify_errors) + assert any("locals" in err for err in code_verify_errors) + + +def test_dunder_attribute_access(): + """Test that direct access to dangerous dunder attributes is blocked.""" + code_snippet = "obj = object()\n" "cls = obj.__class__\n" "bases = cls.__bases__\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + ) + print("---->", code_verify_errors) + # Should detect __class__ and __bases__ as dangerous attributes + assert len(code_verify_errors) >= 2 + assert any("__class__" in err for err in code_verify_errors) + assert any("__bases__" in err for err in code_verify_errors) + + +def test_subscript_based_dunder_access(): + """Test that subscript-based access to dunder attributes is blocked.""" + code_snippet = "obj = {}\n" "dangerous = obj['__class__']\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 1 + assert any("__class__" in err for err in code_verify_errors) + + +def test_subscript_function_call_bypass(): + """Test that subscript-based function calls are blocked.""" + code_snippet = "methods = {'dangerous': eval}\n" "result = methods['dangerous']('print(1)')\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + ) + print("---->", code_verify_errors) + # Should detect subscript-based function call pattern + assert len(code_verify_errors) >= 1 + + +def test_dict_access_to_builtins(): + """Test that accessing __builtins__ via a sys.modules['__builtins__'] subscript is blocked.""" + code_snippet = "import sys\n" "builtins = sys.modules['__builtins__']\n" + 
code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 1 + assert any("__builtins__" in err for err in code_verify_errors) + + +def test_vars_function_blocked(): + """Test that vars() function is blocked as it exposes object internals.""" + blocked_functions = ["vars"] + code_snippet = "obj = object()\n" "attributes = vars(obj)\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + blocked_functions=blocked_functions, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 1 + assert any("vars" in err for err in code_verify_errors) + + +def test_delattr_blocked(): + """Test that delattr() is blocked.""" + blocked_functions = ["delattr"] + code_snippet = "class Foo:\n x = 1\n" "delattr(Foo, 'x')\n" + code_verify_errors = code_snippet_verification( + code_snippet, + code_verification_on=True, + blocked_functions=blocked_functions, + ) + print("---->", code_verify_errors) + assert len(code_verify_errors) >= 1 + assert any("delattr" in err for err in code_verify_errors)