import os
import boto3
import pika
import json
import time
import threading
import functools
# Get the environment variables
# R2 is S3-compatible; the region is conventionally "auto" for R2 endpoints.
r2_aws_region = "auto"
r2_aws_access_key_id = os.getenv('R2_AWS_ACCESS_KEY_ID')
r2_aws_secret_access_key = os.getenv('R2_AWS_SECRET_ACCESS_KEY')
r2_s3_endpoint_url = os.getenv('R2_S3_ENDPOINT_URL')
r2_bucket_name = os.getenv('R2_BUCKET_NAME')
# Broker connection string and the queue this worker consumes from.
amqp_url = os.getenv('AMQP_URL')
job_queue = os.getenv('JOB_QUEUE')
# Used below as the consumer tag, so the broker can identify this worker.
machine_id = os.getenv('SALAD_MACHINE_ID')
# Create the R2 client
# NOTE(review): os.getenv returns None for unset variables; the client is
# built eagerly here and will only fail at request time if config is missing.
r2 = boto3.client('s3',
aws_access_key_id=r2_aws_access_key_id,
aws_secret_access_key=r2_aws_secret_access_key,
region_name=r2_aws_region,
endpoint_url=r2_s3_endpoint_url)
def download_checkpoint(job_id: str) -> dict | None:
    '''
    Download the checkpoint for a job from R2 (S3-compatible) storage.
    Parameters:
    - job_id: str, the job ID
    Returns:
    - dict, the parsed checkpoint, or None if no checkpoint exists yet
    '''
    try:
        response = r2.get_object(
            Bucket=r2_bucket_name,
            Key=f'{job_id}/checkpoint.json'
        )
    except r2.exceptions.NoSuchKey:
        # First run for this job: nothing has been checkpointed yet.
        return None
    return json.loads(response['Body'].read())
def upload_checkpoint(job_id: str, checkpoint: dict) -> None:
    '''
    Upload the checkpoint to S3
    Parameters:
    - job_id: str, the job ID
    - checkpoint: dict, the checkpoint
    '''
    # Serialize first, then store under the job's fixed checkpoint key.
    object_key = f'{job_id}/checkpoint.json'
    serialized = json.dumps(checkpoint)
    r2.put_object(Bucket=r2_bucket_name, Key=object_key, Body=serialized)
    print(f'Checkpoint uploaded for job {job_id}', flush=True)
def validate_job(job: dict) -> bool:
    '''
    Validate the job
    Parameters:
    - job: dict, the job
    Returns:
    - bool, whether the job is valid
    '''
    # Placeholder validation for this simple application: a job only needs
    # its ID and a step count. Replace with real validation logic as needed.
    required_keys = ('job_id', 'steps')
    return all(key in job for key in required_keys)
# Set during shutdown; worker threads poll this between steps so they can
# stop cleanly (after a checkpoint) instead of being killed mid-job.
cancel_signal = threading.Event()
def do_the_actual_work(job: dict, checkpoint: dict) -> int | None:
    '''
    Do the actual work for the job. This function will simulate work by
    sleeping for 30 seconds per step and incrementing the step and sum in
    the checkpoint, uploading the checkpoint after every completed step.
    Parameters:
    - job: dict, the job (must contain 'job_id' and 'steps')
    - checkpoint: dict, the checkpoint (must contain 'step' and 'sum')
    Returns:
    - int, the final sum once all steps are done, or
    - None if the job was interrupted by the cancel signal
    '''
    # No `global` needed: cancel_signal is only read, never rebound.
    print(f'Starting job {job["job_id"]}', flush=True)
    print(f"Max steps: {job['steps']}", flush=True)
    print(f"Starting step: {checkpoint['step']}", flush=True)
    while checkpoint['step'] < job['steps']:
        if cancel_signal.is_set():
            # Interrupted before starting the next step. The original code
            # fell out of the loop here and returned a PARTIAL sum, which the
            # caller then uploaded as a final result and acked. Returning
            # None makes the caller requeue the message instead.
            return None
        # Simulate work
        print(
            f"Working on job {job['job_id']}, step {checkpoint['step']}", flush=True)
        time.sleep(30)
        if cancel_signal.is_set():
            # If we were interrupted mid-step, return None to indicate that
            # the job was interrupted (the step is not checkpointed).
            return None
        # Advance and persist the checkpoint so a restart resumes from here.
        checkpoint['step'] += 1
        checkpoint['sum'] += checkpoint['step']
        upload_checkpoint(job['job_id'], checkpoint)
    # flush=True for consistency with every other log line in this file.
    print(f'Job {job["job_id"]} finished', flush=True)
    return checkpoint['sum']
def upload_result(job_id: str, result: int) -> None:
    '''
    Upload the result to S3
    Parameters:
    - job_id: str, the job ID
    - result: int, the result
    '''
    # Store the result as plain text under the job's fixed result key.
    object_key = f'{job_id}/result.txt'
    r2.put_object(Bucket=r2_bucket_name, Key=object_key, Body=str(result))
    print(f'Result uploaded for job {job_id}', flush=True)
def ack_message(channel, delivery_tag):
    '''
    Acknowledge the message, indicating that it has been processed successfully
    '''
    if not channel.is_open:
        # Channel is already closed, so we can't ack this message;
        print("Channel is closed, message not acked")
        return
    channel.basic_ack(delivery_tag)
