update actions config, fix lint errors

This commit is contained in:
david 2026-02-24 19:19:15 +03:00
parent a9b0fd34a2
commit 1c35df931a
26 changed files with 732 additions and 691 deletions

View file

@ -14,7 +14,7 @@ on:
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: [docker]
strategy: strategy:
matrix: matrix:
@ -32,8 +32,6 @@ jobs:
--health-interval 10s --health-interval 10s
--health-timeout 5s --health-timeout 5s
--health-retries 5 --health-retries 5
ports:
- 5432:5432
redis: redis:
image: redis:7-alpine image: redis:7-alpine
@ -106,7 +104,7 @@ jobs:
coverage report --fail-under=80 coverage report --fail-under=80
security-scan: security-scan:
runs-on: ubuntu-latest runs-on: [docker]
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4

View file

@ -1,56 +1,56 @@
name: CD # name: CD
on: # on:
push: # push:
branches: [ main ] # branches: [ main ]
workflow_dispatch: # workflow_dispatch:
jobs: # jobs:
deploy: # deploy:
runs-on: [docker] # runs-on: [docker]
steps: # steps:
- uses: actions/checkout@v3 # - uses: actions/checkout@v3
- name: Configure AWS credentials # - name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v2 # uses: aws-actions/configure-aws-credentials@v2
with: # with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} # aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} # aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ secrets.AWS_REGION }} # aws-region: ${{ secrets.AWS_REGION }}
- name: Login to Amazon ECR # - name: Login to Amazon ECR
id: login-ecr # id: login-ecr
uses: aws-actions/amazon-ecr-login@v1 # uses: aws-actions/amazon-ecr-login@v1
- name: Build and push backend # - name: Build and push backend
env: # env:
ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} # ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
ECR_REPOSITORY: crafting-shop-backend # ECR_REPOSITORY: crafting-shop-backend
IMAGE_TAG: ${{ github.sha }} # IMAGE_TAG: ${{ github.sha }}
run: | # run: |
cd backend # cd backend
docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . # docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG # docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest
docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest
- name: Build and push frontend # - name: Build and push frontend
env: # env:
ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} # ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
ECR_REPOSITORY: crafting-shop-frontend # ECR_REPOSITORY: crafting-shop-frontend
IMAGE_TAG: ${{ github.sha }} # IMAGE_TAG: ${{ github.sha }}
run: | # run: |
cd frontend # cd frontend
docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . # docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG # docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest
docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest
- name: Deploy to ECS # - name: Deploy to ECS
uses: aws-actions/amazon-ecs-deploy-task-definition@v1 # uses: aws-actions/amazon-ecs-deploy-task-definition@v1
with: # with:
task-definition: crafting-shop-task # task-definition: crafting-shop-task
service: crafting-shop-service # service: crafting-shop-service
cluster: crafting-shop-cluster # cluster: crafting-shop-cluster
wait-for-service-stability: true # wait-for-service-stability: true

View file

@ -22,25 +22,26 @@ jobs:
--health-interval 10s --health-interval 10s
--health-timeout 5s --health-timeout 5s
--health-retries 5 --health-retries 5
ports:
- 5432:5432 container:
image: nikolaik/python-nodejs:python3.12-nodejs24-alpine
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v6
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 run: |
with: python --version
python-version: '3.11'
- name: Cache pip packages - name: Cache pip packages
uses: actions/cache@v3 uses: actions/cache@v3
with: with:
path: ~/.cache/pip # This path is on the runner, but we mounted it to the container
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} path: /tmp/pip-cache
key: ${{ runner.os }}-pip-${{ hashFiles('backend/requirements/dev.txt') }}
restore-keys: | restore-keys: |
${{ runner.os }}-pip- ${{ runner.os }}-pip-
- name: Install dependencies - name: Install dependencies
run: | run: |
cd backend cd backend
@ -50,12 +51,11 @@ jobs:
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
cd backend cd backend
flake8 app tests --count --select=E9,F63,F7,F82 --show-source --statistics flake8 app tests --count --max-complexity=10 --max-line-length=127 --statistics --show-source
flake8 app tests --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run tests - name: Run tests
env: env:
DATABASE_URL: postgresql://test:test@localhost:5432/test_db DATABASE_URL: postgresql://test:test@postgres:5432/test_db
SECRET_KEY: test-secret-key SECRET_KEY: test-secret-key
JWT_SECRET_KEY: test-jwt-secret JWT_SECRET_KEY: test-jwt-secret
FLASK_ENV: testing FLASK_ENV: testing
@ -63,13 +63,6 @@ jobs:
cd backend cd backend
pytest --cov=app --cov-report=xml --cov-report=term pytest --cov=app --cov-report=xml --cov-report=term
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
files: ./backend/coverage.xml
flags: backend
name: backend-coverage
frontend-test: frontend-test:
runs-on: [docker] runs-on: [docker]

View file

@ -126,7 +126,23 @@ lint-frontend: ## Lint frontend only
format: ## Format code format: ## Format code
@echo "Formatting backend..." @echo "Formatting backend..."
cd backend && . venv/bin/activate && black app tests && isort app tests cd backend && . venv/bin/activate && black app tests
cd backend && . venv/bin/activate && isort app tests
@echo "Formatting frontend..."
cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}"
format-backend: ## Format backend code only
@echo "Formatting backend with black..."
cd backend && . venv/bin/activate && black app tests
@echo "Sorting imports with isort..."
cd backend && . venv/bin/activate && isort app tests
format-backend-check: ## Check if backend code needs formatting
@echo "Checking backend formatting..."
cd backend && . venv/bin/activate && black --check app tests
cd backend && . venv/bin/activate && isort --check-only app tests
format-frontend: ## Format frontend code only
@echo "Formatting frontend..." @echo "Formatting frontend..."
cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}" cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}"

View file

@ -1,11 +1,12 @@
import json import os
from dotenv import load_dotenv
from flask import Flask, jsonify from flask import Flask, jsonify
from flask_cors import CORS from flask_cors import CORS
from flask_jwt_extended import JWTManager from flask_jwt_extended import JWTManager
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate from flask_migrate import Migrate
import os from flask_sqlalchemy import SQLAlchemy
from dotenv import load_dotenv
# Create extensions but don't initialize them yet # Create extensions but don't initialize them yet
db = SQLAlchemy() db = SQLAlchemy()
migrate = Migrate() migrate = Migrate()
@ -13,54 +14,62 @@ jwt = JWTManager()
cors = CORS() cors = CORS()
load_dotenv(override=True) load_dotenv(override=True)
def create_app(config_name=None): def create_app(config_name=None):
"""Application factory pattern""" """Application factory pattern"""
app = Flask(__name__) app = Flask(__name__)
# Load configuration # Load configuration
if config_name is None: if config_name is None:
config_name = os.environ.get("FLASK_ENV", "development") config_name = os.environ.get("FLASK_ENV", "development")
from app.config import config_by_name from app.config import config_by_name
app.config.from_object(config_by_name[config_name]) app.config.from_object(config_by_name[config_name])
print('----------------------------------------------------------') print("----------------------------------------------------------")
print(F'------------------ENVIRONMENT: {config_name}-------------------------------------') print(
f"------------------ENVIRONMENT: {config_name}-------------------------------------"
)
# print(F'------------------CONFIG: {app.config}-------------------------------------') # print(F'------------------CONFIG: {app.config}-------------------------------------')
# print(json.dumps(dict(app.config), indent=2, default=str)) # print(json.dumps(dict(app.config), indent=2, default=str))
print('----------------------------------------------------------') print("----------------------------------------------------------")
# Initialize extensions with app # Initialize extensions with app
db.init_app(app) db.init_app(app)
migrate.init_app(app, db) migrate.init_app(app, db)
jwt.init_app(app) jwt.init_app(app)
cors.init_app(app, resources={r"/api/*": {"origins": app.config.get("CORS_ORIGINS", "*")}}) cors.init_app(
app, resources={r"/api/*": {"origins": app.config.get("CORS_ORIGINS", "*")}}
)
# Initialize Celery # Initialize Celery
from app.celery import init_celery from app.celery import init_celery
init_celery(app) init_celery(app)
# Import models (required for migrations) # Import models (required for migrations)
from app.models import user, product, order from app.models import order, product, user # noqa: F401
# Register blueprints # Register blueprints
from app.routes import api_bp, health_bp from app.routes import api_bp, health_bp
app.register_blueprint(api_bp, url_prefix="/api") app.register_blueprint(api_bp, url_prefix="/api")
app.register_blueprint(health_bp) app.register_blueprint(health_bp)
# Global error handlers # Global error handlers
@app.errorhandler(404) @app.errorhandler(404)
def not_found(error): def not_found(error):
print(f"404 Error: {error}") print(f"404 Error: {error}")
return jsonify({"error": "Not found"}), 404 return jsonify({"error": "Not found"}), 404
@app.errorhandler(500) @app.errorhandler(500)
def internal_error(error): def internal_error(error):
print(f"500 Error: {error}") print(f"500 Error: {error}")
return jsonify({"error": "Internal server error"}), 500 return jsonify({"error": "Internal server error"}), 500
@app.errorhandler(422) @app.errorhandler(422)
def validation_error(error): def validation_error(error):
print(f"422 Error: {error}") print(f"422 Error: {error}")
return jsonify({"error": "Validation error"}), 422 return jsonify({"error": "Validation error"}), 422
return app return app

View file

