Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,19 @@ def __init__(self, config: Dict[str, Any]):
self.jobs_handler = jobs_handler

async def set_scale(self):
self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
# Concurrency modifier is user-provided and can return invalid values (e.g. None).
# Defensive validation prevents crashes like: TypeError: '<' not supported between 'int' and 'NoneType'
# when current_concurrency is used for queue sizing / task scheduling.
try:
new_concurrency = self.concurrency_modifier(self.current_concurrency)
except Exception as error:
log.warn(
f"JobScaler.set_scale | concurrency_modifier raised {type(error).__name__}: {error}. "
f"Keeping concurrency at {self.current_concurrency}."
)
new_concurrency = self.current_concurrency

self.current_concurrency = self._sanitize_concurrency(new_concurrency)

if self.jobs_queue and (self.current_concurrency == self.jobs_queue.maxsize):
# no need to resize
Expand All @@ -88,6 +100,34 @@ async def set_scale(self):
f"JobScaler.set_scale | New concurrency set to: {self.current_concurrency}"
)

@staticmethod
def _sanitize_concurrency(value: Any) -> int:
"""
Coerce a user-provided concurrency value into a safe integer >= 1.
"""
# Reject common footguns explicitly.
if value is None or isinstance(value, bool) or isinstance(value, float):
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

try:
v = int(value)
except Exception:
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

if v < 1:
log.warn(
f"JobScaler.set_scale | Invalid concurrency value: {value!r}. Defaulting to 1."
)
return 1

return v

def start(self):
"""
This is required for the worker to be able to shut down gracefully
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio
from unittest import TestCase

from runpod.serverless.modules.rp_scale import JobScaler


class TestJobScalerConcurrencyValidation(TestCase):
def test_concurrency_modifier_none_defaults_to_one(self):
scaler = JobScaler({"concurrency_modifier": lambda _: None})
asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 1)

def test_concurrency_modifier_zero_defaults_to_one(self):
scaler = JobScaler({"concurrency_modifier": lambda _: 0})
asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 1)

def test_concurrency_modifier_negative_defaults_to_one(self):
scaler = JobScaler({"concurrency_modifier": lambda _: -3})
asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 1)

def test_concurrency_modifier_valid_int_is_applied(self):
scaler = JobScaler({"concurrency_modifier": lambda _: 4})
asyncio.run(scaler.set_scale())
self.assertEqual(scaler.current_concurrency, 4)