from clients import sqs, r2, sqs_queue_url, sqs_dlq_url, r2_bucket_name
import json
import time
import threading
import boto3
# Seconds a received SQS message stays hidden from other workers.
# heartbeat_job re-extends it every visibility_timeout / 2 seconds,
# so a crashed worker's job becomes visible again within ~60s.
visibility_timeout = 60
def get_job() -> tuple:
    '''
    Fetch a single job from the SQS queue using long polling.
    Returns:
    - job: dict, the decoded job payload
    - receipt_handle: str, the receipt handle needed to ack/release it
    Returns (None, None) when the queue produced no message.
    '''
    response = sqs.receive_message(
        QueueUrl=sqs_queue_url,
        AttributeNames=['All'],
        MaxNumberOfMessages=1,
        MessageAttributeNames=['All'],
        VisibilityTimeout=visibility_timeout,
        WaitTimeSeconds=20
    )
    messages = response.get('Messages')
    if not messages:
        return None, None
    message = messages[0]
    return json.loads(message['Body']), message['ReceiptHandle']
def heartbeat_job(receipt_handle: str, heartbeat_stop_signal: threading.Event, job_stop_signal: threading.Event) -> None:
    '''
    Periodically extend the message's visibility timeout to keep the job
    claimed by this worker while it is being processed.
    Parameters:
    - receipt_handle: str, the receipt handle of the message
    - heartbeat_stop_signal: threading.Event, a signal to stop the heartbeat
    - job_stop_signal: threading.Event, a signal to stop the main job
    '''
    while not heartbeat_stop_signal.is_set():
        try:
            sqs.change_message_visibility(
                QueueUrl=sqs_queue_url,
                ReceiptHandle=receipt_handle,
                VisibilityTimeout=visibility_timeout
            )
        # BUG FIX: the original caught boto3.SQS.Client.exceptions.
        # ReceiptHandleIsInvalid, but boto3 exposes no `SQS` attribute, so
        # the handler itself raised AttributeError. Client exceptions are
        # available on the client instance: sqs.exceptions.<Name>.
        except sqs.exceptions.ReceiptHandleIsInvalid:
            # If the receipt handle is invalid, it means the job has been
            # acknowledged, or the message has been given to another worker.
            # In this case, we can stop the heartbeat, and interrupt the
            # main job.
            job_stop_signal.set()
            break
        # Event.wait() returns early when the stop signal is set, so
        # shutdown is not delayed by up to half the visibility timeout
        # the way time.sleep() would be.
        heartbeat_stop_signal.wait(visibility_timeout / 2)
def release_job(receipt_handle: str) -> None:
    '''
    Release the job back to the SQS queue by zeroing its visibility
    timeout, making it immediately available to other workers.
    Parameters:
    - receipt_handle: str, the receipt handle of the message
    '''
    try:
        sqs.change_message_visibility(
            QueueUrl=sqs_queue_url,
            ReceiptHandle=receipt_handle,
            VisibilityTimeout=0
        )
    # BUG FIX: the original caught boto3.SQS.Client.exceptions.
    # ReceiptHandleIsInvalid, which raises AttributeError (boto3 has no
    # `SQS` attribute). Client exceptions live on the client instance.
    except sqs.exceptions.ReceiptHandleIsInvalid:
        # If the receipt handle is invalid, it means the job has been
        # acknowledged, or the message has been given to another worker.
        # In this case, we can ignore the error, because we were trying to
        # release the job anyway.
        pass
def acknowledge_job(receipt_handle: str) -> None:
    '''
    Mark the job as done by deleting its message from the SQS queue.
    Parameters:
    - receipt_handle: str, the receipt handle of the message
    '''
    sqs.delete_message(QueueUrl=sqs_queue_url, ReceiptHandle=receipt_handle)
def fail_job(job: dict, receipt_handle: str) -> None:
    '''
    Move the job to the dead-letter queue.
    Parameters:
    - job: dict, the job payload to forward to the DLQ
    - receipt_handle: str, the receipt handle of the message
    '''
    # BUG FIX: the original deleted the message from the main queue BEFORE
    # sending it to the DLQ; if send_message failed in between, the job was
    # lost forever. Sending first preserves at-least-once delivery — the
    # worst case is now a duplicate in the DLQ, not a lost job.
    sqs.send_message(
        QueueUrl=sqs_dlq_url,
        MessageBody=json.dumps(job)
    )
    # Only after the DLQ has the job do we remove it from the main queue.
    acknowledge_job(receipt_handle)
def download_checkpoint(job_id: str) -> dict | None:
    '''
    Download the checkpoint from S3-compatible storage (R2).
    Parameters:
    - job_id: str, the job ID
    Returns:
    - checkpoint: dict, the checkpoint, or None if no checkpoint exists yet
    '''
    try:
        response = r2.get_object(
            Bucket=r2_bucket_name,
            Key=f'{job_id}/checkpoint.json'
        )
    # BUG FIX: the original caught boto3.exceptions.S3.NoSuchKey, which is
    # not a real attribute path and would raise AttributeError instead of
    # handling the missing key. S3 client exceptions are exposed on the
    # client instance: r2.exceptions.NoSuchKey.
    except r2.exceptions.NoSuchKey:
        return None
    return json.loads(response['Body'].read())