@ -9,10 +9,10 @@ from flask import Flask
def make_celery(app: Flask) -> Celery: def make_celery(app: Flask) -> Celery:
""" """
Create and configure a Celery application with Flask context. Create and configure a Celery application with Flask context.
Args: Args:
app: Flask application instance app: Flask application instance
Returns: Returns:
Configured Celery application instance. Configured Celery application instance.
""" """
@ -20,36 +20,38 @@ def make_celery(app: Flask) -> Celery:
celery_app = Celery( celery_app = Celery(
app.import_name, app.import_name,
broker=app.config["CELERY"]["broker_url"], broker=app.config["CELERY"]["broker_url"],
backend=app.config["CELERY"]["result_backend"] backend=app.config["CELERY"]["result_backend"],
) )
# Update configuration from Flask config # Update configuration from Flask config
celery_app.conf.update(app.config["CELERY"]) celery_app.conf.update(app.config["CELERY"])
# Set up Flask application context for tasks # Set up Flask application context for tasks
# This ensures tasks have access to Flask extensions (db, etc.) # This ensures tasks have access to Flask extensions (db, etc.)
class ContextTask(celery_app.Task): class ContextTask(celery_app.Task):
"""Celery task that runs within Flask application context.""" """Celery task that runs within Flask application context."""
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
with app.app_context(): with app.app_context():
return self.run(*args, **kwargs) return self.run(*args, **kwargs)
celery_app.Task = ContextTask celery_app.Task = ContextTask
# Auto-discover tasks in the tasks module # Auto-discover tasks in the tasks module
celery_app.autodiscover_tasks(['app.celery.tasks']) celery_app.autodiscover_tasks(["app.celery.tasks"])
# Configure Beat schedule # Configure Beat schedule
from .beat_schedule import configure_beat_schedule from .beat_schedule import configure_beat_schedule
configure_beat_schedule(celery_app) configure_beat_schedule(celery_app)
# Import tasks to ensure they're registered # Import tasks to ensure they're registered
from .tasks import example_tasks from .tasks import example_tasks # noqa: F401
print(f"✅ Celery configured with broker: {celery_app.conf.broker_url}") print(f"✅ Celery configured with broker: {celery_app.conf.broker_url}")
print(f"✅ Celery configured with backend: {celery_app.conf.result_backend}") print(f"✅ Celery configured with backend: {celery_app.conf.result_backend}")
print(f"✅ Beat schedule configured with {len(celery_app.conf.beat_schedule)} tasks") print(f"✅ Beat schedule configured with {len(celery_app.conf.beat_schedule)} tasks")
return celery_app return celery_app
@ -61,10 +63,10 @@ def init_celery(app: Flask) -> Celery:
""" """
Initialize the global celery instance with Flask app. Initialize the global celery instance with Flask app.
This should be called in create_app() after Flask app is created. This should be called in create_app() after Flask app is created.
Args: Args:
app: Flask application instance app: Flask application instance
Returns: Returns:
Configured Celery application instance Configured Celery application instance
""" """

View file

@ -4,7 +4,6 @@ This defines when scheduled tasks should run.
""" """
from celery.schedules import crontab from celery.schedules import crontab
# Celery Beat schedule configuration # Celery Beat schedule configuration
beat_schedule = { beat_schedule = {
# Run every minute (for testing/demo) # Run every minute (for testing/demo)
@ -14,14 +13,12 @@ beat_schedule = {
"args": ("Celery Beat",), "args": ("Celery Beat",),
"options": {"queue": "default"}, "options": {"queue": "default"},
}, },
# Run daily at 9:00 AM # Run daily at 9:00 AM
"send-daily-report": { "send-daily-report": {
"task": "tasks.send_daily_report", "task": "tasks.send_daily_report",
"schedule": crontab(hour=9, minute=0), # 9:00 AM daily "schedule": crontab(hour=9, minute=0), # 9:00 AM daily
"options": {"queue": "reports"}, "options": {"queue": "reports"},
}, },
# Run every hour at minute 0 # Run every hour at minute 0
"update-product-stats-hourly": { "update-product-stats-hourly": {
"task": "tasks.update_product_statistics", "task": "tasks.update_product_statistics",
@ -29,7 +26,6 @@ beat_schedule = {
"args": (None,), # Update all products "args": (None,), # Update all products
"options": {"queue": "stats"}, "options": {"queue": "stats"},
}, },
# Run every Monday at 8:00 AM # Run every Monday at 8:00 AM
"weekly-maintenance": { "weekly-maintenance": {
"task": "tasks.long_running_task", "task": "tasks.long_running_task",
@ -37,7 +33,6 @@ beat_schedule = {
"args": (5,), # 5 iterations "args": (5,), # 5 iterations
"options": {"queue": "maintenance"}, "options": {"queue": "maintenance"},
}, },
# Run every 5 minutes (for monitoring/heartbeat) # Run every 5 minutes (for monitoring/heartbeat)
"heartbeat-check": { "heartbeat-check": {
"task": "tasks.print_hello", "task": "tasks.print_hello",
@ -51,16 +46,16 @@ beat_schedule = {
def configure_beat_schedule(celery_app): def configure_beat_schedule(celery_app):
""" """
Configure Celery Beat schedule on the Celery app. Configure Celery Beat schedule on the Celery app.
Args: Args:
celery_app: Celery application instance celery_app: Celery application instance
""" """
celery_app.conf.beat_schedule = beat_schedule celery_app.conf.beat_schedule = beat_schedule
# Configure timezone # Configure timezone
celery_app.conf.timezone = "UTC" celery_app.conf.timezone = "UTC"
celery_app.conf.enable_utc = True celery_app.conf.enable_utc = True
# Configure task routes for scheduled tasks # Configure task routes for scheduled tasks
celery_app.conf.task_routes = { celery_app.conf.task_routes = {
"tasks.print_hello": {"queue": "default"}, "tasks.print_hello": {"queue": "default"},
@ -68,7 +63,7 @@ def configure_beat_schedule(celery_app):
"tasks.update_product_statistics": {"queue": "stats"}, "tasks.update_product_statistics": {"queue": "stats"},
"tasks.long_running_task": {"queue": "maintenance"}, "tasks.long_running_task": {"queue": "maintenance"},
} }
# Configure queues # Configure queues
celery_app.conf.task_queues = { celery_app.conf.task_queues = {
"default": { "default": {
@ -96,4 +91,4 @@ def configure_beat_schedule(celery_app):
"exchange_type": "direct", "exchange_type": "direct",
"routing_key": "monitoring", "routing_key": "monitoring",
}, },
} }

View file

@ -4,12 +4,12 @@ Tasks are organized by domain/functionality.
""" """
# Import all task modules here to ensure they're registered with Celery # Import all task modules here to ensure they're registered with Celery
from . import example_tasks from . import example_tasks # noqa: F401
# Re-export tasks for easier imports # Re-export tasks for easier imports
from .example_tasks import ( from .example_tasks import ( # noqa: F401
print_hello,
divide_numbers, divide_numbers,
print_hello,
send_daily_report, send_daily_report,
update_product_statistics, update_product_statistics,
) )

View file

@ -2,11 +2,11 @@
Example Celery tasks for the Crafting Shop application. Example Celery tasks for the Crafting Shop application.
These tasks demonstrate various Celery features and best practices. These tasks demonstrate various Celery features and best practices.
""" """
import time
import logging import logging
import time
from datetime import datetime from datetime import datetime
from celery import shared_task from celery import shared_task
from celery.exceptions import MaxRetriesExceededError
# Get logger for this module # Get logger for this module
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -16,10 +16,10 @@ logger = logging.getLogger(__name__)
def print_hello(self, name: str = "World") -> str: def print_hello(self, name: str = "World") -> str:
""" """
Simple task that prints a greeting. Simple task that prints a greeting.
Args: Args:
name: Name to greet (default: "World") name: Name to greet (default: "World")
Returns: Returns:
Greeting message Greeting message
""" """
@ -36,51 +36,47 @@ def print_hello(self, name: str = "World") -> str:
retry_backoff=True, retry_backoff=True,
retry_backoff_max=60, retry_backoff_max=60,
retry_jitter=True, retry_jitter=True,
max_retries=3 max_retries=3,
) )
def divide_numbers(self, x: float, y: float) -> float: def divide_numbers(self, x: float, y: float) -> float:
""" """
Task that demonstrates error handling and retry logic. Task that demonstrates error handling and retry logic.
Args: Args:
x: Numerator x: Numerator
y: Denominator y: Denominator
Returns: Returns:
Result of division Result of division
Raises: Raises:
ZeroDivisionError: If y is zero (will trigger retry) ZeroDivisionError: If y is zero (will trigger retry)
""" """
logger.info(f"Dividing {x} by {y} (attempt {self.request.retries + 1})") logger.info(f"Dividing {x} by {y} (attempt {self.request.retries + 1})")
if y == 0: if y == 0:
logger.warning(f"Division by zero detected, retrying...") logger.warning("Division by zero detected, retrying...")
raise ZeroDivisionError("Cannot divide by zero") raise ZeroDivisionError("Cannot divide by zero")
result = x / y result = x / y
logger.info(f"Division result: {result}") logger.info(f"Division result: {result}")
return result return result
@shared_task( @shared_task(bind=True, name="tasks.send_daily_report", ignore_result=False)
bind=True,
name="tasks.send_daily_report",
ignore_result=False
)
def send_daily_report(self) -> dict: def send_daily_report(self) -> dict:
""" """
Simulates sending a daily report. Simulates sending a daily report.
This task would typically send emails, generate reports, etc. This task would typically send emails, generate reports, etc.
Returns: Returns:
Dictionary with report details Dictionary with report details
""" """
logger.info("Starting daily report generation...") logger.info("Starting daily report generation...")
# Simulate some work # Simulate some work
time.sleep(2) time.sleep(2)
report_data = { report_data = {
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"task_id": self.request.id, "task_id": self.request.id,
@ -90,46 +86,43 @@ def send_daily_report(self) -> dict:
"total_products": 150, "total_products": 150,
"total_orders": 42, "total_orders": 42,
"total_users": 89, "total_users": 89,
"revenue": 12500.75 "revenue": 12500.75,
} },
} }
logger.info(f"Daily report generated: {report_data}") logger.info(f"Daily report generated: {report_data}")
print(f"📊 Daily Report Generated at {report_data['date']}") print(f"📊 Daily Report Generated at {report_data['date']}")
return report_data return report_data
@shared_task( @shared_task(
bind=True, bind=True, name="tasks.update_product_statistics", queue="stats", priority=5
name="tasks.update_product_statistics",
queue="stats",
priority=5
) )
def update_product_statistics(self, product_id: int = None) -> dict: def update_product_statistics(self, product_id: int = None) -> dict:
""" """
Simulates updating product statistics. Simulates updating product statistics.
Demonstrates task routing to a specific queue. Demonstrates task routing to a specific queue.
Args: Args:
product_id: Optional specific product ID to update. product_id: Optional specific product ID to update.
If None, updates all products. If None, updates all products.
Returns: Returns:
Dictionary with update results Dictionary with update results
""" """
logger.info(f"Updating product statistics for product_id={product_id}") logger.info(f"Updating product statistics for product_id={product_id}")
# Simulate database work # Simulate database work
time.sleep(1) time.sleep(1)
if product_id is None: if product_id is None:
# Update all products # Update all products
result = { result = {
"task": "update_all_product_stats", "task": "update_all_product_stats",
"status": "completed", "status": "completed",
"products_updated": 150, "products_updated": 150,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat(),
} }
else: else:
# Update specific product # Update specific product
@ -138,13 +131,9 @@ def update_product_statistics(self, product_id: int = None) -> dict:
"product_id": product_id, "product_id": product_id,
"status": "completed", "status": "completed",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"new_stats": { "new_stats": {"views": 125, "purchases": 15, "rating": 4.5},
"views": 125,
"purchases": 15,
"rating": 4.5
}
} }
logger.info(f"Product statistics updated: {result}") logger.info(f"Product statistics updated: {result}")
return result return result
@ -153,46 +142,46 @@ def update_product_statistics(self, product_id: int = None) -> dict:
bind=True, bind=True,
name="tasks.long_running_task", name="tasks.long_running_task",
time_limit=300, # 5 minutes time_limit=300, # 5 minutes
soft_time_limit=240 # 4 minutes soft_time_limit=240, # 4 minutes
) )
def long_running_task(self, iterations: int = 10) -> dict: def long_running_task(self, iterations: int = 10) -> dict:
""" """
Simulates a long-running task with progress tracking. Simulates a long-running task with progress tracking.
Args: Args:
iterations: Number of iterations to simulate iterations: Number of iterations to simulate
Returns: Returns:
Dictionary with results Dictionary with results
""" """
logger.info(f"Starting long-running task with {iterations} iterations") logger.info(f"Starting long-running task with {iterations} iterations")
results = [] results = []
for i in range(iterations): for i in range(iterations):
# Check if task has been revoked # Check if task has been revoked
if self.is_aborted(): if self.is_aborted():
logger.warning("Task was aborted") logger.warning("Task was aborted")
return {"status": "aborted", "completed_iterations": i} return {"status": "aborted", "completed_iterations": i}
# Simulate work # Simulate work
time.sleep(1) time.sleep(1)
# Update progress # Update progress
progress = (i + 1) / iterations * 100 progress = (i + 1) / iterations * 100
self.update_state( self.update_state(
state="PROGRESS", state="PROGRESS",
meta={"current": i + 1, "total": iterations, "progress": progress} meta={"current": i + 1, "total": iterations, "progress": progress},
) )
results.append(f"iteration_{i + 1}") results.append(f"iteration_{i + 1}")
logger.info(f"Completed iteration {i + 1}/{iterations}") logger.info(f"Completed iteration {i + 1}/{iterations}")
final_result = { final_result = {
"status": "completed", "status": "completed",
"iterations": iterations, "iterations": iterations,
"results": results, "results": results,
"completed_at": datetime.now().isoformat() "completed_at": datetime.now().isoformat(),
} }
logger.info(f"Long-running task completed: {final_result}") logger.info(f"Long-running task completed: {final_result}")
return final_result return final_result

