update actions config, fix lint errors

This commit is contained in:
david 2026-02-24 19:19:15 +03:00
parent a9b0fd34a2
commit 1c35df931a
26 changed files with 732 additions and 691 deletions

View file

@ -14,7 +14,7 @@ on:
jobs: jobs:
test: test:
runs-on: ubuntu-latest runs-on: [docker]
strategy: strategy:
matrix: matrix:
@ -32,8 +32,6 @@ jobs:
--health-interval 10s --health-interval 10s
--health-timeout 5s --health-timeout 5s
--health-retries 5 --health-retries 5
ports:
- 5432:5432
redis: redis:
image: redis:7-alpine image: redis:7-alpine
@ -106,7 +104,7 @@ jobs:
coverage report --fail-under=80 coverage report --fail-under=80
security-scan: security-scan:
runs-on: ubuntu-latest runs-on: [docker]
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4

View file

@ -1,56 +1,56 @@
name: CD # name: CD
on: # on:
push: # push:
branches: [ main ] # branches: [ main ]
workflow_dispatch: # workflow_dispatch:
jobs: # jobs:
deploy: # deploy:
runs-on: [docker] # runs-on: [docker]
steps: # steps:
- uses: actions/checkout@v3 # - uses: actions/checkout@v3
- name: Configure AWS credentials # - name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v2 # uses: aws-actions/configure-aws-credentials@v2
with: # with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} # aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} # aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ secrets.AWS_REGION }} # aws-region: ${{ secrets.AWS_REGION }}
- name: Login to Amazon ECR # - name: Login to Amazon ECR
id: login-ecr # id: login-ecr
uses: aws-actions/amazon-ecr-login@v1 # uses: aws-actions/amazon-ecr-login@v1
- name: Build and push backend # - name: Build and push backend
env: # env:
ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} # ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
ECR_REPOSITORY: crafting-shop-backend # ECR_REPOSITORY: crafting-shop-backend
IMAGE_TAG: ${{ github.sha }} # IMAGE_TAG: ${{ github.sha }}
run: | # run: |
cd backend # cd backend
docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . # docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG # docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest
docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest
- name: Build and push frontend # - name: Build and push frontend
env: # env:
ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} # ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
ECR_REPOSITORY: crafting-shop-frontend # ECR_REPOSITORY: crafting-shop-frontend
IMAGE_TAG: ${{ github.sha }} # IMAGE_TAG: ${{ github.sha }}
run: | # run: |
cd frontend # cd frontend
docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . # docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG # docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker tag $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG $ECR_REGISTRY/$ECR_REPOSITORY:latest
docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest # docker push $ECR_REGISTRY/$ECR_REPOSITORY:latest
- name: Deploy to ECS # - name: Deploy to ECS
uses: aws-actions/amazon-ecs-deploy-task-definition@v1 # uses: aws-actions/amazon-ecs-deploy-task-definition@v1
with: # with:
task-definition: crafting-shop-task # task-definition: crafting-shop-task
service: crafting-shop-service # service: crafting-shop-service
cluster: crafting-shop-cluster # cluster: crafting-shop-cluster
wait-for-service-stability: true # wait-for-service-stability: true

View file

@ -22,22 +22,23 @@ jobs:
--health-interval 10s --health-interval 10s
--health-timeout 5s --health-timeout 5s
--health-retries 5 --health-retries 5
ports:
- 5432:5432 container:
image: nikolaik/python-nodejs:python3.12-nodejs24-alpine
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v6
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v4 run: |
with: python --version
python-version: '3.11'
- name: Cache pip packages - name: Cache pip packages
uses: actions/cache@v3 uses: actions/cache@v3
with: with:
path: ~/.cache/pip # This path is on the runner, but we mounted it to the container
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} path: /tmp/pip-cache
key: ${{ runner.os }}-pip-${{ hashFiles('backend/requirements/dev.txt') }}
restore-keys: | restore-keys: |
${{ runner.os }}-pip- ${{ runner.os }}-pip-
@ -50,12 +51,11 @@ jobs:
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
cd backend cd backend
flake8 app tests --count --select=E9,F63,F7,F82 --show-source --statistics flake8 app tests --count --max-complexity=10 --max-line-length=127 --statistics --show-source
flake8 app tests --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Run tests - name: Run tests
env: env:
DATABASE_URL: postgresql://test:test@localhost:5432/test_db DATABASE_URL: postgresql://test:test@postgres:5432/test_db
SECRET_KEY: test-secret-key SECRET_KEY: test-secret-key
JWT_SECRET_KEY: test-jwt-secret JWT_SECRET_KEY: test-jwt-secret
FLASK_ENV: testing FLASK_ENV: testing
@ -63,13 +63,6 @@ jobs:
cd backend cd backend
pytest --cov=app --cov-report=xml --cov-report=term pytest --cov=app --cov-report=xml --cov-report=term
- name: Upload coverage
uses: codecov/codecov-action@v3
with:
files: ./backend/coverage.xml
flags: backend
name: backend-coverage
frontend-test: frontend-test:
runs-on: [docker] runs-on: [docker]

View file

@ -126,7 +126,23 @@ lint-frontend: ## Lint frontend only
format: ## Format code format: ## Format code
@echo "Formatting backend..." @echo "Formatting backend..."
cd backend && . venv/bin/activate && black app tests && isort app tests cd backend && . venv/bin/activate && black app tests
cd backend && . venv/bin/activate && isort app tests
@echo "Formatting frontend..."
cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}"
format-backend: ## Format backend code only
@echo "Formatting backend with black..."
cd backend && . venv/bin/activate && black app tests
@echo "Sorting imports with isort..."
cd backend && . venv/bin/activate && isort app tests
format-backend-check: ## Check if backend code needs formatting
@echo "Checking backend formatting..."
cd backend && . venv/bin/activate && black --check app tests
cd backend && . venv/bin/activate && isort --check-only app tests
format-frontend: ## Format frontend code only
@echo "Formatting frontend..." @echo "Formatting frontend..."
cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}" cd frontend && npx prettier --write "src/**/*.{js,jsx,ts,tsx,css}"

View file

@ -1,11 +1,12 @@
import json import os
from dotenv import load_dotenv
from flask import Flask, jsonify from flask import Flask, jsonify
from flask_cors import CORS from flask_cors import CORS
from flask_jwt_extended import JWTManager from flask_jwt_extended import JWTManager
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate from flask_migrate import Migrate
import os from flask_sqlalchemy import SQLAlchemy
from dotenv import load_dotenv
# Create extensions but don't initialize them yet # Create extensions but don't initialize them yet
db = SQLAlchemy() db = SQLAlchemy()
migrate = Migrate() migrate = Migrate()
@ -13,6 +14,7 @@ jwt = JWTManager()
cors = CORS() cors = CORS()
load_dotenv(override=True) load_dotenv(override=True)
def create_app(config_name=None): def create_app(config_name=None):
"""Application factory pattern""" """Application factory pattern"""
app = Flask(__name__) app = Flask(__name__)
@ -22,28 +24,35 @@ def create_app(config_name=None):
config_name = os.environ.get("FLASK_ENV", "development") config_name = os.environ.get("FLASK_ENV", "development")
from app.config import config_by_name from app.config import config_by_name
app.config.from_object(config_by_name[config_name]) app.config.from_object(config_by_name[config_name])
print('----------------------------------------------------------') print("----------------------------------------------------------")
print(F'------------------ENVIRONMENT: {config_name}-------------------------------------') print(
f"------------------ENVIRONMENT: {config_name}-------------------------------------"
)
# print(F'------------------CONFIG: {app.config}-------------------------------------') # print(F'------------------CONFIG: {app.config}-------------------------------------')
# print(json.dumps(dict(app.config), indent=2, default=str)) # print(json.dumps(dict(app.config), indent=2, default=str))
print('----------------------------------------------------------') print("----------------------------------------------------------")
# Initialize extensions with app # Initialize extensions with app
db.init_app(app) db.init_app(app)
migrate.init_app(app, db) migrate.init_app(app, db)
jwt.init_app(app) jwt.init_app(app)
cors.init_app(app, resources={r"/api/*": {"origins": app.config.get("CORS_ORIGINS", "*")}}) cors.init_app(
app, resources={r"/api/*": {"origins": app.config.get("CORS_ORIGINS", "*")}}
)
# Initialize Celery # Initialize Celery
from app.celery import init_celery from app.celery import init_celery
init_celery(app) init_celery(app)
# Import models (required for migrations) # Import models (required for migrations)
from app.models import user, product, order from app.models import order, product, user # noqa: F401
# Register blueprints # Register blueprints
from app.routes import api_bp, health_bp from app.routes import api_bp, health_bp
app.register_blueprint(api_bp, url_prefix="/api") app.register_blueprint(api_bp, url_prefix="/api")
app.register_blueprint(health_bp) app.register_blueprint(health_bp)

View file

