import os
import boto3
import json
import time
import threading
from google.cloud import pubsub_v1
# Read the worker configuration from environment variables.
# The R2 region is hard-coded to "auto" rather than read from the environment.
r2_aws_region = "auto"
r2_aws_access_key_id = os.getenv('R2_AWS_ACCESS_KEY_ID')
r2_aws_secret_access_key = os.getenv('R2_AWS_SECRET_ACCESS_KEY')
r2_s3_endpoint_url = os.getenv('R2_S3_ENDPOINT_URL')
r2_bucket_name = os.getenv('R2_BUCKET_NAME')
project_id = os.getenv('PROJECT_ID')
subscription_id = os.getenv('SUBSCRIPTION_ID')
# ACK_DEADLINE_SECONDS is optional: int(None) would raise an opaque TypeError
# at import time, so fall back to 60 seconds when it is unset or empty.
ack_deadline_seconds = int(os.getenv('ACK_DEADLINE_SECONDS') or 60)
# Build the S3-compatible client used for every R2 bucket operation below.
r2 = boto3.client(
    's3',
    endpoint_url=r2_s3_endpoint_url,
    region_name=r2_aws_region,
    aws_access_key_id=r2_aws_access_key_id,
    aws_secret_access_key=r2_aws_secret_access_key,
)
# Pub/Sub subscriber client, plus the subscription path that every
# pull / ack / modify-deadline call in this file targets.
subscriber = pubsub_v1.SubscriberClient()
subscription_path = subscriber.subscription_path(project_id, subscription_id)
def download_checkpoint(job_id: str) -> dict | None:
    '''
    Download the checkpoint for a job from the R2 bucket.

    Parameters:
    - job_id: str, the job ID; the checkpoint is read from the object key
      "<job_id>/checkpoint.json"

    Returns:
    - checkpoint: dict, the parsed JSON checkpoint, or None when no
      checkpoint object exists for this job
    '''
    try:
        response = r2.get_object(
            Bucket=r2_bucket_name,
            Key=f'{job_id}/checkpoint.json'
        )
    except r2.exceptions.NoSuchKey:
        # No checkpoint stored under this key: the caller initializes
        # fresh state instead.
        return None
    body = response['Body']
    try:
        return json.loads(body.read())
    finally:
        # Close the stream even when the JSON is malformed, so the
        # response body is never left open.
        body.close()
def upload_checkpoint(job_id: str, checkpoint: dict) -> None:
    '''
    Serialize the checkpoint as JSON and store it in the R2 bucket.

    Parameters:
    - job_id: str, the job ID; the object key is "<job_id>/checkpoint.json"
    - checkpoint: dict, the checkpoint to persist
    '''
    object_key = f'{job_id}/checkpoint.json'
    serialized = json.dumps(checkpoint)
    r2.put_object(Bucket=r2_bucket_name, Key=object_key, Body=serialized)
    print(f'Checkpoint uploaded for job {job_id}', flush=True)
def validate_job(job: dict) -> bool:
    '''
    Validate the job payload.

    Parameters:
    - job: dict, the decoded job message

    Returns:
    - bool, True when the payload is a dict containing both the 'job_id'
      and 'steps' keys, False otherwise
    '''
    # This is a very simple function for our very simple application.
    # You should replace this with your actual validation logic.
    #
    # The payload comes from json.loads, so it may be any JSON type. Without
    # the dict check a `null` payload raises TypeError and a list such as
    # ["job_id", "steps"] is wrongly accepted.
    if not isinstance(job, dict):
        return False
    return 'job_id' in job and 'steps' in job
def upload_result(job_id: str, result: int) -> None:
    '''
    Store the final result of a job in the R2 bucket as plain text.

    Parameters:
    - job_id: str, the job ID; the object key is "<job_id>/result.txt"
    - result: int, the result, written as its str() representation
    '''
    object_key = f'{job_id}/result.txt'
    result_text = str(result)
    r2.put_object(Bucket=r2_bucket_name, Key=object_key, Body=result_text)
    print(f'Result uploaded for job {job_id}', flush=True)