View file

@ -4,17 +4,20 @@ from datetime import timedelta
class Config: class Config:
"""Base configuration""" """Base configuration"""
SECRET_KEY = os.environ.get("SECRET_KEY") or "dev-secret-key-change-in-production" SECRET_KEY = os.environ.get("SECRET_KEY") or "dev-secret-key-change-in-production"
SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_TRACK_MODIFICATIONS = False
JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"] JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"]
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=1) JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=1)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30) JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=30)
CORS_ORIGINS = os.environ.get("CORS_ORIGINS", "*") CORS_ORIGINS = os.environ.get("CORS_ORIGINS", "*")
# Celery Configuration # Celery Configuration
CELERY = { CELERY = {
"broker_url": os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"), "broker_url": os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"),
"result_backend": os.environ.get("CELERY_RESULT_BACKEND", "redis://redis:6379/0"), "result_backend": os.environ.get(
"CELERY_RESULT_BACKEND", "redis://redis:6379/0"
),
"task_serializer": "json", "task_serializer": "json",
"result_serializer": "json", "result_serializer": "json",
"accept_content": ["json"], "accept_content": ["json"],
@ -31,12 +34,14 @@ class Config:
class DevelopmentConfig(Config): class DevelopmentConfig(Config):
"""Development configuration""" """Development configuration"""
DEBUG = True DEBUG = True
SQLALCHEMY_DATABASE_URI = os.environ.get("DEV_DATABASE_URL") or "sqlite:///dev.db" SQLALCHEMY_DATABASE_URI = os.environ.get("DEV_DATABASE_URL") or "sqlite:///dev.db"
class TestingConfig(Config): class TestingConfig(Config):
"""Testing configuration""" """Testing configuration"""
TESTING = True TESTING = True
SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite:///test.db" SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite:///test.db"
WTF_CSRF_ENABLED = False WTF_CSRF_ENABLED = False
@ -44,9 +49,12 @@ class TestingConfig(Config):
class ProductionConfig(Config): class ProductionConfig(Config):
"""Production configuration""" """Production configuration"""
DEBUG = False DEBUG = False
SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URL") or "postgresql://user:password@localhost/proddb" SQLALCHEMY_DATABASE_URI = (
os.environ.get("DATABASE_URL") or "postgresql://user:password@localhost/proddb"
)
# Security headers # Security headers
SESSION_COOKIE_SECURE = True SESSION_COOKIE_SECURE = True
SESSION_COOKIE_HTTPONLY = True SESSION_COOKIE_HTTPONLY = True
@ -56,5 +64,5 @@ class ProductionConfig(Config):
config_by_name = { config_by_name = {
"dev": DevelopmentConfig, "dev": DevelopmentConfig,
"test": TestingConfig, "test": TestingConfig,
"prod": ProductionConfig "prod": ProductionConfig,
} }

View file

@ -1,5 +1,5 @@
from app.models.user import User
from app.models.product import Product
from app.models.order import Order, OrderItem from app.models.order import Order, OrderItem
from app.models.product import Product
from app.models.user import User
__all__ = ["User", "Product", "Order", "OrderItem"] __all__ = ["User", "Product", "Order", "OrderItem"]

View file

@ -1,23 +1,34 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from app import db from app import db
class Order(db.Model): class Order(db.Model):
"""Order model""" """Order model"""
__tablename__ = "orders" __tablename__ = "orders"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False)
status = db.Column(db.String(20), default="pending", index=True) status = db.Column(db.String(20), default="pending", index=True)
total_amount = db.Column(db.Numeric(10, 2), nullable=False) total_amount = db.Column(db.Numeric(10, 2), nullable=False)
shipping_address = db.Column(db.Text) shipping_address = db.Column(db.Text)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
user = db.relationship("User", back_populates="orders") user = db.relationship("User", back_populates="orders")
items = db.relationship("OrderItem", back_populates="order", lazy="dynamic", cascade="all, delete-orphan") items = db.relationship(
"OrderItem",
back_populates="order",
lazy="dynamic",
cascade="all, delete-orphan",
)
def to_dict(self): def to_dict(self):
"""Convert order to dictionary""" """Convert order to dictionary"""
return { return {
@ -28,27 +39,28 @@ class Order(db.Model):
"shipping_address": self.shipping_address, "shipping_address": self.shipping_address,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None,
"items": [item.to_dict() for item in self.items] "items": [item.to_dict() for item in self.items],
} }
def __repr__(self): def __repr__(self):
return f"<Order {self.id}>" return f"<Order {self.id}>"
class OrderItem(db.Model): class OrderItem(db.Model):
"""Order Item model""" """Order Item model"""
__tablename__ = "order_items" __tablename__ = "order_items"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
order_id = db.Column(db.Integer, db.ForeignKey("orders.id"), nullable=False) order_id = db.Column(db.Integer, db.ForeignKey("orders.id"), nullable=False)
product_id = db.Column(db.Integer, db.ForeignKey("products.id"), nullable=False) product_id = db.Column(db.Integer, db.ForeignKey("products.id"), nullable=False)
quantity = db.Column(db.Integer, nullable=False) quantity = db.Column(db.Integer, nullable=False)
price = db.Column(db.Numeric(10, 2), nullable=False) price = db.Column(db.Numeric(10, 2), nullable=False)
# Relationships # Relationships
order = db.relationship("Order", back_populates="items") order = db.relationship("Order", back_populates="items")
product = db.relationship("Product", back_populates="order_items") product = db.relationship("Product", back_populates="order_items")
def to_dict(self): def to_dict(self):
"""Convert order item to dictionary""" """Convert order item to dictionary"""
return { return {
@ -56,8 +68,8 @@ class OrderItem(db.Model):
"order_id": self.order_id, "order_id": self.order_id,
"product_id": self.product_id, "product_id": self.product_id,
"quantity": self.quantity, "quantity": self.quantity,
"price": float(self.price) if self.price else None "price": float(self.price) if self.price else None,
} }
def __repr__(self): def __repr__(self):
return f"<OrderItem {self.id}>" return f"<OrderItem {self.id}>"

View file

@ -1,11 +1,13 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from app import db from app import db
class Product(db.Model): class Product(db.Model):
"""Product model""" """Product model"""
__tablename__ = "products" __tablename__ = "products"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String(200), nullable=False, index=True) name = db.Column(db.String(200), nullable=False, index=True)
description = db.Column(db.Text) description = db.Column(db.Text)
@ -14,11 +16,15 @@ class Product(db.Model):
image_url = db.Column(db.String(500)) image_url = db.Column(db.String(500))
is_active = db.Column(db.Boolean, default=True) is_active = db.Column(db.Boolean, default=True)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
order_items = db.relationship("OrderItem", back_populates="product", lazy="dynamic") order_items = db.relationship("OrderItem", back_populates="product", lazy="dynamic")
def to_dict(self): def to_dict(self):
"""Convert product to dictionary""" """Convert product to dictionary"""
return { return {
@ -30,8 +36,8 @@ class Product(db.Model):
"image_url": self.image_url, "image_url": self.image_url,
"is_active": self.is_active, "is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None "updated_at": self.updated_at.isoformat() if self.updated_at else None,
} }
def __repr__(self): def __repr__(self):
return f"<Product {self.name}>" return f"<Product {self.name}>"

View file

@ -1,12 +1,15 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.security import check_password_hash, generate_password_hash
from app import db from app import db
class User(db.Model): class User(db.Model):
"""User model""" """User model"""
__tablename__ = "users" __tablename__ = "users"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
email = db.Column(db.String(120), unique=True, nullable=False, index=True) email = db.Column(db.String(120), unique=True, nullable=False, index=True)
username = db.Column(db.String(80), unique=True, nullable=False, index=True) username = db.Column(db.String(80), unique=True, nullable=False, index=True)
@ -16,19 +19,23 @@ class User(db.Model):
is_active = db.Column(db.Boolean, default=True) is_active = db.Column(db.Boolean, default=True)
is_admin = db.Column(db.Boolean, default=False) is_admin = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
orders = db.relationship("Order", back_populates="user", lazy="dynamic") orders = db.relationship("Order", back_populates="user", lazy="dynamic")
def set_password(self, password): def set_password(self, password):
"""Hash and set password""" """Hash and set password"""
self.password_hash = generate_password_hash(password) self.password_hash = generate_password_hash(password)
def check_password(self, password): def check_password(self, password):
"""Check if provided password matches hash""" """Check if provided password matches hash"""
return check_password_hash(self.password_hash, password) return check_password_hash(self.password_hash, password)
def to_dict(self): def to_dict(self):
"""Convert user to dictionary""" """Convert user to dictionary"""
return { return {
@ -40,8 +47,8 @@ class User(db.Model):
"is_active": self.is_active, "is_active": self.is_active,
"is_admin": self.is_admin, "is_admin": self.is_admin,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None "updated_at": self.updated_at.isoformat() if self.updated_at else None,
} }
def __repr__(self): def __repr__(self):
return f"<User {self.username}>" return f"<User {self.username}>"