@ -20,7 +20,7 @@ def make_celery(app: Flask) -> Celery:
celery_app = Celery( celery_app = Celery(
app.import_name, app.import_name,
broker=app.config["CELERY"]["broker_url"], broker=app.config["CELERY"]["broker_url"],
backend=app.config["CELERY"]["result_backend"] backend=app.config["CELERY"]["result_backend"],
) )
# Update configuration from Flask config # Update configuration from Flask config
@ -30,6 +30,7 @@ def make_celery(app: Flask) -> Celery:
# This ensures tasks have access to Flask extensions (db, etc.) # This ensures tasks have access to Flask extensions (db, etc.)
class ContextTask(celery_app.Task): class ContextTask(celery_app.Task):
"""Celery task that runs within Flask application context.""" """Celery task that runs within Flask application context."""
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
with app.app_context(): with app.app_context():
return self.run(*args, **kwargs) return self.run(*args, **kwargs)
@ -37,14 +38,15 @@ def make_celery(app: Flask) -> Celery:
celery_app.Task = ContextTask celery_app.Task = ContextTask
# Auto-discover tasks in the tasks module # Auto-discover tasks in the tasks module
celery_app.autodiscover_tasks(['app.celery.tasks']) celery_app.autodiscover_tasks(["app.celery.tasks"])
# Configure Beat schedule # Configure Beat schedule
from .beat_schedule import configure_beat_schedule from .beat_schedule import configure_beat_schedule
configure_beat_schedule(celery_app) configure_beat_schedule(celery_app)
# Import tasks to ensure they're registered # Import tasks to ensure they're registered
from .tasks import example_tasks from .tasks import example_tasks # noqa: F401
print(f"✅ Celery configured with broker: {celery_app.conf.broker_url}") print(f"✅ Celery configured with broker: {celery_app.conf.broker_url}")
print(f"✅ Celery configured with backend: {celery_app.conf.result_backend}") print(f"✅ Celery configured with backend: {celery_app.conf.result_backend}")

View file

@ -4,7 +4,6 @@ This defines when scheduled tasks should run.
""" """
from celery.schedules import crontab from celery.schedules import crontab
# Celery Beat schedule configuration # Celery Beat schedule configuration
beat_schedule = { beat_schedule = {
# Run every minute (for testing/demo) # Run every minute (for testing/demo)
@ -14,14 +13,12 @@ beat_schedule = {
"args": ("Celery Beat",), "args": ("Celery Beat",),
"options": {"queue": "default"}, "options": {"queue": "default"},
}, },
# Run daily at 9:00 AM # Run daily at 9:00 AM
"send-daily-report": { "send-daily-report": {
"task": "tasks.send_daily_report", "task": "tasks.send_daily_report",
"schedule": crontab(hour=9, minute=0), # 9:00 AM daily "schedule": crontab(hour=9, minute=0), # 9:00 AM daily
"options": {"queue": "reports"}, "options": {"queue": "reports"},
}, },
# Run every hour at minute 0 # Run every hour at minute 0
"update-product-stats-hourly": { "update-product-stats-hourly": {
"task": "tasks.update_product_statistics", "task": "tasks.update_product_statistics",
@ -29,7 +26,6 @@ beat_schedule = {
"args": (None,), # Update all products "args": (None,), # Update all products
"options": {"queue": "stats"}, "options": {"queue": "stats"},
}, },
# Run every Monday at 8:00 AM # Run every Monday at 8:00 AM
"weekly-maintenance": { "weekly-maintenance": {
"task": "tasks.long_running_task", "task": "tasks.long_running_task",
@ -37,7 +33,6 @@ beat_schedule = {
"args": (5,), # 5 iterations "args": (5,), # 5 iterations
"options": {"queue": "maintenance"}, "options": {"queue": "maintenance"},
}, },
# Run every 5 minutes (for monitoring/heartbeat) # Run every 5 minutes (for monitoring/heartbeat)
"heartbeat-check": { "heartbeat-check": {
"task": "tasks.print_hello", "task": "tasks.print_hello",

View file

@ -4,12 +4,12 @@ Tasks are organized by domain/functionality.
""" """
# Import all task modules here to ensure they're registered with Celery # Import all task modules here to ensure they're registered with Celery
from . import example_tasks from . import example_tasks # noqa: F401
# Re-export tasks for easier imports # Re-export tasks for easier imports
from .example_tasks import ( from .example_tasks import ( # noqa: F401
print_hello,
divide_numbers, divide_numbers,
print_hello,
send_daily_report, send_daily_report,
update_product_statistics, update_product_statistics,
) )

View file

@ -2,11 +2,11 @@
Example Celery tasks for the Crafting Shop application. Example Celery tasks for the Crafting Shop application.
These tasks demonstrate various Celery features and best practices. These tasks demonstrate various Celery features and best practices.
""" """
import time
import logging import logging
import time
from datetime import datetime from datetime import datetime
from celery import shared_task from celery import shared_task
from celery.exceptions import MaxRetriesExceededError
# Get logger for this module # Get logger for this module
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ def print_hello(self, name: str = "World") -> str:
retry_backoff=True, retry_backoff=True,
retry_backoff_max=60, retry_backoff_max=60,
retry_jitter=True, retry_jitter=True,
max_retries=3 max_retries=3,
) )
def divide_numbers(self, x: float, y: float) -> float: def divide_numbers(self, x: float, y: float) -> float:
""" """
@ -55,7 +55,7 @@ def divide_numbers(self, x: float, y: float) -> float:
logger.info(f"Dividing {x} by {y} (attempt {self.request.retries + 1})") logger.info(f"Dividing {x} by {y} (attempt {self.request.retries + 1})")
if y == 0: if y == 0:
logger.warning(f"Division by zero detected, retrying...") logger.warning("Division by zero detected, retrying...")
raise ZeroDivisionError("Cannot divide by zero") raise ZeroDivisionError("Cannot divide by zero")
result = x / y result = x / y
@ -63,11 +63,7 @@ def divide_numbers(self, x: float, y: float) -> float:
return result return result
@shared_task( @shared_task(bind=True, name="tasks.send_daily_report", ignore_result=False)
bind=True,
name="tasks.send_daily_report",
ignore_result=False
)
def send_daily_report(self) -> dict: def send_daily_report(self) -> dict:
""" """
Simulates sending a daily report. Simulates sending a daily report.
@ -90,8 +86,8 @@ def send_daily_report(self) -> dict:
"total_products": 150, "total_products": 150,
"total_orders": 42, "total_orders": 42,
"total_users": 89, "total_users": 89,
"revenue": 12500.75 "revenue": 12500.75,
} },
} }
logger.info(f"Daily report generated: {report_data}") logger.info(f"Daily report generated: {report_data}")
@ -101,10 +97,7 @@ def send_daily_report(self) -> dict:
@shared_task( @shared_task(
bind=True, bind=True, name="tasks.update_product_statistics", queue="stats", priority=5
name="tasks.update_product_statistics",
queue="stats",
priority=5
) )
def update_product_statistics(self, product_id: int = None) -> dict: def update_product_statistics(self, product_id: int = None) -> dict:
""" """
@ -129,7 +122,7 @@ def update_product_statistics(self, product_id: int = None) -> dict:
"task": "update_all_product_stats", "task": "update_all_product_stats",
"status": "completed", "status": "completed",
"products_updated": 150, "products_updated": 150,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat(),
} }
else: else:
# Update specific product # Update specific product
@ -138,11 +131,7 @@ def update_product_statistics(self, product_id: int = None) -> dict:
"product_id": product_id, "product_id": product_id,
"status": "completed", "status": "completed",
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
"new_stats": { "new_stats": {"views": 125, "purchases": 15, "rating": 4.5},
"views": 125,
"purchases": 15,
"rating": 4.5
}
} }
logger.info(f"Product statistics updated: {result}") logger.info(f"Product statistics updated: {result}")
@ -153,7 +142,7 @@ def update_product_statistics(self, product_id: int = None) -> dict:
bind=True, bind=True,
name="tasks.long_running_task", name="tasks.long_running_task",
time_limit=300, # 5 minutes time_limit=300, # 5 minutes
soft_time_limit=240 # 4 minutes soft_time_limit=240, # 4 minutes
) )
def long_running_task(self, iterations: int = 10) -> dict: def long_running_task(self, iterations: int = 10) -> dict:
""" """
@ -181,7 +170,7 @@ def long_running_task(self, iterations: int = 10) -> dict:
progress = (i + 1) / iterations * 100 progress = (i + 1) / iterations * 100
self.update_state( self.update_state(
state="PROGRESS", state="PROGRESS",
meta={"current": i + 1, "total": iterations, "progress": progress} meta={"current": i + 1, "total": iterations, "progress": progress},
) )
results.append(f"iteration_{i + 1}") results.append(f"iteration_{i + 1}")
@ -191,7 +180,7 @@ def long_running_task(self, iterations: int = 10) -> dict:
"status": "completed", "status": "completed",
"iterations": iterations, "iterations": iterations,
"results": results, "results": results,
"completed_at": datetime.now().isoformat() "completed_at": datetime.now().isoformat(),
} }
logger.info(f"Long-running task completed: {final_result}") logger.info(f"Long-running task completed: {final_result}")