def do_the_actual_work(job: dict, checkpoint: dict, cancel_signal: threading.Event) -> int | None:
'''
Do the actual work for the job. This function will simulate work by
sleeping for 30 seconds and incrementing the step and sum in the
checkpoint.
Parameters:
- job: dict, the job
- checkpoint: dict, the checkpoint
'''
print(f'Starting job {job["job_id"]}', flush=True)
print(f"Max steps: {job['steps']}", flush=True)
print(f"Starting step: {checkpoint['step']}", flush=True)
while checkpoint['step'] < job['steps'] and not cancel_signal.is_set():
# Simulate work
print(
f"Working on job {job['job_id']}, step {checkpoint['step']}", flush=True)
time.sleep(30)
if cancel_signal.is_set():
# If we were interrupted, we need to return None to indicate that
# the job was interrupted.
return None
# Update the checkpoint.
checkpoint['step'] += 1
checkpoint['sum'] += checkpoint['step']
upload_checkpoint(job['job_id'], checkpoint)
print(f'Job {job["job_id"]} finished')
return checkpoint['sum']
def heartbeat_job(ack_id: str, heartbeat_stop_signal: threading.Event, job_stop_signal: threading.Event) -> None:
    '''
    Periodically extend the ack deadline of a message so the subscription
    does not consider it expired while the job is still running.

    Parameters:
    - ack_id: str, the ack ID of the message
    - heartbeat_stop_signal: threading.Event, a signal to stop the heartbeat
    - job_stop_signal: threading.Event, a signal to stop the main job; it
      is set here when extending the deadline fails
    '''
    # Renew at half the deadline. Clamp to at least one second so a very
    # small deadline cannot turn this loop into a busy loop.
    interval_seconds = max(1, ack_deadline_seconds // 2)

    while not heartbeat_stop_signal.is_set():
        try:
            subscriber.modify_ack_deadline(
                subscription=subscription_path,
                ack_ids=[ack_id],
                ack_deadline_seconds=ack_deadline_seconds
            )
        except Exception as e:
            # The deadline could not be extended, so tell the job to stop.
            print(f"Error in heartbeat: {str(e)}", flush=True)
            job_stop_signal.set()
            break
        # Wait on the stop event instead of sleeping, so the thread exits
        # as soon as it is asked to rather than after a full interval.
        heartbeat_stop_signal.wait(interval_seconds)
def ack_job(ack_id: str) -> None:
    '''
    Acknowledge the message for a job on the subscription.

    Parameters:
    - ack_id: str, the ack ID of the message to acknowledge
    '''
    acknowledged_ids = [ack_id]
    subscriber.acknowledge(
        subscription=subscription_path,
        ack_ids=acknowledged_ids,
    )
def nack_job(ack_id: str) -> None:
    '''
    Reject the job by setting the ack deadline of its message to zero.

    Parameters:
    - ack_id: str, the ack ID of the message to reject
    '''
    rejected_ids = [ack_id]
    subscriber.modify_ack_deadline(
        subscription=subscription_path,
        ack_ids=rejected_ids,
        ack_deadline_seconds=0,
    )
def _run_job(job: dict, job_id, job_stop_signal: threading.Event) -> bool:
    '''
    Validate the job, run it from its checkpoint and upload the result.

    Returns:
    - bool, True only when the result was uploaded; False when validation
      failed or the work was interrupted
    '''
    # Some jobs may have a validation step. For instance, dreambooth training may have a step
    # that verifies if all inputs have faces. If the validation fails, we should nack the job.
    if not validate_job(job):
        print(f"Job {job_id} failed validation", flush=True)
        return False

    # If there's a checkpoint, we want to use it, but if not, we need to
    # initialize our state.
    checkpoint = download_checkpoint(job_id)
    if checkpoint is None:
        print('No checkpoint found. Initializing state.', flush=True)
        checkpoint = {'step': 0, 'sum': 0}
    else:
        print(
            f'Found checkpoint, resuming from step {checkpoint["step"]}', flush=True)

    # Now we can do the actual work
    result = do_the_actual_work(job, checkpoint, job_stop_signal)
    if result is None:
        # Heartbeat failed, so we need to nack the job
        print(f"Heartbeat for {job_id} failed", flush=True)
        return False

    upload_result(job_id, result)
    return True


def process_job(job: dict, ack_id: str) -> None:
    '''
    Process one job message end to end: keep the message alive with a
    heartbeat thread while the job runs, then ack it on success or nack
    it on any failure.

    Parameters:
    - job: dict, the decoded job message
    - ack_id: str, the ack ID of the message carrying the job
    '''
    # The payload is not validated yet, so read the ID defensively; a
    # KeyError here would crash the worker before the message is settled.
    job_id = job.get('job_id') if isinstance(job, dict) else None
    print(f"Received job {job_id}", flush=True)

    # Start the heartbeat thread to keep the job alive
    heartbeat_stop_signal = threading.Event()
    job_stop_signal = threading.Event()
    heartbeat_thread = threading.Thread(
        target=heartbeat_job,
        args=(ack_id, heartbeat_stop_signal, job_stop_signal))
    heartbeat_thread.start()

    succeeded = False
    try:
        succeeded = _run_job(job, job_id, job_stop_signal)
    except Exception as e:
        # Any failure (validation, checkpoint download, work, result
        # upload) ends in a nack rather than an unhandled crash.
        print(f"Error in job {job_id}: {str(e)}", flush=True)
    finally:
        # Always stop the heartbeat. Otherwise the non-daemon thread
        # would keep extending the deadline after the job is over.
        heartbeat_stop_signal.set()

    try:
        # Settle the message only after the heartbeat was told to stop.
        if succeeded:
            ack_job(ack_id)
        else:
            nack_job(ack_id)
    finally:
        heartbeat_thread.join()
if __name__ == '__main__':
    # Worker loop: pull one message at a time and process it to completion
    # before pulling the next.
    print("Polling for messages", flush=True)
    while True:
        response = subscriber.pull(
            subscription=subscription_path, max_messages=1, timeout=30)
        if not response or not response.received_messages:
            print("No messages received, sleeping for 10s", flush=True)
            time.sleep(10)
            continue

        message = response.received_messages[0]
        ack_id = message.ack_id
        try:
            body = json.loads(message.message.data.decode('utf-8'))
        except ValueError as e:
            # UnicodeDecodeError and json.JSONDecodeError are both
            # ValueError subclasses. A malformed payload must not crash
            # the worker, so reject the message and keep polling.
            # NOTE(review): a nacked message is presumably redelivered;
            # confirm the subscription has a dead-letter policy so bad
            # payloads do not loop forever.
            print(f"Could not decode message payload: {e}", flush=True)
            nack_job(ack_id)
            continue

        process_job(body, ack_id)