def nack_message(channel, delivery_tag, requeue=True):
    '''
    Reject the message, indicating that it has not been processed successfully
    '''
    if not channel.is_open:
        # Channel is already closed, so we can't nack this message;
        print("Channel is closed, message not nacked")
        return
    channel.basic_nack(delivery_tag, requeue=requeue)
def _schedule_threadsafe(channel, cb) -> None:
    '''
    Schedule cb to run on the connection's I/O loop thread.
    pika channels are not thread-safe, so acks/nacks issued from worker
    threads must go through add_callback_threadsafe. If the channel has
    already closed there is nothing to do; the broker will redeliver the
    unacked message.
    '''
    if channel.is_open:
        channel.connection.add_callback_threadsafe(cb)

def process_job(channel, delivery_tag, body):
    '''
    Run one job end to end in a worker thread, then ack/nack its message.
    Parameters:
    - channel: the pika channel the message was delivered on
    - delivery_tag: identifies the message for ack/nack
    - body: bytes, the raw message body (expected: a JSON-encoded job)
    '''
    # A body that is not valid JSON can never succeed: route it to the DLQ
    # (requeue=False) instead of letting the exception kill this thread and
    # leave the message unacked forever (which stalls a prefetch-1 worker).
    try:
        job = json.loads(body)
    except json.JSONDecodeError as e:
        print(f"Invalid job body: {str(e)}", flush=True)
        _schedule_threadsafe(
            channel, functools.partial(nack_message, channel, delivery_tag, False))
        return
    # Some jobs may have a validation step. For instance, dreambooth training
    # may verify that all inputs have faces. If validation fails, we should
    # stop the job and not retry it, but instead move it to the DLQ: we can
    # be confident the job will never succeed.
    # Validate BEFORE touching job['job_id'] -- an invalid job may be missing
    # that key entirely, and a KeyError here would also strand the message.
    if not validate_job(job):
        print("Invalid job, sending to DLQ", flush=True)
        _schedule_threadsafe(
            channel, functools.partial(nack_message, channel, delivery_tag, False))
        return
    print(f"Received job {job['job_id']}", flush=True)
    # If there's a checkpoint, resume from it; otherwise initialize state.
    checkpoint = download_checkpoint(job['job_id'])
    if checkpoint is None:
        checkpoint = {'step': 0, 'sum': 0}
    # Now we can do the actual work
    try:
        result = do_the_actual_work(job, checkpoint)
    except Exception as e:
        # Unknown failure: requeue so this or another worker can retry.
        print(f"Error in job {job['job_id']}: {str(e)}", flush=True)
        _schedule_threadsafe(
            channel, functools.partial(nack_message, channel, delivery_tag))
        return
    if result is None:
        # Job was interrupted by shutdown; requeue it for another worker.
        _schedule_threadsafe(
            channel, functools.partial(nack_message, channel, delivery_tag))
        return
    # Success: persist the result, then ack the message.
    upload_result(job['job_id'], result)
    _schedule_threadsafe(
        channel, functools.partial(ack_message, channel, delivery_tag))
def on_message(channel, method_frame, header_frame, body, args):
    '''
    Consumer callback: hand the message off to a worker thread so the
    connection thread stays free (e.g. to service heartbeats), and record
    the thread in the shared list passed via args so it can be joined later.
    '''
    thread_list = args
    worker = threading.Thread(
        target=process_job,
        args=(channel, method_frame.delivery_tag, body),
    )
    worker.start()
    thread_list.append(worker)
if __name__ == "__main__":
    # We will be doing all of our work in separate threads, so that rabbitmq's
    # heartbeat can be properly handled.
    threads = []
    # Initialized to None so the shutdown paths below can tell whether a
    # connection/channel was ever established. Previously a KeyboardInterrupt
    # before the first successful connect raised a NameError on `channel`,
    # and `connection.close()` could likewise hit an undefined name.
    connection = None
    channel = None
    while True:
        try:
            # Create the connection and channel, heartbeating every 30 seconds.
            # NOTE(review): assumes amqp_url carries no query string already;
            # verify, otherwise the "?heartbeat=30" concatenation is malformed.
            connection = pika.BlockingConnection(
                pika.URLParameters(amqp_url + "?heartbeat=30"))
            channel = connection.channel()
            # We only want 1 job at a time per worker
            channel.basic_qos(prefetch_count=1)
            # on_message receives the shared thread list (a single object, not
            # a tuple -- the old `args=(threads)` parens were misleading) so it
            # can register each worker thread for the final join below.
            on_message_callback = functools.partial(on_message, args=threads)
            channel.basic_consume(
                queue=job_queue, on_message_callback=on_message_callback,
                consumer_tag=machine_id)
            channel.start_consuming()
        # Don't recover if connection was closed by broker
        except pika.exceptions.ConnectionClosedByBroker:
            print("Connection closed by broker")
            break
        # Don't recover on channel errors
        except pika.exceptions.AMQPChannelError as e:
            print("Channel error")
            print(str(e))
            break
        # Recover on all other connection errors
        except pika.exceptions.AMQPConnectionError as e:
            print("Connection error, retrying...")
            print(str(e))
            time.sleep(1)
            continue
        except KeyboardInterrupt:
            print("Keyboard interrupt")
            if channel is not None and channel.is_open:
                channel.stop_consuming()
            break
        except Exception as e:
            print(f"Error: {str(e)}")
            break
    # Tell in-flight jobs to stop at the next step boundary, wait for all
    # worker threads to finish, then drop the connection.
    cancel_signal.set()
    print("Exiting")
    for thread in threads:
        thread.join()
    if connection is not None and connection.is_open:
        connection.close()