View file

@ -4,6 +4,7 @@ from datetime import timedelta
class Config: class Config:
"""Base configuration""" """Base configuration"""
SECRET_KEY = os.environ.get("SECRET_KEY") or "dev-secret-key-change-in-production" SECRET_KEY = os.environ.get("SECRET_KEY") or "dev-secret-key-change-in-production"
SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_TRACK_MODIFICATIONS = False
JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"] JWT_SECRET_KEY = os.environ["JWT_SECRET_KEY"]
@ -14,7 +15,9 @@ class Config:
# Celery Configuration # Celery Configuration
CELERY = { CELERY = {
"broker_url": os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"), "broker_url": os.environ.get("CELERY_BROKER_URL", "redis://redis:6379/0"),
"result_backend": os.environ.get("CELERY_RESULT_BACKEND", "redis://redis:6379/0"), "result_backend": os.environ.get(
"CELERY_RESULT_BACKEND", "redis://redis:6379/0"
),
"task_serializer": "json", "task_serializer": "json",
"result_serializer": "json", "result_serializer": "json",
"accept_content": ["json"], "accept_content": ["json"],
@ -31,12 +34,14 @@ class Config:
class DevelopmentConfig(Config): class DevelopmentConfig(Config):
"""Development configuration""" """Development configuration"""
DEBUG = True DEBUG = True
SQLALCHEMY_DATABASE_URI = os.environ.get("DEV_DATABASE_URL") or "sqlite:///dev.db" SQLALCHEMY_DATABASE_URI = os.environ.get("DEV_DATABASE_URL") or "sqlite:///dev.db"
class TestingConfig(Config): class TestingConfig(Config):
"""Testing configuration""" """Testing configuration"""
TESTING = True TESTING = True
SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite:///test.db" SQLALCHEMY_DATABASE_URI = os.environ.get("TEST_DATABASE_URL") or "sqlite:///test.db"
WTF_CSRF_ENABLED = False WTF_CSRF_ENABLED = False
@ -44,8 +49,11 @@ class TestingConfig(Config):
class ProductionConfig(Config): class ProductionConfig(Config):
"""Production configuration""" """Production configuration"""
DEBUG = False DEBUG = False
SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URL") or "postgresql://user:password@localhost/proddb" SQLALCHEMY_DATABASE_URI = (
os.environ.get("DATABASE_URL") or "postgresql://user:password@localhost/proddb"
)
# Security headers # Security headers
SESSION_COOKIE_SECURE = True SESSION_COOKIE_SECURE = True
@ -56,5 +64,5 @@ class ProductionConfig(Config):
config_by_name = { config_by_name = {
"dev": DevelopmentConfig, "dev": DevelopmentConfig,
"test": TestingConfig, "test": TestingConfig,
"prod": ProductionConfig "prod": ProductionConfig,
} }

View file

@ -1,5 +1,5 @@
from app.models.user import User
from app.models.product import Product
from app.models.order import Order, OrderItem from app.models.order import Order, OrderItem
from app.models.product import Product
from app.models.user import User
__all__ = ["User", "Product", "Order", "OrderItem"] __all__ = ["User", "Product", "Order", "OrderItem"]

View file

@ -1,9 +1,11 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from app import db from app import db
class Order(db.Model): class Order(db.Model):
"""Order model""" """Order model"""
__tablename__ = "orders" __tablename__ = "orders"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -12,11 +14,20 @@ class Order(db.Model):
total_amount = db.Column(db.Numeric(10, 2), nullable=False) total_amount = db.Column(db.Numeric(10, 2), nullable=False)
shipping_address = db.Column(db.Text) shipping_address = db.Column(db.Text)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
user = db.relationship("User", back_populates="orders") user = db.relationship("User", back_populates="orders")
items = db.relationship("OrderItem", back_populates="order", lazy="dynamic", cascade="all, delete-orphan") items = db.relationship(
"OrderItem",
back_populates="order",
lazy="dynamic",
cascade="all, delete-orphan",
)
def to_dict(self): def to_dict(self):
"""Convert order to dictionary""" """Convert order to dictionary"""
@ -28,7 +39,7 @@ class Order(db.Model):
"shipping_address": self.shipping_address, "shipping_address": self.shipping_address,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None,
"items": [item.to_dict() for item in self.items] "items": [item.to_dict() for item in self.items],
} }
def __repr__(self): def __repr__(self):
@ -37,6 +48,7 @@ class Order(db.Model):
class OrderItem(db.Model): class OrderItem(db.Model):
"""Order Item model""" """Order Item model"""
__tablename__ = "order_items" __tablename__ = "order_items"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -56,7 +68,7 @@ class OrderItem(db.Model):
"order_id": self.order_id, "order_id": self.order_id,
"product_id": self.product_id, "product_id": self.product_id,
"quantity": self.quantity, "quantity": self.quantity,
"price": float(self.price) if self.price else None "price": float(self.price) if self.price else None,
} }
def __repr__(self): def __repr__(self):

View file

@ -1,9 +1,11 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from app import db from app import db
class Product(db.Model): class Product(db.Model):
"""Product model""" """Product model"""
__tablename__ = "products" __tablename__ = "products"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -14,7 +16,11 @@ class Product(db.Model):
image_url = db.Column(db.String(500)) image_url = db.Column(db.String(500))
is_active = db.Column(db.Boolean, default=True) is_active = db.Column(db.Boolean, default=True)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
order_items = db.relationship("OrderItem", back_populates="product", lazy="dynamic") order_items = db.relationship("OrderItem", back_populates="product", lazy="dynamic")
@ -30,7 +36,7 @@ class Product(db.Model):
"image_url": self.image_url, "image_url": self.image_url,
"is_active": self.is_active, "is_active": self.is_active,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None "updated_at": self.updated_at.isoformat() if self.updated_at else None,
} }
def __repr__(self): def __repr__(self):

View file

@ -1,10 +1,13 @@
from datetime import datetime, UTC from datetime import UTC, datetime
from werkzeug.security import generate_password_hash, check_password_hash
from werkzeug.security import check_password_hash, generate_password_hash
from app import db from app import db
class User(db.Model): class User(db.Model):
"""User model""" """User model"""
__tablename__ = "users" __tablename__ = "users"
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
@ -16,7 +19,11 @@ class User(db.Model):
is_active = db.Column(db.Boolean, default=True) is_active = db.Column(db.Boolean, default=True)
is_admin = db.Column(db.Boolean, default=False) is_admin = db.Column(db.Boolean, default=False)
created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC)) created_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC))
updated_at = db.Column(db.DateTime, default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) updated_at = db.Column(
db.DateTime,
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
# Relationships # Relationships
orders = db.relationship("Order", back_populates="user", lazy="dynamic") orders = db.relationship("Order", back_populates="user", lazy="dynamic")
@ -40,7 +47,7 @@ class User(db.Model):
"is_active": self.is_active, "is_active": self.is_active,
"is_admin": self.is_admin, "is_admin": self.is_admin,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
"updated_at": self.updated_at.isoformat() if self.updated_at else None "updated_at": self.updated_at.isoformat() if self.updated_at else None,
} }
def __repr__(self): def __repr__(self):

View file

