diff --git a/testflows/snapshots/snapshots.py b/testflows/snapshots/snapshots.py index 6b72f23..61ad63f 100644 --- a/testflows/snapshots/snapshots.py +++ b/testflows/snapshots/snapshots.py @@ -19,6 +19,7 @@ from .compare import Compare from .errors import * from .v1 import snapshot as snapshot_v1 +from .v2 import snapshot as snapshot_v2 __all__ = ["snapshot"] @@ -53,6 +54,48 @@ def get_snapshot_filename(frame, path, id): return filename +def get_snapshot_filename_v2(frame, path, id, filename): + """Return snapshot filename for V2 (JSON) snapshots. + + When ``filename`` is provided, it is used directly (joined with ``path`` + if ``path`` is also given). Otherwise, the filename is derived from + the caller's frame info similar to V1 but with a ``.json`` extension. + + :param frame: caller's stack frame + :param path: custom snapshot directory, default: ``./snapshots`` + :param id: unique id of the snapshot file, default: ``None`` + :param filename: explicit filename to use, default: ``None`` + """ + if filename is not None: + # When an explicit filename is provided, just join it with path + if path is None: + frame_info = inspect.getframeinfo(frame) + path = os.path.join(os.path.dirname(frame_info.filename), "snapshots") + + if not os.path.exists(path): + os.makedirs(path) + + return os.path.join(path, filename) + + # Fall back to auto-generated filename with .json extension + frame_info = inspect.getframeinfo(frame) + + id_parts = [os.path.basename(frame_info.filename)] + if id is not None: + id_parts.append(str(id).lower()) + id_parts.append("json") + + file_id = ".".join(id_parts) + + if path is None: + path = os.path.join(os.path.dirname(frame_info.filename), "snapshots") + + if not os.path.exists(path): + os.makedirs(path) + + return os.path.join(path, file_id) + + def snapshot( value, id=None, @@ -65,27 +108,33 @@ def snapshot( version=snapshot_v1.VERSION, frame=None, compare=Compare.eq, + filename=None, ): """Compare value representation to a stored snapshot. If snapshot does not exist, assertion passes else representation of the value is compared to the stored snapshot. - Snapshot files have format: + For V1 (default), snapshot files have format: [.].snapshot + For V2 (JSON), snapshot files have format: + + [.].json (or custom ``filename``) + :param value: value to be used for snapshot - :param id: unique id of the snapshot file, default: `None` + :param id: unique id of the snapshot file, default: ``None`` :param output: function to output the representation of the value - :param path: custom snapshot path, default: `./snapshots` - :param name: name of the snapshot value inside the snapshots file, default: `snapshot` - :param encoder: custom snapshot encoder, default: `repr` + :param path: custom snapshot path, default: ``./snapshots`` + :param name: name of the snapshot value inside the snapshots file, default: ``snapshot`` + :param encoder: custom snapshot encoder, default: ``repr`` :param comment: (deprecated) :param mode: mode of operation: CHECK, UPDATE, REWRITE, default: CHECK | UPDATE - :param version: snapshot version, default: snapshot_v1.VERSION - :param frame: caller frame, default: `None` + :param version: snapshot version, default: ``snapshot_v1.VERSION`` + :param frame: caller frame, default: ``None`` :param compare: custom comparison function, default: equals + :param filename: explicit snapshot filename (V2 only), default: ``None`` """ if frame is None: frame = inspect.currentframe().f_back @@ -98,11 +147,24 @@ def snapshot( if output: output(repr_value) - filename = get_snapshot_filename(frame=frame, path=path, id=id) - if version == snapshot_v1.VERSION: + snapshot_file = get_snapshot_filename(frame=frame, path=path, id=id) + return snapshot_v1( - filename=filename, + filename=snapshot_file, + repr_value=repr_value, + name=name, + mode=mode, + compare=compare, + ) + + if version == snapshot_v2.VERSION: + snapshot_file = get_snapshot_filename_v2( + frame=frame, path=path, id=id, filename=filename + ) + + return snapshot_v2( + filename=snapshot_file, repr_value=repr_value, name=name, mode=mode, @@ -122,3 +184,4 @@ def snapshot( # define supported versions snapshot.VERSION_V1 = snapshot_v1.VERSION +snapshot.VERSION_V2 = snapshot_v2.VERSION \ No newline at end of file diff --git a/testflows/snapshots/v2.py b/testflows/snapshots/v2.py new file mode 100644 index 0000000..c4645f7 --- /dev/null +++ b/testflows/snapshots/v2.py @@ -0,0 +1,168 @@ +# Copyright 2024 Katteli Inc. +# TestFlows.com Open-Source Software Testing Framework (http://testflows.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import json +import textwrap +import difflib + +from .errors import SnapshotError as SnapshotErrorBase +from .errors import SnapshotNotFoundError as SnapshotNotFoundErrorBase +from .parallel import RWLock +from .mode import * +from .compare import Compare + +locks = {} + + +def get_lock(filename): + if filename not in locks: + locks[filename] = RWLock() + return locks[filename] + + +def _json_repr(value): + """Return a stable JSON string representation of a value.""" + return json.dumps(value, indent=2, sort_keys=True) + + +class SnapshotError(SnapshotErrorBase): + def __init__(self, filename, name, snapshot_value, actual_value): + self.snapshot_value = snapshot_value + self.actual_value = actual_value + self.filename = str(filename) + self.name = str(name) + + def __bool__(self): + return False + + def __repr__(self): + snapshot_str = _json_repr(self.snapshot_value) + actual_str = _json_repr(self.actual_value) + + r = "SnapshotError(" + r += "\nfilename=" + self.filename + r += "\nname=" + self.name + r += '\nsnapshot_value="""\n' + r += textwrap.indent(snapshot_str, " " * 4) + r += '""",\nactual_value="""\n' + r += textwrap.indent(actual_str, " " * 4) + r += '""",\ndiff="""\n' + r += textwrap.indent( + "\n".join( + [ + line.strip("\n") + for line in difflib.unified_diff( + snapshot_str.splitlines(), + actual_str.splitlines(), + fromfile=self.filename, + tofile="actual", + ) + ] + ), + " " * 4, + ) + r += '\n""")' + return r + + +class SnapshotNotFoundError(SnapshotNotFoundErrorBase): + def __init__(self, filename, name, actual_value): + self.actual_value = actual_value + self.filename = str(filename) + self.name = str(name) + + def __bool__(self): + return False + + def __repr__(self): + r = "SnapshotNotFoundError(" + r += "\nfilename=" + self.filename + r += "\nname=" + self.name + r += '\nactual_value="""\n' + r += textwrap.indent(_json_repr(self.actual_value), " " * 4) + r += '\n""")' + return r + + +def read_snapshot_file(filename): + """Read and parse a JSON snapshot file.""" + with open(filename, "r", encoding="utf-8") as fd: + return json.load(fd) + + +def write_snapshot_file(filename, data): + """Write data to a JSON snapshot file with sorted keys.""" + with open(filename, "w", encoding="utf-8") as fd: + json.dump(data, fd, indent=2, sort_keys=True) + fd.write("\n") + + +def snapshot( + filename, + repr_value, + name="snapshot", + mode=SNAPSHOT_MODE_CHECK | SNAPSHOT_MODE_UPDATE, + compare=Compare.eq, +): + """Check value against a snapshot value stored in a JSON file. + + The JSON file stores a single object where each key is a snapshot name + and each value is the stored snapshot value. + + :param filename: path to the JSON snapshot file + :param repr_value: the encoded value to compare (JSON-serializable) + :param name: name of the snapshot entry within the file, default: ``snapshot`` + :param mode: mode of operation: CHECK, UPDATE, REWRITE, default: CHECK | UPDATE + :param compare: custom comparison function, default: equals + """ + lock = get_lock(filename) + + if os.path.exists(filename): + with lock.read(): + data = read_snapshot_file(filename) + + if name in data: + snapshot_value = data[name] + if not compare(snapshot_value, repr_value): + if mode & SNAPSHOT_MODE_CHECK: + return SnapshotError(filename, name, snapshot_value, repr_value) + else: + return True + + if not (mode & SNAPSHOT_MODE_UPDATE): + return SnapshotNotFoundError(filename, name, repr_value) + + # write or update snapshot entry + with lock.write(): + if os.path.exists(filename): + data = read_snapshot_file(filename) + else: + data = {} + + data[name] = repr_value + write_snapshot_file(filename, data) + + if mode & SNAPSHOT_MODE_REWRITE: + # For JSON, rewriting simply re-reads and re-writes the file + # to ensure canonical formatting (sorted keys, consistent indent). + with lock.write(): + data = read_snapshot_file(filename) + write_snapshot_file(filename, data) + + return True + + +# define version +snapshot.VERSION = 2 diff --git a/tests/actions/model.py b/tests/actions/model.py index 98f4355..7bce06f 100644 --- a/tests/actions/model.py +++ b/tests/actions/model.py @@ -1,6 +1,6 @@ from testflows.core import current, debug from testflows.snapshots import snapshot -from testflows.snapshots.snapshots import get_snapshot_filename +from testflows.snapshots.snapshots import get_snapshot_filename, get_snapshot_filename_v2 import actions.expect @@ -16,11 +16,20 @@ def __init__(self, **kwargs): self.encoder = kwargs.pop("encoder", "repr") self.mode = kwargs.pop("mode", snapshot.CHECK | snapshot.UPDATE) self.version = kwargs.pop("version", snapshot.VERSION_V1) - self.filename = get_snapshot_filename( - frame=kwargs.pop("frame"), - path=kwargs.pop("path", None), - id=kwargs.pop("id", None), - ) + + frame = kwargs.pop("frame") + path = kwargs.pop("path", None) + id = kwargs.pop("id", None) + explicit_filename = kwargs.pop("filename", None) + + if self.version == snapshot.VERSION_V2: + self.filename = get_snapshot_filename_v2( + frame=frame, path=path, id=id, filename=explicit_filename, + ) + else: + self.filename = get_snapshot_filename( + frame=frame, path=path, id=id, + ) def __str__(self): mode = [] diff --git a/tests/actions/snapshot.py b/tests/actions/snapshot.py index fe921b2..ebf32a9 100644 --- a/tests/actions/snapshot.py +++ b/tests/actions/snapshot.py @@ -6,13 +6,13 @@ from testflows.snapshots import snapshot from testflows.snapshots.errors import SnapshotError -from testflows.snapshots.snapshots import get_snapshot_filename +from testflows.snapshots.snapshots import get_snapshot_filename, get_snapshot_filename_v2 import actions.model @TestStep(Given) -def get_unique_id(self, frame=None, path=None): +def get_unique_id(self, frame=None, path=None, version=None): """Generate unique snapshot id and delete snapshot file for that id at the end of the test.""" @@ -24,7 +24,30 @@ def get_unique_id(self, frame=None, path=None): yield id finally: - filename = get_snapshot_filename(frame=frame, path=path, id=id) + if version == snapshot.VERSION_V2: + filename = get_snapshot_filename_v2(frame=frame, path=path, id=id, filename=None) + else: + filename = get_snapshot_filename(frame=frame, path=path, id=id) + with By("deleting file for the snapshot id", description=f"{filename}"): + try: + os.remove(filename) + except FileNotFoundError: + pass + + +@TestStep(Given) +def get_unique_id_v2(self, frame=None, path=None): + """Generate unique snapshot id for V2 and delete snapshot file at the end.""" + + if frame is None: + frame = inspect.currentframe() + + try: + id = uuid.uuid4().hex + yield id + + finally: + filename = get_snapshot_filename_v2(frame=frame, path=path, id=id, filename=None) with By("deleting file for the snapshot id", description=f"{filename}"): try: os.remove(filename) diff --git a/tests/snapshot.py b/tests/snapshot.py index 4d792a5..94a5c06 100644 --- a/tests/snapshot.py +++ b/tests/snapshot.py @@ -11,3 +11,4 @@ def feature(self): Feature(run=load("value", "feature")) Feature(run=load("compare", "feature")) Feature(run=load("mode", "feature")) + Feature(run=load("value_v2", "feature")) \ No newline at end of file diff --git a/tests/value_v2.py b/tests/value_v2.py new file mode 100644 index 0000000..be81349 --- /dev/null +++ b/tests/value_v2.py @@ -0,0 +1,195 @@ +import json +import os +import tempfile + +from testflows.core import * +from testflows.asserts import error, raises +from testflows.snapshots import snapshot + +import actions.python +import actions.snapshot + +currentframe = actions.python.currentframe + + +@TestScenario +def stored_and_compared(self): + """Check V2 JSON value is stored and compared against.""" + + value = {"key": "hello", "number": 42} + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with When("value is stored using the UPDATE mode"): + assert snapshot( + value, id=id, encoder=lambda v: v, + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + + with Then("the snapshot matches the stored value using the CHECK mode"): + assert snapshot( + value, id=id, encoder=lambda v: v, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ), error() + + +@TestScenario +def mismatch_detected(self): + """Check V2 detects mismatch between stored and actual values.""" + + original = {"status": 200, "has_data": True} + changed = {"status": 403, "has_data": False} + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with When("original value is stored"): + assert snapshot( + original, id=id, encoder=lambda v: v, + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + + with Then("changed value fails CHECK"): + result = snapshot( + changed, id=id, encoder=lambda v: v, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ) + assert not result, error() + from testflows.snapshots.errors import SnapshotError + assert isinstance(result, SnapshotError), error() + + +@TestScenario +def multiple_names_same_file(self): + """Check V2 supports multiple named entries in the same JSON file.""" + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with When("store multiple named values"): + for name in ["admin", "user", "guest"]: + value = {"role": name, "status_code": 200 if name != "guest" else 403} + assert snapshot( + value, id=id, name=name, encoder=lambda v: v, + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + + with Then("each named value matches"): + for name in ["admin", "user", "guest"]: + value = {"role": name, "status_code": 200 if name != "guest" else 403} + assert snapshot( + value, id=id, name=name, encoder=lambda v: v, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ), error() + + +@TestScenario +def custom_filename(self): + """Check V2 with explicit filename parameter.""" + + with Given("a temp directory"): + tmpdir = tempfile.mkdtemp() + + try: + value = {"test": True, "count": 5} + + with When("snapshot is stored with explicit filename"): + assert snapshot( + value, name="entry1", encoder=lambda v: v, + path=tmpdir, filename="custom_test.json", + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + + with Then("the file exists with the custom name"): + filepath = os.path.join(tmpdir, "custom_test.json") + assert os.path.exists(filepath), error() + + with And("the content is correct JSON"): + with open(filepath, "r") as f: + data = json.load(f) + assert data == {"entry1": value}, error() + + with And("snapshot check passes"): + assert snapshot( + value, name="entry1", encoder=lambda v: v, + path=tmpdir, filename="custom_test.json", + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ), error() + + finally: + import shutil + shutil.rmtree(tmpdir, ignore_errors=True) + + +@TestScenario +def not_found_without_update(self): + """Check V2 returns SnapshotNotFoundError when file does not exist and UPDATE is off.""" + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with Then("CHECK-only mode returns SnapshotNotFoundError"): + result = snapshot( + {"test": 1}, id=id, encoder=lambda v: v, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ) + assert not result, error() + from testflows.snapshots.errors import SnapshotNotFoundError + assert isinstance(result, SnapshotNotFoundError), error() + + +@TestScenario +def supported_types(self): + """Check V2 supports various JSON-serializable types.""" + + values = [ + 1, 1.0, "string", [1, 2, 3], + {"key": "value"}, True, False, None, + {"nested": {"deep": [1, 2, {"x": 3}]}}, + ] + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with Then("supported types can be stored and checked"): + for i, value in enumerate(values): + assert snapshot( + value, id=id, name=f"type_{i}", encoder=lambda v: v, + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + assert snapshot( + value, id=id, name=f"type_{i}", encoder=lambda v: v, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ), error() + + +@TestScenario +def json_encoder(self): + """Check V2 works with json.dumps as encoder.""" + + value = {"key": "value", "list": [1, 2]} + + with Given("snapshot id"): + id = actions.snapshot.get_unique_id_v2(frame=currentframe()) + + with When("store with json.dumps encoder"): + assert snapshot( + value, id=id, encoder=json.dumps, + mode=snapshot.UPDATE, version=snapshot.VERSION_V2, + ), error() + + with Then("check with json.dumps encoder"): + assert snapshot( + value, id=id, encoder=json.dumps, + mode=snapshot.CHECK, version=snapshot.VERSION_V2, + ), error() + + +@TestFeature +@Name("value_v2") +def feature(self): + """Check V2 JSON snapshot values.""" + + for scenario in loads(current_module(), Scenario): + scenario()