def upload_checkpoint(job_id: str, checkpoint: dict) -> None:
    '''
    Persist the checkpoint to S3-compatible storage (R2) so another worker
    can resume the job from where this one left off.
    Parameters:
    - job_id: str, the job ID
    - checkpoint: dict, the checkpoint
    '''
    key = f'{job_id}/checkpoint.json'
    body = json.dumps(checkpoint)
    r2.put_object(Bucket=r2_bucket_name, Key=key, Body=body)
def validate_job(job: dict) -> bool:
    '''
    Check that the job payload carries every field the worker needs.
    Parameters:
    - job: dict, the job
    Returns:
    - bool, whether the job is valid
    '''
    # This is a very simple function for our very simple application.
    # You should replace this with your actual validation logic.
    required_fields = ('job_id', 'steps')
    return all(field in job for field in required_fields)
def do_the_actual_work(job: dict, checkpoint: dict, stop_signal: threading.Event) -> int | None:
'''
Do the actual work for the job. This function will simulate work by
sleeping for 30 seconds and incrementing the step and sum in the
checkpoint.
Parameters:
- job: dict, the job
- checkpoint: dict, the checkpoint
- stop_signal: threading.Event, a signal to stop the work
'''
while checkpoint['step'] < job['steps'] and not stop_signal.is_set():
# Simulate work
time.sleep(30)
# If the job was interrupted, we don't want to upload the
# checkpoint, because it may conflict with the next worker.
if not stop_signal.is_set():
# Update the checkpoint.
checkpoint['step'] += 1
checkpoint['sum'] += checkpoint['step']
upload_checkpoint(job['job_id'], checkpoint)
if not stop_signal.is_set():
return checkpoint['sum']
else:
return None
def upload_result(job_id: str, result: int) -> None:
    '''
    Write the finished job's result to S3-compatible storage (R2).
    Parameters:
    - job_id: str, the job ID
    - result: int, the result
    '''
    key = f'{job_id}/result.txt'
    r2.put_object(Bucket=r2_bucket_name, Key=key, Body=str(result))
def process_job(job: dict, receipt_handle: str) -> None:
    '''
    Run one claimed job end-to-end: heartbeat it, validate it, resume from
    any checkpoint, do the work, and either ack (success), release (retry),
    or dead-letter (permanently invalid) the message.
    Parameters:
    - job: dict, the job payload
    - receipt_handle: str, the receipt handle of the message
    '''
    # Now that we have the job, we need to start a separate thread that
    # heartbeats for it. This will keep the job alive in the SQS queue.
    # Separate threads are critical here, because our main work is likely
    # blocking, and we don't want to block the heartbeat.
    heartbeat_stop_signal = threading.Event()
    job_stop_signal = threading.Event()
    heartbeat_thread = threading.Thread(
        target=heartbeat_job, args=(
            receipt_handle, heartbeat_stop_signal, job_stop_signal))
    heartbeat_thread.start()

    def _stop_heartbeat() -> None:
        # Signal the heartbeat loop and wait for the thread to exit.
        heartbeat_stop_signal.set()
        heartbeat_thread.join()

    # Some jobs may have a validation step. For instance, dreambooth
    # training may have a step that verifies if all inputs have faces. If
    # the validation fails, we should stop the job and not retry it, but
    # instead move it to the DLQ — we can be confident it will never
    # succeed. Validate before downloading the checkpoint so we don't pay
    # for a storage round-trip on a hopeless job.
    if not validate_job(job):
        _stop_heartbeat()
        fail_job(job, receipt_handle)
        return

    # If there's a checkpoint, resume from it; otherwise start fresh.
    checkpoint = download_checkpoint(job['job_id'])
    if checkpoint is None:
        checkpoint = {'step': 0, 'sum': 0}

    try:
        # BUG FIX: the original called do_the_actual_work(job, checkpoint)
        # without the required stop_signal argument, raising TypeError on
        # every job; the broad except below swallowed it and released the
        # job, so every job retried forever. Pass job_stop_signal so the
        # heartbeat thread can interrupt the work when the claim is lost.
        result = do_the_actual_work(job, checkpoint, job_stop_signal)
        if result is None:
            # The job was interrupted; release it back to the queue so
            # another worker can pick it up.
            _stop_heartbeat()
            release_job(receipt_handle)
            return
        # The job isn't really done until the result is uploaded.
        upload_result(job['job_id'], result)
        # Once the result is uploaded, we can acknowledge the job and stop
        # the heartbeat.
        acknowledge_job(receipt_handle)
        _stop_heartbeat()
    except Exception:
        # Any unexpected failure: release the job back to the queue so it
        # can be retried (consider logging the exception here).
        _stop_heartbeat()
        release_job(receipt_handle)
if __name__ == '__main__':
    # Worker main loop: poll for a job forever; back off briefly when the
    # queue is empty.
    while True:
        job, receipt_handle = get_job()
        if job is None:
            time.sleep(10)
            continue
        process_job(job, receipt_handle)