@ -1,12 +1,15 @@
import time from flask import Blueprint, jsonify, request
from decimal import Decimal from flask_jwt_extended import (
create_access_token,
create_refresh_token,
get_jwt_identity,
jwt_required,
)
from pydantic import ValidationError from pydantic import ValidationError
from flask import Blueprint, request, jsonify
from flask_jwt_extended import jwt_required, get_jwt_identity, create_access_token, create_refresh_token
from app import db from app import db
from app.models import User, Product, OrderItem, Order
from app.celery import celery from app.celery import celery
from app.models import Order, OrderItem, Product, User
from app.schemas import ProductCreateRequest, ProductResponse from app.schemas import ProductCreateRequest, ProductResponse
api_bp = Blueprint("api", __name__) api_bp = Blueprint("api", __name__)
@ -28,7 +31,7 @@ def register():
email=data["email"], email=data["email"],
username=data.get("username", data["email"].split("@")[0]), username=data.get("username", data["email"].split("@")[0]),
first_name=data.get("first_name"), first_name=data.get("first_name"),
last_name=data.get("last_name") last_name=data.get("last_name"),
) )
user.set_password(data["password"]) user.set_password(data["password"])
@ -57,11 +60,16 @@ def login():
access_token = create_access_token(identity=str(user.id)) access_token = create_access_token(identity=str(user.id))
refresh_token = create_refresh_token(identity=str(user.id)) refresh_token = create_refresh_token(identity=str(user.id))
return jsonify({ return (
"user": user.to_dict(), jsonify(
"access_token": access_token, {
"refresh_token": refresh_token "user": user.to_dict(),
}), 200 "access_token": access_token,
"refresh_token": refresh_token,
}
),
200,
)
@api_bp.route("/users/me", methods=["GET"]) @api_bp.route("/users/me", methods=["GET"])
@ -82,7 +90,6 @@ def get_current_user():
def get_products(): def get_products():
"""Get all products""" """Get all products"""
# time.sleep(5) # This adds a 5 second delay # time.sleep(5) # This adds a 5 second delay
products = Product.query.filter_by(is_active=True).all() products = Product.query.filter_by(is_active=True).all()
@ -118,7 +125,7 @@ def create_product():
description=product_data.description, description=product_data.description,
price=product_data.price, price=product_data.price,
stock=product_data.stock, stock=product_data.stock,
image_url=product_data.image_url image_url=product_data.image_url,
) )
db.session.add(product) db.session.add(product)
@ -207,22 +214,27 @@ def create_order():
for item_data in data["items"]: for item_data in data["items"]:
product = db.session.get(Product, item_data["product_id"]) product = db.session.get(Product, item_data["product_id"])
if not product: if not product:
return jsonify({"error": f'Product {item_data["product_id"]} not found'}), 404 return (
jsonify({"error": f'Product {item_data["product_id"]} not found'}),
404,
)
if product.stock < item_data["quantity"]: if product.stock < item_data["quantity"]:
return jsonify({"error": f'Insufficient stock for {product.name}'}), 400 return jsonify({"error": f"Insufficient stock for {product.name}"}), 400
item_total = product.price * item_data["quantity"] item_total = product.price * item_data["quantity"]
total_amount += item_total total_amount += item_total
order_items.append({ order_items.append(
"product": product, {
"quantity": item_data["quantity"], "product": product,
"price": product.price "quantity": item_data["quantity"],
}) "price": product.price,
}
)
order = Order( order = Order(
user_id=user_id, user_id=user_id,
total_amount=total_amount, total_amount=total_amount,
shipping_address=data.get("shipping_address") shipping_address=data.get("shipping_address"),
) )
db.session.add(order) db.session.add(order)
@ -233,7 +245,7 @@ def create_order():
order_id=order.id, order_id=order.id,
product_id=item_data["product"].id, product_id=item_data["product"].id,
quantity=item_data["quantity"], quantity=item_data["quantity"],
price=item_data["price"] price=item_data["price"],
) )
item_data["product"].stock -= item_data["quantity"] item_data["product"].stock -= item_data["quantity"]
db.session.add(order_item) db.session.add(order_item)
@ -270,11 +282,12 @@ def trigger_hello_task():
task = celery.send_task("tasks.print_hello", args=[name]) task = celery.send_task("tasks.print_hello", args=[name])
return jsonify({ return (
"message": "Hello task triggered", jsonify(
"task_id": task.id, {"message": "Hello task triggered", "task_id": task.id, "status": "pending"}
"status": "pending" ),
}), 202 202,
)
@api_bp.route("/tasks/divide", methods=["POST"]) @api_bp.route("/tasks/divide", methods=["POST"])
@ -287,12 +300,17 @@ def trigger_divide_task():
task = celery.send_task("tasks.divide_numbers", args=[x, y]) task = celery.send_task("tasks.divide_numbers", args=[x, y])
return jsonify({ return (
"message": "Divide task triggered", jsonify(
"task_id": task.id, {
"operation": f"{x} / {y}", "message": "Divide task triggered",
"status": "pending" "task_id": task.id,
}), 202 "operation": f"{x} / {y}",
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/report", methods=["POST"]) @api_bp.route("/tasks/report", methods=["POST"])
@ -301,11 +319,16 @@ def trigger_report_task():
"""Trigger the daily report task""" """Trigger the daily report task"""
task = celery.send_task("tasks.send_daily_report") task = celery.send_task("tasks.send_daily_report")
return jsonify({ return (
"message": "Daily report task triggered", jsonify(
"task_id": task.id, {
"status": "pending" "message": "Daily report task triggered",
}), 202 "task_id": task.id,
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/stats", methods=["POST"]) @api_bp.route("/tasks/stats", methods=["POST"])
@ -322,11 +345,7 @@ def trigger_stats_task():
task = celery.send_task("tasks.update_product_statistics", args=[None]) task = celery.send_task("tasks.update_product_statistics", args=[None])
message = "Product statistics update triggered for all products" message = "Product statistics update triggered for all products"
return jsonify({ return jsonify({"message": message, "task_id": task.id, "status": "pending"}), 202
"message": message,
"task_id": task.id,
"status": "pending"
}), 202
@api_bp.route("/tasks/long-running", methods=["POST"]) @api_bp.route("/tasks/long-running", methods=["POST"])
@ -338,11 +357,16 @@ def trigger_long_running_task():
task = celery.send_task("tasks.long_running_task", args=[iterations]) task = celery.send_task("tasks.long_running_task", args=[iterations])
return jsonify({ return (
"message": f"Long-running task triggered with {iterations} iterations", jsonify(
"task_id": task.id, {
"status": "pending" "message": f"Long-running task triggered with {iterations} iterations",
}), 202 "task_id": task.id,
"status": "pending",
}
),
202,
)
@api_bp.route("/tasks/<task_id>", methods=["GET"]) @api_bp.route("/tasks/<task_id>", methods=["GET"])
@ -354,7 +378,7 @@ def get_task_status(task_id):
response = { response = {
"task_id": task_id, "task_id": task_id,
"status": task_result.status, "status": task_result.status,
"ready": task_result.ready() "ready": task_result.ready(),
} }
if task_result.ready(): if task_result.ready():
@ -376,18 +400,16 @@ def celery_health():
stats = inspector.stats() stats = inspector.stats()
if stats: if stats:
return jsonify({ return (
"status": "healthy", jsonify(
"workers": len(stats), {"status": "healthy", "workers": len(stats), "workers_info": stats}
"workers_info": stats ),
}), 200 200,
)
else: else:
return jsonify({ return (
"status": "unhealthy", jsonify({"status": "unhealthy", "message": "No workers available"}),
"message": "No workers available" 503,
}), 503 )
except Exception as e: except Exception as e:
return jsonify({ return jsonify({"status": "error", "message": str(e)}), 500
"status": "error",
"message": str(e)
}), 500

View file

@ -1,22 +1,16 @@
from flask import Blueprint, jsonify from flask import Blueprint, jsonify
health_bp = Blueprint('health', __name__) health_bp = Blueprint("health", __name__)
@health_bp.route('/', methods=['GET']) @health_bp.route("/", methods=["GET"])
def health_check(): def health_check():
"""Health check endpoint""" """Health check endpoint"""
return jsonify({ return jsonify({"status": "healthy", "service": "crafting-shop-backend"}), 200
'status': 'healthy',
'service': 'crafting-shop-backend'
}), 200
@health_bp.route('/readiness', methods=['GET']) @health_bp.route("/readiness", methods=["GET"])
def readiness_check(): def readiness_check():
"""Readiness check endpoint""" """Readiness check endpoint"""
# Add database check here if needed # Add database check here if needed
return jsonify({ return jsonify({"status": "ready", "service": "crafting-shop-backend"}), 200
'status': 'ready',
'service': 'crafting-shop-backend'
}), 200

View file

@ -1,12 +1,14 @@
"""Pydantic schemas for Product model""" """Pydantic schemas for Product model"""
from pydantic import BaseModel, Field, field_validator, ConfigDict
from decimal import Decimal
from datetime import datetime from datetime import datetime
from decimal import Decimal
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ProductCreateRequest(BaseModel): class ProductCreateRequest(BaseModel):
"""Schema for creating a new product""" """Schema for creating a new product"""
model_config = ConfigDict( model_config = ConfigDict(
json_schema_extra={ json_schema_extra={
"example": { "example": {
@ -14,16 +16,20 @@ class ProductCreateRequest(BaseModel):
"description": "A beautiful handcrafted bowl made from oak", "description": "A beautiful handcrafted bowl made from oak",
"price": 45.99, "price": 45.99,
"stock": 10, "stock": 10,
"image_url": "https://example.com/bowl.jpg" "image_url": "https://example.com/bowl.jpg",
} }
} }
) )
name: str = Field(..., min_length=1, max_length=200, description="Product name") name: str = Field(..., min_length=1, max_length=200, description="Product name")
description: Optional[str] = Field(None, description="Product description") description: Optional[str] = Field(None, description="Product description")
price: Decimal = Field(..., gt=0, description="Product price (must be greater than 0)") price: Decimal = Field(
..., gt=0, description="Product price (must be greater than 0)"
)
stock: int = Field(default=0, ge=0, description="Product stock quantity") stock: int = Field(default=0, ge=0, description="Product stock quantity")
image_url: Optional[str] = Field(None, max_length=500, description="Product image URL") image_url: Optional[str] = Field(
None, max_length=500, description="Product image URL"
)
@field_validator("price") @field_validator("price")
@classmethod @classmethod
@ -36,6 +42,7 @@ class ProductCreateRequest(BaseModel):
class ProductResponse(BaseModel): class ProductResponse(BaseModel):
"""Schema for product response""" """Schema for product response"""
model_config = ConfigDict( model_config = ConfigDict(
from_attributes=True, from_attributes=True,
json_schema_extra={ json_schema_extra={
@ -48,9 +55,9 @@ class ProductResponse(BaseModel):
"image_url": "https://example.com/bowl.jpg", "image_url": "https://example.com/bowl.jpg",
"is_active": True, "is_active": True,
"created_at": "2024-01-15T10:30:00", "created_at": "2024-01-15T10:30:00",
"updated_at": "2024-01-15T10:30:00" "updated_at": "2024-01-15T10:30:00",
} }
} },
) )
id: int id: int

View file

@ -1,28 +1,31 @@
"""Pytest configuration and fixtures""" """Pytest configuration and fixtures"""
import pytest
import tempfile
import os import os
import tempfile
import pytest
from faker import Faker from faker import Faker
from app import create_app, db from app import create_app, db
from app.models import User, Product, Order, OrderItem from app.models import Order, OrderItem, Product, User
fake = Faker() fake = Faker()
@pytest.fixture(scope='function') @pytest.fixture(scope="function")
def app(): def app():
"""Create application for testing with isolated database""" """Create application for testing with isolated database"""
db_fd, db_path = tempfile.mkstemp() db_fd, db_path = tempfile.mkstemp()
app = create_app(config_name='test') app = create_app(config_name="test")
app.config.update({ app.config.update(
'TESTING': True, {
'SQLALCHEMY_DATABASE_URI': f'sqlite:///{db_path}', "TESTING": True,
'WTF_CSRF_ENABLED': False, "SQLALCHEMY_DATABASE_URI": f"sqlite:///{db_path}",
'JWT_SECRET_KEY': 'test-secret-keytest-secret-keytest-secret-keytest-secret-keytest-secret-key', "WTF_CSRF_ENABLED": False,
'SERVER_NAME': 'localhost.localdomain' "JWT_SECRET_KEY": "test-secret-keytest-secret-keytest-secret-keytest-secret-keytest-secret-key",
}) "SERVER_NAME": "localhost.localdomain",
}
)
with app.app_context(): with app.app_context():
db.create_all() db.create_all()
@ -62,9 +65,9 @@ def admin_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=True, is_admin=True,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -79,9 +82,9 @@ def regular_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=False, is_admin=False,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -96,9 +99,9 @@ def inactive_user(db_session):
first_name=fake.first_name(), first_name=fake.first_name(),
last_name=fake.last_name(), last_name=fake.last_name(),
is_admin=False, is_admin=False,
is_active=False is_active=False,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
return user return user
@ -112,7 +115,7 @@ def product(db_session):
description=fake.paragraph(), description=fake.paragraph(),
price=fake.pydecimal(left_digits=2, right_digits=2, positive=True), price=fake.pydecimal(left_digits=2, right_digits=2, positive=True),
stock=fake.pyint(min_value=0, max_value=100), stock=fake.pyint(min_value=0, max_value=100),
image_url=fake.url() image_url=fake.url(),
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
@ -129,7 +132,7 @@ def products(db_session):
description=fake.paragraph(), description=fake.paragraph(),
price=fake.pydecimal(left_digits=2, right_digits=2, positive=True), price=fake.pydecimal(left_digits=2, right_digits=2, positive=True),
stock=fake.pyint(min_value=0, max_value=100), stock=fake.pyint(min_value=0, max_value=100),
image_url=fake.url() image_url=fake.url(),
) )
db_session.add(product) db_session.add(product)
products.append(product) products.append(product)
@ -140,36 +143,32 @@ def products(db_session):
@pytest.fixture @pytest.fixture
def auth_headers(client, regular_user): def auth_headers(client, regular_user):
"""Get authentication headers for a regular user""" """Get authentication headers for a regular user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': regular_user.email, "/api/auth/login", json={"email": regular_user.email, "password": "password123"}
'password': 'password123' )
})
data = response.get_json() data = response.get_json()
token = data['access_token'] token = data["access_token"]
print(f"Auth headers token for user {regular_user.email}: {token[:50]}...") print(f"Auth headers token for user {regular_user.email}: {token[:50]}...")
return {'Authorization': f'Bearer {token}'} return {"Authorization": f"Bearer {token}"}
@pytest.fixture @pytest.fixture
def admin_headers(client, admin_user): def admin_headers(client, admin_user):
"""Get authentication headers for an admin user""" """Get authentication headers for an admin user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': admin_user.email, "/api/auth/login", json={"email": admin_user.email, "password": "password123"}
'password': 'password123' )
})
data = response.get_json() data = response.get_json()
token = data['access_token'] token = data["access_token"]
print(f"Admin headers token for user {admin_user.email}: {token[:50]}...") print(f"Admin headers token for user {admin_user.email}: {token[:50]}...")
return {'Authorization': f'Bearer {token}'} return {"Authorization": f"Bearer {token}"}
@pytest.fixture @pytest.fixture
def order(db_session, regular_user, products): def order(db_session, regular_user, products):
"""Create an order for testing""" """Create an order for testing"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id, total_amount=0.0, shipping_address=fake.address()
total_amount=0.0,
shipping_address=fake.address()
) )
db_session.add(order) db_session.add(order)
db_session.flush() db_session.flush()
@ -181,7 +180,7 @@ def order(db_session, regular_user, products):
order_id=order.id, order_id=order.id,
product_id=product.id, product_id=product.id,
quantity=quantity, quantity=quantity,
price=product.price price=product.price,
) )
total_amount += float(product.price) * quantity total_amount += float(product.price) * quantity
db_session.add(order_item) db_session.add(order_item)

View file

@ -1,8 +1,9 @@
"""Test models""" """Test models"""
import pytest
from decimal import Decimal from decimal import Decimal
from datetime import datetime
from app.models import User, Product, Order, OrderItem import pytest
from app.models import Order, OrderItem, Product, User
class TestUserModel: class TestUserModel:
@ -12,62 +13,62 @@ class TestUserModel:
def test_user_creation(self, db_session): def test_user_creation(self, db_session):
"""Test creating a user""" """Test creating a user"""
user = User( user = User(
email='test@example.com', email="test@example.com",
username='testuser', username="testuser",
first_name='Test', first_name="Test",
last_name='User', last_name="User",
is_admin=False, is_admin=False,
is_active=True is_active=True,
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert user.id is not None assert user.id is not None
assert user.email == 'test@example.com' assert user.email == "test@example.com"
assert user.username == 'testuser' assert user.username == "testuser"
assert user.first_name == 'Test' assert user.first_name == "Test"
assert user.last_name == 'User' assert user.last_name == "User"
@pytest.mark.unit @pytest.mark.unit
def test_user_password_hashing(self, db_session): def test_user_password_hashing(self, db_session):
"""Test password hashing and verification""" """Test password hashing and verification"""
user = User(email='test@example.com', username='testuser') user = User(email="test@example.com", username="testuser")
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert user.check_password('password123') is True assert user.check_password("password123") is True
assert user.check_password('wrongpassword') is False assert user.check_password("wrongpassword") is False
@pytest.mark.unit @pytest.mark.unit
def test_user_to_dict(self, db_session): def test_user_to_dict(self, db_session):
"""Test user serialization to dictionary""" """Test user serialization to dictionary"""
user = User( user = User(
email='test@example.com', email="test@example.com",
username='testuser', username="testuser",
first_name='Test', first_name="Test",
last_name='User' last_name="User",
) )
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
user_dict = user.to_dict() user_dict = user.to_dict()
assert user_dict['email'] == 'test@example.com' assert user_dict["email"] == "test@example.com"
assert user_dict['username'] == 'testuser' assert user_dict["username"] == "testuser"
assert 'password' not in user_dict assert "password" not in user_dict
assert 'password_hash' not in user_dict assert "password_hash" not in user_dict
@pytest.mark.unit @pytest.mark.unit
def test_user_repr(self, db_session): def test_user_repr(self, db_session):
"""Test user string representation""" """Test user string representation"""
user = User(email='test@example.com', username='testuser') user = User(email="test@example.com", username="testuser")
user.set_password('password123') user.set_password("password123")
db_session.add(user) db_session.add(user)
db_session.commit() db_session.commit()
assert repr(user) == '<User testuser>' assert repr(user) == "<User testuser>"
class TestProductModel: class TestProductModel:
@ -77,18 +78,18 @@ class TestProductModel:
def test_product_creation(self, db_session): def test_product_creation(self, db_session):
"""Test creating a product""" """Test creating a product"""
product = Product( product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('99.99'), price=Decimal("99.99"),
stock=10, stock=10,
image_url='https://example.com/product.jpg' image_url="https://example.com/product.jpg",
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
assert product.id is not None assert product.id is not None
assert product.name == 'Test Product' assert product.name == "Test Product"
assert product.price == Decimal('99.99') assert product.price == Decimal("99.99")
assert product.stock == 10 assert product.stock == 10
assert product.is_active is True assert product.is_active is True
@ -96,27 +97,24 @@ class TestProductModel:
def test_product_to_dict(self, db_session): def test_product_to_dict(self, db_session):
"""Test product serialization to dictionary""" """Test product serialization to dictionary"""
product = Product( product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('99.99'), price=Decimal("99.99"),
stock=10 stock=10,
) )
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
product_dict = product.to_dict() product_dict = product.to_dict()
assert product_dict['name'] == 'Test Product' assert product_dict["name"] == "Test Product"
assert product_dict['price'] == 99.99 assert product_dict["price"] == 99.99
assert isinstance(product_dict['created_at'], str) assert isinstance(product_dict["created_at"], str)
assert isinstance(product_dict['updated_at'], str) assert isinstance(product_dict["updated_at"], str)
@pytest.mark.unit @pytest.mark.unit
def test_product_defaults(self, db_session): def test_product_defaults(self, db_session):
"""Test product default values""" """Test product default values"""
product = Product( product = Product(name="Test Product", price=Decimal("9.99"))
name='Test Product',
price=Decimal('9.99')
)
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
@ -128,11 +126,11 @@ class TestProductModel:
@pytest.mark.unit @pytest.mark.unit
def test_product_repr(self, db_session): def test_product_repr(self, db_session):
"""Test product string representation""" """Test product string representation"""
product = Product(name='Test Product', price=Decimal('9.99')) product = Product(name="Test Product", price=Decimal("9.99"))
db_session.add(product) db_session.add(product)
db_session.commit() db_session.commit()
assert repr(product) == '<Product Test Product>' assert repr(product) == "<Product Test Product>"
class TestOrderModel: class TestOrderModel:
@ -143,31 +141,31 @@ class TestOrderModel:
"""Test creating an order""" """Test creating an order"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id,
total_amount=Decimal('199.99'), total_amount=Decimal("199.99"),
shipping_address='123 Test St' shipping_address="123 Test St",
) )
db_session.add(order) db_session.add(order)
db_session.commit() db_session.commit()
assert order.id is not None assert order.id is not None
assert order.user_id == regular_user.id assert order.user_id == regular_user.id
assert order.total_amount == Decimal('199.99') assert order.total_amount == Decimal("199.99")
@pytest.mark.unit @pytest.mark.unit
def test_order_to_dict(self, db_session, regular_user): def test_order_to_dict(self, db_session, regular_user):
"""Test order serialization to dictionary""" """Test order serialization to dictionary"""
order = Order( order = Order(
user_id=regular_user.id, user_id=regular_user.id,
total_amount=Decimal('199.99'), total_amount=Decimal("199.99"),
shipping_address='123 Test St' shipping_address="123 Test St",
) )
db_session.add(order) db_session.add(order)
db_session.commit() db_session.commit()
order_dict = order.to_dict() order_dict = order.to_dict()
assert order_dict['user_id'] == regular_user.id assert order_dict["user_id"] == regular_user.id
assert order_dict['total_amount'] == 199.99 assert order_dict["total_amount"] == 199.99
assert isinstance(order_dict['created_at'], str) assert isinstance(order_dict["created_at"], str)
class TestOrderItemModel: class TestOrderItemModel:
@ -177,10 +175,7 @@ class TestOrderItemModel:
def test_order_item_creation(self, db_session, order, product): def test_order_item_creation(self, db_session, order, product):
"""Test creating an order item""" """Test creating an order item"""
order_item = OrderItem( order_item = OrderItem(
order_id=order.id, order_id=order.id, product_id=product.id, quantity=2, price=product.price
product_id=product.id,
quantity=2,
price=product.price
) )
db_session.add(order_item) db_session.add(order_item)
db_session.commit() db_session.commit()
@ -194,15 +189,12 @@ class TestOrderItemModel:
def test_order_item_to_dict(self, db_session, order, product): def test_order_item_to_dict(self, db_session, order, product):
"""Test order item serialization to dictionary""" """Test order item serialization to dictionary"""
order_item = OrderItem( order_item = OrderItem(
order_id=order.id, order_id=order.id, product_id=product.id, quantity=2, price=product.price
product_id=product.id,
quantity=2,
price=product.price
) )
db_session.add(order_item) db_session.add(order_item)
db_session.commit() db_session.commit()
item_dict = order_item.to_dict() item_dict = order_item.to_dict()
assert item_dict['order_id'] == order.id assert item_dict["order_id"] == order.id
assert item_dict['product_id'] == product.id assert item_dict["product_id"] == product.id
assert item_dict['quantity'] == 2 assert item_dict["quantity"] == 2