View file

@ -1,4 +1,4 @@
from .api import api_bp from .api import api_bp
from .health import health_bp from .health import health_bp
__all__ = ["api_bp", "health_bp"] __all__ = ["api_bp", "health_bp"]

View file

@ -1,12 +1,15 @@
import time from flask import Blueprint, jsonify, request
from decimal import Decimal from flask_jwt_extended import (
create_access_token,
create_refresh_token,
get_jwt_identity,
jwt_required,
)
from pydantic import ValidationError from pydantic import ValidationError
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity, create_access_token, create_refresh_token
from app import db from app import db
from app.models import User, Product, OrderItem, Order
from app.celery import celery from app.celery import celery
from app.models import Order, OrderItem, Product, User
from app.schemas import ProductCreateRequest, ProductResponse from app.schemas import ProductCreateRequest, ProductResponse
api_bp = Blueprint("api", __name__) api_bp = Blueprint("api", __name__)
@ -17,24 +20,24 @@ api_bp = Blueprint("api", __name__)
def register(): def register():
"""Register a new user""" """Register a new user"""
data = request.get_json() data = request.get_json()
if not data or not data.get("email") or not data.get("password"): if not data or not data.get("email") or not data.get("password"):
return jsonify({"error": "Email and password are required"}), 400 return jsonify({"error": "Email and password are required"}), 400
if User.query.filter_by(email=data["email"]).first(): if User.query.filter_by(email=data["email"]).first():
return jsonify({"error": "Email already exists"}), 400 return jsonify({"error": "Email already exists"}), 400
user = User( user = User(
email=data["email"], email=data["email"],
username=data.get("username", data["email"].split("@")[0]), username=data.get("username", data["email"].split("@")[0]),
first_name=data.get("first_name"), first_name=data.get("first_name"),
last_name=data.get("last_name") last_name=data.get("last_name"),
) )
user.set_password(data["password"]) user.set_password(data["password"])
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
return jsonify(user.to_dict()), 201 return jsonify(user.to_dict()), 201
@ -42,26 +45,31 @@ def register():
def login(): def login():
"""Login user""" """Login user"""
data = request.get_json() data = request.get_json()
if not data or not data.get("email") or not data.get("password"): if not data or not data.get("email") or not data.get("password"):
return jsonify({"error": "Email and password are required"}), 400 return jsonify({"error": "Email and password are required"}), 400
user = User.query.filter_by(email=data["email"]).first() user = User.query.filter_by(email=data["email"]).first()
if not user or not user.check_password(data["password"]): if not user or not user.check_password(data["password"]):
return jsonify({"error": "Invalid credentials"}), 401 return jsonify({"error": "Invalid credentials"}), 401
if not user.is_active: if not user.is_active:
return jsonify({"error": "Account is inactive"}), 401 return jsonify({"error": "Account is inactive"}), 401
access_token = create_access_token(identity=str(user.id)) access_token = create_access_token(identity=str(user.id))
refresh_token = create_refresh_token(identity=str(user.id)) refresh_token = create_refresh_token(identity=str(user.id))
return jsonify({ return (
"user": user.to_dict(), jsonify(
"access_token": access_token, {
"refresh_token": refresh_token "user": user.to_dict(),
}), 200 "access_token": access_token,
"refresh_token": refresh_token,
}
),
200,
)
@api_bp.route("/users/me", methods=["GET"]) @api_bp.route("/users/me", methods=["GET"])
@ -70,10 +78,10 @@ def get_current_user():
"""Get current user""" """Get current user"""
user_id = int(get_jwt_identity()) user_id = int(get_jwt_identity())
user = db.session.get(User, user_id) user = db.session.get(User, user_id)
if not user: if not user:
return jsonify({"error": "User not found"}), 404 return jsonify({"error": "User not found"}), 404
return jsonify(user.to_dict()), 200 return jsonify(user.to_dict()), 200
@ -81,12 +89,11 @@ def get_current_user():
@api_bp.route("/products", methods=["GET"]) @api_bp.route("/products", methods=["GET"])
def get_products(): def get_products():
"""Get all products""" """Get all products"""
# time.sleep(5) # This adds a 5 second delay # time.sleep(5) # This adds a 5 second delay
products = Product.query.filter_by(is_active=True).all() products = Product.query.filter_by(is_active=True).all()
return jsonify([product.to_dict() for product in products]), 200 return jsonify([product.to_dict() for product in products]), 200
@ -105,29 +112,29 @@ def create_product():
"""Create a new product (admin only)""" """Create a new product (admin only)"""
user_id = int(get_jwt_identity()) user_id = int(get_jwt_identity())
user = db.session.get(User, user_id) user = db.session.get(User, user_id)
if not user or not user.is_admin: if not user or not user.is_admin:
return jsonify({"error": "Admin access required"}), 403 return jsonify({"error": "Admin access required"}), 403
try: try:
# Validate request data using Pydantic schema # Validate request data using Pydantic schema
product_data = ProductCreateRequest(**request.get_json()) product_data = ProductCreateRequest(**request.get_json())
product = Product( product = Product(
name=product_data.name, name=product_data.name,
description=product_data.description, description=product_data.description,
price=product_data.price, price=product_data.price,
stock=product_data.stock, stock=product_data.stock,
image_url=product_data.image_url image_url=product_data.image_url,
) )
db.session.add(product) db.session.add(product)
db.session.commit() db.session.commit()
# Use Pydantic schema for response # Use Pydantic schema for response
response = ProductResponse.model_validate(product) response = ProductResponse.model_validate(product)
return jsonify(response.model_dump()), 201 return jsonify(response.model_dump()), 201
except ValidationError as e: except ValidationError as e:
print(f"Pydantic Validation Error: {e.errors()}") print(f"Pydantic Validation Error: {e.errors()}")
return jsonify({"error": "Validation error", "details": e.errors()}), 400 return jsonify({"error": "Validation error", "details": e.errors()}), 400
@ -139,25 +146,25 @@ def update_product(product_id):
"""Update a product (admin only)""" """Update a product (admin only)"""
user_id = int(get_jwt_identity()) user_id = int(get_jwt_identity())
user = db.session.get(User, user_id) user = db.session.get(User, user_id)
if not user or not user.is_admin: if not user or not user.is_admin:
return jsonify({"error": "Admin access required"}), 403 return jsonify({"error": "Admin access required"}), 403
product = db.session.get(Product, product_id) product = db.session.get(Product, product_id)
if not product: if not product:
return jsonify({"error": "Product not found"}), 404 return jsonify({"error": "Product not found"}), 404
data = request.get_json() data = request.get_json()
product.name = data.get("name", product.name) product.name = data.get("name", product.name)
product.description = data.get("description", product.description) product.description = data.get("description", product.description)
product.price = data.get("price", product.price) product.price = data.get("price", product.price)
product.stock = data.get("stock", product.stock) product.stock = data.get("stock", product.stock)
product.image_url = data.get("image_url", product.image_url) product.image_url = data.get("image_url", product.image_url)
product.is_active = data.get("is_active", product.is_active) product.is_active = data.get("is_active", product.is_active)
db.session.commit() db.session.commit()
return jsonify(product.to_dict()), 200 return jsonify(product.to_dict()), 200
@ -167,17 +174,17 @@ def delete_product(product_id):
"""Delete a product (admin only)""" """Delete a product (admin only)"""
user_id = int(get_jwt_identity()) user_id = int(get_jwt_identity())
user = db.session.get(User, user_id) user = db.session.get(User, user_id)
if not user or not user.is_admin: if not user or not user.is_admin:
return jsonify({"error": "Admin access required"}), 403 return jsonify({"error": "Admin access required"}), 403
product = db.session.get(Product, product_id) product = db.session.get(Product, product_id)
if not product: if not product:
return jsonify({"error": "Product not found"}), 404 return jsonify({"error": "Product not found"}), 404
db.session.delete(product) db.session.delete(product)
db.session.commit() db.session.commit()
return jsonify({"message": "Product deleted"}), 200 return jsonify({"message": "Product deleted"}), 200
@ -197,49 +204,54 @@ def create_order():
"""Create a new order""" """Create a new order"""
user_id = int(get_jwt_identity()) user_id = int(get_jwt_identity())
data = request.get_json() data = request.get_json()
if not data or not data.get("items"): if not data or not data.get("items"):
return jsonify({"error": "Order items are required"}), 400 return jsonify({"error": "Order items are required"}), 400
total_amount = 0 total_amount = 0
order_items = [] order_items = []
for item_data in data["items"]: for item_data in data["items"]:
product = db.session.get(Product, item_data["product_id"]) product = db.session.get(Product, item_data["product_id"])
if not product: if not product:
return jsonify({"error": f'Product {item_data["product_id"]} not found'}), 404 return (
jsonify({"error": f'Product {item_data["product_id"]} not found'}),
404,
)
if product.stock < item_data["quantity"]: if product.stock < item_data["quantity"]:
return jsonify({"error": f'Insufficient stock for {product.name}'}), 400 return jsonify({"error": f"Insufficient stock for {product.name}"}), 400
item_total = product.price * item_data["quantity"] item_total = product.price * item_data["quantity"]
total_amount += item_total total_amount += item_total
order_items.append({ order_items.append(
"product": product, {
"quantity": item_data["quantity"], "product": product,
"price": product.price "quantity": item_data["quantity"],
}) "price": product.price,
}
)
order = Order( order = Order(
user_id=user_id, user_id=user_id,
total_amount=total_amount, total_amount=total_amount,
shipping_address=data.get("shipping_address") shipping_address=data.get("shipping_address"),
) )
db.session.add(order) db.session.add(order)
db.session.flush() db.session.flush()
for item_data in order_items: for item_data in order_items:
order_item = OrderItem( order_item = OrderItem(
order_id=order.id, order_id=order.id,
product_id=item_data["product"].id, product_id=item_data["product"].id,
quantity=item_data["quantity"], quantity=item_data["quantity"],
price=item_data["price"] price=item_data["price"],
) )
item_data["product"].stock -= item_data["quantity"] item_data["product"].stock -= item_data["quantity"]
db.session.add(order_item) db.session.add(order_item)
db.session.commit() db.session.commit()
return jsonify(order.to_dict()), 201 return jsonify(order.to_dict()), 201
@ -251,12 +263,12 @@ def get_order(order_id):
order = db.session.get(Order, order_id) order = db.session.get(Order, order_id)
if not order: if not order:
return jsonify({"error": "Order not found"}), 404 return jsonify({"error": "Order not found"}), 404
if order.user_id != user_id: if order.user_id != user_id:
user = db.session.get(User, user_id) user = db.session.get(User, user_id)
if not user or not user.is_admin: if not user or not user.is_admin:
return jsonify({"error": "Access denied"}), 403 return jsonify({"error": "Access denied"}), 403
return jsonify(order.to_dict()), 200 return jsonify(order.to_dict()), 200
@ -267,14 +279,15 @@ def trigger_hello_task():
"""Trigger the hello task""" """Trigger the hello task"""
data = request.get_json() or {} data = request.get_json() or {}
name = data.get("name", "World") name = data.get("name", "World")
task = celery.send_task("tasks.print_hello", args=[name]) task = celery.send_task("tasks.print_hello", args=[name])
return jsonify({ return (
"message": "Hello task triggered", jsonify(
"task_id": task.id, {"message": "Hello task triggered", "task_id": task.id, "status": "pending"}
"status": "pending" ),
}), 202 202,
)
@api_bp.route("/tasks/divide", methods=["POST"]) @api_bp.route("/tasks/divide", methods=["POST"])
@ -284,15 +297,20 @@ def trigger_divide_task():
data = request.get_json() or {} data = request.get_json() or {}
x = data.get("x", 10) x = data.get("x", 10)
y = data.get("y", 2) y = data.get("y", 2)
task = celery.send_task("tasks.divide_numbers", args=[x, y]) task = celery.send_task("tasks.divide_numbers", args=[x, y])
return jsonify({ return (
"message": "Divide task triggered", jsonify(
"task_id": task.id, {
"operation": f"{x} / {y}", "message": "Divide task triggered",
"status": "pending" "task_id": task.id,
}), 202 "operation": f"{x} / {y}",
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/report", methods=["POST"]) @api_bp.route("/tasks/report", methods=["POST"])
@ -300,12 +318,17 @@ def trigger_divide_task():
def trigger_report_task(): def trigger_report_task():
"""Trigger the daily report task""" """Trigger the daily report task"""
task = celery.send_task("tasks.send_daily_report") task = celery.send_task("tasks.send_daily_report")
return jsonify({ return (
"message": "Daily report task triggered", jsonify(
"task_id": task.id, {
"status": "pending" "message": "Daily report task triggered",
}), 202 "task_id": task.id,
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/stats", methods=["POST"]) @api_bp.route("/tasks/stats", methods=["POST"])
@ -314,19 +337,15 @@ def trigger_stats_task():
"""Trigger product statistics update task""" """Trigger product statistics update task"""
data = request.get_json() or {} data = request.get_json() or {}
product_id = data.get("product_id") product_id = data.get("product_id")
if product_id: if product_id:
task = celery.send_task("tasks.update_product_statistics", args=[product_id]) task = celery.send_task("tasks.update_product_statistics", args=[product_id])
message = f"Product statistics update triggered for product {product_id}" message = f"Product statistics update triggered for product {product_id}"
else: else:
task = celery.send_task("tasks.update_product_statistics", args=[None]) task = celery.send_task("tasks.update_product_statistics", args=[None])
message = "Product statistics update triggered for all products" message = "Product statistics update triggered for all products"
return jsonify({ return jsonify({"message": message, "task_id": task.id, "status": "pending"}), 202
"message": message,
"task_id": task.id,
"status": "pending"
}), 202
@api_bp.route("/tasks/long-running", methods=["POST"]) @api_bp.route("/tasks/long-running", methods=["POST"])
@ -335,14 +354,19 @@ def trigger_long_running_task():
"""Trigger a long-running task""" """Trigger a long-running task"""
data = request.get_json() or {} data = request.get_json() or {}
iterations = data.get("iterations", 10) iterations = data.get("iterations", 10)
task = celery.send_task("tasks.long_running_task", args=[iterations]) task = celery.send_task("tasks.long_running_task", args=[iterations])
return jsonify({ return (
"message": f"Long-running task triggered with {iterations} iterations", jsonify(
"task_id": task.id, {
"status": "pending" "message": f"Long-running task triggered with {iterations} iterations",
}), 202 "task_id": task.id,
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/<task_id>", methods=["GET"]) @api_bp.route("/tasks/<task_id>", methods=["GET"])
@ -350,20 +374,20 @@ def trigger_long_running_task():
def get_task_status(task_id): def get_task_status(task_id):
"""Get the status of a Celery task""" """Get the status of a Celery task"""
task_result = celery.AsyncResult(task_id) task_result = celery.AsyncResult(task_id)
response = { response = {
"task_id": task_id, "task_id": task_id,
"status": task_result.status, "status": task_result.status,
"ready": task_result.ready() "ready": task_result.ready(),
} }
if task_result.ready(): if task_result.ready():
if task_result.successful(): if task_result.successful():
response["result"] = task_result.result response["result"] = task_result.result
else: else:
response["error"] = str(task_result.result) response["error"] = str(task_result.result)
response["traceback"] = task_result.traceback response["traceback"] = task_result.traceback
return jsonify(response), 200 return jsonify(response), 200
@ -374,20 +398,18 @@ def celery_health():
# Try to ping the worker # Try to ping the worker
inspector = celery.control.inspect() inspector = celery.control.inspect()
stats = inspector.stats() stats = inspector.stats()
if stats: if stats:
return jsonify({ return (
"status": "healthy", jsonify(
"workers": len(stats), {"status": "healthy", "workers": len(stats), "workers_info": stats}
"workers_info": stats ),
}), 200 200,
)
else: else:
return jsonify({ return (
"status": "unhealthy", jsonify({"status": "unhealthy", "message": "No workers available"}),
"message": "No workers available" 503,
}), 503 )
except Exception as e: except Exception as e:
return jsonify({ return jsonify({"status": "error", "message": str(e)}), 500
"status": "error",
"message": str(e)
}), 500

View file

@ -1,22 +1,16 @@
from flask import Blueprint, jsonify from flask import Blueprint, jsonify
health_bp = Blueprint('health', __name__) health_bp = Blueprint("health", __name__)
@health_bp.route('/', methods=['GET']) @health_bp.route("/", methods=["GET"])
def health_check(): def health_check():
"""Health check endpoint""" """Health check endpoint"""
return jsonify({ return jsonify({"status": "healthy", "service": "crafting-shop-backend"}), 200
'status': 'healthy',
'service': 'crafting-shop-backend'
}), 200
@health_bp.route('/readiness', methods=['GET']) @health_bp.route("/readiness", methods=["GET"])
def readiness_check(): def readiness_check():
"""Readiness check endpoint""" """Readiness check endpoint"""
# Add database check here if needed # Add database check here if needed
return jsonify({ return jsonify({"status": "ready", "service": "crafting-shop-backend"}), 200
'status': 'ready',
'service': 'crafting-shop-backend'
}), 200

View file

@ -1,4 +1,4 @@
"""Pydantic schemas for request/response validation""" """Pydantic schemas for request/response validation"""
from app.schemas.product import ProductCreateRequest, ProductResponse from app.schemas.product import ProductCreateRequest, ProductResponse
__all__ = ["ProductCreateRequest", "ProductResponse"] __all__ = ["ProductCreateRequest", "ProductResponse"]

View file

@ -1,12 +1,14 @@
"""Pydantic schemas for Product model""" """Pydantic schemas for Product model"""
from pydantic import BaseModel, Field, field_validator, ConfigDict
from decimal import Decimal
from datetime import datetime from datetime import datetime
from decimal import Decimal
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ProductCreateRequest(BaseModel): class ProductCreateRequest(BaseModel):
"""Schema for creating a new product""" """Schema for creating a new product"""
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
"example": { "example": {
@ -14,17 +16,21 @@ class ProductCreateRequest(BaseModel):
"description": "A beautiful handcrafted bowl made from oak", "description": "A beautiful handcrafted bowl made from oak",
"price": 45.99, "price": 45.99,
"stock": 10, "stock": 10,
"image_url": "https://example.com/bowl.jpg" "image_url": "https://example.com/bowl.jpg",
} }
} }
) )
name: str = Field(..., min_length=1, max_length=200, description="Product name") name: str = Field(..., min_length=1, max_length=200, description="Product name")
description: Optional[str] = Field(None, description="Product description") description: Optional[str] = Field(None, description="Product description")
price: Decimal = Field(..., gt=0, description="Product price (must be greater than 0)") price: Decimal = Field(
..., gt=0, description="Product price (must be greater than 0)"
)
stock: int = Field(default=0, ge=0, description="Product stock quantity") stock: int = Field(default=0, ge=0, description="Product stock quantity")
image_url: Optional[str] = Field(None, max_length=500, description="Product image URL") image_url: Optional[str] = Field(
None, max_length=500, description="Product image URL"
)
@field_validator("price") @field_validator("price")
@classmethod @classmethod
def validate_price(cls, v: Decimal) -> Decimal: def validate_price(cls, v: Decimal) -> Decimal:
@ -36,6 +42,7 @@ class ProductCreateRequest(BaseModel):
class ProductResponse(BaseModel): class ProductResponse(BaseModel):
"""Schema for product response""" """Schema for product response"""
model_config = ConfigDict( model_config = ConfigDict(
from_attributes=True, from_attributes=True,
json_schema_extra={ json_schema_extra={
@ -48,11 +55,11 @@ class ProductResponse(BaseModel):
"image_url": "https://example.com/bowl.jpg", "image_url": "https://example.com/bowl.jpg",
"is_active": True, "is_active": True,
"created_at": "2024-01-15T10:30:00", "created_at": "2024-01-15T10:30:00",
"updated_at": "2024-01-15T10:30:00" "updated_at": "2024-01-15T10:30:00",
} }
} },
) )
id: int id: int
name: str name: str
description: Optional[str] = None description: Optional[str] = None

View file

@ -1 +1 @@
"""Business logic services""" """Business logic services"""

View file

@ -1 +1 @@
"""Utility functions and helpers""" """Utility functions and helpers"""

View file

@ -1 +1 @@
"""Tests package for Flask application""" """Tests package for Flask application"""

View file

@ -1,28 +1,31 @@
"""Pytest configuration and fixtures""" """Pytest configuration and fixtures"""
import pytest
import tempfile
import os import os
import tempfile
import pytest
from faker import Faker from faker import Faker
from app import create_app, db from app import create_app, db
from app.models import User, Product, Order, OrderItem from app.models import Order, OrderItem, Product, User
fake = Faker() fake = Faker()
@pytest.fixture(scope='function') @pytest.fixture(scope="function")
def app(): def app():
"""Create application for testing with isolated database""" """Create application for testing with isolated database"""
db_fd, db_path = tempfile.mkstemp() db_fd, db_path = tempfile.mkstemp()
app = create_app(config_name='test') app = create_app(config_name="test")
app.config.update({ app.config.update(
'TESTING': True, {
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}', "TESTING": True,
'WTF_CSRF_ENABLED': False, "SQLALCHEMY_DATABASE_URI": f"sqlite:///{db_path}",
'JWT_SECRET_KEY': 'test-secret-keytest-secret-keytest-secret-keytest-secret-keytest-secret-key', "WTF_CSRF_ENABLED": False,
'SERVER_NAME': 'localhost.localdomain' "JWT_SECRET_KEY": "test-secret-keytest-secret-keytest-secret-keytest-secret-keytest-secret-key",
}) "SERVER_NAME": "localhost.localdomain",
}
)
with app.app_context(): with app.app_context():
db.create_all() db.create_all()
@ -62,9 +65,9 @@ def admin_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=True, is_admin=True,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -79,9 +82,9 @@ def regular_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=False, is_admin=False,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -96,9 +99,9 @@ def inactive_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=False, is_admin=False,
is_active=False is_active=False,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -112,7 +115,7 @@ def product(db_session):
description=fake.paragraph(), description=fake.paragraph(),
price=fake.pydecimal(left_digits=2, right_digits=2, positive=True), price=fake.pydecimal(left_digits=2, right_digits=2, positive=True),
stock=fake.pyint(min_value=0, max_value=100), stock=fake.pyint(min_value=0, max_value=100),
image_url=fake.url() image_url=fake.url(),
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
@ -129,7 +132,7 @@ def products(db_session):
description=fake.paragraph(), description=fake.paragraph(),
price=fake.pydecimal(left_digits=2, right_digits=2, positive=True), price=fake.pydecimal(left_digits=2, right_digits=2, positive=True),
stock=fake.pyint(min_value=0, max_value=100), stock=fake.pyint(min_value=0, max_value=100),
image_url=fake.url() image_url=fake.url(),
) )
db_session.add(product) db_session.add(product)
products.append(product) products.append(product)
@ -140,40 +143,36 @@ def products(db_session):
@pytest.fixture @pytest.fixture
def auth_headers(client, regular_user): def auth_headers(client, regular_user):
"""Get authentication headers for a regular user""" """Get authentication headers for a regular user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': regular_user.email, "/api/auth/login", json={"email": regular_user.email, "password": "password123"}
'password': 'password123' )
})
data = response.get_json() data = response.get_json()
token = data['access_token'] token = data["access_token"]
print(f"Auth headers token for user {regular_user.email}: {token[:50]}...") print(f"Auth headers token for user {regular_user.email}: {token[:50]}...")
return {'Authorization': f'Bearer {token}'} return {"Authorization": f"Bearer {token}"}
@pytest.fixture @pytest.fixture
def admin_headers(client, admin_user): def admin_headers(client, admin_user):
"""Get authentication headers for an admin user""" """Get authentication headers for an admin user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': admin_user.email, "/api/auth/login", json={"email": admin_user.email, "password": "password123"}
'password': 'password123' )
})
data = response.get_json() data = response.get_json()
token = data['access_token'] token = data["access_token"]
print(f"Admin headers token for user {admin_user.email}: {token[:50]}...") print(f"Admin headers token for user {admin_user.email}: {token[:50]}...")
return {'Authorization': f'Bearer {token}'} return {"Authorization": f"Bearer {token}"}
@pytest.fixture @pytest.fixture
def order(db_session, regular_user, products): def order(db_session, regular_user, products):
"""Create an order for testing""" """Create an order for testing"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id, total_amount=0.0, shipping_address=fake.address()
total_amount=0.0,
shipping_address=fake.address()
) )
db_session.add(order) db_session.add(order)
db_session.flush() db_session.flush()
total_amount = 0 total_amount = 0
for i, product in enumerate(products[:2]): for i, product in enumerate(products[:2]):
quantity = fake.pyint(min_value=1, max_value=5) quantity = fake.pyint(min_value=1, max_value=5)
@ -181,11 +180,11 @@ def order(db_session, regular_user, products):
order_id=order.id, order_id=order.id,
product_id=product.id, product_id=product.id,
quantity=quantity, quantity=quantity,
price=product.price price=product.price,
) )
total_amount += float(product.price) * quantity total_amount += float(product.price) * quantity
db_session.add(order_item) db_session.add(order_item)
order.total_amount = total_amount order.total_amount = total_amount
db_session.commit() db_session.commit()
return order return order

View file

@ -1,8 +1,9 @@
"""Test models""" """Test models"""
import pytest
from decimal import Decimal from decimal import Decimal
from datetime import datetime
from app.models import User, Product, Order, OrderItem import pytest
from app.models import Order, OrderItem, Product, User
class TestUserModel: class TestUserModel:
@ -12,62 +13,62 @@ class TestUserModel:
def test_user_creation(self, db_session): def test_user_creation(self, db_session):
"""Test creating a user""" """Test creating a user"""
user = User( user = User(
email='test@example.com', email="test@example.com",
username='testuser', username="testuser",
first_name='Test', first_name="Test",
last_name='User', last_name="User",
is_admin=False, is_admin=False,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert user.id is not None assert user.id is not None
assert user.email == 'test@example.com' assert user.email == "test@example.com"
assert user.username == 'testuser' assert user.username == "testuser"
assert user.first_name == 'Test' assert user.first_name == "Test"
assert user.last_name == 'User' assert user.last_name == "User"
@pytest.mark.unit @pytest.mark.unit
def test_user_password_hashing(self, db_session): def test_user_password_hashing(self, db_session):
"""Test password hashing and verification""" """Test password hashing and verification"""
user = User(email='test@example.com', username='testuser') user = User(email="test@example.com", username="testuser")
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert user.check_password('password123') is True assert user.check_password("password123") is True
assert user.check_password('wrongpassword') is False assert user.check_password("wrongpassword") is False
@pytest.mark.unit @pytest.mark.unit
def test_user_to_dict(self, db_session): def test_user_to_dict(self, db_session):
"""Test user serialization to dictionary""" """Test user serialization to dictionary"""
user = User( user = User(
email='test@example.com', email="test@example.com",
username='testuser', username="testuser",
first_name='Test', first_name="Test",
last_name='User' last_name="User",
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
user_dict = user.to_dict() user_dict = user.to_dict()
assert user_dict['email'] == 'test@example.com' assert user_dict["email"] == "test@example.com"
assert user_dict['username'] == 'testuser' assert user_dict["username"] == "testuser"
assert 'password' not in user_dict assert "password" not in user_dict
assert 'password_hash' not in user_dict assert "password_hash" not in user_dict
@pytest.mark.unit @pytest.mark.unit
def test_user_repr(self, db_session): def test_user_repr(self, db_session):
"""Test user string representation""" """Test user string representation"""
user = User(email='test@example.com', username='testuser') user = User(email="test@example.com", username="testuser")
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert repr(user) == '<User testuser>' assert repr(user) == "<User testuser>"
class TestProductModel: class TestProductModel:
@ -77,18 +78,18 @@ class TestProductModel:
def test_product_creation(self, db_session): def test_product_creation(self, db_session):
"""Test creating a product""" """Test creating a product"""
product = Product( product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('99.99'), price=Decimal("99.99"),
stock=10, stock=10,
image_url='https://example.com/product.jpg' image_url="https://example.com/product.jpg",
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
assert product.id is not None assert product.id is not None
assert product.name == 'Test Product' assert product.name == "Test Product"
assert product.price == Decimal('99.99') assert product.price == Decimal("99.99")
assert product.stock == 10 assert product.stock == 10
assert product.is_active is True assert product.is_active is True
@ -96,27 +97,24 @@ class TestProductModel:
def test_product_to_dict(self, db_session): def test_product_to_dict(self, db_session):
"""Test product serialization to dictionary""" """Test product serialization to dictionary"""
product = Product( product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('99.99'), price=Decimal("99.99"),
stock=10 stock=10,
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
product_dict = product.to_dict() product_dict = product.to_dict()
assert product_dict['name'] == 'Test Product' assert product_dict["name"] == "Test Product"
assert product_dict['price'] == 99.99 assert product_dict["price"] == 99.99
assert isinstance(product_dict['created_at'], str) assert isinstance(product_dict["created_at"], str)
assert isinstance(product_dict['updated_at'], str) assert isinstance(product_dict["updated_at"], str)
@pytest.mark.unit @pytest.mark.unit
def test_product_defaults(self, db_session): def test_product_defaults(self, db_session):
"""Test product default values""" """Test product default values"""
product = Product( product = Product(name="Test Product", price=Decimal("9.99"))
name='Test Product',
price=Decimal('9.99')
)
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
@ -128,11 +126,11 @@ class TestProductModel:
@pytest.mark.unit @pytest.mark.unit
def test_product_repr(self, db_session): def test_product_repr(self, db_session):
"""Test product string representation""" """Test product string representation"""
product = Product(name='Test Product', price=Decimal('9.99')) product = Product(name="Test Product", price=Decimal("9.99"))
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
assert repr(product) == '<Product Test Product>' assert repr(product) == "<Product Test Product>"
class TestOrderModel: class TestOrderModel:
@ -143,31 +141,31 @@ class TestOrderModel:
"""Test creating an order""" """Test creating an order"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id,
total_amount=Decimal('199.99'), total_amount=Decimal("199.99"),
shipping_address='123 Test St' shipping_address="123 Test St",
) )
db_session.add(order) db_session.add(order)
db_session.commit() db_session.commit()
assert order.id is not None assert order.id is not None
assert order.user_id == regular_user.id assert order.user_id == regular_user.id
assert order.total_amount == Decimal('199.99') assert order.total_amount == Decimal("199.99")
@pytest.mark.unit @pytest.mark.unit
def test_order_to_dict(self, db_session, regular_user): def test_order_to_dict(self, db_session, regular_user):
"""Test order serialization to dictionary""" """Test order serialization to dictionary"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id,
total_amount=Decimal('199.99'), total_amount=Decimal("199.99"),
shipping_address='123 Test St' shipping_address="123 Test St",
) )
db_session.add(order) db_session.add(order)
db_session.commit() db_session.commit()
order_dict = order.to_dict() order_dict = order.to_dict()
assert order_dict['user_id'] == regular_user.id assert order_dict["user_id"] == regular_user.id
assert order_dict['total_amount'] == 199.99 assert order_dict["total_amount"] == 199.99
assert isinstance(order_dict['created_at'], str) assert isinstance(order_dict["created_at"], str)
class TestOrderItemModel: class TestOrderItemModel:
@ -177,10 +175,7 @@ class TestOrderItemModel:
def test_order_item_creation(self, db_session, order, product): def test_order_item_creation(self, db_session, order, product):
"""Test creating an order item""" """Test creating an order item"""
order_item = OrderItem( order_item = OrderItem(
order_id=order.id, order_id=order.id, product_id=product.id, quantity=2, price=product.price
product_id=product.id,
quantity=2,
price=product.price
) )
db_session.add(order_item) db_session.add(order_item)
db_session.commit() db_session.commit()
@ -194,15 +189,12 @@ class TestOrderItemModel:
def test_order_item_to_dict(self, db_session, order, product): def test_order_item_to_dict(self, db_session, order, product):
"""Test order item serialization to dictionary""" """Test order item serialization to dictionary"""
order_item = OrderItem( order_item = OrderItem(
order_id=order.id, order_id=order.id, product_id=product.id, quantity=2, price=product.price
product_id=product.id,
quantity=2,
price=product.price
) )
db_session.add(order_item) db_session.add(order_item)
db_session.commit() db_session.commit()
item_dict = order_item.to_dict() item_dict = order_item.to_dict()
assert item_dict['order_id'] == order.id assert item_dict["order_id"] == order.id
assert item_dict['product_id'] == product.id assert item_dict["product_id"] == product.id
assert item_dict['quantity'] == 2 assert item_dict["quantity"] == 2

View file

@ -1,7 +1,5 @@
"""Test API routes""" """Test API routes"""
import pytest import pytest
import json
from decimal import Decimal
class TestAuthRoutes: class TestAuthRoutes:
@ -10,101 +8,109 @@ class TestAuthRoutes:
@pytest.mark.auth @pytest.mark.auth
def test_register_success(self, client): def test_register_success(self, client):
"""Test successful user registration""" """Test successful user registration"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': 'newuser@example.com', "/api/auth/register",
'password': 'password123', json={
'username': 'newuser', "email": "newuser@example.com",
'first_name': 'New', "password": "password123",
'last_name': 'User' "username": "newuser",
}) "first_name": "New",
"last_name": "User",
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['email'] == 'newuser@example.com' assert data["email"] == "newuser@example.com"
assert data['username'] == 'newuser' assert data["username"] == "newuser"
assert 'password' not in data assert "password" not in data
assert 'password_hash' not in data assert "password_hash" not in data
@pytest.mark.auth @pytest.mark.auth
def test_register_missing_fields(self, client): def test_register_missing_fields(self, client):
"""Test registration with missing required fields""" """Test registration with missing required fields"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': 'newuser@example.com' "/api/auth/register", json={"email": "newuser@example.com"}
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'error' in data assert "error" in data
@pytest.mark.auth @pytest.mark.auth
def test_register_duplicate_email(self, client, regular_user): def test_register_duplicate_email(self, client, regular_user):
"""Test registration with duplicate email""" """Test registration with duplicate email"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': regular_user.email, "/api/auth/register",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'already exists' in data['error'].lower() assert "already exists" in data["error"].lower()
@pytest.mark.auth @pytest.mark.auth
def test_login_success(self, client, regular_user): def test_login_success(self, client, regular_user):
"""Test successful login""" """Test successful login"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': regular_user.email, "/api/auth/login",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert 'access_token' in data assert "access_token" in data
assert 'refresh_token' in data assert "refresh_token" in data
assert data['user']['email'] == regular_user.email assert data["user"]["email"] == regular_user.email
@pytest.mark.auth @pytest.mark.auth
@pytest.mark.parametrize("email,password,expected_status", [ @pytest.mark.parametrize(
("wrong@example.com", "password123", 401), "email,password,expected_status",
("user@example.com", "wrongpassword", 401), [
(None, "password123", 400), ("wrong@example.com", "password123", 401),
("user@example.com", None, 400), ("user@example.com", "wrongpassword", 401),
]) (None, "password123", 400),
def test_login_validation(self, client, regular_user, email, password, expected_status): ("user@example.com", None, 400),
],
)
def test_login_validation(
self, client, regular_user, email, password, expected_status
):
"""Test login with various invalid inputs""" """Test login with various invalid inputs"""
login_data = {} login_data = {}
if email is not None: if email is not None:
login_data['email'] = email login_data["email"] = email
if password is not None: if password is not None:
login_data['password'] = password login_data["password"] = password
response = client.post('/api/auth/login', json=login_data) response = client.post("/api/auth/login", json=login_data)
assert response.status_code == expected_status assert response.status_code == expected_status
@pytest.mark.auth @pytest.mark.auth
def test_login_inactive_user(self, client, inactive_user): def test_login_inactive_user(self, client, inactive_user):
"""Test login with inactive user""" """Test login with inactive user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': inactive_user.email, "/api/auth/login",
'password': 'password123' json={"email": inactive_user.email, "password": "password123"},
}) )
assert response.status_code == 401 assert response.status_code == 401
data = response.get_json() data = response.get_json()
assert 'inactive' in data['error'].lower() assert "inactive" in data["error"].lower()
@pytest.mark.auth @pytest.mark.auth
def test_get_current_user(self, client, auth_headers, regular_user): def test_get_current_user(self, client, auth_headers, regular_user):
"""Test getting current user""" """Test getting current user"""
response = client.get('/api/users/me', headers=auth_headers) response = client.get("/api/users/me", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['email'] == regular_user.email assert data["email"] == regular_user.email
@pytest.mark.auth @pytest.mark.auth
def test_get_current_user_unauthorized(self, client): def test_get_current_user_unauthorized(self, client):
"""Test getting current user without authentication""" """Test getting current user without authentication"""
response = client.get('/api/users/me') response = client.get("/api/users/me")
assert response.status_code == 401 assert response.status_code == 401
@ -114,7 +120,7 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_products(self, client, products): def test_get_products(self, client, products):
"""Test getting all products""" """Test getting all products"""
response = client.get('/api/products') response = client.get("/api/products")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -123,7 +129,7 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_products_empty(self, client): def test_get_products_empty(self, client):
"""Test getting products when none exist""" """Test getting products when none exist"""
response = client.get('/api/products') response = client.get("/api/products")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -132,113 +138,122 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_single_product(self, client, product): def test_get_single_product(self, client, product):
"""Test getting a single product""" """Test getting a single product"""
response = client.get(f'/api/products/{product.id}') response = client.get(f"/api/products/{product.id}")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['id'] == product.id assert data["id"] == product.id
assert data['name'] == product.name assert data["name"] == product.name
@pytest.mark.product @pytest.mark.product
def test_get_product_not_found(self, client): def test_get_product_not_found(self, client):
"""Test getting non-existent product""" """Test getting non-existent product"""
response = client.get('/api/products/999') response = client.get("/api/products/999")
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.product @pytest.mark.product
def test_create_product_admin(self, client, admin_headers): def test_create_product_admin(self, client, admin_headers):
"""Test creating product as admin""" """Test creating product as admin"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'description': 'A new product', headers=admin_headers,
'price': 29.99, json={
'stock': 10 "name": "New Product",
}) "description": "A new product",
"price": 29.99,
"stock": 10,
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['name'] == 'New Product' assert data["name"] == "New Product"
assert data['price'] == 29.99 assert data["price"] == 29.99
@pytest.mark.product @pytest.mark.product
def test_create_product_regular_user(self, client, auth_headers): def test_create_product_regular_user(self, client, auth_headers):
"""Test creating product as regular user (should fail)""" """Test creating product as regular user (should fail)"""
response = client.post('/api/products', headers=auth_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'price': 29.99 headers=auth_headers,
}) json={"name": "New Product", "price": 29.99},
)
assert response.status_code == 403 assert response.status_code == 403
data = response.get_json() data = response.get_json()
assert 'admin' in data['error'].lower() assert "admin" in data["error"].lower()
@pytest.mark.product @pytest.mark.product
def test_create_product_unauthorized(self, client): def test_create_product_unauthorized(self, client):
"""Test creating product without authentication""" """Test creating product without authentication"""
response = client.post('/api/products', json={ response = client.post(
'name': 'New Product', "/api/products", json={"name": "New Product", "price": 29.99}
'price': 29.99 )
})
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.product @pytest.mark.product
def test_create_product_validation_error(self, client, admin_headers): def test_create_product_validation_error(self, client, admin_headers):
"""Test creating product with invalid data""" """Test creating product with invalid data"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'price': -10.99 headers=admin_headers,
}) json={"name": "New Product", "price": -10.99},
)
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'Validation error' in data['error'] assert "Validation error" in data["error"]
@pytest.mark.product @pytest.mark.product
def test_create_product_missing_required_fields(self, client, admin_headers): def test_create_product_missing_required_fields(self, client, admin_headers):
"""Test creating product with missing required fields""" """Test creating product with missing required fields"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'description': 'Missing name and price' "/api/products",
}) headers=admin_headers,
json={"description": "Missing name and price"},
)
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'Validation error' in data['error'] assert "Validation error" in data["error"]
@pytest.mark.product @pytest.mark.product
def test_create_product_minimal_data(self, client, admin_headers): def test_create_product_minimal_data(self, client, admin_headers):
"""Test creating product with minimal valid data""" """Test creating product with minimal valid data"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'Minimal Product', "/api/products",
'price': 19.99 headers=admin_headers,
}) json={"name": "Minimal Product", "price": 19.99},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['name'] == 'Minimal Product' assert data["name"] == "Minimal Product"
assert data['stock'] == 0 # Default value assert data["stock"] == 0 # Default value
@pytest.mark.product @pytest.mark.product
def test_update_product_admin(self, client, admin_headers, product): def test_update_product_admin(self, client, admin_headers, product):
"""Test updating product as admin""" """Test updating product as admin"""
response = client.put(f'/api/products/{product.id}', headers=admin_headers, json={ response = client.put(
'name': 'Updated Product', f"/api/products/{product.id}",
'price': 39.99 headers=admin_headers,
}) json={"name": "Updated Product", "price": 39.99},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['name'] == 'Updated Product' assert data["name"] == "Updated Product"
assert data['price'] == 39.99 assert data["price"] == 39.99
@pytest.mark.product @pytest.mark.product
def test_delete_product_admin(self, client, admin_headers, product): def test_delete_product_admin(self, client, admin_headers, product):
"""Test deleting product as admin""" """Test deleting product as admin"""
response = client.delete(f'/api/products/{product.id}', headers=admin_headers) response = client.delete(f"/api/products/{product.id}", headers=admin_headers)
assert response.status_code == 200 assert response.status_code == 200
# Verify product is deleted # Verify product is deleted
response = client.get(f'/api/products/{product.id}') response = client.get(f"/api/products/{product.id}")
assert response.status_code == 404 assert response.status_code == 404
@ -248,7 +263,7 @@ class TestOrderRoutes:
@pytest.mark.order @pytest.mark.order
def test_get_orders(self, client, auth_headers, order): def test_get_orders(self, client, auth_headers, order):
"""Test getting orders for current user""" """Test getting orders for current user"""
response = client.get('/api/orders', headers=auth_headers) response = client.get("/api/orders", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -257,63 +272,68 @@ class TestOrderRoutes:
@pytest.mark.order @pytest.mark.order
def test_get_orders_unauthorized(self, client): def test_get_orders_unauthorized(self, client):
"""Test getting orders without authentication""" """Test getting orders without authentication"""
response = client.get('/api/orders') response = client.get("/api/orders")
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.order @pytest.mark.order
def test_create_order(self, client, auth_headers, products): def test_create_order(self, client, auth_headers, products):
"""Test creating an order""" """Test creating an order"""
response = client.post('/api/orders', headers=auth_headers, json={ response = client.post(
'items': [ "/api/orders",
{'product_id': products[0].id, 'quantity': 2}, headers=auth_headers,
{'product_id': products[1].id, 'quantity': 1} json={
], "items": [
'shipping_address': '123 Test St' {"product_id": products[0].id, "quantity": 2},
}) {"product_id": products[1].id, "quantity": 1},
],
"shipping_address": "123 Test St",
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert 'id' in data assert "id" in data
assert len(data['items']) == 2 assert len(data["items"]) == 2
@pytest.mark.order @pytest.mark.order
def test_create_order_insufficient_stock(self, client, auth_headers, db_session, products): def test_create_order_insufficient_stock(
self, client, auth_headers, db_session, products
):
"""Test creating order with insufficient stock""" """Test creating order with insufficient stock"""
# Set stock to 0 # Set stock to 0
products[0].stock = 0 products[0].stock = 0
db_session.commit() db_session.commit()
response = client.post('/api/orders', headers=auth_headers, json={ response = client.post(
'items': [ "/api/orders",
{'product_id': products[0].id, 'quantity': 2} headers=auth_headers,
] json={"items": [{"product_id": products[0].id, "quantity": 2}]},
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'insufficient' in data['error'].lower() assert "insufficient" in data["error"].lower()
@pytest.mark.order @pytest.mark.order
def test_get_single_order(self, client, auth_headers, order): def test_get_single_order(self, client, auth_headers, order):
"""Test getting a single order""" """Test getting a single order"""
response = client.get(f'/api/orders/{order.id}', headers=auth_headers) response = client.get(f"/api/orders/{order.id}", headers=auth_headers)
print('test_get_single_order', response.get_json()) print("test_get_single_order", response.get_json())
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['id'] == order.id assert data["id"] == order.id
@pytest.mark.order @pytest.mark.order
def test_get_other_users_order(self, client, admin_headers, regular_user, products): def test_get_other_users_order(self, client, admin_headers, regular_user, products):
"""Test admin accessing another user's order""" """Test admin accessing another user's order"""
# Create an order for regular_user # Create an order for regular_user
client.post('/api/auth/login', json={ client.post(
'email': regular_user.email, "/api/auth/login",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
# Admin should be able to access any order # Admin should be able to access any order
response = client.get(f'/api/orders/1', headers=admin_headers)
# This test assumes order exists, adjust as needed # This test assumes order exists, adjust as needed
pass pass

View file

@ -1,7 +1,9 @@
"""Test Pydantic schemas""" """Test Pydantic schemas"""
from decimal import Decimal
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from decimal import Decimal
from app.schemas import ProductCreateRequest, ProductResponse from app.schemas import ProductCreateRequest, ProductResponse
@ -12,31 +14,28 @@ class TestProductCreateRequestSchema:
def test_valid_product_request(self): def test_valid_product_request(self):
"""Test valid product creation request""" """Test valid product creation request"""
data = { data = {
'name': 'Handcrafted Wooden Bowl', "name": "Handcrafted Wooden Bowl",
'description': 'A beautiful handcrafted bowl', "description": "A beautiful handcrafted bowl",
'price': 45.99, "price": 45.99,
'stock': 10, "stock": 10,
'image_url': 'https://example.com/bowl.jpg' "image_url": "https://example.com/bowl.jpg",
} }
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.name == data['name'] assert product.name == data["name"]
assert product.description == data['description'] assert product.description == data["description"]
assert product.price == Decimal('45.99') assert product.price == Decimal("45.99")
assert product.stock == 10 assert product.stock == 10
assert product.image_url == data['image_url'] assert product.image_url == data["image_url"]
@pytest.mark.unit @pytest.mark.unit
def test_minimal_valid_request(self): def test_minimal_valid_request(self):
"""Test minimal valid request (only required fields)""" """Test minimal valid request (only required fields)"""
data = { data = {"name": "Simple Product", "price": 19.99}
'name': 'Simple Product',
'price': 19.99
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.name == 'Simple Product' assert product.name == "Simple Product"
assert product.price == Decimal('19.99') assert product.price == Decimal("19.99")
assert product.stock == 0 assert product.stock == 0
assert product.description is None assert product.description is None
assert product.image_url is None assert product.image_url is None
@ -44,134 +43,107 @@ class TestProductCreateRequestSchema:
@pytest.mark.unit @pytest.mark.unit
def test_missing_name(self): def test_missing_name(self):
"""Test request with missing name""" """Test request with missing name"""
data = { data = {"price": 19.99}
'price': 19.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('name',) for error in errors) assert any(error["loc"] == ("name",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_missing_price(self): def test_missing_price(self):
"""Test request with missing price""" """Test request with missing price"""
data = { data = {"name": "Test Product"}
'name': 'Test Product'
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('price',) for error in errors) assert any(error["loc"] == ("price",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_negative(self): def test_invalid_price_negative(self):
"""Test request with negative price""" """Test request with negative price"""
data = { data = {"name": "Test Product", "price": -10.99}
'name': 'Test Product',
'price': -10.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than' for error in errors) assert any(error["type"] == "greater_than" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_zero(self): def test_invalid_price_zero(self):
"""Test request with zero price""" """Test request with zero price"""
data = { data = {"name": "Test Product", "price": 0.0}
'name': 'Test Product',
'price': 0.0
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than' for error in errors) assert any(error["type"] == "greater_than" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_too_many_decimals(self): def test_invalid_price_too_many_decimals(self):
"""Test request with too many decimal places""" """Test request with too many decimal places"""
data = { data = {"name": "Test Product", "price": 10.999}
'name': 'Test Product',
'price': 10.999
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any('decimal places' in str(error).lower() for error in errors) assert any("decimal places" in str(error).lower() for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_stock_negative(self): def test_invalid_stock_negative(self):
"""Test request with negative stock""" """Test request with negative stock"""
data = { data = {"name": "Test Product", "price": 19.99, "stock": -5}
'name': 'Test Product',
'price': 19.99,
'stock': -5
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than_equal' for error in errors) assert any(error["type"] == "greater_than_equal" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_name_too_long(self): def test_name_too_long(self):
"""Test request with name exceeding max length""" """Test request with name exceeding max length"""
data = { data = {"name": "A" * 201, "price": 19.99} # Exceeds 200 character limit
'name': 'A' * 201, # Exceeds 200 character limit
'price': 19.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('name',) for error in errors) assert any(error["loc"] == ("name",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_image_url_too_long(self): def test_image_url_too_long(self):
"""Test request with image_url exceeding max length""" """Test request with image_url exceeding max length"""
data = { data = {
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'image_url': 'A' * 501 # Exceeds 500 character limit "image_url": "A" * 501, # Exceeds 500 character limit
} }
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('image_url',) for error in errors) assert any(error["loc"] == ("image_url",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_price_string_conversion(self): def test_price_string_conversion(self):
"""Test price string to Decimal conversion""" """Test price string to Decimal conversion"""
data = { data = {"name": "Test Product", "price": "29.99"}
'name': 'Test Product',
'price': '29.99'
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.price == Decimal('29.99') assert product.price == Decimal("29.99")
@pytest.mark.unit @pytest.mark.unit
def test_stock_string_conversion(self): def test_stock_string_conversion(self):
"""Test stock string to int conversion""" """Test stock string to int conversion"""
data = { data = {"name": "Test Product", "price": 19.99, "stock": "10"}
'name': 'Test Product',
'price': 19.99,
'stock': '10'
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.stock == 10 assert product.stock == 10
@ -185,20 +157,20 @@ class TestProductResponseSchema:
def test_valid_product_response(self): def test_valid_product_response(self):
"""Test valid product response""" """Test valid product response"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'description': 'A test product', "description": "A test product",
'price': 45.99, "price": 45.99,
'stock': 10, "stock": 10,
'image_url': 'https://example.com/product.jpg', "image_url": "https://example.com/product.jpg",
'is_active': True, "is_active": True,
'created_at': '2024-01-15T10:30:00', "created_at": "2024-01-15T10:30:00",
'updated_at': '2024-01-15T10:30:00' "updated_at": "2024-01-15T10:30:00",
} }
product = ProductResponse(**data) product = ProductResponse(**data)
assert product.id == 1 assert product.id == 1
assert product.name == 'Test Product' assert product.name == "Test Product"
assert product.price == 45.99 assert product.price == 45.99
assert product.stock == 10 assert product.stock == 10
assert product.is_active is True assert product.is_active is True
@ -207,11 +179,11 @@ class TestProductResponseSchema:
def test_product_response_with_none_fields(self): def test_product_response_with_none_fields(self):
"""Test product response with optional None fields""" """Test product response with optional None fields"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 0, "stock": 0,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
@ -224,19 +196,19 @@ class TestProductResponseSchema:
def test_model_validate_from_sqlalchemy(self, db_session): def test_model_validate_from_sqlalchemy(self, db_session):
"""Test validating SQLAlchemy model to Pydantic schema""" """Test validating SQLAlchemy model to Pydantic schema"""
from app.models import Product from app.models import Product
db_product = Product( db_product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('45.99'), price=Decimal("45.99"),
stock=10 stock=10,
) )
db_session.add(db_product) db_session.add(db_product)
db_session.commit() db_session.commit()
# Validate using model_validate (for SQLAlchemy models) # Validate using model_validate (for SQLAlchemy models)
response = ProductResponse.model_validate(db_product) response = ProductResponse.model_validate(db_product)
assert response.name == 'Test Product' assert response.name == "Test Product"
assert response.price == 45.99 assert response.price == 45.99
assert response.stock == 10 assert response.stock == 10
@ -244,34 +216,34 @@ class TestProductResponseSchema:
def test_model_dump(self): def test_model_dump(self):
"""Test model_dump method""" """Test model_dump method"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 5, "stock": 5,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
dumped = product.model_dump() dumped = product.model_dump()
assert isinstance(dumped, dict) assert isinstance(dumped, dict)
assert dumped['id'] == 1 assert dumped["id"] == 1
assert dumped['name'] == 'Test Product' assert dumped["name"] == "Test Product"
assert dumped['price'] == 19.99 assert dumped["price"] == 19.99
@pytest.mark.unit @pytest.mark.unit
def test_model_dump_json(self): def test_model_dump_json(self):
"""Test model_dump_json method""" """Test model_dump_json method"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 5, "stock": 5,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
json_str = product.model_dump_json() json_str = product.model_dump_json()
assert isinstance(json_str, str) assert isinstance(json_str, str)
assert 'Test Product' in json_str assert "Test Product" in json_str