From b59d91ddd9e957a4cdf9dbb35704d7664de840f3 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Wed, 25 Aug 2021 09:18:57 +0200 Subject: [PATCH 1/6] fix: by default columns should not have indexes --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..f745b244a9 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -417,7 +417,7 @@ def get_column_from_field(field: ModelField) -> Column: nullable = not field.required index = getattr(field.field_info, "index", Undefined) if index is Undefined: - index = True + index = False if hasattr(field.field_info, "nullable"): field_nullable = getattr(field.field_info, "nullable") if field_nullable != Undefined: From b64d2ddea8bca9c5001b62a7a1b6df7e1c082dab Mon Sep 17 00:00:00 2001 From: zhangbc <1731259685@qq.com> Date: Thu, 26 Aug 2021 13:21:37 +0800 Subject: [PATCH 2/6] Fix Enum Type Mapping --- sqlmodel/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..1992472b09 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -49,6 +49,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time +from sqlalchemy.sql.sqltypes import Enum as SQLAlchemyEnum from .sql.sqltypes import GUID, AutoString @@ -389,7 +390,7 @@ def get_sqlachemy_type(field: ModelField) -> Any: if issubclass(field.type_, time): return Time if issubclass(field.type_, Enum): - return Enum + return SQLAlchemyEnum(field.type_) if issubclass(field.type_, bytes): return LargeBinary if issubclass(field.type_, Decimal): From 959f3802187de63cced7a68c21cccbb174379795 Mon Sep 17 00:00:00 2001 From: Andrew Bolster Date: Thu, 26 Aug 2021 07:58:46 +0200 Subject: [PATCH 3/6] Update GUID handling use stdlib UUID.hex Rather than integer based serialization grandfathered in from sqlalchemy, use the stdlib [`UUID.hex`](https://docs.python.org/3/library/uuid.html#uuid.UUID.hex) method. This also fixes #25 --- sqlmodel/sql/sqltypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/sql/sqltypes.py b/sqlmodel/sql/sqltypes.py index e7b77b8c52..ac9dd773fc 100644 --- a/sqlmodel/sql/sqltypes.py +++ b/sqlmodel/sql/sqltypes.py @@ -46,10 +46,10 @@ def process_bind_param(self, value, dialect): return str(value) else: if not isinstance(value, uuid.UUID): - return f"{uuid.UUID(value).int:x}" + return uuid.UUID(value).hex else: # hexstring - return f"{value.int:x}" + return value.hex def process_result_value(self, value, dialect): if value is None: From 19805e9e42edb453f66632fd585e52e23c97d749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AB=A0=E4=B8=99=E8=BE=B0?= <1731259685@qq.com> Date: Thu, 26 Aug 2021 14:36:41 +0800 Subject: [PATCH 4/6] Unify Code Style Change SQLAlchemyEnum to _Enum --- sqlmodel/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 1992472b09..bf26c2effe 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -49,7 +49,7 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time -from sqlalchemy.sql.sqltypes import Enum as SQLAlchemyEnum +from sqlalchemy.sql.sqltypes import Enum as _Enum from .sql.sqltypes import GUID, AutoString @@ -390,7 +390,7 @@ def get_sqlachemy_type(field: ModelField) -> Any: if issubclass(field.type_, time): return Time if issubclass(field.type_, Enum): - return SQLAlchemyEnum(field.type_) + return _Enum(field.type_) if issubclass(field.type_, bytes): return LargeBinary if issubclass(field.type_, Decimal): From 83870d19d6aa6b9b7cd93183d9729b023c8208b2 Mon Sep 17 00:00:00 2001 From: Evangelos Anagnostopoulos Date: Fri, 3 Sep 2021 18:53:54 +0300 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=8E=A8=20Fix=20nullable=20property=20?= =?UTF-8?q?of=20Fields?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 13 ++++++++++++- .../test_create_db_and_table/test_tutorial001.py | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..a76a0cd072 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -26,6 +26,7 @@ from pydantic import BaseModel from pydantic.errors import ConfigError, DictError +from pydantic.fields import SHAPE_SINGLETON from pydantic.fields import FieldInfo as PydanticFieldInfo from pydantic.fields import ModelField, Undefined, UndefinedType from pydantic.main import BaseConfig, ModelMetaclass, validate_model @@ -414,7 +415,7 @@ def get_column_from_field(field: ModelField) -> Column: return sa_column sa_type = get_sqlachemy_type(field) primary_key = getattr(field.field_info, "primary_key", False) - nullable = not field.required + nullable = not primary_key and _is_field_nullable(field) index = getattr(field.field_info, "index", Undefined) if index is Undefined: index = True @@ -634,3 +635,13 @@ def _calculate_keys( # type: ignore @declared_attr # type: ignore def __tablename__(cls) -> str: return cls.__name__.lower() + + +def _is_field_nullable(field: ModelField) -> bool: + if not field.required: + # Taken from [Pydantic](https://github.com/samuelcolvin/pydantic/blob/v1.8.2/pydantic/fields.py#L946-L947) + is_optional = field.allow_none and ( + field.shape != SHAPE_SINGLETON or not field.sub_fields + ) + return is_optional and field.default is None and field.default_factory is None + return False diff --git a/tests/test_tutorial/test_create_db_and_table/test_tutorial001.py b/tests/test_tutorial/test_create_db_and_table/test_tutorial001.py index 591a51cc22..b6a2e72628 100644 --- a/tests/test_tutorial/test_create_db_and_table/test_tutorial001.py +++ b/tests/test_tutorial/test_create_db_and_table/test_tutorial001.py @@ -9,7 +9,7 @@ def test_create_db_and_table(cov_tmp_path: Path): assert "BEGIN" in result.stdout assert 'PRAGMA main.table_info("hero")' in result.stdout assert "CREATE TABLE hero (" in result.stdout - assert "id INTEGER," in result.stdout + assert "id INTEGER NOT NULL," in result.stdout assert "name VARCHAR NOT NULL," in result.stdout assert "secret_name VARCHAR NOT NULL," in result.stdout assert "age INTEGER," in result.stdout From dd9dc129c082a3ec93e2c58a38a123d128d759d1 Mon Sep 17 00:00:00 2001 From: Raphael Gibson Date: Tue, 7 Sep 2021 00:20:17 -0300 Subject: [PATCH 6/6] feat: add unique constraint param to Field function --- sqlmodel/main.py | 6 +++ tests/test_main.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 tests/test_main.py diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 661276b31d..7be05165d7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -70,6 +70,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: primary_key = kwargs.pop("primary_key", False) nullable = kwargs.pop("nullable", Undefined) foreign_key = kwargs.pop("foreign_key", Undefined) + unique = kwargs.pop("unique", False) index = kwargs.pop("index", Undefined) sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) @@ -89,6 +90,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.primary_key = primary_key self.nullable = nullable self.foreign_key = foreign_key + self.unique = unique self.index = index self.sa_column = sa_column self.sa_column_args = sa_column_args @@ -150,6 +152,7 @@ def Field( regex: str = None, primary_key: bool = False, foreign_key: Optional[Any] = None, + unique: bool = False, nullable: Union[bool, UndefinedType] = Undefined, index: Union[bool, UndefinedType] = Undefined, sa_column: Union[Column, UndefinedType] = Undefined, @@ -180,6 +183,7 @@ def Field( regex=regex, primary_key=primary_key, foreign_key=foreign_key, + unique=unique, nullable=nullable, index=index, sa_column=sa_column, @@ -424,12 +428,14 @@ def get_column_from_field(field: ModelField) -> Column: nullable = field_nullable args = [] foreign_key = getattr(field.field_info, "foreign_key", None) + unique = getattr(field.field_info, "unique", False) if foreign_key: args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, "nullable": nullable, "index": index, + "unique": unique } sa_default = Undefined if field.field_info.default_factory: diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000000..65ad0d9b56 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,91 @@ +import pytest +from typing import Optional + +from sqlmodel import Field, Session, SQLModel, create_engine +from sqlalchemy.exc import IntegrityError + + +def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_allow_duplicate_row_if_unique_constraint_is_false(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=False) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2) + + with Session(engine) as session: + heroes = session.query(Hero).all() + assert len(heroes) == 2 + assert heroes[0].name == heroes[1].name + + +def test_should_raise_exception_when_try_to_duplicate_row_if_unique_constraint_is_true(clear_sqlmodel): + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str = Field(unique=True) + age: Optional[int] = None + + hero_1 = Hero(name="Deadpond", secret_name="Dive Wilson") + hero_2 = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_engine("sqlite://") + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(hero_1) + session.commit() + session.refresh(hero_1) + + with pytest.raises(IntegrityError): + with Session(engine) as session: + session.add(hero_2) + session.commit() + session.refresh(hero_2)