View file

@ -1,7 +1,5 @@
"""Test API routes""" """Test API routes"""
import pytest import pytest
import json
from decimal import Decimal
class TestAuthRoutes: class TestAuthRoutes:
@ -10,101 +8,109 @@ class TestAuthRoutes:
@pytest.mark.auth @pytest.mark.auth
def test_register_success(self, client): def test_register_success(self, client):
"""Test successful user registration""" """Test successful user registration"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': 'newuser@example.com', "/api/auth/register",
'password': 'password123', json={
'username': 'newuser', "email": "newuser@example.com",
'first_name': 'New', "password": "password123",
'last_name': 'User' "username": "newuser",
}) "first_name": "New",
"last_name": "User",
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['email'] == 'newuser@example.com' assert data["email"] == "newuser@example.com"
assert data['username'] == 'newuser' assert data["username"] == "newuser"
assert 'password' not in data assert "password" not in data
assert 'password_hash' not in data assert "password_hash" not in data
@pytest.mark.auth @pytest.mark.auth
def test_register_missing_fields(self, client): def test_register_missing_fields(self, client):
"""Test registration with missing required fields""" """Test registration with missing required fields"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': 'newuser@example.com' "/api/auth/register", json={"email": "newuser@example.com"}
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'error' in data assert "error" in data
@pytest.mark.auth @pytest.mark.auth
def test_register_duplicate_email(self, client, regular_user): def test_register_duplicate_email(self, client, regular_user):
"""Test registration with duplicate email""" """Test registration with duplicate email"""
response = client.post('/api/auth/register', json={ response = client.post(
'email': regular_user.email, "/api/auth/register",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'already exists' in data['error'].lower() assert "already exists" in data["error"].lower()
@pytest.mark.auth @pytest.mark.auth
def test_login_success(self, client, regular_user): def test_login_success(self, client, regular_user):
"""Test successful login""" """Test successful login"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': regular_user.email, "/api/auth/login",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert 'access_token' in data assert "access_token" in data
assert 'refresh_token' in data assert "refresh_token" in data
assert data['user']['email'] == regular_user.email assert data["user"]["email"] == regular_user.email
@pytest.mark.auth @pytest.mark.auth
@pytest.mark.parametrize("email,password,expected_status", [ @pytest.mark.parametrize(
("wrong@example.com", "password123", 401), "email,password,expected_status",
("user@example.com", "wrongpassword", 401), [
(None, "password123", 400), ("wrong@example.com", "password123", 401),
("user@example.com", None, 400), ("user@example.com", "wrongpassword", 401),
]) (None, "password123", 400),
def test_login_validation(self, client, regular_user, email, password, expected_status): ("user@example.com", None, 400),
],
)
def test_login_validation(
self, client, regular_user, email, password, expected_status
):
"""Test login with various invalid inputs""" """Test login with various invalid inputs"""
login_data = {} login_data = {}
if email is not None: if email is not None:
login_data['email'] = email login_data["email"] = email
if password is not None: if password is not None:
login_data['password'] = password login_data["password"] = password
response = client.post('/api/auth/login', json=login_data) response = client.post("/api/auth/login", json=login_data)
assert response.status_code == expected_status assert response.status_code == expected_status
@pytest.mark.auth @pytest.mark.auth
def test_login_inactive_user(self, client, inactive_user): def test_login_inactive_user(self, client, inactive_user):
"""Test login with inactive user""" """Test login with inactive user"""
response = client.post('/api/auth/login', json={ response = client.post(
'email': inactive_user.email, "/api/auth/login",
'password': 'password123' json={"email": inactive_user.email, "password": "password123"},
}) )
assert response.status_code == 401 assert response.status_code == 401
data = response.get_json() data = response.get_json()
assert 'inactive' in data['error'].lower() assert "inactive" in data["error"].lower()
@pytest.mark.auth @pytest.mark.auth
def test_get_current_user(self, client, auth_headers, regular_user): def test_get_current_user(self, client, auth_headers, regular_user):
"""Test getting current user""" """Test getting current user"""
response = client.get('/api/users/me', headers=auth_headers) response = client.get("/api/users/me", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['email'] == regular_user.email assert data["email"] == regular_user.email
@pytest.mark.auth @pytest.mark.auth
def test_get_current_user_unauthorized(self, client): def test_get_current_user_unauthorized(self, client):
"""Test getting current user without authentication""" """Test getting current user without authentication"""
response = client.get('/api/users/me') response = client.get("/api/users/me")
assert response.status_code == 401 assert response.status_code == 401
@ -114,7 +120,7 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_products(self, client, products): def test_get_products(self, client, products):
"""Test getting all products""" """Test getting all products"""
response = client.get('/api/products') response = client.get("/api/products")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -123,7 +129,7 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_products_empty(self, client): def test_get_products_empty(self, client):
"""Test getting products when none exist""" """Test getting products when none exist"""
response = client.get('/api/products') response = client.get("/api/products")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -132,113 +138,122 @@ class TestProductRoutes:
@pytest.mark.product @pytest.mark.product
def test_get_single_product(self, client, product): def test_get_single_product(self, client, product):
"""Test getting a single product""" """Test getting a single product"""
response = client.get(f'/api/products/{product.id}') response = client.get(f"/api/products/{product.id}")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['id'] == product.id assert data["id"] == product.id
assert data['name'] == product.name assert data["name"] == product.name
@pytest.mark.product @pytest.mark.product
def test_get_product_not_found(self, client): def test_get_product_not_found(self, client):
"""Test getting non-existent product""" """Test getting non-existent product"""
response = client.get('/api/products/999') response = client.get("/api/products/999")
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.product @pytest.mark.product
def test_create_product_admin(self, client, admin_headers): def test_create_product_admin(self, client, admin_headers):
"""Test creating product as admin""" """Test creating product as admin"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'description': 'A new product', headers=admin_headers,
'price': 29.99, json={
'stock': 10 "name": "New Product",
}) "description": "A new product",
"price": 29.99,
"stock": 10,
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['name'] == 'New Product' assert data["name"] == "New Product"
assert data['price'] == 29.99 assert data["price"] == 29.99
@pytest.mark.product @pytest.mark.product
def test_create_product_regular_user(self, client, auth_headers): def test_create_product_regular_user(self, client, auth_headers):
"""Test creating product as regular user (should fail)""" """Test creating product as regular user (should fail)"""
response = client.post('/api/products', headers=auth_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'price': 29.99 headers=auth_headers,
}) json={"name": "New Product", "price": 29.99},
)
assert response.status_code == 403 assert response.status_code == 403
data = response.get_json() data = response.get_json()
assert 'admin' in data['error'].lower() assert "admin" in data["error"].lower()
@pytest.mark.product @pytest.mark.product
def test_create_product_unauthorized(self, client): def test_create_product_unauthorized(self, client):
"""Test creating product without authentication""" """Test creating product without authentication"""
response = client.post('/api/products', json={ response = client.post(
'name': 'New Product', "/api/products", json={"name": "New Product", "price": 29.99}
'price': 29.99 )
})
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.product @pytest.mark.product
def test_create_product_validation_error(self, client, admin_headers): def test_create_product_validation_error(self, client, admin_headers):
"""Test creating product with invalid data""" """Test creating product with invalid data"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'New Product', "/api/products",
'price': -10.99 headers=admin_headers,
}) json={"name": "New Product", "price": -10.99},
)
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'Validation error' in data['error'] assert "Validation error" in data["error"]
@pytest.mark.product @pytest.mark.product
def test_create_product_missing_required_fields(self, client, admin_headers): def test_create_product_missing_required_fields(self, client, admin_headers):
"""Test creating product with missing required fields""" """Test creating product with missing required fields"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'description': 'Missing name and price' "/api/products",
}) headers=admin_headers,
json={"description": "Missing name and price"},
)
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'Validation error' in data['error'] assert "Validation error" in data["error"]
@pytest.mark.product @pytest.mark.product
def test_create_product_minimal_data(self, client, admin_headers): def test_create_product_minimal_data(self, client, admin_headers):
"""Test creating product with minimal valid data""" """Test creating product with minimal valid data"""
response = client.post('/api/products', headers=admin_headers, json={ response = client.post(
'name': 'Minimal Product', "/api/products",
'price': 19.99 headers=admin_headers,
}) json={"name": "Minimal Product", "price": 19.99},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert data['name'] == 'Minimal Product' assert data["name"] == "Minimal Product"
assert data['stock'] == 0 # Default value assert data["stock"] == 0 # Default value
@pytest.mark.product @pytest.mark.product
def test_update_product_admin(self, client, admin_headers, product): def test_update_product_admin(self, client, admin_headers, product):
"""Test updating product as admin""" """Test updating product as admin"""
response = client.put(f'/api/products/{product.id}', headers=admin_headers, json={ response = client.put(
'name': 'Updated Product', f"/api/products/{product.id}",
'price': 39.99 headers=admin_headers,
}) json={"name": "Updated Product", "price": 39.99},
)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['name'] == 'Updated Product' assert data["name"] == "Updated Product"
assert data['price'] == 39.99 assert data["price"] == 39.99
@pytest.mark.product @pytest.mark.product
def test_delete_product_admin(self, client, admin_headers, product): def test_delete_product_admin(self, client, admin_headers, product):
"""Test deleting product as admin""" """Test deleting product as admin"""
response = client.delete(f'/api/products/{product.id}', headers=admin_headers) response = client.delete(f"/api/products/{product.id}", headers=admin_headers)
assert response.status_code == 200 assert response.status_code == 200
# Verify product is deleted # Verify product is deleted
response = client.get(f'/api/products/{product.id}') response = client.get(f"/api/products/{product.id}")
assert response.status_code == 404 assert response.status_code == 404
@ -248,7 +263,7 @@ class TestOrderRoutes:
@pytest.mark.order @pytest.mark.order
def test_get_orders(self, client, auth_headers, order): def test_get_orders(self, client, auth_headers, order):
"""Test getting orders for current user""" """Test getting orders for current user"""
response = client.get('/api/orders', headers=auth_headers) response = client.get("/api/orders", headers=auth_headers)
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
@ -257,63 +272,68 @@ class TestOrderRoutes:
@pytest.mark.order @pytest.mark.order
def test_get_orders_unauthorized(self, client): def test_get_orders_unauthorized(self, client):
"""Test getting orders without authentication""" """Test getting orders without authentication"""
response = client.get('/api/orders') response = client.get("/api/orders")
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.order @pytest.mark.order
def test_create_order(self, client, auth_headers, products): def test_create_order(self, client, auth_headers, products):
"""Test creating an order""" """Test creating an order"""
response = client.post('/api/orders', headers=auth_headers, json={ response = client.post(
'items': [ "/api/orders",
{'product_id': products[0].id, 'quantity': 2}, headers=auth_headers,
{'product_id': products[1].id, 'quantity': 1} json={
], "items": [
'shipping_address': '123 Test St' {"product_id": products[0].id, "quantity": 2},
}) {"product_id": products[1].id, "quantity": 1},
],
"shipping_address": "123 Test St",
},
)
assert response.status_code == 201 assert response.status_code == 201
data = response.get_json() data = response.get_json()
assert 'id' in data assert "id" in data
assert len(data['items']) == 2 assert len(data["items"]) == 2
@pytest.mark.order @pytest.mark.order
def test_create_order_insufficient_stock(self, client, auth_headers, db_session, products): def test_create_order_insufficient_stock(
self, client, auth_headers, db_session, products
):
"""Test creating order with insufficient stock""" """Test creating order with insufficient stock"""
# Set stock to 0 # Set stock to 0
products[0].stock = 0 products[0].stock = 0
db_session.commit() db_session.commit()
response = client.post('/api/orders', headers=auth_headers, json={ response = client.post(
'items': [ "/api/orders",
{'product_id': products[0].id, 'quantity': 2} headers=auth_headers,
] json={"items": [{"product_id": products[0].id, "quantity": 2}]},
}) )
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
assert 'insufficient' in data['error'].lower() assert "insufficient" in data["error"].lower()
@pytest.mark.order @pytest.mark.order
def test_get_single_order(self, client, auth_headers, order): def test_get_single_order(self, client, auth_headers, order):
"""Test getting a single order""" """Test getting a single order"""
response = client.get(f'/api/orders/{order.id}', headers=auth_headers) response = client.get(f"/api/orders/{order.id}", headers=auth_headers)
print('test_get_single_order', response.get_json()) print("test_get_single_order", response.get_json())
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert data['id'] == order.id assert data["id"] == order.id
@pytest.mark.order @pytest.mark.order
def test_get_other_users_order(self, client, admin_headers, regular_user, products): def test_get_other_users_order(self, client, admin_headers, regular_user, products):
"""Test admin accessing another user's order""" """Test admin accessing another user's order"""
# Create an order for regular_user # Create an order for regular_user
client.post('/api/auth/login', json={ client.post(
'email': regular_user.email, "/api/auth/login",
'password': 'password123' json={"email": regular_user.email, "password": "password123"},
}) )
# Admin should be able to access any order # Admin should be able to access any order
response = client.get(f'/api/orders/1', headers=admin_headers)
# This test assumes order exists, adjust as needed # This test assumes order exists, adjust as needed
pass pass

View file

@ -1,7 +1,9 @@
"""Test Pydantic schemas""" """Test Pydantic schemas"""
from decimal import Decimal
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from decimal import Decimal
from app.schemas import ProductCreateRequest, ProductResponse from app.schemas import ProductCreateRequest, ProductResponse
@ -12,31 +14,28 @@ class TestProductCreateRequestSchema:
def test_valid_product_request(self): def test_valid_product_request(self):
"""Test valid product creation request""" """Test valid product creation request"""
data = { data = {
'name': 'Handcrafted Wooden Bowl', "name": "Handcrafted Wooden Bowl",
'description': 'A beautiful handcrafted bowl', "description": "A beautiful handcrafted bowl",
'price': 45.99, "price": 45.99,
'stock': 10, "stock": 10,
'image_url': 'https://example.com/bowl.jpg' "image_url": "https://example.com/bowl.jpg",
} }
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.name == data['name'] assert product.name == data["name"]
assert product.description == data['description'] assert product.description == data["description"]
assert product.price == Decimal('45.99') assert product.price == Decimal("45.99")
assert product.stock == 10 assert product.stock == 10
assert product.image_url == data['image_url'] assert product.image_url == data["image_url"]
@pytest.mark.unit @pytest.mark.unit
def test_minimal_valid_request(self): def test_minimal_valid_request(self):
"""Test minimal valid request (only required fields)""" """Test minimal valid request (only required fields)"""
data = { data = {"name": "Simple Product", "price": 19.99}
'name': 'Simple Product',
'price': 19.99
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.name == 'Simple Product' assert product.name == "Simple Product"
assert product.price == Decimal('19.99') assert product.price == Decimal("19.99")
assert product.stock == 0 assert product.stock == 0
assert product.description is None assert product.description is None
assert product.image_url is None assert product.image_url is None
@ -44,134 +43,107 @@ class TestProductCreateRequestSchema:
@pytest.mark.unit @pytest.mark.unit
def test_missing_name(self): def test_missing_name(self):
"""Test request with missing name""" """Test request with missing name"""
data = { data = {"price": 19.99}
'price': 19.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('name',) for error in errors) assert any(error["loc"] == ("name",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_missing_price(self): def test_missing_price(self):
"""Test request with missing price""" """Test request with missing price"""
data = { data = {"name": "Test Product"}
'name': 'Test Product'
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('price',) for error in errors) assert any(error["loc"] == ("price",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_negative(self): def test_invalid_price_negative(self):
"""Test request with negative price""" """Test request with negative price"""
data = { data = {"name": "Test Product", "price": -10.99}
'name': 'Test Product',
'price': -10.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than' for error in errors) assert any(error["type"] == "greater_than" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_zero(self): def test_invalid_price_zero(self):
"""Test request with zero price""" """Test request with zero price"""
data = { data = {"name": "Test Product", "price": 0.0}
'name': 'Test Product',
'price': 0.0
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than' for error in errors) assert any(error["type"] == "greater_than" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_price_too_many_decimals(self): def test_invalid_price_too_many_decimals(self):
"""Test request with too many decimal places""" """Test request with too many decimal places"""
data = { data = {"name": "Test Product", "price": 10.999}
'name': 'Test Product',
'price': 10.999
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any('decimal places' in str(error).lower() for error in errors) assert any("decimal places" in str(error).lower() for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_invalid_stock_negative(self): def test_invalid_stock_negative(self):
"""Test request with negative stock""" """Test request with negative stock"""
data = { data = {"name": "Test Product", "price": 19.99, "stock": -5}
'name': 'Test Product',
'price': 19.99,
'stock': -5
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['type'] == 'greater_than_equal' for error in errors) assert any(error["type"] == "greater_than_equal" for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_name_too_long(self): def test_name_too_long(self):
"""Test request with name exceeding max length""" """Test request with name exceeding max length"""
data = { data = {"name": "A" * 201, "price": 19.99} # Exceeds 200 character limit
'name': 'A' * 201, # Exceeds 200 character limit
'price': 19.99
}
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('name',) for error in errors) assert any(error["loc"] == ("name",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_image_url_too_long(self): def test_image_url_too_long(self):
"""Test request with image_url exceeding max length""" """Test request with image_url exceeding max length"""
data = { data = {
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'image_url': 'A' * 501 # Exceeds 500 character limit "image_url": "A" * 501, # Exceeds 500 character limit
} }
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
ProductCreateRequest(**data) ProductCreateRequest(**data)
errors = exc_info.value.errors() errors = exc_info.value.errors()
assert any(error['loc'] == ('image_url',) for error in errors) assert any(error["loc"] == ("image_url",) for error in errors)
@pytest.mark.unit @pytest.mark.unit
def test_price_string_conversion(self): def test_price_string_conversion(self):
"""Test price string to Decimal conversion""" """Test price string to Decimal conversion"""
data = { data = {"name": "Test Product", "price": "29.99"}
'name': 'Test Product',
'price': '29.99'
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.price == Decimal('29.99') assert product.price == Decimal("29.99")
@pytest.mark.unit @pytest.mark.unit
def test_stock_string_conversion(self): def test_stock_string_conversion(self):
"""Test stock string to int conversion""" """Test stock string to int conversion"""
data = { data = {"name": "Test Product", "price": 19.99, "stock": "10"}
'name': 'Test Product',
'price': 19.99,
'stock': '10'
}
product = ProductCreateRequest(**data) product = ProductCreateRequest(**data)
assert product.stock == 10 assert product.stock == 10
@ -185,20 +157,20 @@ class TestProductResponseSchema:
def test_valid_product_response(self): def test_valid_product_response(self):
"""Test valid product response""" """Test valid product response"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'description': 'A test product', "description": "A test product",
'price': 45.99, "price": 45.99,
'stock': 10, "stock": 10,
'image_url': 'https://example.com/product.jpg', "image_url": "https://example.com/product.jpg",
'is_active': True, "is_active": True,
'created_at': '2024-01-15T10:30:00', "created_at": "2024-01-15T10:30:00",
'updated_at': '2024-01-15T10:30:00' "updated_at": "2024-01-15T10:30:00",
} }
product = ProductResponse(**data) product = ProductResponse(**data)
assert product.id == 1 assert product.id == 1
assert product.name == 'Test Product' assert product.name == "Test Product"
assert product.price == 45.99 assert product.price == 45.99
assert product.stock == 10 assert product.stock == 10
assert product.is_active is True assert product.is_active is True
@ -207,11 +179,11 @@ class TestProductResponseSchema:
def test_product_response_with_none_fields(self): def test_product_response_with_none_fields(self):
"""Test product response with optional None fields""" """Test product response with optional None fields"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 0, "stock": 0,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
@ -226,17 +198,17 @@ class TestProductResponseSchema:
from app.models import Product from app.models import Product
db_product = Product( db_product = Product(
name='Test Product', name="Test Product",
description='A test product', description="A test product",
price=Decimal('45.99'), price=Decimal("45.99"),
stock=10 stock=10,
) )
db_session.add(db_product) db_session.add(db_product)
db_session.commit() db_session.commit()
# Validate using model_validate (for SQLAlchemy models) # Validate using model_validate (for SQLAlchemy models)
response = ProductResponse.model_validate(db_product) response = ProductResponse.model_validate(db_product)
assert response.name == 'Test Product' assert response.name == "Test Product"
assert response.price == 45.99 assert response.price == 45.99
assert response.stock == 10 assert response.stock == 10
@ -244,34 +216,34 @@ class TestProductResponseSchema:
def test_model_dump(self): def test_model_dump(self):
"""Test model_dump method""" """Test model_dump method"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 5, "stock": 5,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
dumped = product.model_dump() dumped = product.model_dump()
assert isinstance(dumped, dict) assert isinstance(dumped, dict)
assert dumped['id'] == 1 assert dumped["id"] == 1
assert dumped['name'] == 'Test Product' assert dumped["name"] == "Test Product"
assert dumped['price'] == 19.99 assert dumped["price"] == 19.99
@pytest.mark.unit @pytest.mark.unit
def test_model_dump_json(self): def test_model_dump_json(self):
"""Test model_dump_json method""" """Test model_dump_json method"""
data = { data = {
'id': 1, "id": 1,
'name': 'Test Product', "name": "Test Product",
'price': 19.99, "price": 19.99,
'stock': 5, "stock": 5,
'is_active': True "is_active": True,
} }
product = ProductResponse(**data) product = ProductResponse(**data)
json_str = product.model_dump_json() json_str = product.model_dump_json()
assert isinstance(json_str, str) assert isinstance(json_str, str)
assert 'Test Product' in json_str assert "Test Product" in json_str